# Conv2d Crash Course for Graph Optimization (10 minutes)

Learn exactly what you need to know about Conv2d operations to work with them in graph optimization.

In [1]:
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


## 1. WHAT IS A CONVOLUTION? (2D Image Example)

A convolution slides a small filter (kernel) over an image to detect patterns.

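```
Example: 3x3 kernel sliding over a 5x5 image (stride=1, no padding):

   Input Image (5x5)              3x3 Kernel           Output Feature Map (3x3)

   [1 2 3 4 5]                    [1 0 1]              [15 20 25]
   [6 7 8 9 0]                    [0 1 0]              [30 35 20]
   [1 2 3 4 5]        *           [1 0 1]       →      [15 20 25]
   [6 7 8 9 0]
   [1 2 3 4 5]

   Each output value = sum of the element-wise products of the kernel
   and the 3x3 patch under it. Top-left: 1 + 3 + 7 + 1 + 3 = 15
```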

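**The kernel slides across the image:**
- Position 1: Kernel over the top-left 3x3 patch → compute one output value
- Position 2: Slide right by `stride` pixels → compute the next output value
- At the end of a row, move down by `stride` and restart at the left edge
- Continue until every valid position has been visited (5x5 input, 3x3 kernel, stride 1 → 3x3 = 9 positions)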
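The sliding procedure above can be sketched as a naive double loop and checked against `torch.nn.functional.conv2d`. `slide_kernel` is a hypothetical helper written for illustration only (single channel, no padding, no bias). Note that PyTorch's conv2d does not flip the kernel (strictly speaking it computes a cross-correlation); the kernel here is symmetric, so that makes no difference.

```python
import torch
import torch.nn.functional as F

# The 5x5 image and 3x3 kernel from the diagram
image = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 0],
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 0],
    [1, 2, 3, 4, 5],
], dtype=torch.float32)
kernel = torch.tensor([
    [1, 0, 1],
    [0, 1, 0],
    [1, 0, 1],
], dtype=torch.float32)


def slide_kernel(img, k, stride=1):
    """Naive single-channel convolution (no padding, no bias), for illustration only."""
    kh, kw = k.shape
    out_h = (img.shape[0] - kh) // stride + 1
    out_w = (img.shape[1] - kw) // stride + 1
    out = torch.zeros(out_h, out_w)
    for i in range(out_h):        # move down one row of positions at a time
        for j in range(out_w):    # slide right across the row
            patch = img[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = (patch * k).sum()  # element-wise multiply, then sum
    return out


manual = slide_kernel(image, kernel)
print(manual)
# tensor([[15., 20., 25.],
#         [30., 35., 20.],
#         [15., 20., 25.]])

# F.conv2d expects (batch, channels, H, W) input and (out_ch, in_ch, kH, kW) weight
builtin = F.conv2d(image[None, None], kernel[None, None])[0, 0]
print(torch.allclose(manual, builtin))  # True
```

Calling `slide_kernel(image, kernel, stride=2)` visits only every other position and returns a 2x2 map.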

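**Key concepts:**
- Each output value = weighted sum of a local region, plus a bias (the kernel acts as a pattern detector)
- With several input channels, each kernel spans all of them: its shape is `(in_channels, k, k)`, and the products are summed across channels into one output value
- Multiple kernels = multiple output channels (each detects a different pattern)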
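For multi-channel input, each kernel covers every input channel, and the number of kernels sets the number of output channels. The sketch below checks this by recomputing one output value by hand. `bias=False` is used only to keep the manual sum short; by default `nn.Conv2d` adds one bias per output channel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, bias=False)
x = torch.randn(1, 3, 5, 5)

with torch.no_grad():
    y = conv(x)
    # Recompute output channel 0 at position (0, 0) by hand
    patch = x[0, :, 0:3, 0:3]                # (3, 3, 3): all input channels of the top-left patch
    manual = (patch * conv.weight[0]).sum()  # multiply by kernel 0, sum over channels and space

print(conv.weight.shape)  # torch.Size([2, 3, 3, 3]): 2 kernels, each spanning 3 input channels
print(y.shape)            # torch.Size([1, 2, 3, 3]): one output channel per kernel
print(torch.allclose(manual, y[0, 0, 0, 0], atol=1e-5))  # True
```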

## 2. KEY PARAMETERS

```python
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
# (simplified: the full signature also has dilation, groups, bias, padding_mode)
```

**Parameters:**
- `in_channels`: Number of input feature maps (e.g., 3 for RGB image)
- `out_channels`: Number of output feature maps (number of different kernels/filters)
- `kernel_size`: Size of the sliding window (e.g., 3 means 3x3 kernel)
- `stride`: How many pixels to move kernel each step (default=1)
- `padding`: Pixels added around border (padding=1 adds 1 pixel border of zeros)

**Example:**
```python
nn.Conv2d(3, 16, kernel_size=3, padding=1)
```
- Takes 3-channel input (RGB)
- Applies 16 different 3x3 kernels
- Produces 16 output feature maps
- padding=1 (with the default stride=1) keeps spatial dimensions the same
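Inspecting the example layer shows where these parameters end up; a graph pass reads exactly these attributes. The weight holds one `(in_channels, kernel_size, kernel_size)` kernel per output channel, and integer arguments are stored as `(height, width)` tuples.

```python
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

# Weight layout: (out_channels, in_channels, kernel_h, kernel_w)
print(conv.weight.shape)  # torch.Size([16, 3, 3, 3])
print(conv.bias.shape)    # torch.Size([16]): one bias per output channel

n_params = sum(p.numel() for p in conv.parameters())
print(n_params)           # 448 = 16*3*3*3 weights + 16 biases

# Integer arguments are stored as (height, width) tuples
print(conv.kernel_size, conv.stride, conv.padding)  # (3, 3) (1, 1) (1, 1)
```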

## 3. SHAPE CALCULATION FORMULAS ⭐

**Input shape:** `(batch, in_channels, H_in, W_in)`  
**Output shape:** `(batch, out_channels, H_out, W_out)`

### Height/Width Calculation:
```
H_out = floor((H_in + 2*padding - kernel_size) / stride) + 1
W_out = floor((W_in + 2*padding - kernel_size) / stride) + 1
(assumes dilation=1, the default)
```
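The formula translates directly into integer arithmetic, where `//` performs the floor. `conv2d_out_size` is a hypothetical helper name, and like the formula it assumes `dilation=1`.

```python
def conv2d_out_size(size_in, kernel_size, stride=1, padding=0):
    """H_out (or W_out) for Conv2d, assuming dilation=1; // is floor division."""
    return (size_in + 2 * padding - kernel_size) // stride + 1


# Worked examples (height shown; width is computed the same way)
print(conv2d_out_size(32, kernel_size=3, padding=1))            # 32: (32 + 2 - 3) // 1 + 1
print(conv2d_out_size(32, kernel_size=3, stride=2, padding=1))  # 16: 31 // 2 + 1
print(conv2d_out_size(32, kernel_size=3))                       # 30: (32 - 3) // 1 + 1
print(conv2d_out_size(33, kernel_size=3, stride=2, padding=1))  # 17: 32 // 2 + 1
```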

### Common Cases:

**1. Keep same size:** `padding = (kernel_size - 1) // 2` (odd `kernel_size`, `stride=1`)
   - Example: kernel=3, padding=1 → same size

**2. Halve size:** `stride=2, padding=1, kernel=3`
   - Example: 32x32 → 16x16 (odd sizes round up, e.g. 33x33 → 17x17)

**3. No padding:** `padding=0`
   - Example: 32x32 with kernel=3 → 30x30
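The three common cases can be confirmed empirically by running real layers on a 32x32 input. The choice of 8 output channels is arbitrary; only the spatial sizes matter here.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)
cases = {
    "same size (k=3, p=1)": nn.Conv2d(3, 8, kernel_size=3, padding=1),             # expect 32x32
    "halve (k=3, s=2, p=1)": nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),  # expect 16x16
    "no padding (k=3, p=0)": nn.Conv2d(3, 8, kernel_size=3, padding=0),            # expect 30x30
}

shapes = {}
with torch.no_grad():
    for name, conv in cases.items():
        shapes[name] = tuple(conv(x).shape)
        print(f"{name:24s} -> {shapes[name]}")
```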

### Other Layers:

**BatchNorm2d(num_features):**
- `num_features` must equal the number of channels (`C`, dim 1 of the input)
- Normalizes each channel with its own statistics; does NOT change shape: `(B, C, H, W) → (B, C, H, W)`

**MaxPool2d(kernel_size, stride):**
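- `stride` defaults to `kernel_size`, so windows are usually non-overlapping
- Output shape (padding=0, kernel_size == stride): `H_out = floor(H_in / stride)`, `W_out = floor(W_in / stride)`
- Example: MaxPool2d(2, 2) halves both dimensions (a 7x7 input becomes 3x3)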

**ReLU():**
- Does NOT change shape: `(B, C, H, W) → (B, C, H, W)`
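A quick check of the three shape rules, using an odd 7x7 input on purpose: BatchNorm2d and ReLU leave the shape alone, while MaxPool2d(2, 2) floors 7 / 2 down to 3. When `stride` is omitted, MaxPool2d uses `kernel_size` as the stride.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 7, 7)   # odd spatial size on purpose
bn = nn.BatchNorm2d(16)        # num_features must equal x.shape[1]
relu = nn.ReLU()
pool = nn.MaxPool2d(2, 2)

print(bn(x).shape)             # torch.Size([2, 16, 7, 7]): unchanged
print(relu(x).shape)           # torch.Size([2, 16, 7, 7]): unchanged
print(pool(x).shape)           # torch.Size([2, 16, 3, 3]): floor(7 / 2) = 3
print(nn.MaxPool2d(2).stride)  # 2: stride defaults to kernel_size
```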

## 4. STEP-BY-STEP SHAPE TRACKING EXAMPLE ⭐⭐

Let's track shapes through a typical CNN block:

```
Starting input: (1, 3, 32, 32)
                 ↓ batch=1, channels=3 (RGB), 32x32 pixels

Layer 1: Conv2d(3, 16, kernel_size=3, padding=1)
  H_out = (32 + 2*1 - 3)/1 + 1 = 32
  W_out = (32 + 2*1 - 3)/1 + 1 = 32
  Shape: (1, 16, 32, 32)  ← 3 channels → 16 channels, size preserved
                           ↓

Layer 2: BatchNorm2d(16)
  Shape: (1, 16, 32, 32)  ← No change (normalizes per channel)
                           ↓

Layer 3: ReLU()
  Shape: (1, 16, 32, 32)  ← No change (elementwise activation)
                           ↓

Layer 4: MaxPool2d(2, 2)
  H_out = 32/2 = 16
  W_out = 32/2 = 16
  Shape: (1, 16, 16, 16)  ← Spatial dimensions halved
                           ↓

Final output: (1, 16, 16, 16)
```
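The same tracking can be done statically, without running the model, by reading each layer's attributes and applying the formulas. This is roughly what a shape-aware graph pass needs. `infer_shape` is a hypothetical helper that only covers the four layer types in this block, under the assumptions listed in its docstring; the result is cross-checked against a real forward pass.

```python
import torch
import torch.nn as nn


def infer_shape(module, shape):
    """Hypothetical helper: static (B, C, H, W) shape rule for the layers used above.

    Assumes numeric padding, dilation=1, and MaxPool2d with int kernel_size/stride,
    padding=0 and ceil_mode=False.
    """
    b, c, h, w = shape
    if isinstance(module, nn.Conv2d):
        (kh, kw), (sh, sw), (ph, pw) = module.kernel_size, module.stride, module.padding
        return (b, module.out_channels,
                (h + 2 * ph - kh) // sh + 1,
                (w + 2 * pw - kw) // sw + 1)
    if isinstance(module, nn.MaxPool2d):
        k, s = module.kernel_size, module.stride
        return (b, c, (h - k) // s + 1, (w - k) // s + 1)
    if isinstance(module, (nn.BatchNorm2d, nn.ReLU)):
        return shape  # shape-preserving layers
    raise NotImplementedError(type(module).__name__)


layers = [
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
]
shape = (1, 3, 32, 32)
for layer in layers:
    shape = infer_shape(layer, shape)
    print(f"{type(layer).__name__:12s} -> {shape}")

# Cross-check the static result against a real forward pass
actual = nn.Sequential(*layers)(torch.randn(1, 3, 32, 32)).shape
print(tuple(actual) == shape)  # True
```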

### KEY INSIGHT FOR STACKING LAYERS:

**Next layer's in_channels MUST equal previous layer's out_channels!**

```python
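# ✓ OK: Conv2d(3, 16, ...) → Conv2d(16, 32, ...)   16 matches!
nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 32, 3))
# ✗ ERROR: Conv2d(3, 16, ...) → Conv2d(8, 32, ...)  16 ≠ 8 (raised on the first forward pass)
nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(8, 32, 3))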
```
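A mismatched stack can be built without complaint; PyTorch only raises a `RuntimeError` once data flows through it. A graph pass can catch the problem earlier by comparing neighboring channel counts directly, as in this sketch.

```python
import torch
import torch.nn as nn

bad = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.Conv2d(8, 32, kernel_size=3, padding=1),  # expects 8 channels but receives 16
)  # construction succeeds: layers are not checked against each other yet

# Static check: compare neighboring conv channel counts without running anything
convs = [m for m in bad if isinstance(m, nn.Conv2d)]
mismatches = [(a.out_channels, b.in_channels)
              for a, b in zip(convs, convs[1:]) if a.out_channels != b.in_channels]
print(mismatches)  # [(16, 8)]

# Runtime check: the forward pass raises
raised = False
try:
    bad(torch.randn(1, 3, 32, 32))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```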

## 5. WORKING CODE EXAMPLE

Let's build a 2-block CNN and track shapes through each layer.

In [2]:
class DetailedConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Block 1: 3 → 16 channels
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # Must match conv1 out_channels
        self.pool1 = nn.MaxPool2d(2, 2)

        # Block 2: 16 → 32 channels (in_channels must match previous out_channels!)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)  # Must match conv2 out_channels
        self.pool2 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        print(f"Input:               {x.shape}")

        # Block 1
        x = self.conv1(x)
        print(f"After Conv2d(3→16):  {x.shape}")

        x = self.bn1(x)
        print(f"After BatchNorm:     {x.shape}")

        x = torch.relu(x)
        print(f"After ReLU:          {x.shape}")

        x = self.pool1(x)
        print(f"After MaxPool(2x2):  {x.shape}")

        # Block 2
        x = self.conv2(x)
        print(f"After Conv2d(16→32): {x.shape}")

        x = self.bn2(x)
        print(f"After BatchNorm:     {x.shape}")

        x = torch.relu(x)
        print(f"After ReLU:          {x.shape}")

        x = self.pool2(x)
        print(f"After MaxPool(2x2):  {x.shape}")

        return x

# Test the network
model = DetailedConvNet().to(device)
test_input = torch.randn(1, 3, 32, 32, device=device)

print("Running forward pass:\n")
output = model(test_input)

Running forward pass:

Input:               torch.Size([1, 3, 32, 32])
After Conv2d(3→16):  torch.Size([1, 16, 32, 32])
After BatchNorm:     torch.Size([1, 16, 32, 32])
After ReLU:          torch.Size([1, 16, 32, 32])
After MaxPool(2x2):  torch.Size([1, 16, 16, 16])
After Conv2d(16→32): torch.Size([1, 32, 16, 16])
After BatchNorm:     torch.Size([1, 32, 16, 16])
After ReLU:          torch.Size([1, 32, 16, 16])
After MaxPool(2x2):  torch.Size([1, 32, 8, 8])


### Verify Shape Calculations

In [3]:
print("Expected shape progression:")
print("  (1, 3, 32, 32)   Input")
print("→ (1, 16, 32, 32)  Conv2d: channels 3→16, size preserved by padding=1")
print("→ (1, 16, 32, 32)  BatchNorm: no change")
print("→ (1, 16, 32, 32)  ReLU: no change")
print("→ (1, 16, 16, 16)  MaxPool: size halved (32→16)")
print("→ (1, 32, 16, 16)  Conv2d: channels 16→32, size preserved by padding=1")
print("→ (1, 32, 16, 16)  BatchNorm: no change")
print("→ (1, 32, 16, 16)  ReLU: no change")
print("→ (1, 32, 8, 8)    MaxPool: size halved (16→8)")
print(f"\nActual final shape: {output.shape}")

assert output.shape == torch.Size([1, 32, 8, 8]), "Shape mismatch!"
print("\n✓ All shape calculations correct!")

Expected shape progression:
  (1, 3, 32, 32)   Input
→ (1, 16, 32, 32)  Conv2d: channels 3→16, size preserved by padding=1
→ (1, 16, 32, 32)  BatchNorm: no change
→ (1, 16, 32, 32)  ReLU: no change
→ (1, 16, 16, 16)  MaxPool: size halved (32→16)
→ (1, 32, 16, 16)  Conv2d: channels 16→32, size preserved by padding=1
→ (1, 32, 16, 16)  BatchNorm: no change
→ (1, 32, 16, 16)  ReLU: no change
→ (1, 32, 8, 8)    MaxPool: size halved (16→8)

Actual final shape: torch.Size([1, 32, 8, 8])

✓ All shape calculations correct!


## QUICK REFERENCE FOR GRAPH OPTIMIZATION

When analyzing torch.fx graphs, you'll see:

### 1. Module Nodes
```python
# call_module nodes have a string target like "conv1", "bn1", "pool1"
# gm.get_submodule(node.target) returns the actual module
# isinstance(module, nn.Conv2d) identifies conv layers
# Functional calls such as torch.relu show up as call_function nodes instead
```
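Here is a minimal sketch of walking a traced graph with `torch.fx.symbolic_trace`. `TinyNet` is a small stand-in model (one block of `DetailedConvNet`, without the prints). Because the model calls `torch.relu` directly, that step appears as a `call_function` node rather than a `call_module` node, which matters when matching patterns for fusion.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class TinyNet(nn.Module):
    """Small stand-in model: one Conv2d -> BatchNorm2d -> relu -> MaxPool2d block."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        return self.pool1(torch.relu(self.bn1(self.conv1(x))))


gm = fx.symbolic_trace(TinyNet())
conv_targets = []
for node in gm.graph.nodes:
    if node.op == "call_module":
        module = gm.get_submodule(node.target)  # string target -> actual nn.Module
        if isinstance(module, nn.Conv2d):
            conv_targets.append(node.target)
        print(f"{node.op:13s} {node.name:6s} {type(module).__name__}")
    else:
        print(f"{node.op:13s} {node.name}")

print("Conv2d nodes:", conv_targets)  # ['conv1']
```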

### 2. Channel Tracking
```python
# Conv2d has .in_channels and .out_channels attributes
# BatchNorm2d has .num_features (must match conv out_channels)
# These must align when stacking layers!
```
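Channel alignment can be verified on the graph itself. The sketch below uses a hypothetical helper, `check_conv_bn_channels`, that looks at every BatchNorm2d node fed directly by a Conv2d node and reports mismatches. Tracing does not execute the conv or BN layers, so even a broken model can be traced and checked.

```python
import torch.fx as fx
import torch.nn as nn


class ConvBN(nn.Module):
    """Stand-in model whose BN width can be set wrong on purpose."""

    def __init__(self, bn_features):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(bn_features)

    def forward(self, x):
        return self.bn(self.conv(x))


def check_conv_bn_channels(gm):
    """Hypothetical helper: list Conv2d -> BatchNorm2d edges whose channels disagree."""
    problems = []
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        bn = gm.get_submodule(node.target)
        src = node.args[0]  # the node feeding this module
        if not isinstance(bn, nn.BatchNorm2d) or not isinstance(src, fx.Node):
            continue
        if src.op == "call_module":
            conv = gm.get_submodule(src.target)
            if isinstance(conv, nn.Conv2d) and conv.out_channels != bn.num_features:
                problems.append((src.target, node.target, conv.out_channels, bn.num_features))
    return problems


print(check_conv_bn_channels(fx.symbolic_trace(ConvBN(16))))  # []
print(check_conv_bn_channels(fx.symbolic_trace(ConvBN(8))))   # [('conv', 'bn', 16, 8)]
```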

### 3. Spatial Size Tracking
```python
# Conv2d: apply the formula using padding, kernel_size, stride
# MaxPool2d (padding=0, kernel_size == stride): floor(H_in / stride)
# ReLU/BatchNorm2d: preserve spatial dimensions
```
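Instead of applying the formulas by hand, torch.fx ships a `ShapeProp` pass that runs the graph once on a sample input and stores each node's output shape in `node.meta["tensor_meta"]`. A minimal sketch, again using a small stand-in `TinyNet`:

```python
import torch
import torch.fx as fx
import torch.nn as nn
from torch.fx.passes.shape_prop import ShapeProp


class TinyNet(nn.Module):
    """Small stand-in model: Conv2d -> BatchNorm2d -> relu -> MaxPool2d."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        return self.pool1(torch.relu(self.bn1(self.conv1(x))))


gm = fx.symbolic_trace(TinyNet())
ShapeProp(gm).propagate(torch.randn(1, 3, 32, 32))  # runs the graph once, records metadata

shapes = {}
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")  # TensorMetadata namedtuple (shape, dtype, ...)
    if meta is not None:
        shapes[node.name] = tuple(meta.shape)
        print(f"{node.name:6s} {shapes[node.name]}")
```

Because `ShapeProp` actually executes the model, the recorded shapes agree with what the formulas predict for this input size.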

### 4. Common Optimizations
```python
# Fuse Conv2d + BatchNorm2d into a single conv (inference only: folds BN's running stats)
# Fuse Conv2d + ReLU into a single operation
# Replace a sequence of ops with an optimized kernel
# Any replacement must produce the same output shape (and values) as the ops it replaces
```
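Conv + BatchNorm fusion works because an eval-mode BatchNorm2d is just a per-channel scale and shift, `gamma * (y - mean) / sqrt(var + eps) + beta`, which can be folded into the conv's weight and bias. The sketch below is an illustrative implementation (`fuse_conv_bn` is a hypothetical name, and it assumes the default `padding_mode="zeros"`); it checks that the fused conv reproduces `bn(conv(x))`.

```python
import torch
import torch.nn as nn


def fuse_conv_bn(conv, bn):
    """Fold an eval-mode BatchNorm2d into the preceding Conv2d (illustrative sketch)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # one factor per out channel
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # scale each kernel
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(16)
with torch.no_grad():  # give BN non-trivial statistics and affine parameters
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()  # folding is only valid with running stats, not per-batch stats

fused = fuse_conv_bn(conv, bn)
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    reference = bn(conv(x))
    result = fused(x)
print(result.shape == reference.shape)                          # True
print(torch.allclose(reference, result, rtol=1e-4, atol=1e-5))  # True
```

In training mode BatchNorm normalizes with the current batch's statistics, which change every step; that is why this fusion is an inference-only optimization.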

## YOU'RE READY!

You now understand Conv2d well enough to:
- ✓ Read torch.fx graphs with conv operations
- ✓ Track tensor shapes through transformations
- ✓ Identify which layers can be fused/optimized
- ✓ Verify shape compatibility when modifying graphs

**Go tackle Exercise 1 in `00_essentials_only.ipynb`!**

---

## Practice: Calculate Shapes Manually

Given this architecture, calculate the output shape at each step:

```python
Input: (2, 3, 64, 64)  # batch=2, RGB, 64x64 image
Conv2d(3, 32, kernel_size=5, padding=2)
ReLU()
MaxPool2d(2, 2)
Conv2d(32, 64, kernel_size=3, padding=1)
ReLU()
MaxPool2d(2, 2)
```

Try calculating before running the cell below!

In [4]:
# Solution:
class PracticeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
    
    def forward(self, x):
        print(f"Input:          {x.shape}")  # (2, 3, 64, 64)
        
        x = self.conv1(x)
        # H_out = (64 + 2*2 - 5)/1 + 1 = 64
        print(f"After Conv1:    {x.shape}")  # (2, 32, 64, 64)
        
        x = torch.relu(x)
        print(f"After ReLU:     {x.shape}")  # (2, 32, 64, 64)
        
        x = self.pool1(x)
        # H_out = 64/2 = 32
        print(f"After Pool1:    {x.shape}")  # (2, 32, 32, 32)
        
        x = self.conv2(x)
        # H_out = (32 + 2*1 - 3)/1 + 1 = 32
        print(f"After Conv2:    {x.shape}")  # (2, 64, 32, 32)
        
        x = torch.relu(x)
        print(f"After ReLU:     {x.shape}")  # (2, 64, 32, 32)
        
        x = self.pool2(x)
        # H_out = 32/2 = 16
        print(f"After Pool2:    {x.shape}")  # (2, 64, 16, 16)
        
        return x

practice_model = PracticeNet().to(device)
practice_input = torch.randn(2, 3, 64, 64, device=device)
practice_output = practice_model(practice_input)

print(f"\n✓ Final shape: {practice_output.shape}")
assert practice_output.shape == torch.Size([2, 64, 16, 16])

Input:          torch.Size([2, 3, 64, 64])
After Conv1:    torch.Size([2, 32, 64, 64])
After ReLU:     torch.Size([2, 32, 64, 64])
After Pool1:    torch.Size([2, 32, 32, 32])
After Conv2:    torch.Size([2, 64, 32, 32])
After ReLU:     torch.Size([2, 64, 32, 32])
After Pool2:    torch.Size([2, 64, 16, 16])

✓ Final shape: torch.Size([2, 64, 16, 16])
