# SqueezeNet

In this file we train the SqueezeNet model as described in the paper found [here](https://arxiv.org/abs/1602.07360).
This implementation uses the CIFAR10 dataset.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from collections import OrderedDict

## Implementation

The class in the cell below defines the SqueezeNet architecture and its forward pass. We insert quantization stubs so that the model can later be used for Quantization Aware Training. The class also provides helper methods to save and load the model weights.

In [None]:
class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier built from a stem conv, eight fire modules and a 1x1 head.

    Every fire module is stored as three flat attributes
    (``fire<i>_squeeze``, ``fire<i>_expand1x1``, ``fire<i>_expand3x3``) so the
    parameter names in the state dict stay ``fire<i>_*.weight`` / ``.bias``.
    A ``QuantStub`` / ``DeQuantStub`` pair wraps the network input and output.
    """

    # One row per fire module:
    # (index, input channels, squeeze channels, channels of EACH expand branch).
    # The module output has twice the expand width because both branches are concatenated.
    _FIRE_LAYOUT = (
        (2, 96, 16, 64),
        (3, 128, 16, 64),
        (4, 128, 32, 128),
        (5, 256, 32, 128),
        (6, 256, 48, 192),
        (7, 384, 48, 192),
        (8, 384, 64, 256),
        (9, 512, 64, 256),
    )
    # Fire modules that are followed by a 3x3 / stride-2 max pool.
    _POOL_AFTER = (4, 8)

    def __init__(self, num_classes=10, input_channels=3):
        """Create all layers.

        Args:
            num_classes: Number of output channels of the final 1x1 conv (one logit per class).
            input_channels: Number of channels of the input images.
        """
        super().__init__()

        # A single ReLU module is shared by every activation in the network.
        self.relu = nn.ReLU(inplace=True)

        # Quantization stubs (identity during FP32 training).
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        # Stem: conv1 followed by the first max pool.
        self.conv1 = nn.Conv2d(
            input_channels, 96, kernel_size=7, stride=2, padding=3
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Fire modules, registered in index order so the module / state-dict
        # ordering is the same as if each attribute were assigned by hand.
        for idx, c_in, c_squeeze, c_expand in self._FIRE_LAYOUT:
            setattr(self, f"fire{idx}_squeeze",
                    nn.Conv2d(c_in, c_squeeze, kernel_size=1))
            setattr(self, f"fire{idx}_expand1x1",
                    nn.Conv2d(c_squeeze, c_expand, kernel_size=1))
            setattr(self, f"fire{idx}_expand3x3",
                    nn.Conv2d(c_squeeze, c_expand, kernel_size=3, padding=1))
            if idx in self._POOL_AFTER:
                setattr(self, f"maxpool{idx}",
                        nn.MaxPool2d(kernel_size=3, stride=2))

        self.dropout = nn.Dropout(p=0.5)

        # conv10: 1x1 classifier head producing one channel per class.
        self.conv10 = nn.Conv2d(512, num_classes, kernel_size=1)

    def _fire_forward(self, x, squeeze, expand1, expand3):
        """Apply one fire module: squeeze, then both expand branches, concatenated on dim 1."""
        squeezed = self.relu(squeeze(x))
        branches = [self.relu(expand1(squeezed)), self.relu(expand3(squeezed))]
        return torch.cat(branches, dim=1)

    def forward(self, x):
        """Return the class logits for a batch of images, flattened from dim 1 onward."""
        # quant input (for task 2)
        out = self.quant(x)

        out = self.maxpool1(self.relu(self.conv1(out)))

        for idx, *_ in self._FIRE_LAYOUT:
            out = self._fire_forward(
                out,
                getattr(self, f"fire{idx}_squeeze"),
                getattr(self, f"fire{idx}_expand1x1"),
                getattr(self, f"fire{idx}_expand3x3"),
            )
            if idx in self._POOL_AFTER:
                out = getattr(self, f"maxpool{idx}")(out)

        out = self.conv10(self.dropout(out))
        # Global average pool collapses the spatial dims to 1x1 before flattening.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)

        # dequant before logits output (for task 2)
        return self.dequant(out)

    def load_model(self, path='squeezenet_cifar10.pth',device='cpu'):
        """Load a state dict from ``path``, move the model to ``device`` and switch to eval mode.

        A leading ``'module.'`` on a key is stripped before loading, so keys
        saved with that prefix match this model's parameter names.
        """
        # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
        raw_state = torch.load(path, map_location=device)

        prefix = 'module.'
        cleaned = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, tensor)
            for key, tensor in raw_state.items()
        )

        self.load_state_dict(cleaned)
        self.to(device)
        self.eval()

        print(f"Model loaded from {path}")

    def save_model(self, path='squeezenet_cifar10.pth'):
        """Write this model's state dict to ``path``."""
        weights = self.state_dict()
        torch.save(weights, path)
        print(f"Model saved to {path}")


## Load the Dataset

In this cell we define a function to load our dataset.

In [None]:
def load_dataset(path='./data', batch_size=64):
    """Download CIFAR10 (if needed) and wrap both splits in data loaders.

    Args:
        path: Root directory passed to ``datasets.CIFAR10``.
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles; the
        test loader does not. Both splits use the same preprocessing.
    """
    print("Loading the CIFAR10 dataset")

    # Per-channel statistics used to normalize the images after ToTensor
    # has scaled RGB 0-255 down to 0-1.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Training split first, then the test split.
    splits = {
        is_train: datasets.CIFAR10(root=path, train=is_train, download=True, transform=preprocess)
        for is_train in (True, False)
    }

    train_loader = DataLoader(splits[True], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(splits[False], batch_size=batch_size, shuffle=False)

    train_summary = f"Loaded train data: {len(train_loader.dataset)} total samples, {len(train_loader)} batches\n"
    test_summary = f"Loaded test data: {len(test_loader.dataset)} total samples, {len(test_loader)} batches"
    print(train_summary + test_summary)

    return train_loader, test_loader

In [None]:
# Build the CIFAR10 loaders using load_dataset's defaults (root './data', batch size 64).
train_loader, test_loader = load_dataset()

## Train the model

In the cells below we define a function to visualize the training metrics, and then train our model.

In [None]:
import matplotlib.pyplot as plt

def plot_metrics(metrics):
    """Plot loss and accuracy curves side by side.

    Args:
        metrics: Mapping that may contain the keys ``'train_loss'``,
            ``'test_loss'``, ``'train_acc'`` and ``'test_acc'``, each holding
            one value per epoch. Missing or empty series are skipped.

    Every series is drawn against its own 1-based epoch axis. The axis used to
    be derived from ``'train_loss'`` alone, which failed when that series was
    missing, or was empty while the test series was not (e.g. an
    evaluation-only run).
    """
    # (y-axis label, title, ((metrics key, legend label, marker), ...)) per panel.
    panels = (
        ('Loss', 'Training vs Test Loss',
         (('train_loss', 'Train Loss', 'o'), ('test_loss', 'Test Loss', 's'))),
        ('Accuracy (%)', 'Training vs Test Accuracy',
         (('train_acc', 'Train Accuracy', 'o'), ('test_acc', 'Test Accuracy', 's'))),
    )

    plt.figure(figsize=(12, 5))

    for position, (ylabel, title, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)

        drew_curve = False
        for key, label, marker in series:
            values = metrics.get(key)
            if not values:
                continue
            plt.plot(range(1, len(values) + 1), values, label=label, marker=marker)
            drew_curve = True

        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        # Only build a legend when at least one curve was drawn on this panel.
        if drew_curve:
            plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.show()

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    Each batch loss is weighted by its batch size, so the returned loss is a
    per-example average and a short final batch does not skew it.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, seen, correct = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            seen += batch_size
            _, predicted = outputs.max(1)  # index of the highest logit
            correct += predicted.eq(labels).sum().item()

    return loss_sum / seen, 100.0 * correct / seen


def train_model(model,train_loader,test_loader,train=True,test=True,device='cpu',epochs=10,lr=1e-3,weight_decay=1e-4):
    """Train and/or evaluate ``model`` for ``epochs`` epochs and collect metrics.

    Args:
        model: Network to optimise; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` batches used for training.
        test_loader: Iterable of ``(inputs, labels)`` batches used for evaluation.
        train: Run a training pass in every epoch.
        test: Run an evaluation pass in every epoch.
        device: Device the model and batches are moved to.
        epochs: Number of epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        Dict with the lists ``train_loss``, ``train_acc``, ``test_loss`` and
        ``test_acc``; each list gets one entry per epoch for the passes that
        were enabled. Losses are per-example averages for both passes and
        accuracies are percentages.
    """
    model.to(device)
    metrics = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
    }

    criterion = nn.CrossEntropyLoss()
    # The optimizer is only needed (and only built) when a training pass runs.
    optimizer = (
        optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        if train else None
    )

    for e in range(epochs):
        print(f"Epoch [{e+1}/{epochs}] ",end='')

        if train:
            train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
            metrics["train_loss"].append(train_loss)
            metrics["train_acc"].append(train_acc)
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% ",end='')

        if test:
            test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
            metrics["test_loss"].append(test_loss)
            metrics["test_acc"].append(test_acc)
            print(f"Test/Val Loss: {test_loss:.4f}, Test/Val Acc: {test_acc:.2f}%")
        else:
            # Terminate the epoch line, which the evaluation print would otherwise do.
            print()

    return metrics

In [None]:
# Use a CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device={device}")

# Fresh FP32 model built with the constructor defaults (10 classes, 3 input channels).
model_fp32 = SqueezeNet()

# Uncomment to restore saved weights (default path 'squeezenet_cifar10.pth') instead of training from scratch.
# model_fp32.load_model()

In [None]:
# Run both the training pass and the evaluation pass in every epoch.
train, test = True, True
epochs = 100
# lr is not passed, so train_model uses its default learning rate (1e-3).
fp32_metrics = train_model(model=model_fp32,train_loader=train_loader,test_loader=test_loader,train=train,test=test,device=device,epochs=epochs)

In [None]:
# Persist the trained FP32 weights as a ".pth" checkpoint. The file was
# previously named "*.ipynb", which made a binary state dict look like a
# notebook and could overwrite a notebook of the same name.
model_fp32.save_model("squeezenet_cifar10_fp32.pth")

In [None]:
# Plot the per-epoch loss and accuracy curves returned by train_model.
plot_metrics(fp32_metrics)