Paper link: https://arxiv.org/pdf/1409.4842

Problem: Can we make networks sparser while keeping them expressive and learnable? Fully connected layers are expensive and prone to overfitting.

Sparse connections: Each neuron connects only to a small subset of neurons in the previous layer.

If two neurons are frequently active together, they are likely:
- Responding to the same underlying cause
- Inputs to the same higher-level feature

The paper (citing Arora et al.) suggests constructing the network layer by layer:
1. Compute correlations between neurons in layer L
2. Cluster neurons based on correlation
3. Connect each neuron in layer L+1 to a cluster of neurons in layer L.

This creates a sparse connectivity pattern that reflects the data distribution.
Shared dependencies induce correlations between neurons.
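
A minimal sketch of the three steps above on toy data. The greedy threshold clustering, the `threshold=0.8` value, and the helper names `pearson` and `cluster_by_correlation` are assumptions for illustration. The paper does not prescribe a specific clustering algorithm, and Inception itself approximates this sparse structure with dense convolution branches rather than running such a procedure.

```python
import math

def pearson(a, b):
    """Pearson correlation between two equal-length activation traces."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)

def cluster_by_correlation(activations, threshold=0.8):
    """Greedy clustering: a neuron joins the first cluster whose seed
    (first member) it correlates with at or above `threshold`."""
    clusters = []  # each cluster is a list of neuron indices
    for i, trace in enumerate(activations):
        for cluster in clusters:
            if pearson(trace, activations[cluster[0]]) >= threshold:
                cluster.append(i)
                break
        else:
            clusters.append([i])  # no match: start a new cluster
    return clusters

# Toy activations: rows are neurons in layer L, columns are input samples.
activations = [
    [1.0, 0.0, 1.0, 0.0, 1.0, 0.0],  # neuron 0
    [0.9, 0.1, 1.1, 0.0, 0.8, 0.1],  # neuron 1: fires together with neuron 0
    [0.0, 1.0, 0.0, 0.0, 1.0, 1.0],  # neuron 2
    [0.1, 0.9, 0.0, 0.1, 1.0, 0.9],  # neuron 3: fires together with neuron 2
]

# Step 3: each cluster would become the input set of one neuron in layer L+1.
clusters = cluster_by_correlation(activations)
print(clusters)
```

Here neurons 0 and 1 end up in one cluster and neurons 2 and 3 in another, so a layer L+1 neuron would connect to two inputs instead of all four.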

Many clusters are concentrated in a single small region, so they can be covered by a 1x1 conv in layer L+1.

However, one can also expect that there will be a smaller number of more spatially spread out clusters that can be covered by convolutions over larger patches, and there will be a decreasing number of patches over larger and larger regions.

We use a larger share of 3×3 and 5×5 kernels in higher layers because the features there are more abstract and less spatially concentrated: each neuron should cover a larger receptive field, and precise localization matters less. (The ratio of 3×3 and 5×5 convolutions should increase as we move to higher layers.)
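
To make the receptive-field argument concrete, the sketch below uses the standard recurrence for stacked convolution/pooling layers: each layer adds (kernel - 1) times the product of the strides before it. The function name `receptive_field` and the two layer stacks are illustrative assumptions, not taken from the paper.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output unit.

    `layers` is a list of (kernel_size, stride) pairs, ordered input to output.
    """
    rf, jump = 1, 1  # jump = spacing in input pixels between adjacent units
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf

# A 5x5 conv applied directly to the input.
rf_5x5_early = receptive_field([(5, 1)])
# The same 5x5 conv applied after a 3x3 pooling layer with stride 2.
rf_5x5_late = receptive_field([(3, 2), (5, 1)])
print(rf_5x5_early, rf_5x5_late)  # 5 11
```

The same kernel size covers a larger input area once it sits above a downsampling layer, which is one way higher layers obtain the larger receptive fields described above.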

Insert 1x1 conv layers before the expensive 3x3 and 5x5 convolutions to reduce the channel dimension. This largely preserves expressive power while dramatically reducing the number of parameters and the amount of computation.
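
A quick weight count shows the size of the saving. The channel numbers (64 in, 8 after reduction, 16 out) are those of the 5x5 branch of the first Inception module in the code further down; the helper name `conv_weights` is an assumption for illustration, and biases are ignored.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a 2D convolution (biases ignored)."""
    return in_channels * out_channels * kernel_size * kernel_size

in_channels, reduce_channels, out_channels = 64, 8, 16

# 5x5 convolution applied directly to all input channels.
direct = conv_weights(in_channels, out_channels, 5)

# 1x1 reduction to 8 channels, followed by the 5x5 convolution.
reduced = conv_weights(in_channels, reduce_channels, 1) + conv_weights(reduce_channels, out_channels, 5)

print(direct, reduced)  # 25600 3712
print(f"{direct / reduced:.1f}x fewer weights with the 1x1 reduction")
```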

- All the convolutions, including those inside the Inception modules, use rectified linear activation.

“the strong performance of relatively shallower networks … suggests that the features produced by the layers in the middle of the network should be very discriminative”

- Implies that mid-level features (edges → textures → parts) are already quite useful for classification; this is the paper's motivation for attaching auxiliary classifiers to intermediate layers.

**Outputs of all branches inside an Inception module must be concatenated.**
All branches must output tensors with the same spatial dimensions (height and width) so that they can be concatenated along the channel dimension.
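
The spatial size of a convolution or pooling output follows floor((size + 2·padding − kernel) / stride) + 1 (no dilation). The sketch below checks that the kernel/padding pairs used by the four branches in the code below all preserve a 32×32 input at stride 1; the helper name `conv_out_size` is an assumption for illustration.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output height/width of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, padding) of the last spatial op in each branch, all with stride 1.
branches = {
    "1x1 conv": (1, 0),
    "3x3 conv": (3, 1),
    "5x5 conv": (5, 2),
    "3x3 max pool": (3, 1),
}

sizes = {name: conv_out_size(32, k, stride=1, padding=p) for name, (k, p) in branches.items()}
print(sizes)  # every branch keeps 32x32, so the outputs can be concatenated on dim=1
```

For odd kernels at stride 1, padding = (kernel − 1) // 2 is what keeps the size unchanged.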


In [8]:
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset

class Path1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layers(x)

class Path2(nn.Module):
    def __init__(self, in_channels, reduce_channel, out_channels):
        # reduce_channel: width of the 1x1 bottleneck placed before the 3x3 conv
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, reduce_channel, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(reduce_channel, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.layers(x)
        return x

class Path3(nn.Module):
    def __init__(self, in_channels, reduce_channel, out_channels):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, reduce_channel, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(reduce_channel, out_channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.layers(x)
        return x   

class Path4(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.layers = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.layers(x)
        return x            

class InceptionBranches(nn.Module):
    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_pool):
        super().__init__()
        self.path1 = Path1(in_channels, out_1x1)
        self.path2 = Path2(in_channels, red_3x3, out_3x3)
        self.path3 = Path3(in_channels, red_5x5, out_5x5)
        self.path4 = Path4(in_channels, out_pool)
    
    def forward(self, x):
        p1 = self.path1(x)
        p2 = self.path2(x)
        p3 = self.path3(x)
        p4 = self.path4(x)

        return torch.cat([p1, p2, p3, p4], dim=1)

In [9]:
class Inception(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()

        self.features = nn.Sequential(
            # Stem: lighter initial layers for small images
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # Keep 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            # Inception 3a, 3b (reduced channels)
            InceptionBranches(in_channels=64, out_1x1=32, red_3x3=48, out_3x3=64, red_5x5=8, out_5x5=16, out_pool=16),  # 128 out
            InceptionBranches(in_channels=128, out_1x1=64, red_3x3=64, out_3x3=96, red_5x5=16, out_5x5=32, out_pool=32),  # 224 out
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 16x16

            # Inception 4a, 4b (fewer modules)
            InceptionBranches(in_channels=224, out_1x1=96, red_3x3=64, out_3x3=128, red_5x5=16, out_5x5=32, out_pool=32),  # 288 out
            InceptionBranches(in_channels=288, out_1x1=128, red_3x3=96, out_3x3=160, red_5x5=24, out_5x5=48, out_pool=48),  # 384 out
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 8x8

            # Inception 5a (final module)
            InceptionBranches(in_channels=384, out_1x1=160, red_3x3=128, out_3x3=192, red_5x5=32, out_5x5=64, out_pool=64),  # 480 out
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(480, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
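
A small bookkeeping check of the channel comments in the Inception class above: the output width of each module is the sum of its four branch widths, and it must equal the `in_channels` of the next module. The tuples are copied from the class definition; the variable names and the helper `pool_out_size` are assumptions for illustration.

```python
# (name, in_channels, (out_1x1, out_3x3, out_5x5, out_pool))
modules = [
    ("3a", 64, (32, 64, 16, 16)),
    ("3b", 128, (64, 96, 32, 32)),
    ("4a", 224, (96, 128, 32, 32)),
    ("4b", 288, (128, 160, 48, 48)),
    ("5a", 384, (160, 192, 64, 64)),
]

channels = 64  # output channels of the stem convolution
for name, in_channels, branch_outs in modules:
    assert in_channels == channels, f"{name}: expected {channels} input channels"
    channels = sum(branch_outs)  # concatenation adds up the branch widths
    print(f"Inception {name}: {in_channels} -> {channels} channels")

def pool_out_size(size, kernel=3, stride=2, padding=1):
    """Output size of the stride-2 max pools between module groups."""
    return (size + 2 * padding - kernel) // stride + 1

size_after_first_pool = pool_out_size(32)
size_after_second_pool = pool_out_size(size_after_first_pool)
print(size_after_first_pool, size_after_second_pool)  # 16 8
```

The final width of 480 matches the input size of the `nn.Linear(480, num_classes)` classifier.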

In [10]:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Mild to avoid over-distortion
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    transforms.RandomErasing(p=0.25)  # Apply after normalization for consistency
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    # Use the same normalization statistics as train_transform
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761])
])

# Load raw datasets
cifar_train_raw = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))  # 45,000 of the 50,000 training images

train_indices = list(range(0, train_size))
val_indices = list(range(train_size, len(cifar_train_raw)))

# Create datasets with appropriate transforms
cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)
cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)
# Use the official test split (10,000 images) for the final evaluation
cifar_test = datasets.CIFAR100(root="./data", train=False, transform=test_transform)

train_loader = DataLoader(
    cifar_train,
    batch_size=1024,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

val_loader = DataLoader(
    cifar_val,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

test_loader = DataLoader(
    cifar_test,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

num_classes = 100

model = Inception(num_classes=num_classes).to(device)

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
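# NOTE: base_lr and scaled_lr are computed for reference only; neither is
# passed to the optimizer or the scheduler below.
base_lr = 4e-3

batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # Square root scaling

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-3,  # Overridden by OneCycleLR, which starts at max_lr / div_factor
    weight_decay=1e-4
)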

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,  # Peak learning rate of the one-cycle schedule
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1000.0
)

best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    model.train()

    current_loss = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
            
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # Clip the global gradient norm at 5.0
        
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        current_loss += loss.item()
        num_batches += 1

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')


    avg_train_loss = current_loss / num_batches
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % 2 == 0:
        model.eval()
        val_loss = 0.0
        val_batches = 0

        print(f'Epoch {epoch+1} finished')
        print(f'average training loss is {avg_train_loss:.4f}')

        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_targets = val_data
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)  # Move the batch to the device

                val_outputs = model(val_inputs)
                val_batch_loss = loss_function(val_outputs, val_targets)

                val_loss += val_batch_loss.item()
                val_batches += 1


        avg_val_loss = val_loss / val_batches

        print(f'Epoch {epoch+1} finished')
        print(f'Training - Loss: {avg_train_loss:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()


Using device: cuda
GPU Memory: 15.8 GB
Starting Epoch 1
Batch 0/44, Loss: 4.6045
Epoch 1 finished
Training - Loss: 4.5589
Starting Epoch 2
Batch 0/44, Loss: 4.4399
Epoch 2 finished
Training - Loss: 4.3840
Epoch 2 finished
average training loss is 4.3840
Epoch 2 finished
Training - Loss: 4.3840
Validation - Loss: 4.3235
Starting Epoch 3
Batch 0/44, Loss: 4.3180


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d976ff89f80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process

Epoch 3 finished
Training - Loss: 4.2739
Starting Epoch 4
Batch 0/44, Loss: 4.1906
Epoch 4 finished
Training - Loss: 4.1512
Epoch 4 finished
average training loss is 4.1512
Epoch 4 finished
Training - Loss: 4.1512
Validation - Loss: 4.0511
Starting Epoch 5
Batch 0/44, Loss: 4.0453
Epoch 5 finished
Training - Loss: 4.0740
Starting Epoch 6
Batch 0/44, Loss: 3.9457
Epoch 6 finished
Training - Loss: 3.9947
Epoch 6 finished
average training loss is 3.9947
Epoch 6 finished
Training - Loss: 3.9947
Validation - Loss: 3.8473
Starting Epoch 7
Batch 0/44, Loss: 3.9125
Epoch 7 finished
Training - Loss: 3.8956
Starting Epoch 8
Batch 0/44, Loss: 3.8057
Epoch 8 finished
Training - Loss: 3.8273
Epoch 8 finished
average training loss is 3.8273
Epoch 8 finished
Training - Loss: 3.8273
Validation - Loss: 3.9594
Starting Epoch 9
Batch 0/44, Loss: 4.0689
Epoch 9 finished
Training - Loss: 3.8178
Starting Epoch 10
Batch 0/44, Loss: 3.6722
Epoch 10 finished
Training - Loss: 3.6678
Epoch 10 finished
average tr


Epoch 34 finished
Training - Loss: 2.3629
Validation - Loss: 2.2965
Starting Epoch 35
Batch 0/44, Loss: 2.3481
Epoch 35 finished
Training - Loss: 2.3509
Starting Epoch 36
Batch 0/44, Loss: 2.3651
Epoch 36 finished
Training - Loss: 2.3342
Epoch 36 finished
average training loss is 2.3342
Epoch 36 finished
Training - Loss: 2.3342
Validation - Loss: 2.2859
Starting Epoch 37
Batch 0/44, Loss: 2.3193
Epoch 37 finished
Training - Loss: 2.3220
Starting Epoch 38
Batch 0/44, Loss: 2.3066
Epoch 38 finished
Training - Loss: 2.3118
Epoch 38 finished
average training loss is 2.3118
Epoch 38 finished
Training - Loss: 2.3118
Validation - Loss: 2.2726
Starting Epoch 39
Batch 0/44, Loss: 2.3092
Epoch 39 finished
Training - Loss: 2.3118
Starting Epoch 40
Batch 0/44, Loss: 2.3418
Epoch 40 finished
Training - Loss: 2.3025
Epoch 40 finished
average training loss is 2.3025
Epoch 40 finished
Training - Loss: 2.3025
Validation - Loss: 2.2714
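
Two reference points help read the loss values in this log. The sketch assumes the label smoothing in `nn.CrossEntropyLoss(label_smoothing=0.1)` mixes the one-hot target with a uniform distribution over all 100 classes; the variable names are illustrative.

```python
import math

num_classes = 100
eps = 0.1  # label_smoothing value used in the training cell

# Cross-entropy of a model that predicts a uniform distribution (chance level).
chance_loss = math.log(num_classes)

# Smoothed target: (1 - eps) on the true class plus eps / num_classes everywhere.
p_true = 1 - eps + eps / num_classes
p_other = eps / num_classes

# Lowest achievable loss = entropy of the smoothed target distribution.
loss_floor = -(p_true * math.log(p_true) + (num_classes - 1) * p_other * math.log(p_other))

print(f"chance-level loss: {chance_loss:.3f}")
print(f"floor with label smoothing: {loss_floor:.3f}")
```

The first logged batch loss (4.6045) sits at the chance level of ln(100) ≈ 4.605, as expected for an untrained 100-class model. With label smoothing the loss cannot reach zero, so the final values near 2.3 should be read against a floor of roughly 0.78. The training loss is probably higher than the validation loss because it is measured with dropout and augmentation active.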


In [11]:
def evaluate_test_set(model):
    model.eval()
    correct = 0
    total = 0
    
    print("Starting evaluation...")
    
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)

            # Predicted class = index of the largest logit
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
standard_accuracy = evaluate_test_set(model)   
print(f'Standard Test Accuracy: {standard_accuracy:.2f}%')  # accuracy is already a percentage


=== Running standard evaluation ===
Starting evaluation...
Accuracy of the network on the test images: 56.68%
Standard Test Accuracy: 56.68%
