# SqueezeNet Paper: Implementing the 2016 Paper in PyTorch with Flexibility
There's a paper out there on SqueezeNet, which, based on my light reading, is basically a much smaller CNN architecture. It's lightweight: with model compression it fits in less than 0.5 MB, so it looks very attractive for mobile and Edge AI applications. Before we go any further, let's discuss the paper's abstract first.

>Recent research on deep convolutional neural networks (CNNs) has focused primarily on improving accuracy. For a given accuracy level, it is typically possible to identify multiple CNN architectures that achieve that accuracy level. With equivalent accuracy, smaller CNN architectures offer at least three advantages: (1) Smaller CNNs require less communication across servers during distributed training. (2) Smaller CNNs require less bandwidth to export a new model from the cloud to an autonomous car. (3) Smaller CNNs are more feasible to deploy on FPGAs and other hardware with limited memory. To provide all of these advantages, we propose a small CNN architecture called SqueezeNet. SqueezeNet achieves AlexNet-level accuracy on ImageNet with 50x fewer parameters. Additionally, with model compression techniques, we are able to compress SqueezeNet to less than 0.5MB (510× smaller than AlexNet).

OK, the abstract is quite straightforward. It points out that CNN research at the time (2016, when this paper was released) focused mostly on accuracy, and that many of these architectures are bloated with parameters. The authors propose SqueezeNet, which drastically cuts the parameter count, and compare it with a famous architecture, AlexNet: the same ImageNet accuracy with 50x fewer parameters. I know a little about AlexNet; it is widely seen as the model that proved how effective deep learning could be compared with the traditional computer vision methods of its time. Matching AlexNet's performance with 50x fewer parameters highlights SqueezeNet's potential for Edge AI.

## How did they achieve it?
Without going too deep (pun intended) into the paper, let's look at how they achieved this. I didn't read the paper in its entirety; I skimmed it for the key points, and to be frank, I used ChatGPT to summarize them. Hopefully ChatGPT isn't scamming me. From what I read, the major factors in SqueezeNet's success are how they manage the filters and the pooling.

I've covered CNNs extensively in my previous articles; you can read them here:

Part 1:
https://medium.com/@maercaestro/siri-belajar-ai-jaringan-neural-berlingkar-convoluted-neural-network-bahagian-1-ef517726609f

Part 2:
https://medium.com/@maercaestro/siri-belajar-ai-jaringan-neural-berlingkar-convolutional-neural-network-bahagian-2-f19956754288

From those articles, we know that a CNN adds two processes in front of a traditional neural network. The first is convolution, which lets the network detect edges and patterns in the image, loosely akin to how humans and animals process images. The second is pooling, which shrinks the spatial dimensions of the feature maps. That reduces the computation and the number of parameters needed in the layers that follow.
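To make the pooling point concrete, here is a back-of-the-envelope sketch in plain Python. It assumes a hypothetical 96-channel, 32x32 feature map that is flattened into a 10-class fully connected layer. A single 2x2 pooling step before flattening shrinks that layer's parameter count by roughly 4x.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected (nn.Linear-style) layer."""
    return in_features * out_features + out_features


def flattened_size(channels: int, height: int, width: int) -> int:
    """Number of values left after flattening a C x H x W feature map."""
    return channels * height * width


# Hypothetical setup: 96 feature maps of size 32x32 feeding a 10-class FC layer
channels, size, classes = 96, 32, 10

# Flatten straight after the convolution
no_pool = linear_params(flattened_size(channels, size, size), classes)

# A 2x2 max pool with stride 2 halves height and width before flattening
with_pool = linear_params(flattened_size(channels, size // 2, size // 2), classes)

print(f"FC params without pooling: {no_pool:,}")
print(f"FC params with 2x2 pooling: {with_pool:,}")
print(f"reduction: {no_pool / with_pool:.2f}x")
```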

From here, there are two issues that I noticed when I first learned about this:

**1. The filter size**
Although CNNs were introduced as a more efficient way to process images than traditional neural networks, later work tended to bloat the filter size. AlexNet itself uses 11x11 filters in its first convolution layer. This is understandable, since a large filter is a direct way to learn the spatial relationship between each pixel and its neighbouring pixels. But larger filters mean more parameters (an 11x11 filter has 121 weights per input channel, versus 9 for a 3x3 filter), which makes the network inefficient.
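A quick way to see the cost of large filters is to count the parameters of a convolution the way `nn.Conv2d` stores them: `in_channels * out_channels * k * k` weights plus one bias per output channel. This sketch assumes the 96 first-layer filters of the original AlexNet paper and compares 11x11 filters against 3x3 filters with the same number of input and output channels.

```python
def conv_params(in_channels: int, out_channels: int, kernel_size: int, bias: bool = True) -> int:
    """Parameters of a square 2D convolution: weights plus optional biases."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# AlexNet-style first layer: 3 input channels (RGB), 96 filters of size 11x11
alexnet_conv1 = conv_params(3, 96, 11)

# The same layer with 3x3 filters instead
small_conv1 = conv_params(3, 96, 3)

print(f"11x11 conv1: {alexnet_conv1:,} params")
print(f" 3x3 conv1: {small_conv1:,} params")
print(f"ratio: {alexnet_conv1 / small_conv1:.1f}x")
```

The gap grows in deeper layers, where the number of input channels is in the hundreds rather than 3.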

**2. Always Pooling after Filter**
This is something I questioned during my first implementation of a CNN. Pooling always comes right after the convolution. Why do they feel the need to do that? Why not use just one pooling step at the end of the process? That seemed simpler to me.

So, that's roughly what the SqueezeNet authors did. First, they shrink their filters, using mostly 1x1 filters. Second, they delay pooling until later in the network instead of pooling right after every convolution. To keep enough spatial information, they still stack many layers that contain small 3x3 filters, so the receptive field grows layer by layer while the parameter count stays low. The building block that combines these ideas is what they call the Fire module: a 1x1 "squeeze" layer that feeds an "expand" layer made of both 1x1 and 3x3 filters.


## Let's build this fire module
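So, in the paper, they present three main design strategies. The first two shape the Fire module itself, and the third shapes the overall network:

**1. Replace 3x3 filters with 1x1 filters**, since a 1x1 filter has 9x fewer parameters than a 3x3 filter

**2. Decrease the number of input channels to the 3x3 filters**, which is the job of the 1x1 squeeze layer

**3. Downsample late in the network (delayed downsampling)**, so that pooling happens mostly in the later layers and most convolution layers work on large activation maps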

So, that's basically how we want to implement our Fire module.
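Before writing the module, it helps to check with arithmetic how much strategies 1 and 2 save. This plain-Python sketch counts Fire module parameters as squeeze 1x1, then parallel expand 1x1 and expand 3x3, each with biases as in `nn.Conv2d`. It compares the total against one ordinary 3x3 convolution with the same input and output channels. The sizes are the ones SqueezeNet v1.0 uses for fire2: 96 input channels, 16 squeeze filters, and 64 + 64 expand filters.

```python
def conv_params(in_channels: int, out_channels: int, kernel_size: int, bias: bool = True) -> int:
    """Weights (+ biases) of a square nn.Conv2d-style layer."""
    return in_channels * out_channels * kernel_size ** 2 + (out_channels if bias else 0)


def fire_params(in_channels: int, sq: int, exp1: int, exp3: int) -> int:
    """Parameters of a Fire module: 1x1 squeeze, then parallel 1x1 and 3x3 expand."""
    squeeze = conv_params(in_channels, sq, 1)  # strategy 2: few channels reach the 3x3 filters
    expand1x1 = conv_params(sq, exp1, 1)       # strategy 1: cheap 1x1 filters
    expand3x3 = conv_params(sq, exp3, 3)       # the only 3x3 filters in the module
    return squeeze + expand1x1 + expand3x3


in_ch, sq, exp1, exp3 = 96, 16, 64, 64

fire = fire_params(in_ch, sq, exp1, exp3)
plain = conv_params(in_ch, exp1 + exp3, 3)  # same in/out channels, all 3x3

print(f"Fire module:     {fire:,} params")
print(f"Plain 3x3 conv:  {plain:,} params")
print(f"Fire is {plain / fire:.1f}x smaller")
```

Both layers output 128 channels, but the Fire module gets there with roughly a ninth of the parameters. Most of the saving comes from the 3x3 filters seeing only 16 squeezed channels instead of 96.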



In [1]:
# We will use PyTorch for our custom implementation
import torch
import torch.nn as nn

class FireModule(nn.Module):
    """
    My own implementation of the Fire module as explained in the SqueezeNet paper.
    Since the Fire module is designed to be modular, we can tune it to build different
    architectures depending on the task. In the paper, the Fire module comes with 3 tunable
    parameters:

    1. The number of 1x1 squeeze filters, which we call sq
    2. The number of 1x1 expand filters, which we call exp1
    3. The number of 3x3 expand filters, which we call exp3

    Bear in mind that sq should be smaller than the total number of expand filters (exp1 + exp3).

    """

    def __init__(self, in_channels, sq, exp1, exp3):
        super(FireModule,self).__init__()

        # Enforce the constraint that sq is always less than exp1 + exp3
        if sq >= (exp1+exp3):
            raise ValueError(f"squeeze layer {sq} should always be smaller than {exp1+exp3}")

        # Squeeze Layer (1×1 convolution)
        self.squeeze = nn.Conv2d(in_channels, sq, kernel_size=1)
        self.squeeze_activation = nn.ReLU()  # always followed by a ReLU activation function

        # Expand Layer (1×1 and 3×3 convolutions)
        self.expand1x1 = nn.Conv2d(sq, exp1, kernel_size=1)
        self.expand3x3 = nn.Conv2d(sq, exp3, kernel_size=3, padding=1)  # padding=1 keeps H x W unchanged
        self.expand_activation = nn.ReLU()  # always followed by a ReLU activation function

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))  # Squeeze
        x1 = self.expand_activation(self.expand1x1(x))  # Expand (1×1)
        x3 = self.expand_activation(self.expand3x3(x))  # Expand (3×3)
        return torch.cat([x1, x3], dim=1)


## Let's Test SqueezeNet Theoretically via Receptive Field
So here's a new term that I just learned. We know that AlexNet uses large filters (11x11) to capture the spatial relationships between pixels, but that results in a high parameter count. SqueezeNet reduces the parameter count by stacking many layers with small filters (at most 3x3 inside the Fire modules) instead of relying on very large filters early on, so spatial context is built up gradually. So how do we know SqueezeNet can match AlexNet's performance with that strategy? Let's compare their **receptive fields**.

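According to Wikipedia, the receptive field (of a biological neuron) is **“the portion of the sensory space that can elicit neuronal responses, when stimulated”**. In a CNN, the receptive field of a unit is the region of the input image that can influence that unit's value.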

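We can compute it backwards through the network, one layer at a time, with the following recursion:

$$
r_{i-1} = s_i \times r_i + (k_i - s_i)
$$

where:
- $r_i$ = receptive field size measured at the output of layer $i$ (we start from $r_L = 1$, a single unit at the output of the last layer $L$).
- $r_{i-1}$ = the corresponding size one step back, at the output of layer $i-1$ (which is the input of layer $i$).
- $s_i$ = stride of layer $i$.
- $k_i$ = kernel size of layer $i$.

Applying this from the last layer back to the first gives $r_0$, the receptive field on the input image.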
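As a small worked example of this recursion, take two 3x3 convolutions with stride 1 stacked on top of each other. Start from a single output unit, $r_2 = 1$:

$$
r_1 = 1 \times 1 + (3 - 1) = 3, \qquad r_0 = 1 \times 3 + (3 - 1) = 5
$$

So two stacked 3x3 layers see a 5x5 patch of the input. A third layer gives $r_0 = 7$, the same as a single 7x7 layer. Assume every layer has $C$ input and $C$ output channels and ignore biases. The weight counts are then

$$
3 \times \left(3^2 C^2\right) = 27C^2 \quad \text{versus} \quad 7^2 C^2 = 49C^2,
$$

so the stacked version covers the same receptive field with roughly 45% fewer weights, and it adds two extra nonlinearities along the way.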


With this equation, we can compare the receptive fields of AlexNet and SqueezeNet and see whether SqueezeNet's strategy has merit. Let's test it.

In [2]:
#first we define the equation to calculate the receptive field
def compute_receptive_field(layers):
    """
    Compute the receptive field using the backward approach.
    
    layers: list of (layer_name, kernel_size, stride),
            ordered from the FIRST layer to the LAST layer.
            
    We'll reverse this list internally and apply:
        r_(i-1) = s_i * r_i + (k_i - s_i)
    
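    Returns:
        final_rf: The receptive field seen at the input for the LAST layer.
        rf_per_layer: A list of (layer_name, receptive_field) in forward order (first layer
                      to last), where each value is the size of the region at that layer's
                      input that can affect a single unit of the final layer's output.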
    """
    # Start with receptive field of size 1 at the last layer
    r = 1
    rf_per_layer = []
    
    # Process from last to first
    for (layer_name, k, s) in reversed(layers):
        r = s * r + (k - s)
        rf_per_layer.append((layer_name, r))
    
    # Reverse rf_per_layer so it's in forward order
    rf_per_layer.reverse()
    return r, rf_per_layer


### Let's test this on AlexNet
We will apply the equation to AlexNet's architecture only (its kernel sizes and strides), without building the entire model in PyTorch.

In [3]:
alexnet_layers = [
    ("conv1", 11, 4),   # Conv1
    ("pool1", 3, 2),    # MaxPool1
    ("conv2", 5, 1),    # Conv2
    ("pool2", 3, 2),    # MaxPool2
    ("conv3", 3, 1),    # Conv3
    ("conv4", 3, 1),    # Conv4
    ("conv5", 3, 1),    # Conv5
    ("pool3", 3, 2)     # MaxPool3
]


In [4]:
alex_rf, alex_rf_layers = compute_receptive_field(alexnet_layers)
print("AlexNet Receptive Field (Ignoring Padding):", alex_rf)
print("Layer-by-Layer Growth (From Last to First):")
for layer_name, rf_val in alex_rf_layers:
    print(f"  {layer_name} -> RF = {rf_val}")


AlexNet Receptive Field (Ignoring Padding): 195
Layer-by-Layer Growth (From Last to First):
  conv1 -> RF = 195
  pool1 -> RF = 47
  conv2 -> RF = 23
  pool2 -> RF = 19
  conv3 -> RF = 9
  conv4 -> RF = 7
  conv5 -> RF = 5
  pool3 -> RF = 3


### Now, let's compare with SqueezeNet

In [5]:
squeezenet_layers = [
    ("conv1", 7, 2),       # Conv1
    ("pool1", 3, 2),       # MaxPool (after Conv1)
    
    # Fire2 - expand3x3
    ("fire2", 3, 1),
    # Fire3 - expand3x3
    ("fire3", 3, 1),
    # Fire4 - expand3x3
    ("fire4", 3, 1),
    ("pool4", 3, 2),       # MaxPool (after Fire4)
    
    # Fire5, Fire6, Fire7, Fire8 - expand3x3
    ("fire5", 3, 1),
    ("fire6", 3, 1),
    ("fire7", 3, 1),
    ("fire8", 3, 1),
    ("pool8", 3, 2),       # MaxPool (after Fire8)
    
    # Fire9 - expand3x3
    ("fire9", 3, 1),
    
    # Final conv layer
    ("conv10", 1, 1)       # This 1×1 kernel doesn't expand RF, but let's include it for completeness
]


In [6]:
squeeze_rf, squeeze_rf_layers = compute_receptive_field(squeezenet_layers)
print("\n SqueezeNet Receptive Field (Ignoring Padding):", squeeze_rf)
print(" Layer-by-Layer Growth (From Last to First):")
for layer_name, rf_val in squeeze_rf_layers:
    print(f"  {layer_name} -> RF = {rf_val}")



 SqueezeNet Receptive Field (Ignoring Padding): 155
 Layer-by-Layer Growth (From Last to First):
  conv1 -> RF = 155
  pool1 -> RF = 75
  fire2 -> RF = 37
  fire3 -> RF = 35
  fire4 -> RF = 33
  pool4 -> RF = 31
  fire5 -> RF = 15
  fire6 -> RF = 13
  fire7 -> RF = 11
  fire8 -> RF = 9
  pool8 -> RF = 7
  fire9 -> RF = 3
  conv10 -> RF = 1


Based on the results above, SqueezeNet's receptive field is 40 pixels smaller than AlexNet's (155 vs 195). At first that looks bad, since a smaller receptive field could hurt the model's performance. But SqueezeNet has 50x fewer parameters while its receptive field is only about 20% smaller, which seems like a good trade-off. Still, this is only a theoretical comparison, so let's test it on real data.
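As a sanity check on the 195 and 155 figures above, the same receptive fields can be computed with the equivalent forward recursion. Going from the first layer to the last, the receptive field grows by `(k - 1) * jump`, and `jump` (the cumulative stride between neighbouring units) is multiplied by `s`. This sketch copies the two layer lists from above so that it runs on its own. Like the backward version, it ignores padding and treats each Fire module as a single 3x3 layer.

```python
def receptive_field_forward(layers):
    """Forward recursion over (name, kernel_size, stride) tuples, first layer to last.

    Returns (receptive_field, jump), where jump is the total stride of the stack.
    """
    r, jump = 1, 1
    for _name, k, s in layers:
        r += (k - 1) * jump  # each new kernel widens the view by (k - 1) steps of size jump
        jump *= s            # strides compound multiplicatively
    return r, jump


alexnet_layers = [("conv1", 11, 4), ("pool1", 3, 2), ("conv2", 5, 1), ("pool2", 3, 2),
                  ("conv3", 3, 1), ("conv4", 3, 1), ("conv5", 3, 1), ("pool3", 3, 2)]

squeezenet_layers = [("conv1", 7, 2), ("pool1", 3, 2),
                     ("fire2", 3, 1), ("fire3", 3, 1), ("fire4", 3, 1), ("pool4", 3, 2),
                     ("fire5", 3, 1), ("fire6", 3, 1), ("fire7", 3, 1), ("fire8", 3, 1),
                     ("pool8", 3, 2), ("fire9", 3, 1), ("conv10", 1, 1)]

alex_fwd, alex_jump = receptive_field_forward(alexnet_layers)
squeeze_fwd, squeeze_jump = receptive_field_forward(squeezenet_layers)

print(f"AlexNet:    RF = {alex_fwd}, total stride = {alex_jump}")
print(f"SqueezeNet: RF = {squeeze_fwd}, total stride = {squeeze_jump}")
gap = alex_fwd - squeeze_fwd
print(f"Difference: {gap} pixels ({gap / alex_fwd:.1%} smaller)")
```

Both directions agree: 195 for AlexNet and 155 for SqueezeNet. The forward pass also shows that this part of SqueezeNet has a total stride of 16 versus AlexNet's 32. That is one way to see the "delayed downsampling" strategy in numbers.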

## Full Implementation

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [8]:
class FireModule(nn.Module):
    """
    Optimized Fire module with enforced constraint:
    s1x1 < e1x1 + e3x3 (Squeeze filters must be fewer than Expand filters).
    """
    def __init__(self, in_channels, s1x1, e1x1, e3x3):
        super(FireModule, self).__init__()

        # Enforce constraint
        if s1x1 >= (e1x1 + e3x3):
            raise ValueError(f"Invalid Fire Module: s1x1 ({s1x1}) must be smaller than e1x1 + e3x3 ({e1x1 + e3x3}).")

        # Squeeze Layer (1×1)
        self.squeeze = nn.Conv2d(in_channels, s1x1, kernel_size=1)
        self.squeeze_activation = nn.ReLU()

        # Expand Layer (1×1 and 3×3)
        self.expand1x1 = nn.Conv2d(s1x1, e1x1, kernel_size=1)
        self.expand3x3 = nn.Conv2d(s1x1, e3x3, kernel_size=3, padding=1)
        self.expand_activation = nn.ReLU()

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))  # Squeeze
        x1 = self.expand_activation(self.expand1x1(x))  # Expand (1×1)
        x3 = self.expand_activation(self.expand3x3(x))  # Expand (3×3)
        return torch.cat([x1, x3], dim=1)  # Concatenate across channels


class FlexibleSqueezeNet(nn.Module):
    """
    Flexible SqueezeNet implementation with Fire Modules and optional pooling layers.
    """
    def __init__(self, num_classes=10, fire_configs=None, pooling_layers=None):
        super(FlexibleSqueezeNet, self).__init__()
        
        # Default config: Similar to SqueezeNet v1.0 (but smaller to handle CIFAR-10 quickly)
        if fire_configs is None:
            fire_configs = [
                (16, 32, 32),  # Fire2
                (16, 32, 32),  # Fire3
                (32, 64, 64),  # Fire4
                (32, 64, 64),  # Fire5
                (48, 96, 96),  # Fire6
                (48, 96, 96),  # Fire7
                (64, 128, 128),# Fire8
                (64, 128, 128) # Fire9
            ]

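        # Where to apply max pooling (delayed pooling approach)
        if pooling_layers is None:
            pooling_layers = [False, True, False, True, False, True, False, True]

        # One pooling flag is needed per Fire module
        if len(pooling_layers) != len(fire_configs):
            raise ValueError(f"pooling_layers has {len(pooling_layers)} entries, "
                             f"but there are {len(fire_configs)} Fire modules")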

        # Enforce constraints for each Fire module
        for s1x1, e1x1, e3x3 in fire_configs:
            if s1x1 >= (e1x1 + e3x3):
                raise ValueError(f"Invalid Fire Module: s1x1 ({s1x1}) >= e1x1 + e3x3 ({e1x1 + e3x3})")

        # Initial layer: smaller for CIFAR-10 (32x32)
        # We'll use kernel_size=3, stride=1 to keep more spatial info
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        
        # Instead of a big maxpool, let's just do a small one here
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # Build Fire modules
        self.fire_layers = nn.ModuleList()
        in_channels = 96
        for i, (s1x1, e1x1, e3x3) in enumerate(fire_configs):
            self.fire_layers.append(FireModule(in_channels, s1x1, e1x1, e3x3))
            in_channels = e1x1 + e3x3  # update for next Fire module
            if pooling_layers[i]:
                # Another small pooling to reduce resolution
                self.fire_layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))

        # Final conv to match num_classes
        self.conv10 = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        
        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

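    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool1(x)  # small early pooling defined in __init__
        for layer in self.fire_layers:
            x = layer(x)  # Fire modules, plus any MaxPool2d layers inserted between them
        x = self.relu(self.conv10(x))  # 1x1 conv maps channels to num_classes (ReLU as in the original SqueezeNet)
        x = self.avgpool(x)  # global average pooling -> [N, num_classes, 1, 1]
        return torch.flatten(x, 1)  # -> [N, num_classes]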
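To see what the default `pooling_layers` does to a 32x32 CIFAR-10 image, here is a small plain-Python trace. `pooled_size` and `trace_spatial_sizes` are hypothetical helpers, not part of PyTorch. The sketch makes four assumptions:

- `conv1` keeps the size at 32x32, since it is 3x3 with stride 1 and padding 1.
- `maxpool1` runs once after `conv1`, as its definition in `__init__` intends.
- Fire modules keep the size, since they use only 1x1 convolutions and padded 3x3 convolutions.
- Every pool is `MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)`, following my understanding of PyTorch's ceil-mode output-size rule.

```python
import math


def pooled_size(n: int, kernel: int = 2, stride: int = 2) -> int:
    """Output size of a max pool with ceil_mode=True and no padding (hypothetical helper)."""
    out = math.ceil((n - kernel) / stride) + 1
    # In ceil mode, a final window that would start outside the input is dropped
    if (out - 1) * stride >= n:
        out -= 1
    return out


def trace_spatial_sizes(input_size: int, pooling_layers: list) -> list:
    """Feature-map height/width after each Fire module (and its optional pool)."""
    size = pooled_size(input_size)  # maxpool1 after conv1
    sizes = []
    for use_pool in pooling_layers:
        # The Fire module itself keeps the size; only the optional pool shrinks it
        if use_pool:
            size = pooled_size(size)
        sizes.append(size)
    return sizes


default_pooling = [False, True, False, True, False, True, False, True]
print(trace_spatial_sizes(32, default_pooling))  # [16, 8, 8, 4, 4, 2, 2, 1]
```

With the defaults, the last Fire module's output is already pooled down to 1x1. So `conv10` and the global average pool work on a single spatial position. If you want larger activation maps late in the network, flip the last flag to `False`.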


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


########################################
# 1. Data Loading for CIFAR-10
########################################

# Recommended CIFAR-10 data augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2
)


########################################
# 2. Define the Fire Module
########################################

class Fire(nn.Module):
    """
    A slightly modified 'Fire' module (same concept from official SqueezeNet).
    Squeeze -> Expand(1x1) + Expand(3x3)
    """
    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super().__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)

        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)

        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        out1x1 = self.expand1x1_activation(self.expand1x1(x))
        out3x3 = self.expand3x3_activation(self.expand3x3(x))
        return torch.cat([out1x1, out3x3], dim=1)


########################################
# 3. Define a SqueezeNet Variant for CIFAR-10
########################################

class SqueezeNetCIFAR(nn.Module):
    """
    A SqueezeNet-like architecture tuned for 32x32 inputs:
    - Smaller initial kernel, stride=1
    - Smaller (2x2) max-pooling windows
    - Final conv to 10 classes
    - Global avg pool
    """
    def __init__(self, num_classes=10):
        super().__init__()

        # Initial conv: smaller kernel and stride=1 for CIFAR-10
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # was 7x7, stride=2 in original
            nn.ReLU(inplace=True),
            
            # Optional small pooling to reduce dimension a bit
            nn.MaxPool2d(kernel_size=2, stride=2),  # from 32x32 -> 16x16
            
            # Fire modules
            Fire(64, 16, 64, 64),   # out = 128
            Fire(128, 16, 64, 64),  # out = 128
            
            # Another max pool (optional)
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x16 -> 8x8
            
            Fire(128, 32, 128, 128), # out = 256
            Fire(256, 32, 128, 128), # out = 256
            
            # Another max pool (optional)
            nn.MaxPool2d(kernel_size=2, stride=2),  # 8x8 -> 4x4
            
            Fire(256, 48, 192, 192), # out = 384
            Fire(384, 48, 192, 192), # out = 384
            
            # Another max pool (optional)
            nn.MaxPool2d(kernel_size=2, stride=2),  # 4x4 -> 2x2
            
            Fire(384, 64, 256, 256), # out = 512
        )
        
        # Final classification layer
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(512, num_classes, kernel_size=1),  # 512 -> 10
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        
        # Initialization similar to official squeezenet
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is self.classifier[1]:  # final conv
                    nn.init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        # Flatten to [N, num_classes]
        return torch.flatten(x, 1)


########################################
# 4. Instantiate Model, Loss, Optimizer
########################################

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model = SqueezeNetCIFAR(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)


########################################
# 5. Training and Evaluation Functions
########################################

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        total_correct += preds.eq(labels).sum().item()
        total_samples += labels.size(0)
    
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total_samples += labels.size(0)
    
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy


########################################
# 6. Main Training Loop
########################################

num_epochs = 50  # Increase if you want better convergence
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Test  Loss: {test_loss:.4f}  | Test  Acc: {test_acc:.2f}%")
    print('-'*50)


Files already downloaded and verified
Files already downloaded and verified
Using device: mps
Epoch [1/50]
  Train Loss: 2.2435 | Train Acc: 13.89%
  Test  Loss: 2.1771  | Test  Acc: 16.70%
--------------------------------------------------
Epoch [2/50]
  Train Loss: 2.1593 | Train Acc: 18.96%
  Test  Loss: 2.0693  | Test  Acc: 23.65%
--------------------------------------------------
Epoch [3/50]
  Train Loss: 1.9965 | Train Acc: 28.19%
  Test  Loss: 1.9075  | Test  Acc: 32.04%
--------------------------------------------------
Epoch [4/50]
  Train Loss: 1.8803 | Train Acc: 33.35%
  Test  Loss: 1.8882  | Test  Acc: 36.14%
--------------------------------------------------
Epoch [5/50]
  Train Loss: 1.8038 | Train Acc: 36.56%
  Test  Loss: 1.7551  | Test  Acc: 39.20%
--------------------------------------------------
Epoch [6/50]
  Train Loss: 1.7437 | Train Acc: 38.91%
  Test  Loss: 1.6376  | Test  Acc: 41.74%
--------------------------------------------------
Epoch [7/50]
  Train Los

In [10]:
import torch

# Check if MPS is available
print(f"MPS Available: {torch.backends.mps.is_available()}")
print(f"MPS Built: {torch.backends.mps.is_built()}")


MPS Available: True
MPS Built: True


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Use MPS if available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Data loading with augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Define SqueezeNet-based model for CIFAR-10
from torchvision.models import squeezenet1_0

model = squeezenet1_0(weights=None, num_classes=10)
model.to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# Training Function
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        total_correct += preds.eq(labels).sum().item()
        total_samples += labels.size(0)
    
    return total_loss / total_samples, 100.0 * total_correct / total_samples

# Evaluation Function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total_samples += labels.size(0)
    
    return total_loss / total_samples, 100.0 * total_correct / total_samples

# Main Training Loop
num_epochs = 50
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Test  Loss: {test_loss:.4f}  | Test  Acc: {test_acc:.2f}%")
    print('-' * 50)


Using device: mps
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/50]
  Train Loss: 2.2097 | Train Acc: 16.48%
  Test  Loss: 2.2218  | Test  Acc: 19.76%
--------------------------------------------------
Epoch [2/50]
  Train Loss: 2.1458 | Train Acc: 21.46%
  Test  Loss: 2.0936  | Test  Acc: 24.64%
--------------------------------------------------
Epoch [3/50]
  Train Loss: 2.0462 | Train Acc: 24.54%
  Test  Loss: 1.9165  | Test  Acc: 29.94%
--------------------------------------------------
Epoch [4/50]
  Train Loss: 1.9320 | Train Acc: 28.54%
  Test  Loss: 1.7919  | Test  Acc: 34.17%
--------------------------------------------------
Epoch [5/50]
  Train Loss: 1.8205 | Train Acc: 33.15%
  Test  Loss: 1.7011  | Test  Acc: 39.12%
--------------------------------------------------
Epoch [6/50]
  Train Loss: 1.7394 | Train Acc: 37.02%
  Test  Loss: 1.6273  | Test  Acc: 41.64%
--------------------------------------------------
Epoch [7/50]
  Train Los

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

########################################
# 1. Data Loading for CIFAR-10
########################################

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

########################################
# 2. Define the Fire Module (my implementation, now with BatchNorm)
########################################

class FireModule(nn.Module):
    def __init__(self, in_channels, sq, exp1, exp3):
        super(FireModule, self).__init__()
        if sq >= (exp1 + exp3):
            raise ValueError(f"Invalid FireModule: squeeze ({sq}) must be smaller than expand ({exp1 + exp3}).")
        
        self.squeeze = nn.Conv2d(in_channels, sq, kernel_size=1, bias=False)
        self.squeeze_bn = nn.BatchNorm2d(sq)
        self.squeeze_activation = nn.ReLU(inplace=True)

        self.expand1x1 = nn.Conv2d(sq, exp1, kernel_size=1, bias=False)
        self.expand3x3 = nn.Conv2d(sq, exp3, kernel_size=3, padding=1, bias=False)
        self.expand_bn1 = nn.BatchNorm2d(exp1)
        self.expand_bn3 = nn.BatchNorm2d(exp3)
        self.expand_activation = nn.ReLU(inplace=True)

        nn.init.kaiming_normal_(self.squeeze.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.expand1x1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.expand3x3.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze_bn(self.squeeze(x)))
        x1 = self.expand_activation(self.expand_bn1(self.expand1x1(x)))
        x3 = self.expand_activation(self.expand_bn3(self.expand3x3(x)))
        return torch.cat([x1, x3], dim=1)

########################################
# 3. Define a SqueezeNet Variant for CIFAR-10
########################################

class SqueezeNetCIFAR(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireModule(64, 16, 64, 64),
            FireModule(128, 16, 64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireModule(128, 32, 128, 128),
            FireModule(256, 32, 128, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireModule(256, 48, 192, 192),
            FireModule(384, 48, 192, 192),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireModule(384, 64, 256, 256),
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(512, num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

########################################
# 4. Instantiate Model, Loss, Optimizer
########################################

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = SqueezeNetCIFAR(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

########################################
# 5. Training and Evaluation Functions
########################################

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        total_correct += preds.eq(labels).sum().item()
        total_samples += labels.size(0)
    
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total_samples += labels.size(0)
    
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy

########################################
# 6. Main Training Loop
########################################

num_epochs = 50
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Test  Loss: {test_loss:.4f}  | Test  Acc: {test_acc:.2f}%")
    print('-'*50)


Files already downloaded and verified
Files already downloaded and verified
Using device: mps
Epoch [1/50]
  Train Loss: 1.7831 | Train Acc: 33.38%
  Test  Loss: 1.4696  | Test  Acc: 46.00%
--------------------------------------------------
Epoch [2/50]
  Train Loss: 1.3720 | Train Acc: 50.03%
  Test  Loss: 1.2785  | Test  Acc: 54.73%
--------------------------------------------------
Epoch [3/50]
  Train Loss: 1.2015 | Train Acc: 56.73%
  Test  Loss: 1.1287  | Test  Acc: 58.98%
--------------------------------------------------
Epoch [4/50]
  Train Loss: 1.0716 | Train Acc: 61.86%
  Test  Loss: 1.0290  | Test  Acc: 63.53%
--------------------------------------------------
Epoch [5/50]
  Train Loss: 0.9769 | Train Acc: 65.16%
  Test  Loss: 0.9376  | Test  Acc: 66.05%
--------------------------------------------------
Epoch [6/50]
  Train Loss: 0.9116 | Train Acc: 67.61%
  Test  Loss: 0.8910  | Test  Acc: 68.60%
--------------------------------------------------
Epoch [7/50]
  Train Los

KeyboardInterrupt: 