# ResNets and Skip Connections
## Module 3.2, Lesson 2 — Implement a ResNet Block and Train on CIFAR-10

In this notebook you will:
1. Load CIFAR-10 data
2. See a **plain deep CNN** baseline (no skip connections) — provided
3. **Implement a ResidualBlock** class (your code)
4. **Assemble a small ResNet** from your blocks (your code)
5. Train both networks and compare

The key observation: the ResNet trains successfully at depths where the plain network degrades.

---

**Prerequisites:** Architecture Evolution lesson, nn.Module lesson, MNIST CNN Project

## 0. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Load CIFAR-10

CIFAR-10 has 60,000 32x32 RGB images (50,000 for training, 10,000 for testing) in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.

We apply basic normalization — no data augmentation (keeping it simple for comparison).

In [None]:
# CIFAR-10 per-channel mean and std (commonly used values, computed on the training set)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Input shape: {train_dataset[0][0].shape}')  # [3, 32, 32]
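`transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)` shifts and scales each channel independently, computing `(x - mean) / std`. A minimal sketch of that arithmetic and of where such constants come from, assuming the images are already stacked into one float tensor of shape `[N, 3, H, W]` with values in [0, 1] (as produced by `ToTensor`). The helpers `channel_stats` and `normalize` are illustrative names, not torchvision functions, and random data stands in for real images.

```python
import torch

def channel_stats(images: torch.Tensor):
    """Per-channel mean and std of a batch shaped [N, C, H, W].

    Statistics are pooled over the batch and both spatial dimensions,
    leaving one value per channel.
    """
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Same per-channel arithmetic as transforms.Normalize: (x - mean) / std."""
    # Reshape to [1, C, 1, 1] so the stats broadcast across N, H, and W
    return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

# Synthetic stand-in for CIFAR-10 images (random values in [0, 1])
torch.manual_seed(0)
fake_images = torch.rand(64, 3, 32, 32)
mean, std = channel_stats(fake_images)
normed = normalize(fake_images, mean, std)
print(normed.mean(dim=(0, 2, 3)))  # each channel close to 0
print(normed.std(dim=(0, 2, 3)))   # each channel close to 1
```

Applied to the training images, `channel_stats` should approximately reproduce `CIFAR10_MEAN` and `CIFAR10_STD`. The statistics come from the training set only, so the test set stays unseen.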

## 2. Plain Deep CNN (Baseline — No Skip Connections)

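This is a deep CNN that stacks many conv layers *without* skip connections. It follows the VGG philosophy of repeated Conv-BN-ReLU blocks. Depth is set by `n_blocks`: the network has 1 + 6n conv layers (19 for n=3, 31 for n=5). The training run in Section 6 uses n=5, which is typically deep enough to see degradation effects.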

**This code is provided.** Read it to understand the structure.

In [None]:
class PlainBlock(nn.Module):
    """A plain block: two 3x3 convs with BN and ReLU. No skip connection."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class PlainDeepCNN(nn.Module):
    """A plain deep CNN with no skip connections.
    
    Architecture:
    - Initial conv: 3 -> 16 channels
    - Stage 1: n blocks at 16 channels (32x32 spatial)
    - Stage 2: n blocks at 32 channels (16x16 spatial) 
    - Stage 3: n blocks at 64 channels (8x8 spatial)
    - Global average pooling -> FC(10)
    
    Total conv layers = 1 + 2*n*3 = 1 + 6n
    With n=3: 19 conv layers. With n=5: 31 conv layers.
    """
    def __init__(self, n_blocks=3):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        
        # Build stages
        self.stage1 = self._make_stage(16, 16, n_blocks, stride=1)
        self.stage2 = self._make_stage(16, 32, n_blocks, stride=2)
        self.stage3 = self._make_stage(32, 64, n_blocks, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, 10)

    def _make_stage(self, in_channels, out_channels, n_blocks, stride):
        blocks = [PlainBlock(in_channels, out_channels, stride=stride)]
        for _ in range(1, n_blocks):
            blocks.append(PlainBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


# Test it
plain_net = PlainDeepCNN(n_blocks=3)
test_input = torch.randn(2, 3, 32, 32)
test_output = plain_net(test_input)
print(f'PlainDeepCNN output shape: {test_output.shape}')  # [2, 10]
print(f'Total parameters: {sum(p.numel() for p in plain_net.parameters()):,}')
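The docstring's formula (1 + 6n conv layers) can be checked directly by walking the module tree. A small sketch using the real `nn.Module.modules()` iterator; `count_layers` is a hypothetical helper name, not a PyTorch function, and the toy model only demonstrates that nested containers are counted.

```python
import torch.nn as nn

def count_layers(model: nn.Module, layer_type=nn.Conv2d) -> int:
    """Count every submodule of the given type, at any nesting depth."""
    # model.modules() walks the module tree recursively, including model itself
    return sum(1 for m in model.modules() if isinstance(m, layer_type))

# Toy example: two 3x3 convs and one 1x1 conv inside nested containers
toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1)),
    nn.Conv2d(16, 32, 1),
)
print(count_layers(toy))           # 3
print(count_layers(toy, nn.ReLU))  # 1
```

Applied to `plain_net` (n_blocks=3), it should report 19. Applied to the ResNet built later, it also counts the 1x1 projection convs in the shortcuts, so it reports two more than 1 + 6n.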

## 3. YOUR TURN: Implement a ResidualBlock

Now implement the residual version. The structure is almost identical to `PlainBlock`, but you add a **skip connection** that adds the input to the output.

Remember the pattern from the lesson:
1. Conv(3x3) → BN → ReLU
2. Conv(3x3) → BN
3. **Add the shortcut (input x)**
4. ReLU

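**Important:** The addition requires the shortcut and the main path to have exactly the same shape. When `stride > 1` or `in_channels != out_channels`, the shapes differ, so you need a **projection shortcut** (1x1 conv + BN) to match dimensions. Otherwise, use the identity shortcut (just add x directly).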

Fill in the `TODO` sections below.

In [None]:
class ResidualBlock(nn.Module):
    """A residual block: two 3x3 convs with BN, ReLU, and a skip connection.
    
    When stride > 1 or in_channels != out_channels, uses a projection shortcut
    (1x1 conv + BN) to match dimensions.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        # TODO 1: Define the two conv layers with batch norm
        # conv1: 3x3 conv from in_channels to out_channels with the given stride
        # bn1: BatchNorm2d for out_channels
        # conv2: 3x3 conv from out_channels to out_channels with stride=1
        # bn2: BatchNorm2d for out_channels
        # Hint: use bias=False when using batch norm (BN has its own bias)
        # Hint: use padding=1 for 3x3 convs to preserve spatial size
        
        self.conv1 = ...  # YOUR CODE HERE
        self.bn1 = ...    # YOUR CODE HERE
        self.conv2 = ...  # YOUR CODE HERE
        self.bn2 = ...    # YOUR CODE HERE
        
        # TODO 2: Define the shortcut connection
        # If in_channels == out_channels AND stride == 1: identity (no-op)
        # Otherwise: 1x1 conv + BN to match dimensions
        # Hint: nn.Sequential() creates an empty no-op module
        # Hint: nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        
        self.shortcut = ...  # YOUR CODE HERE
    
    def forward(self, x):
        # TODO 3: Implement the forward pass
        # 1. Save the identity (input) for the shortcut
        # 2. Pass through conv1 -> bn1 -> relu
        # 3. Pass through conv2 -> bn2 (NO relu yet!)
        # 4. Add the shortcut to the output
        # 5. Apply relu AFTER the addition
        
        ...  # YOUR CODE HERE
        
        return out
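The hint for TODO 2 relies on the fact that an empty `nn.Sequential()` simply returns its input. A quick check of that behavior, alongside `nn.Identity()`, which does the same thing and states the intent more explicitly; either one works as an identity shortcut.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 32, 32)

empty_seq = nn.Sequential()  # no submodules: forward() returns x unchanged
identity = nn.Identity()     # built-in placeholder that also returns x

print(torch.equal(empty_seq(x), x))  # True
print(torch.equal(identity(x), x))   # True
print(sum(p.numel() for p in empty_seq.parameters()))  # 0: adds no parameters
```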

### Test your ResidualBlock

Run these tests to verify your implementation. All three should print the expected output shapes.

In [None]:
# Test 1: Identity shortcut (same channels, stride=1)
block1 = ResidualBlock(16, 16, stride=1)
x1 = torch.randn(2, 16, 32, 32)
out1 = block1(x1)
print(f'Identity shortcut: {x1.shape} -> {out1.shape}')  # Expected: [2, 16, 32, 32]
assert out1.shape == (2, 16, 32, 32), f'Expected (2, 16, 32, 32), got {out1.shape}'

# Test 2: Projection shortcut (different channels)
block2 = ResidualBlock(16, 32, stride=1)
x2 = torch.randn(2, 16, 32, 32)
out2 = block2(x2)
print(f'Channel change: {x2.shape} -> {out2.shape}')  # Expected: [2, 32, 32, 32]
assert out2.shape == (2, 32, 32, 32), f'Expected (2, 32, 32, 32), got {out2.shape}'

# Test 3: Projection shortcut (different channels + stride)
block3 = ResidualBlock(16, 32, stride=2)
x3 = torch.randn(2, 16, 32, 32)
out3 = block3(x3)
print(f'Stride + channels: {x3.shape} -> {out3.shape}')  # Expected: [2, 32, 16, 16]
assert out3.shape == (2, 32, 16, 16), f'Expected (2, 32, 16, 16), got {out3.shape}'

print('\nAll tests passed!')
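Test 3 passes only if the main path and the shortcut shrink the spatial size identically. The standard convolution output-size formula, floor((H + 2p - k) / s) + 1, shows why a 1x1 conv with padding 0 and the same stride as the 3x3 conv (padding 1) lands on the same size. A small sketch comparing the formula with real `nn.Conv2d` layers; `conv_out_size` is an illustrative helper that assumes dilation 1.

```python
import torch
import torch.nn as nn

def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output height/width of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

# Main path, first conv: 3x3, stride 2, padding 1 -> (32 + 2 - 3) // 2 + 1 = 16
main = conv_out_size(32, kernel=3, stride=2, padding=1)
# Projection shortcut: 1x1, stride 2, padding 0 -> (32 - 1) // 2 + 1 = 16
proj = conv_out_size(32, kernel=1, stride=2, padding=0)
print(main, proj)  # 16 16

# Cross-check against actual PyTorch layers
x = torch.randn(1, 16, 32, 32)
main_conv = nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False)
proj_conv = nn.Conv2d(16, 32, 1, stride=2, bias=False)
print(main_conv(x).shape, proj_conv(x).shape)  # both torch.Size([1, 32, 16, 16])
```

The block's second 3x3 conv uses stride 1 and padding 1, so it keeps the size at 16 and the two paths still match at the addition.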

## 4. YOUR TURN: Assemble a Small ResNet

Now build a full ResNet using your `ResidualBlock`. The architecture mirrors `PlainDeepCNN` but uses residual blocks instead of plain blocks:

- Initial conv: 3 → 16 channels
- Stage 1: `n_blocks` residual blocks at 16 channels (32x32)
- Stage 2: `n_blocks` residual blocks at 32 channels (16x16)
- Stage 3: `n_blocks` residual blocks at 64 channels (8x8)
- Global average pooling → FC(10)

Fill in the `TODO` sections.

In [None]:
class SimpleResNet(nn.Module):
    """A small ResNet for CIFAR-10.
    
    Same overall structure as PlainDeepCNN, but uses ResidualBlock instead of PlainBlock.
    """
    def __init__(self, n_blocks=3):
        super().__init__()
        
        # TODO 4: Define the initial conv layer and batch norm
        # conv1: 3x3 conv from 3 channels (RGB) to 16 channels
        # bn1: BatchNorm2d for 16 channels
        
        self.conv1 = ...  # YOUR CODE HERE
        self.bn1 = ...    # YOUR CODE HERE
        
        # TODO 5: Build the three stages using _make_stage
        # Stage 1: 16 -> 16 channels, stride=1 (spatial stays 32x32)
        # Stage 2: 16 -> 32 channels, stride=2 (spatial becomes 16x16)
        # Stage 3: 32 -> 64 channels, stride=2 (spatial becomes 8x8)
        
        self.stage1 = ...  # YOUR CODE HERE
        self.stage2 = ...  # YOUR CODE HERE
        self.stage3 = ...  # YOUR CODE HERE
        
        # Global average pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, 10)
    
    def _make_stage(self, in_channels, out_channels, n_blocks, stride):
        """Create a stage of residual blocks.
        
        The first block uses the given stride (may downsample).
        Remaining blocks use stride=1.
        """
        # TODO 6: Build a list of ResidualBlock instances
        # First block: ResidualBlock(in_channels, out_channels, stride=stride)
        # Remaining blocks: ResidualBlock(out_channels, out_channels, stride=1)
        # Return as nn.Sequential
        
        ...  # YOUR CODE HERE
    
    def forward(self, x):
        # TODO 7: Implement forward pass
        # 1. Initial conv -> bn -> relu
        # 2. Stage 1 -> Stage 2 -> Stage 3
        # 3. Global average pooling
        # 4. Flatten and FC
        
        ...  # YOUR CODE HERE
        
        return out

### Test your ResNet

In [None]:
resnet = SimpleResNet(n_blocks=3)
test_input = torch.randn(2, 3, 32, 32)
test_output = resnet(test_input)
print(f'SimpleResNet output shape: {test_output.shape}')  # Expected: [2, 10]
assert test_output.shape == (2, 10), f'Expected (2, 10), got {test_output.shape}'
print(f'Total parameters: {sum(p.numel() for p in resnet.parameters()):,}')
print('\nResNet test passed!')
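`nn.AdaptiveAvgPool2d(1)` followed by flattening is global average pooling: each of the 64 channel maps (8x8 at the end of stage 3) is averaged down to a single number before the `nn.Linear(64, 10)` classifier. A minimal check, on random features, that it matches a plain mean over the spatial dimensions:

```python
import torch
import torch.nn as nn

features = torch.randn(4, 64, 8, 8)         # [N, C, H, W] after stage 3

pooled = nn.AdaptiveAvgPool2d(1)(features)  # -> [4, 64, 1, 1]
flat = torch.flatten(pooled, 1)             # -> [4, 64], ready for nn.Linear(64, 10)
manual = features.mean(dim=(2, 3))          # average over H and W directly

print(pooled.shape, flat.shape)
print(torch.allclose(flat, manual, atol=1e-5))  # True
```

`torch.flatten(out, 1)` and the `out.view(out.size(0), -1)` used in `PlainDeepCNN` give the same `[N, 64]` result here.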

## 5. Training Loop (Provided)

This training function works for both the plain network and the ResNet. It tracks training loss, training accuracy, and test accuracy per epoch.

**Note:** We use `nn.CrossEntropyLoss`, the standard PyTorch loss for multi-class classification. It applies log-softmax internally, so it expects raw logits from the model; do not add a softmax layer at the end of the network.
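A short sketch of what "takes raw logits" means in practice: `nn.CrossEntropyLoss` gives the same value as `F.log_softmax` followed by `F.nll_loss`. The logits and labels below are random stand-ins, not model outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)              # raw scores: 5 images, 10 classes
targets = torch.tensor([3, 0, 9, 1, 7])  # integer class labels, not one-hot

ce = nn.CrossEntropyLoss()(logits, targets)
# The same value, computed in two explicit steps
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(ce.item(), manual.item())
print(torch.allclose(ce, manual))  # True

# If the model already applied softmax, the loss would normalize a second
# time and report a different value
double = nn.CrossEntropyLoss()(F.softmax(logits, dim=1), targets)
print(double.item())
```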

In [None]:
def train_model(model, train_loader, test_loader, epochs=15, lr=0.1):
    """Train a model and return training history."""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    
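    # Learning rate schedule: divide the LR by 10 after 8 and after 12 completed epochs
    # (scheduler.step() runs once per epoch). Milestones are fixed; adjust them if you change `epochs`.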
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 12], gamma=0.1)
    
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_acc': [],
    }
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        scheduler.step()
        
        train_loss = running_loss / total
        train_acc = 100.0 * correct / total
        
        # Evaluation phase
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        test_acc = 100.0 * correct / total
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        
        print(f'Epoch {epoch+1:2d}/{epochs} | '
              f'Train Loss: {train_loss:.4f} | '
              f'Train Acc: {train_acc:.1f}% | '
              f'Test Acc: {test_acc:.1f}%')
    
    return history
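Because `scheduler.step()` runs once at the end of every epoch in `train_model`, `MultiStepLR` with `milestones=[8, 12]` lowers the rate after 8 and after 12 completed epochs. A minimal sketch that replays the same schedule with a single dummy parameter (an assumption made so no model or data is needed) shows which epoch numbers, 1-based as in the training log, use which rate:

```python
import torch
import torch.optim as optim

# A single dummy parameter stands in for the model; only the schedule matters here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 12], gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(optimizer.param_groups[0]['lr'])  # LR used during this epoch
    optimizer.step()   # the real loop runs a full epoch of batches here
    scheduler.step()   # called once per epoch, as in train_model

for epoch, lr in enumerate(lrs, start=1):
    print(f'Epoch {epoch:2d}: lr = {lr:.4g}')
```

Epochs 1 to 8 run at 0.1, epochs 9 to 12 at 0.01, and epochs 13 to 15 at 0.001.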

## 6. Train Both Models

We train the plain network and the ResNet with the same hyperparameters (learning rate, optimizer, epochs) so the comparison is fair.

**Expected runtime:** a few minutes per model on a Colab GPU; the exact time depends on the GPU assigned and on data-loading speed.

In [None]:
# Use n_blocks=5 for a deeper comparison (31 conv layers)
N_BLOCKS = 5
EPOCHS = 15

print('=' * 60)
print(f'Training Plain Deep CNN ({1 + 6*N_BLOCKS} conv layers, no skip connections)')
print('=' * 60)
plain_model = PlainDeepCNN(n_blocks=N_BLOCKS)
plain_params = sum(p.numel() for p in plain_model.parameters())
print(f'Parameters: {plain_params:,}\n')

start = time.time()
plain_history = train_model(plain_model, train_loader, test_loader, epochs=EPOCHS)
plain_time = time.time() - start
print(f'\nTraining time: {plain_time:.1f}s')

In [None]:
print('\n' + '=' * 60)
print(f'Training ResNet ({1 + 6*N_BLOCKS} conv layers, WITH skip connections)')
print('=' * 60)
resnet_model = SimpleResNet(n_blocks=N_BLOCKS)
resnet_params = sum(p.numel() for p in resnet_model.parameters())
print(f'Parameters: {resnet_params:,}\n')

start = time.time()
resnet_history = train_model(resnet_model, train_loader, test_loader, epochs=EPOCHS)
resnet_time = time.time() - start
print(f'\nTraining time: {resnet_time:.1f}s')

## 7. Compare Results

Plot training loss and accuracy curves side by side. The key observation is in the **training accuracy** — if the plain network degrades, its training accuracy will plateau lower than the ResNet's.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
epochs_range = range(1, EPOCHS + 1)

# Training loss
axes[0].plot(epochs_range, plain_history['train_loss'], 'o-', label='Plain CNN', color='#f59e0b')
axes[0].plot(epochs_range, resnet_history['train_loss'], 's-', label='ResNet', color='#22c55e')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Training accuracy
axes[1].plot(epochs_range, plain_history['train_acc'], 'o-', label='Plain CNN', color='#f59e0b')
axes[1].plot(epochs_range, resnet_history['train_acc'], 's-', label='ResNet', color='#22c55e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Test accuracy
axes[2].plot(epochs_range, plain_history['test_acc'], 'o-', label='Plain CNN', color='#f59e0b')
axes[2].plot(epochs_range, resnet_history['test_acc'], 's-', label='ResNet', color='#22c55e')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy (%)')
axes[2].set_title('Test Accuracy')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Summary comparison
print('\n' + '=' * 60)
print('COMPARISON SUMMARY')
print('=' * 60)
print(f'{"":20s} {"Plain CNN":>15s} {"ResNet":>15s}')
print('-' * 60)
print(f'{"Conv layers":20s} {1 + 6*N_BLOCKS:>15d} {1 + 6*N_BLOCKS:>15d}')
print(f'{"Parameters":20s} {plain_params:>15,} {resnet_params:>15,}')
print(f'{"Final train acc":20s} {plain_history["train_acc"][-1]:>14.1f}% {resnet_history["train_acc"][-1]:>14.1f}%')
print(f'{"Final test acc":20s} {plain_history["test_acc"][-1]:>14.1f}% {resnet_history["test_acc"][-1]:>14.1f}%')
print(f'{"Training time":20s} {plain_time:>14.1f}s {resnet_time:>14.1f}s')
print('-' * 60)

acc_diff = resnet_history['test_acc'][-1] - plain_history['test_acc'][-1]
print(f'\nResNet advantage: {acc_diff:+.1f}% test accuracy')
if acc_diff > 0:
    print('The skip connections helped: the ResNet reached higher test accuracy at this depth.')
print()
print('Look at the training accuracy curves above.')
print('If the plain network plateaus lower, that is the degradation problem in action.')
print('The ResNet avoids degradation because each block can easily stay close to the identity.')

## 8. Reflection

Before moving on, consider:

1. **Did the ResNet achieve higher training accuracy than the plain network?** If so, that is the degradation problem: the plain network could not even fit the training data as well, despite having almost the same number of parameters (the ResNet adds only the small 1x1 projection shortcuts).

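2. **The essential difference between PlainBlock and ResidualBlock is a single addition:** `out = out + identity` (the last ReLU moves after the addition, and a projection shortcut handles shape changes). That small change is what lets the network train well at much greater depth.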

3. **Think about the mental model:** A residual block starts from the identity and learns to deviate from it, making "do nothing" the easiest path rather than the hardest. This is why most major architectures since 2015, including Transformers, use skip connections.
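The "starts from identity" idea can be made concrete. The sketch below uses a simplified residual unit (`TinyResidualUnit`, a hypothetical name; it is not the full `ResidualBlock` from Section 3) whose branch is a conv followed by BatchNorm. Setting the BN scale γ (`weight`) to zero makes the branch output zeros, so the unit computes `relu(x + 0)`, which is exactly the identity for non-negative inputs such as the output of a previous ReLU. Zero-initializing the last BN scale is a known ResNet initialization trick; here it serves only as an illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyResidualUnit(nn.Module):
    """Illustrative unit: out = relu(x + BN(conv(x))), channel count unchanged."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return F.relu(x + self.bn(self.conv(x)))

unit = TinyResidualUnit(16)
nn.init.zeros_(unit.bn.weight)  # gamma = 0 silences the branch (beta already starts at 0)
unit.eval()                     # use running stats so the check is deterministic

x = F.relu(torch.randn(2, 16, 8, 8))  # non-negative, like a previous block's output
print(torch.allclose(unit(x), x))     # True: the unit starts as the identity
```

During training, gradient updates can move γ away from zero wherever deviating from the identity lowers the loss.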

---

**Next lesson:** Transfer Learning — use a pretrained ResNet and adapt it to new tasks in minutes.