# Notebook 4: Going Deeper - Architecting ResNet-18 from Scratch

In the previous notebooks, we've built increasingly sophisticated models—from simple MLPs to CNNs. But as we try to make networks deeper (with more layers), we run into a fundamental problem: **the vanishing gradient problem**.

As networks get deeper, gradients can shrink exponentially as they propagate backward through the layers, because backpropagation multiplies together one local derivative per layer. This makes it very difficult for early layers to learn anything useful, since the gradients they receive are close to zero. The result? Deeper plain networks often perform worse than shallower ones, even on the training set, which defies intuition.

## The Solution: Residual Connections

ResNet (Residual Network) solves this problem with a brilliantly simple idea: **residual connections** (also called "shortcut" or "skip" connections). 

Think of it like an express lane on a highway. Instead of forcing all traffic (data) to go through every single layer, ResNet gives the data a direct path to skip ahead. This means:

1. **If a layer is useful**: The network learns to use it and modify the data flowing through it.
2. **If a layer is harmful**: The network can simply learn to make that layer output zero, effectively "turning it off" and letting the shortcut connection handle everything.

The key insight: By adding the original input `x` directly to the output of a convolutional stack (a "residual block"), the network can easily learn an identity mapping if that's optimal. This makes training very deep networks not just possible, but actually beneficial.

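Mathematically, if `H(x)` is the mapping we want a block to produce, the block's layers learn the residual `F(x) = H(x) - x` instead of `H(x)` itself. The block's output is then `H(x) = F(x) + x`, where the `+ x` term is the shortcut connection. Learning the residual is often easier: if the best mapping is close to the identity, the layers only need to push `F(x)` toward zero rather than reproduce `x` through a stack of nonlinear layers.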
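
To connect this back to vanishing gradients: for a scalar block `H(x) = F(x) + x`, the chain rule gives `dH/dx = dF/dx + 1`, so the gradient always has a direct path that does not pass through `F`. The sketch below is a hypothetical toy, a chain of scalar sigmoid "layers" rather than a real network, that compares the gradient reaching the input with and without shortcuts.

```python
import torch


def input_gradient(depth, residual):
    """Gradient of the final output w.r.t. the input for a chain of `depth` toy scalar layers."""
    x = torch.tensor(0.5, requires_grad=True)
    h = x
    for _ in range(depth):
        f = torch.sigmoid(h)  # toy "layer" F; the sigmoid's derivative is at most 0.25
        # residual: H(h) = F(h) + h    plain: H(h) = F(h)
        h = h + f if residual else f
    h.backward()
    return x.grad.item()


for depth in (5, 20, 50):
    plain = input_gradient(depth, residual=False)
    skip = input_gradient(depth, residual=True)
    print(f"depth={depth:2d}  plain={plain:.3e}  residual={skip:.3e}")
```

In the plain chain every layer multiplies the gradient by a factor of at most 0.25, so it collapses toward zero as depth grows; with shortcuts each factor is `1 + dF/dh`, which here never drops below 1.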


## A More Challenging Dataset: CIFAR-10

So far, we've been working with MNIST—grayscale images of handwritten digits. Now we're stepping up to **CIFAR-10**, a significantly more challenging dataset.

**CIFAR-10 Characteristics:**
- **10 classes**: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- **32×32 color images**: Slightly larger than MNIST's 28×28, and now with **3 color channels** (RGB)
- **50,000 training images** and **10,000 test images**
- **More complex**: Unlike MNIST's simple black-and-white digits, CIFAR-10 contains natural images with varied backgrounds, lighting, and perspectives

**Input Shape: (Batch, 3, 32, 32)**
- Batch: The number of images in a batch
- 3: The three color channels (Red, Green, Blue)
- 32×32: The spatial dimensions of each image

This is a step up in difficulty because:
1. **Color channels**: We now have 3 input channels instead of 1
2. **Natural images**: More complex patterns, textures, and variations
3. **Low resolution**: Squeezing a whole natural object into 32×32 pixels leaves little detail for telling, say, a cat from a dog


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define transforms for CIFAR-10
# Normalize each channel with commonly used CIFAR-10 mean/std values (not ImageNet statistics)
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Download and create datasets
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

# Create DataLoaders
batch_size = 128
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Batch size: {batch_size}")
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Classes: {train_dataset.classes}")

# Pull one batch and inspect its shape
X, y = next(iter(train_dataloader))
print(f"\nImage tensor shape: {X.shape}")
print(f"Label tensor shape: {y.shape}")
print(f"Expected image shape: ({batch_size}, 3, 32, 32)")
print(f"Image dtype: {X.dtype}, Label dtype: {y.dtype}")


## The BasicBlock: The Building Unit of ResNet

ResNets are built from a repeating unit called the **BasicBlock**. This is the fundamental building block that makes ResNet so powerful.

### Internal Structure

Each BasicBlock contains:
1. **Two convolutional layers** stacked together
2. **Batch normalization** after each convolution (for stable training)
3. **ReLU activations** between layers
4. **A shortcut connection** that adds the input directly to the output

### The Residual Connection Logic

The magic happens in the forward pass. The input `x` flows through the convolutional stack (conv → bn → relu → conv → bn), producing some output `out`. Then, instead of just returning `out`, we do:

```python
out = out + x  # The residual connection!
```

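This simple addition is what makes ResNet work. If the layers learn that doing nothing is optimal, they can output zero, so the sum is just `x`; because `x` is already non-negative after the previous block's ReLU, the final ReLU passes it through unchanged, giving an identity mapping. If the layers learn something useful, they modify `x` by adding a learned correction to it.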
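
To check the "output zero, get identity" claim, the sketch below builds the conv → bn → relu → conv → bn stack from standalone layers (a simplified stand-in, not the `BasicBlock` class defined later) and forces the residual branch to zero by zeroing the second batch norm's scale and shift. torchvision's ResNet constructor offers a related `zero_init_residual` option that zero-initializes the last batch-norm scale in each block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
channels = 8
conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
bn1 = nn.BatchNorm2d(channels)
conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
bn2 = nn.BatchNorm2d(channels)

# Zero bn2's scale (gamma) and shift (beta) so the residual branch F(x) outputs exactly zero
nn.init.zeros_(bn2.weight)
nn.init.zeros_(bn2.bias)

# Non-negative input, like the output of a previous block's final ReLU
x = F.relu(torch.randn(4, channels, 16, 16))

residual = bn2(conv2(F.relu(bn1(conv1(x)))))  # F(x)
out = F.relu(residual + x)                     # relu(F(x) + x)

print("max |F(x)|:", residual.abs().max().item())
print("block output equals input:", torch.equal(out, x))
```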

### Handling Dimension Mismatches

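When the spatial size or the number of channels changes (as happens when we downsample), the input `x` no longer has the same shape as the block's output, so the two cannot be added directly. In these cases, the shortcut path uses a 1×1 convolution with the block's stride, followed by batch normalization, to project `x` to the output's shape (a "projection shortcut").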
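
A quick shape check of why the projection is needed. Here `out` is a random stand-in for the conv stack's output in a downsampling block (64 → 128 channels, stride 2):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)     # block input
out = torch.randn(2, 128, 16, 16)  # stand-in for the conv stack output (stride 2, channels doubled)

# Adding the raw input fails: (2, 128, 16, 16) and (2, 64, 32, 32) cannot be broadcast together
plain_failed = False
try:
    out + x
except RuntimeError:
    plain_failed = True
print("plain shortcut fails:", plain_failed)

# Projection shortcut: 1x1 conv with the same stride, followed by batch norm
projection = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
shortcut = projection(x)
result = out + shortcut
print(shortcut.shape)  # torch.Size([2, 128, 16, 16])
```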

### The Stride Parameter

The first convolution in a BasicBlock can use `stride=2` for downsampling. When `stride=2`, the spatial dimensions are halved (e.g., 32×32 → 16×16), which is how ResNet progressively reduces image size while increasing channel depth.
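
The halving follows from the standard convolution output-size formula (as given in the PyTorch `nn.Conv2d` documentation, here with dilation 1): `out = floor((in + 2*padding - kernel_size) / stride) + 1`. The helper `conv_output_size` below is a hypothetical name used only for illustration; it shows that the 3×3 main path and the 1×1 shortcut land on the same size.

```python
import torch
import torch.nn as nn


def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


print(conv_output_size(32, kernel_size=3, stride=1, padding=1))  # 32: main path, stride 1
print(conv_output_size(32, kernel_size=3, stride=2, padding=1))  # 16: main path, stride 2
print(conv_output_size(32, kernel_size=1, stride=2, padding=0))  # 16: 1x1 projection shortcut, stride 2

# Cross-check the formula against a real nn.Conv2d
conv = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
y = conv(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 16, 16])
```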


In [None]:
class BasicBlock(nn.Module):
    """
    BasicBlock for ResNet-18/34.
    Consists of two 3x3 convolutions with batch normalization and ReLU.
    """
    expansion = 1  # For ResNet-18/34, expansion is 1 (no channel expansion)
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        
        # First convolutional layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        # If stride != 1 or channels change, we need to project the shortcut
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        # Save the input for the shortcut connection
        identity = x
        
        # Forward through the convolutional layers
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.functional.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        # Add the shortcut connection
        out += self.shortcut(identity)
        
        # Apply ReLU after the addition
        out = nn.functional.relu(out)
        
        return out

# Test the BasicBlock
print("Testing BasicBlock with matching dimensions:")
block1 = BasicBlock(in_channels=64, out_channels=64, stride=1)
test_input = torch.randn(2, 64, 32, 32)
test_output = block1(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")

print("\nTesting BasicBlock with downsampling:")
block2 = BasicBlock(in_channels=64, out_channels=128, stride=2)
test_input2 = torch.randn(2, 64, 32, 32)
test_output2 = block2(test_input2)
print(f"Input shape: {test_input2.shape}")
print(f"Output shape: {test_output2.shape}")


## Assembling ResNet-18

Now that we understand the BasicBlock, let's see how ResNet-18 is assembled. The architecture follows a clear pattern:

### The Architecture Structure

**1. Stem (Initial Convolution)**
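- `conv1`: A single convolution (followed by batch norm and ReLU) that converts the 3 input channels to 64 feature maps. The original ImageNet ResNet uses a 7×7, stride-2 convolution followed by 3×3 max pooling; for 32×32 CIFAR-10 images we use a 3×3, stride-1 convolution and drop the max pooling, so the feature maps stay at 32×32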

**2. Four Main Layers**
Each layer is a stack of BasicBlocks:
- **layer1**: 2 blocks, 64 channels (no downsampling)
- **layer2**: 2 blocks, 128 channels (downsamples here: 32×32 → 16×16)
- **layer3**: 2 blocks, 256 channels (downsamples here: 16×16 → 8×8)
- **layer4**: 2 blocks, 512 channels (downsamples here: 8×8 → 4×4)

**3. Classification Head**
- `AdaptiveAvgPool2d(1)`: Global average pooling that reduces spatial dimensions to 1×1
- `Linear(512, num_classes)`: Final fully-connected layer for classification
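
Global average pooling is simply the mean over the height and width dimensions. This sketch checks that on a random tensor shaped like `layer4`'s output and then applies the final linear layer:

```python
import torch
import torch.nn as nn

features = torch.randn(2, 512, 4, 4)         # shape of layer4's output for a 32x32 input
pooled = nn.AdaptiveAvgPool2d(1)(features)  # (2, 512, 1, 1)
flat = torch.flatten(pooled, 1)             # (2, 512)
manual = features.mean(dim=(2, 3))          # average over height and width directly

print(tuple(pooled.shape), tuple(flat.shape))
print(torch.allclose(flat, manual, atol=1e-6))

logits = nn.Linear(512, 10)(flat)           # (2, 10): one score per class
```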

### Why "18"?

ResNet-18 gets its name from its **18 weight layers** (convolutional and fully connected; batch-norm layers and the 1×1 shortcut projections are not counted):
- 1 initial conv layer
- 4 layers × 2 blocks × 2 conv layers per block = 16 layers
- 1 final linear layer
- Total: 1 + 16 + 1 = 18 layers

### The _make_layer Helper Method

Instead of manually creating each layer, we use a helper method `_make_layer` that:
- Creates the first block with the specified stride (for downsampling)
- Creates the remaining blocks with stride=1
- Returns all blocks as a `nn.Sequential`
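
To see exactly which blocks `_make_layer` produces, here is a pure-Python sketch that lists the `(in_channels, out_channels, stride)` of each block instead of building it. `make_layer_config` is a hypothetical helper that mirrors the logic of `_make_layer`:

```python
def make_layer_config(in_channels, out_channels, num_blocks, stride):
    """Hypothetical helper: the (in, out, stride) of each BasicBlock that _make_layer would build."""
    strides = [stride] + [1] * (num_blocks - 1)  # only the first block may downsample
    configs = []
    for s in strides:
        configs.append((in_channels, out_channels, s))
        in_channels = out_channels  # later blocks consume the previous block's output
    return configs


layers = [
    ("layer1", (64, 64, 2, 1)),
    ("layer2", (64, 128, 2, 2)),
    ("layer3", (128, 256, 2, 2)),
    ("layer4", (256, 512, 2, 2)),
]
for name, args in layers:
    print(name, make_layer_config(*args))
```

Only the first block of each layer sees a channel change or a stride of 2, so it is the only block that can need a projection shortcut.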


In [None]:
class ResNet18(nn.Module):
    """
    ResNet-18 architecture for CIFAR-10.
    Adapted for 32x32 input images (smaller initial kernel size).
    """
    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        
        # Initial convolution (stem)
        # For CIFAR-10 (32x32), we use 3x3 conv instead of 7x7
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        # Four main layers
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
    
    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """
        Helper method to create a layer of BasicBlocks.
        
        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            num_blocks: Number of BasicBlocks in this layer
            stride: Stride for the first block (used for downsampling)
        """
        layers = []
        
        # First block: may have stride > 1 for downsampling
        layers.append(BasicBlock(in_channels, out_channels, stride=stride))
        
        # Remaining blocks: stride is always 1
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels, stride=1))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        
        # Main layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Classification head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # Flatten: (batch, 512, 1, 1) -> (batch, 512)
        x = self.fc(x)
        
        return x

# Instantiate the model
model = ResNet18(num_classes=10)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("ResNet-18 Architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test forward pass
test_input = torch.randn(2, 3, 32, 32)
test_output = model(test_input)
print(f"\nTest forward pass:")
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")


## Tracing Shapes Through ResNet-18

Let's trace how the tensor shapes transform as data flows through ResNet-18. This is crucial for understanding the architecture and debugging shape mismatches.

**Input**: `(Batch, 3, 32, 32)`
- 3 RGB channels, 32×32 spatial dimensions

**After `conv1` (3→64, kernel=3, stride=1, padding=1)**: 
- Channels: 3 → 64
- Spatial dimensions: unchanged (32×32) due to padding and stride=1
- Shape: `(Batch, 64, 32, 32)`

**After `bn1` + ReLU**: 
- No shape change (element-wise operations)
- Shape: `(Batch, 64, 32, 32)`

**After `layer1` (2 blocks, 64 channels, stride=1)**: 
- No downsampling (stride=1 in first block)
- Channels remain 64
- Shape: `(Batch, 64, 32, 32)`

**After `layer2` (2 blocks, 64→128 channels, stride=2)**: 
- First block downsamples: stride=2 halves spatial dimensions
- Channels: 64 → 128
- Shape: `(Batch, 128, 16, 16)`

**After `layer3` (2 blocks, 128→256 channels, stride=2)**: 
- First block downsamples again
- Channels: 128 → 256
- Shape: `(Batch, 256, 8, 8)`

**After `layer4` (2 blocks, 256→512 channels, stride=2)**: 
- First block downsamples again
- Channels: 256 → 512
- Shape: `(Batch, 512, 4, 4)`

**After `AdaptiveAvgPool2d(1)`**: 
- Reduces spatial dimensions to 1×1
- Channels remain 512
- Shape: `(Batch, 512, 1, 1)`

**After `flatten`**: 
- Flattens all dimensions except batch
- Shape: `(Batch, 512)`

**After `Linear(512, 10)`**: 
- Final classification layer
- Shape: `(Batch, 10)` - one score per class

### Key Observations

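1. **Progressive downsampling**: `layer2`, `layer3`, and `layer4` each halve the spatial size (32→16→8→4); `layer1` keeps it at 32×32
2. **Progressive channel expansion**: The same three layers each double the channel count (64→128→256→512)
3. **Shape-preserving blocks**: Within a layer, every BasicBlock after the first keeps the spatial size and channel count unchanged, so its identity shortcut can be added directly; the first block of `layer2`, `layer3`, and `layer4` changes shape and therefore uses a 1×1 projection shortcut
4. **Global pooling**: `AdaptiveAvgPool2d` always outputs 1×1, so the final linear layer receives exactly 512 features without us having to calculate the spatial size of `layer4`'s output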


## Training the ResNet

Now we'll train ResNet-18 on CIFAR-10. This is significantly more computationally intensive than our previous models: this CIFAR-10 variant of ResNet-18 has about 11.2 million parameters, far more than the simple CNNs from earlier notebooks.

**Note**: Training a ResNet from scratch on CIFAR-10 will take much longer than training simpler models. You may want to reduce the number of epochs for initial testing, or use GPU acceleration if available.

The training process follows the same pattern as before:
1. Forward pass through the model
2. Calculate loss
3. Backpropagate gradients
4. Update weights
5. Repeat

However, with ResNet's depth and complexity, you'll notice:
- **Slower training**: More parameters means more computation per batch
- **Better convergence**: Residual connections help gradients flow better, enabling effective training of deeper networks
- **Better performance**: ResNet should achieve significantly higher accuracy than simpler CNNs on CIFAR-10
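
The training cell only reports training loss, and `test_dataloader` is never used, so the accuracy claim cannot be checked yet. A minimal sketch of an evaluation helper follows; the name `evaluate` is hypothetical (not a PyTorch API), and it assumes `loss_fn` uses the default mean reduction.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(dataloader, model, loss_fn, device):
    """Return (average loss per sample, accuracy) of `model` on `dataloader`."""
    model.eval()  # BatchNorm uses its running statistics instead of per-batch statistics
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            # loss_fn averages over the batch, so weight it by the batch size
            total_loss += loss_fn(logits, y).item() * y.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += y.size(0)
    model.train()  # switch back to training mode for any further training
    return total_loss / count, correct / count


# Sanity check on hypothetical toy data: an identity "model" fed scaled one-hot
# vectors always scores the true class highest, so accuracy should be 1.0.
labels = torch.tensor([0, 2, 1, 2])
inputs = 10.0 * nn.functional.one_hot(labels, num_classes=3).float()
toy_loader = DataLoader(TensorDataset(inputs, labels), batch_size=2)
toy_loss, toy_acc = evaluate(toy_loader, nn.Identity(), nn.CrossEntropyLoss(), torch.device("cpu"))
print(f"toy loss = {toy_loss:.4f}, toy accuracy = {toy_acc:.2f}")
```

After training, a call such as `evaluate(test_dataloader, model, loss_fn, device)` would report test loss and accuracy. Switching to `model.eval()` matters for ResNet, because in training mode batch normalization would normalize each test batch with its own statistics.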


In [None]:
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Instantiate model, loss, and optimizer
model = ResNet18(num_classes=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training function (same structure as before)
def train(dataloader, model, loss_fn, optimizer, epochs=10):
    """
    Train the ResNet model for a specified number of epochs.
    """
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        
        for batch_idx, (X, y) in enumerate(dataloader):
            # Move data to device
            X, y = X.to(device), y.to(device)
            
            # Forward pass
            pred = model(X)
            
            # Calculate loss
            loss = loss_fn(pred, y)
            
            # Clear gradients left over from the previous step
            optimizer.zero_grad()
            
            # Backpropagation
            loss.backward()
            
            # Update weights
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / num_batches
                print(f'Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {avg_loss:.4f}')
        
        # Print average loss for the epoch
        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch + 1}/{epochs} completed. Average Loss: {avg_loss:.4f}\n')

# Start training
print("Starting ResNet-18 training on CIFAR-10...")
print("Note: This will take significantly longer than previous models.\n")
train(train_dataloader, model, loss_fn, optimizer, epochs=10)
print("Training completed!")
