# UNet Architecture Implementation in PyTorch

## Overview
UNet is a convolutional neural network architecture designed for **semantic segmentation** tasks. It was originally developed for biomedical image segmentation and is now widely used for dense, per-pixel prediction in other domains, including medical imaging and satellite imagery.

### Key Features:
- **Encoder-Decoder Architecture**: Captures context (encoder) and enables precise localization (decoder)
- **Skip Connections**: Concatenates features from encoder to decoder, preserving spatial information
- **Symmetric Design**: Creates a U-shaped architecture (hence the name "UNet")

### Architecture Flow:
1. **Contracting Path (Encoder)**: Captures context through downsampling
2. **Bottleneck**: Lowest resolution with highest number of feature channels
3. **Expanding Path (Decoder)**: Enables precise localization through upsampling
4. **Skip Connections**: Bridges encoder and decoder for better gradient flow

In [1]:
import os
import random
from pathlib import Path

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt


## Import Required Libraries

**Purpose**: Import all necessary libraries for building and testing the UNet model.

### Library Breakdown:
- **os, random, Path**: File system operations and randomization
- **numpy**: Numerical operations
- **PIL (Image)**: Image loading and processing
- **torch & torch.nn**: Core PyTorch functionality for building neural networks
- **torch.nn.functional (F)**: Functional operations like padding
- **Dataset, DataLoader**: Data loading utilities
- **matplotlib.pyplot**: Visualization of results

In [2]:
torch.__version__

'2.9.1+cpu'

## Building Block 1: DoubleConv

### Purpose:
The **DoubleConv** is the fundamental building block of UNet. It applies two consecutive convolution operations, each followed by batch normalization and ReLU activation.

### Architecture Pattern:
```
Input → Conv2d → BatchNorm → ReLU → Conv2d → BatchNorm → ReLU → Output
```

### Parameter Explanation:

#### 1. **in_channels** (int)
- **What**: Number of input feature channels/maps
- **Why**: Tells the network how many feature maps to expect as input
- **Scenarios**:
  - `in_channels=3`: RGB image input (Red, Green, Blue channels)
  - `in_channels=1`: Grayscale image or single feature map
  - `in_channels=64`: Feature maps from previous layer
  - `in_channels=512`: Deep layer with many learned features

#### 2. **out_channels** (int)
- **What**: Number of output feature channels/maps to produce
- **Why**: Determines how many different features/patterns the layer will learn
- **Scenarios**:
  - `out_channels=64`: Early layers learn basic features (edges, textures)
  - `out_channels=512`: Deep layers learn complex semantic features (shapes, objects)
  - **More channels = More capacity** but also more parameters and memory
- **Trade-off**: 
  - Higher values → Better feature representation but slower training
  - Lower values → Faster but may miss important patterns

#### 3. **mid_channels** (int, optional, default=None)
- **What**: Number of channels after the first convolution
- **Why**: Allows asymmetric channel progression for flexibility
- **Default behavior**: If `None`, uses `out_channels` (symmetric)
- **Scenarios**:
  ```python
  # Symmetric (default): 64 → 128 → 128
  DoubleConv(64, 128)  # mid_channels automatically = 128
  
  # Asymmetric: 64 → 96 → 128
  DoubleConv(64, 128, mid_channels=96)  # Gradual channel increase
  ```
- **Use case**: Memory optimization or gradual feature expansion

### Conv2d Parameters Deep Dive:

#### **kernel_size=3**
- **What**: 3×3 filter/receptive field
- **Why**: 
  - Standard in deep networks (proven effective)
  - Small enough to stack many layers
  - Large enough to capture local patterns
  - 3×3 is optimal balance (2 stacked 3×3 = 5×5 receptive field with fewer parameters)
- **Alternative scenarios**:
  - `kernel_size=1`: Point-wise convolution (like OutConv)
  - `kernel_size=5`: Larger receptive field but ≈2.8× more parameters (25 vs. 9 weights per filter)
  - `kernel_size=7`: Common in first layer for capturing large patterns

#### **padding=1**
- **What**: Adds 1 pixel border around input before convolution
- **Why**: **Preserves spatial dimensions**
- **Math**: Output size = (Input size - Kernel size + 2×Padding) / Stride + 1
  ```
  With padding=1:  (H - 3 + 2×1) / 1 + 1 = H  ✓ Same size
  Without padding:  (H - 3 + 0) / 1 + 1 = H-2  ✗ Shrinks
  ```
- **Scenarios**:
  - `padding=1` with `kernel_size=3`: Maintains size (used in UNet)
  - `padding=0`: Reduces size (valid convolution)
  - `padding=2` with `kernel_size=5`: Maintains size
- **Why it matters**: UNet needs consistent sizes for skip connections
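
The output-size formula can be checked with a few lines of plain Python. This is a minimal sketch; the helper name `conv_output_size` is introduced here for illustration and is not a PyTorch function.

```python
def conv_output_size(size, kernel_size, padding=0, stride=1):
    """Spatial output size of a convolution, using floor division as PyTorch does."""
    return (size - kernel_size + 2 * padding) // stride + 1


same = conv_output_size(256, kernel_size=3, padding=1)      # padding=1 keeps the size
valid = conv_output_size(256, kernel_size=3, padding=0)     # no padding loses 2 pixels
same_5x5 = conv_output_size(256, kernel_size=5, padding=2)  # a 5×5 kernel needs padding=2

print(same, valid, same_5x5)  # 256 254 256
```

The same helper also covers strided layers, for example `conv_output_size(256, kernel_size=3, padding=1, stride=2)` for a strided convolution.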

#### **bias=False**
- **What**: Removes learnable bias term from convolution
- **Why**: BatchNorm2d already includes bias (shift parameter β)
- **Math**: 
  ```
  With bias:     Conv(x) = W*x + b, then BN(Conv(x)) = γ·norm(W*x + b) + β
  Without bias:  Conv(x) = W*x,     then BN(Conv(x)) = γ·norm(W*x) + β
  ```
- **Benefit**: 
  - Saves parameters (no redundant bias)
  - BatchNorm's β serves as the bias term
  - Slightly faster computation
- **When to use bias=True**: When NOT using BatchNorm
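
The redundancy of the convolution bias can be observed directly. The sketch below is an illustration under assumed toy sizes (3 input channels, 8 output channels, a 16×16 input): it builds two convolutions with identical weights, gives one of them a deliberately large bias, and passes both through the same `BatchNorm2d` in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # identical filters
    conv_with_bias.bias.uniform_(-5.0, 5.0)           # exaggerated bias

bn = nn.BatchNorm2d(8)
bn.train()  # training mode: normalize with the statistics of the current batch

x = torch.randn(4, 3, 16, 16)
with torch.no_grad():
    out_with_bias = bn(conv_with_bias(x))
    out_no_bias = bn(conv_no_bias(x))

# Subtracting the per-channel batch mean removes the constant bias again
max_diff = (out_with_bias - out_no_bias).abs().max().item()
print(f"max difference: {max_diff:.2e}")
```

The difference is at the level of floating-point rounding, which is why the bias can be dropped when BatchNorm follows the convolution.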

### BatchNorm2d Parameters:

- **What**: Normalizes activations to have mean≈0 and variance≈1
- **Formula**: `output = γ × (input - μ) / √(σ² + ε) + β`
  - μ: batch mean, σ²: batch variance
  - γ: learnable scale, β: learnable shift
- **Why**:
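  - **Stabilizes training**: Keeps activation distributions in a consistent range (originally motivated as reducing internal covariate shift)
  - **Allows higher learning rates**: More stable gradients
  - **Mild regularization effect**: Batch statistics add noise during training, which can act as a weak regularizer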
- **Scenarios**:
  - Training: Uses batch statistics (mean/var of current batch)
  - Inference: Uses running statistics (accumulated during training)

### ReLU Parameters:

#### **inplace=True**
- **What**: Modifies input tensor directly instead of creating new tensor
- **Why**: **Saves memory**
- **Scenarios**:
  ```python
  # inplace=False (default): x_output = ReLU(x_input)  → 2 tensors in memory
  # inplace=True:            x = ReLU(x)               → 1 tensor in memory
  ```
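- **Memory saving**: Avoids allocating one extra activation-sized tensor per ReLU
- **Trade-off**: The pre-activation values are overwritten (this block does not need them afterwards)
- **When NOT to use**: If the input tensor is reused elsewhere (e.g., by another branch), or if autograd still needs the original values; PyTorch raises a runtime error when an in-place op overwrites a tensor required for the backward pass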
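
The difference between the two modes is easy to observe on a small tensor. This is an illustrative sketch using a 3-element tensor; `data_ptr()` is used to check whether input and output share memory.

```python
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 2.0, -3.0])

# Default: a new tensor is allocated and the input is left untouched
y = nn.ReLU(inplace=False)(x)
input_preserved = x.tolist() == [-1.0, 2.0, -3.0]
shares_memory_default = y.data_ptr() == x.data_ptr()

# In-place: the result is written into the input's own memory
z = nn.ReLU(inplace=True)(x)
shares_memory_inplace = z.data_ptr() == x.data_ptr()
input_overwritten = x.tolist() == [0.0, 2.0, 0.0]

print(shares_memory_default, shares_memory_inplace)  # False True
```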

### Image Processing Flow:
1. **First Convolution**: 
   - Input: [B, in_channels, H, W]
   - Applies 3×3 filters with padding=1
   - Extracts initial features (edges, colors, textures)
   - Output: [B, mid_channels, H, W]
   
2. **Batch Normalization**: 
   - Normalizes each channel using statistics computed over the batch and spatial dimensions (N, H, W)
   - Ensures stable distribution of activations
   - Output: [B, mid_channels, H, W]
   
3. **ReLU Activation**: 
   - Applies: f(x) = max(0, x)
   - Introduces non-linearity (allows learning complex patterns)
   - Zeros out negative values
   - Output: [B, mid_channels, H, W]
   
4. **Second Convolution**: 
   - Further refines the extracted features
   - Combines patterns from first conv
   - Output: [B, out_channels, H, W]
   
5. **Final BatchNorm + ReLU**: 
   - Normalizes and activates the final output
   - Output: [B, out_channels, H, W]

### Practical Example:
```python
# Scenario: Converting RGB image to 64 feature maps
conv_block = DoubleConv(in_channels=3, out_channels=64)

# Input: RGB image [1, 3, 256, 256]
# After 1st Conv2d: [1, 64, 256, 256]  - spatial size preserved by padding
# After 1st BN+ReLU: [1, 64, 256, 256] - normalized and activated
# After 2nd Conv2d: [1, 64, 256, 256]  - refined features
# Output after 2nd BN+ReLU: [1, 64, 256, 256]

# Total learnable parameters:
# Conv1: (3×3×3×64) + BatchNorm(64×2) = 1,728 + 128 = 1,856
# Conv2: (3×3×64×64) + BatchNorm(64×2) = 36,864 + 128 = 36,992
# Total: 38,848 parameters
```
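
The parameter arithmetic can be reproduced with a small calculator. The helper names below (`conv_params`, `batchnorm_params`, `double_conv_params`) are hypothetical and only mirror the layer layout of this block: two 3×3 convolutions without bias, each followed by BatchNorm.

```python
def conv_params(in_ch, out_ch, kernel_size=3, bias=False):
    """Learnable parameters of a 2D convolution."""
    weights = kernel_size * kernel_size * in_ch * out_ch
    return weights + (out_ch if bias else 0)


def batchnorm_params(channels):
    """Learnable parameters of BatchNorm2d: one γ and one β per channel."""
    return 2 * channels


def double_conv_params(in_ch, out_ch, mid_ch=None):
    """Conv → BN → ReLU → Conv → BN → ReLU, convolutions without bias."""
    mid_ch = mid_ch or out_ch
    return (conv_params(in_ch, mid_ch) + batchnorm_params(mid_ch)
            + conv_params(mid_ch, out_ch) + batchnorm_params(out_ch))


total = double_conv_params(3, 64)
print(total)  # 38848
```

Running statistics of BatchNorm (mean and variance buffers) are not learnable and are therefore not counted.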

In [None]:
class DoubleConv(nn.Module):
    """DoubleConv is a block that combines Convolution + Batch Normalization + ReLU,
    and this sequence is applied twice."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


## Building Block 2: Down (Encoder Block)

### Purpose:
The **Down** block is responsible for the **contracting path** (encoder) in UNet. It reduces spatial dimensions while increasing the number of feature channels.

### Architecture:
```
Input → MaxPool2d (2×2) → DoubleConv → Output
```

### Parameter Explanation:

#### 1. **in_channels** (int)
- **What**: Number of input feature channels
- **Why**: Must match output channels from previous layer
- **Scenarios**:
  ```python
  down1 = Down(64, 128)   # After initial 64-channel layer
  down2 = Down(128, 256)  # After down1 outputs 128 channels
  down3 = Down(256, 512)  # Progressive deepening
  ```
- **Pattern**: Each down block typically doubles channels

#### 2. **out_channels** (int)
- **What**: Number of output feature channels
- **Why**: Increases model capacity as spatial size decreases
- **Principle**: **Spatial resolution ↓ ⇒ Channel count ↑**
- **Reasoning**:
  - Less spatial information → Need more features to compensate
  - Smaller feature maps → Can afford more channels (memory-wise)
  - Higher-level semantics require richer representations

### MaxPool2d(2) Parameters Deep Dive:

#### **kernel_size=2**
- **What**: Size of pooling window (2×2 region)
- **Why**: Standard for downsampling by factor of 2
- **Operation**: 
  ```
  Input 4×4:           After MaxPool2d(2):
  [1  2  3  4]         [6  8]
  [5  6  7  8]    →    [14 16]
  [9  10 11 12]
  [13 14 15 16]
  
  Takes maximum from each 2×2 region
  ```
- **Scenarios**:
  - `MaxPool2d(2)`: Halves dimensions (256→128)
  - `MaxPool2d(3)`: Reduces by 3× (256→85) - less common
  - `MaxPool2d(4)`: Quarters dimensions (256→64) - aggressive
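
The 4×4 example and the size reductions listed above can be reproduced with the functional API. This is a small sketch; the all-zeros 256×256 tensor is only a stand-in used to read off output sizes.

```python
import torch
import torch.nn.functional as F

# The 4×4 example: values 1..16 laid out row by row
x = torch.arange(1, 17, dtype=torch.float32).reshape(1, 1, 4, 4)
pooled = F.max_pool2d(x, kernel_size=2)  # stride defaults to kernel_size
print(pooled.squeeze().tolist())  # [[6.0, 8.0], [14.0, 16.0]]

# Output width of a 256×256 map for different pooling windows
large = torch.zeros(1, 1, 256, 256)
sizes = {k: F.max_pool2d(large, kernel_size=k).shape[-1] for k in (2, 3, 4)}
print(sizes)  # {2: 128, 3: 85, 4: 64}
```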

#### **Stride (default=kernel_size=2)**
- **What**: Step size when sliding the pooling window
- **Default**: Same as kernel_size (non-overlapping windows)
- **Why stride=2**:
  - Non-overlapping pooling (most common)
  - Clean 50% reduction in each dimension
  - No redundant computations
- **Alternative**: `stride=1` would create overlapping pools (rarely used in UNet)

### Why MaxPooling?

#### **Advantages**:
1. **Translation Invariance**: Small shifts in input don't change output much
2. **Computational Efficiency**: Reduces feature map size → faster processing
3. **Receptive Field Growth**: Each neuron "sees" larger area of original image
4. **Feature Concentration**: Keeps strongest activations

#### **Alternatives and Trade-offs**:
```python
# Option 1: MaxPool (used in UNet)
nn.MaxPool2d(2)  
# + Simple, no parameters
# + Preserves strong activations
# - Not learnable

# Option 2: Strided Convolution
nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
# + Learnable downsampling
# + Smoother reduction
# - More parameters, slower

# Option 3: Average Pooling
nn.AvgPool2d(2)
# + Smoother, less aggressive
# - May lose sharp features
# - Not common in segmentation
```

### Image Processing Flow:

#### Step-by-Step Transformation:
```python
Input:  [Batch, 64, 256, 256]   # 64 feature channels, 256×256 spatial
        ↓
MaxPool2d(2):
- Slides 2×2 window with stride=2
- Takes maximum value in each window
- Reduces spatial dimensions by half
Result: [Batch, 64, 128, 128]   # Channels unchanged, size halved
        ↓
DoubleConv(64, 128):
- Two 3×3 convolutions with padding
- Increases channels from 64 to 128
- Spatial size preserved by padding
Result: [Batch, 128, 128, 128]  # More features, same spatial size
```

### Receptive Field Explanation:

**What is Receptive Field?**
- Area of the input image that affects a single output neuron

**How Down blocks increase it**:
```
Layer 0 (input):       Receptive field = 1×1 pixel
After Conv 3×3:        Receptive field = 3×3 pixels
After MaxPool 2×2:     Receptive field = 4×4 pixels (later layers now step 2 input pixels at a time)
After another Conv:    Receptive field = 8×8 pixels
After another MaxPool: Receptive field = 10×10 pixels
```

**Why it matters**:
- Larger receptive field → See more context
- Early layers: Small receptive field (local patterns: edges, textures)
- Deep layers: Large receptive field (global patterns: objects, scenes)
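
Receptive-field growth follows a simple recurrence: each layer adds `(kernel_size - 1) × jump` pixels, where `jump` is the product of all strides so far. The sketch below applies it to a short conv/pool stack; the function name `receptive_fields` and the layer tuples are illustrative, not part of any library.

```python
def receptive_fields(layers):
    """layers: list of (name, kernel_size, stride).

    Returns a list of (name, receptive_field) after each layer.
    """
    rf, jump = 1, 1
    result = []
    for name, kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump  # growth is scaled by accumulated stride
        jump *= stride
        result.append((name, rf))
    return result


stack = [
    ("conv3x3", 3, 1),
    ("maxpool2", 2, 2),
    ("conv3x3", 3, 1),
    ("maxpool2", 2, 2),
]
for name, rf in receptive_fields(stack):
    print(f"{name}: {rf}x{rf}")
# conv3x3: 3x3
# maxpool2: 4x4
# conv3x3: 8x8
# maxpool2: 10x10
```

Pooling itself adds little to the receptive field; its main effect is to double `jump`, so every later 3×3 convolution covers twice as many input pixels.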

### Practical Example with Numbers:

```python
# Creating a Down block
down_block = Down(in_channels=64, out_channels=128)

# Input: Feature maps from previous layer
input_tensor = torch.randn(4, 64, 256, 256)  # Batch=4, Channels=64, Size=256×256

# Process:
output = down_block(input_tensor)

# Output shape: torch.Size([4, 128, 128, 128])
# Memory: 
#   Input:  4 × 64 × 256 × 256 = 16,777,216 values (64 MB float32)
#   Output: 4 × 128 × 128 × 128 = 8,388,608 values (32 MB float32)
# Memory saved: 50% (due to spatial reduction)

# Parameters in DoubleConv(64, 128):
# Conv1: 3×3×64×128 = 73,728
# Conv2: 3×3×128×128 = 147,456
# Total: ~221K parameters + BatchNorm
```

### Real-World Scenario:

**Medical Image Segmentation (Cell Detection)**:
```python
# Receptive fields below assume the initial DoubleConv (inc) followed by
# Down blocks built from MaxPool2d(2) and two 3×3 convolutions.

# down1: Initial features (64→128)
# Learns: Cell boundaries, membrane edges
# Receptive field: 14×14 pixels

# down2: Mid-level features (128→256)
# Learns: Cell shapes, nucleus patterns
# Receptive field: 32×32 pixels

# down3: High-level features (256→512)
# Learns: Cell clusters, tissue structures
# Receptive field: 68×68 pixels

# down4: Semantic features (512→1024)
# Learns: Larger tissue regions, broader context
# Receptive field: 140×140 pixels
```

### Why Downsample?

1. **Computational Efficiency**: 
   - 256×256 → 128×128 = 75% fewer spatial positions, so a convolution at the same channel width does about 75% less work
   - Enables deeper networks

2. **Hierarchical Features**:
   - Early: Low-level (edges, colors)
   - Middle: Mid-level (textures, patterns)
   - Deep: High-level (objects, semantics)

3. **Context Aggregation**:
   - Combines information from larger regions
   - Understands "what" is in the image

4. **Memory Management**:
   - Smaller feature maps use less GPU memory
   - Allows higher batch sizes or more channels

In [4]:
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

## Building Block 3: Up (Decoder Block)

### Purpose:
The **Up** block is responsible for the **expanding path** (decoder) in UNet. It increases spatial dimensions while reducing feature channels, and **crucially**, it incorporates skip connections from the encoder.

### Architecture (Two Modes):

#### Mode 1: Bilinear Upsampling
```
Input (from below) → Bilinear Upsample (2×) → Concatenate with Skip Connection → DoubleConv → Output
```

#### Mode 2: Transposed Convolution
```
Input (from below) → ConvTranspose2d → Concatenate with Skip Connection → DoubleConv → Output
```

### Parameter Explanation:

#### 1. **in_channels** (int)
- **What**: Total number of channels AFTER concatenation
- **Why**: Must account for both upsampled features AND skip connection
- **Critical Understanding**:
  ```python
  # Example: up1 = Up(1024, 512, bilinear=False)
  # Input from below (x1): 1024 channels (bottleneck)
  # After upsampling: 512 channels (halved by upsampling)
  # Skip connection (x2): 512 channels (from encoder)
  # After concatenation: 512 + 512 = 1024 channels
  # Therefore, in_channels = 1024 ✓
  ```
- **Scenarios**:
  ```python
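  # UNet progression (bilinear=False, where the transposed convolution halves the channels):
  Up(1024, 512, bilinear=False)  # in=1024 because 512(upsampled) + 512(skip)
  Up(512, 256, bilinear=False)   # in=512 because 256(upsampled) + 256(skip)
  Up(256, 128, bilinear=False)   # in=256 because 128(upsampled) + 128(skip)
  Up(128, 64, bilinear=False)    # in=128 because 64(upsampled) + 64(skip)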
  ```

#### 2. **out_channels** (int)
- **What**: Number of output feature channels
- **Why**: Progressively reduces channels as we go up (inverse of encoder)
- **Pattern**: Typically halves at each level (1024→512→256→128→64)
- **Reasoning**: 
  - Upsampling increases spatial detail
  - Less need for many channels at high resolution
  - Mirrors encoder structure symmetrically

#### 3. **bilinear** (bool, default=True)
- **What**: Chooses upsampling method
- **Why**: Trade-off between speed/memory vs quality

##### **bilinear=True (Interpolation-based)**:
- **Method**: Mathematical interpolation (no learning)
- **Process**: 
  ```
  Original pixels:    Upsampled 2× (bilinear, align_corners=True):
  [1  2]             [1.00  1.33  1.67  2.00]
  [3  4]       →     [1.67  2.00  2.33  2.67]
                     [2.33  2.67  3.00  3.33]
                     [3.00  3.33  3.67  4.00]
  
  New pixels are weighted averages of neighbors
  ```
- **Formula**: `f(x,y) = w1·p1 + w2·p2 + w3·p3 + w4·p4` (weighted average)
- **Advantages**:
  - ✓ No parameters (0 MB)
  - ✓ Faster forward pass
  - ✓ Less GPU memory
  - ✓ Good for limited hardware
  - ✓ No risk of checkerboard artifacts
- **Disadvantages**:
  - ✗ Not learnable (fixed interpolation)
  - ✗ May miss fine details
  - ✗ Can give slightly lower segmentation quality on some tasks
- **When to use**: 
  - Limited GPU memory
  - Real-time applications
  - Initial prototyping
  - Medical imaging with clear boundaries
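
A small check of what `nn.Upsample` does in this mode, using the same settings as the `Up` block (`scale_factor=2`, `mode='bilinear'`, `align_corners=True`) on a 2×2 toy input:

```python
import torch
import torch.nn as nn

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).reshape(1, 1, 2, 2)
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
y = up(x)

print(y.shape)      # torch.Size([1, 1, 4, 4])
print(y.squeeze())  # corners stay 1, 2, 3, 4; values in between are interpolated

num_params = sum(p.numel() for p in up.parameters())
print(num_params)   # 0
```

With `align_corners=True`, the four corner values of the input land exactly on the four corners of the output, and the module has no learnable parameters.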

##### **bilinear=False (Transposed Convolution)**:
- **Method**: Learned upsampling using ConvTranspose2d
- **Process**: 
  ```
  Learnable 2×2 kernel:
  [w1  w2]
  [w3  w4]
  
  Input pixel "x" gets multiplied by kernel and spread:
  x → [x·w1  x·w2]
      [x·w3  x·w4]
  
  Overlapping regions get summed
  ```
- **Advantages**:
  - ✓ Learnable (adapts to data)
  - ✓ Can recover fine detail better
  - ✓ Can reach higher segmentation accuracy on some tasks
  - ✓ Can learn task-specific upsampling
- **Disadvantages**:
  - ✗ More parameters (each ConvTranspose2d adds 2×2×C_in×C_out weights plus a bias)
  - ✗ More GPU memory needed
  - ✗ Risk of checkerboard artifacts
  - ✗ Slower training
- **When to use**:
  - Sufficient GPU memory available
  - Accuracy is priority
  - Fine-grained segmentation needed
  - Satellite/aerial imagery

### Upsampling Methods Comparison:

```python
# Example: Upsample from 64×64 to 128×128

# Method 1: Bilinear (bilinear=True)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# Parameters: 0
# Output: Smooth interpolation
# Speed: Fast (no learning)

# Method 2: Transposed Conv (bilinear=False)
self.up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
# Parameters: 2×2×512×256 = 524,288 weights + 256 biases = 524,544
# Output: Learned upsampling
# Speed: Slower (backprop through convolution)
```
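
The transposed-convolution variant can be inspected the same way. The sketch below uses an arbitrary 8×8 input to show the doubled spatial size, the halved channel count, and the parameter count.

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
x = torch.randn(1, 512, 8, 8)
y = up(x)

n_weights = up.weight.numel()  # 512 × 256 × 2 × 2
n_bias = up.bias.numel()       # one bias per output channel

print(y.shape)            # torch.Size([1, 256, 16, 16])
print(n_weights, n_bias)  # 524288 256
```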

### Handling Dimension Mismatch:

**The Problem**:
```
Encoder (x2):   [B, 512, 67, 67]  # Odd size: the input was not divisible by 2 at this level
Upsampled (x1): [B, 512, 66, 66]  # MaxPool floors 67 → 33, upsampling gives 33 × 2 = 66
# Cannot concatenate! Dimension mismatch!
```

**The Solution - Padding**:
```python
diffY = x2.size()[2] - x1.size()[2]  # 67 - 66 = 1
diffX = x2.size()[3] - x1.size()[3]  # 67 - 66 = 1

# Pad x1 to match x2:
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
# Pads: [left=0, right=1, top=0, bottom=1]
# Result: x1 is now [B, 512, 67, 67] ✓
```

**Padding Parameters**:
- **Format**: `[left, right, top, bottom]`
- **Why split diffX//2 and diffX-diffX//2?**
  - Handles odd differences (e.g., diff=7 → left=3, right=4)
  - Centers the tensor symmetrically
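- **align_corners=True in Upsample**:
  - Maps the corner pixels of the input and output onto each other during interpolation
  - Does not change the output size, so it has no effect on dimension mismatches (padding handles those)
  - Only affects where interpolated values are placed relative to the encoder features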
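
The padding step can be tried in isolation. In this sketch, `pad_to_match` is a hypothetical helper that wraps the same `F.pad` call used in the `Up` block, and the tensor sizes are toy values chosen to produce a one-pixel mismatch.

```python
import torch
import torch.nn.functional as F


def pad_to_match(x1, x2):
    """Zero-pad x1 (NCHW) so that its height and width equal those of x2."""
    diff_y = x2.size(2) - x1.size(2)
    diff_x = x2.size(3) - x1.size(3)
    # F.pad lists pads for the last dimension first: [left, right, top, bottom]
    return F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                      diff_y // 2, diff_y - diff_y // 2])


x2 = torch.randn(1, 8, 67, 67)  # skip connection with an odd size
x1 = torch.randn(1, 8, 66, 66)  # upsampled tensor, one pixel short
padded = pad_to_match(x1, x2)

print(padded.shape)  # torch.Size([1, 8, 67, 67])
```

With a difference of 1, the extra row and column of zeros go to the bottom and right, because `1 // 2 == 0` on the left/top side.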

### Skip Connection Deep Dive:

**What happens**:
```python
x = torch.cat([x2, x1], dim=1)
```

**Parameters**:
- **[x2, x1]**: Order matters! Usually encoder features first
- **dim=1**: Concatenate along channel dimension
  ```
  Before:
  x2 (skip): [B, 512, 64, 64]  # High-res features from encoder
  x1 (up):   [B, 512, 64, 64]  # Low-res features upsampled
  
  After concatenation:
  x: [B, 1024, 64, 64]  # Combined features
  ```

**Why concatenate (not add)?**:
```python
# Option 1: Concatenation (used in UNet)
x = torch.cat([x2, x1], dim=1)  # [B, 512+512, H, W]
# + Preserves both feature sets completely
# + Network learns how to combine them
# - Doubles channel count (more parameters)

# Option 2: Addition (used in ResNet)
x = x2 + x1  # [B, 512, H, W]
# + Fewer parameters in the following convolution
# - Requires channel counts to match
# - Merges the two feature sets before the network can weigh them
# - Less flexible

# UNet uses concatenation for maximum information preservation!
```

### Image Processing Flow:

```python
# Example: up1 = Up(1024, 512, bilinear=False), traced step by step
B = 2
x1 = torch.randn(B, 1024, 32, 32)  # From bottleneck (deep, low-res)
x2 = torch.randn(B, 512, 64, 64)   # From encoder (skip connection, high-res)

# Step 1: Upsample x1
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
x1 = up(x1)
# Result: [B, 512, 64, 64]  # Doubled spatial, halved channels

# Step 2: Check dimensions
diffY = x2.size(2) - x1.size(2)  # 64 - 64 = 0
diffX = x2.size(3) - x1.size(3)  # 64 - 64 = 0
# No padding needed ✓

# Step 3: Concatenate
x = torch.cat([x2, x1], dim=1)
# Result: [B, 1024, 64, 64]  # 512 + 512 channels

# Step 4: DoubleConv
conv = DoubleConv(1024, 512)
x = conv(x)
# Result: [B, 512, 64, 64]  # Final output
```

### Real-World Scenario Example:

**Satellite Image Segmentation (Building Detection)**:

```python
# Encoder captured:
# x1 (64×64): "There's a rectangular structure here" (semantic)
# x2 (128×128): "Exact pixel-level roof edges" (spatial detail)

# Problem: x1 knows WHAT (building) but not exactly WHERE
#          x2 knows WHERE (edges) but has little sense of WHAT they belong to

# Up block solution:
# 1. Upsample x1 to 128×128: Bring semantic info to higher resolution
# 2. Concatenate with x2 (128×128 skip): Add precise spatial detail
# 3. DoubleConv: Fuse "WHAT" + "WHERE" → Precise building boundaries

# Result: Accurate pixel-level building segmentation!
```

### Parameter Count Comparison:

```python
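# bilinear=True
up_layer_bilinear = Up(1024, 512, bilinear=True)
# Upsample: 0 parameters
# DoubleConv(1024, 512, mid_channels=512): ~7.1M parameters
#   (conv1 3×3×1024×512 ≈ 4.7M, conv2 3×3×512×512 ≈ 2.4M, plus BatchNorm)
# Total: ~7.1M

# bilinear=False  
up_layer_learned = Up(1024, 512, bilinear=False)
# ConvTranspose2d: 2×2×1024×512 ≈ 2.1M parameters
# DoubleConv(1024, 512): ~7.1M parameters
# Total: ~9.2M (about 30% more parameters)

# For a full UNet with 4 Up blocks (3 input channels, 1 class), assuming the
# common layout where bilinear mode uses a 512-channel bottleneck and halves
# the decoder widths, while transposed-conv mode uses a 1024-channel bottleneck:
# bilinear=True: ~17.3M total parameters
# bilinear=False: ~31.0M total parameters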
```
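
The per-block counts can be derived from the layer definitions in the `Up` class. The helpers below are hypothetical names used only for this calculation; they assume 3×3 convolutions without bias inside `DoubleConv`, BatchNorm with γ and β per channel, and a `ConvTranspose2d` with its default bias.

```python
def conv_params(in_ch, out_ch, kernel_size, bias):
    """Weights (plus optional bias) of a 2D (transposed) convolution."""
    return kernel_size * kernel_size * in_ch * out_ch + (out_ch if bias else 0)


def double_conv_params(in_ch, out_ch, mid_ch=None):
    mid_ch = mid_ch or out_ch
    batchnorm = 2 * mid_ch + 2 * out_ch  # γ and β for both BatchNorm layers
    return (conv_params(in_ch, mid_ch, 3, bias=False)
            + conv_params(mid_ch, out_ch, 3, bias=False)
            + batchnorm)


def up_block_params(in_ch, out_ch, bilinear):
    if bilinear:
        # nn.Upsample has no parameters; DoubleConv uses mid_channels = in_ch // 2
        return double_conv_params(in_ch, out_ch, in_ch // 2)
    # ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2), bias included
    return (conv_params(in_ch, in_ch // 2, 2, bias=True)
            + double_conv_params(in_ch, out_ch))


bilinear_total = up_block_params(1024, 512, bilinear=True)
learned_total = up_block_params(1024, 512, bilinear=False)
print(bilinear_total, learned_total)  # 7079936 9177600
```

The extra cost of `bilinear=False` in this block is exactly the transposed convolution: 2×2×1024×512 weights plus 512 biases.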

### Checkerboard Artifacts (bilinear=False issue):

- **What**: Checkerboard patterns in the output
- **Cause**: Uneven overlap between kernel footprints in ConvTranspose2d when kernel_size is not divisible by stride
- **Solution**:
  - Use kernel_size divisible by stride (here 2÷2=1 ✓, so the footprints tile without overlapping)
  - Or use bilinear=True
  - Or use resize + convolution instead

### When to Choose Which:

| Scenario | bilinear=True | bilinear=False |
|----------|---------------|----------------|
| Limited GPU memory (< 8GB) | ✓ | ✗ |
| Real-time inference required | ✓ | ✗ |
| Maximum accuracy needed | ✗ | ✓ |
| Medical imaging (clear boundaries) | ✓ | Either |
| Natural images (complex textures) | ✗ | ✓ |
| Prototyping/testing | ✓ | ✗ |
| Production model | Either | ✓ |

In [5]:
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW: dim 2 is height, dim 3 is width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

## Building Block 4: OutConv (Output Layer)

### Purpose:
The **OutConv** is the final layer that produces the segmentation output. It uses a 1×1 convolution to map feature channels to the number of classes.

### Architecture:
```
Input → Conv2d (1×1 kernel) → Output (Logits)
```

### Parameter Explanation:

#### 1. **in_channels** (int)
- **What**: Number of input feature channels from last Up block
- **Why**: Must match the output of up4 (final decoder layer)
- **Typical value**: 64 (standard UNet architecture)
- **Scenario**: 
  ```python
  up4_output = [B, 64, 572, 572]  # Last decoder output
  outc = OutConv(in_channels=64, out_channels=n_classes)
  ```

#### 2. **out_channels** (int) = n_classes
- **What**: Number of segmentation classes
- **Why**: Each channel represents probability/logit for one class
- **Determines**: Type of segmentation task

##### **Scenarios**:

**Binary Segmentation (n_classes=1)**:
```python
OutConv(64, 1)  # Single channel output
# Examples:
# - Background vs Foreground
# - Tumor vs Healthy tissue
# - Road vs Non-road
# - Cell vs Background

# Output: [B, 1, H, W]
# Each pixel has ONE value (logit)
# Apply sigmoid: probability of positive class
# Decision: if sigmoid(logit) > 0.5 → Class 1, else → Class 0
```

**Multi-class Segmentation (n_classes>1)**:
```python
OutConv(64, 3)  # Three classes
# Example: a Cityscapes-style street scene reduced to Road, Vehicle, Pedestrian

# Output: [B, 3, H, W]
# Each pixel has 3 values (logits for each class)
# Apply softmax: probabilities sum to 1
# Decision: argmax(softmax(logits)) → Predicted class

# Another example:
OutConv(64, 21)  # Pascal VOC - 20 object classes + background
# Output: [B, 21, H, W]
```

**Medical Multi-organ Segmentation**:
```python
OutConv(64, 5)  # n_classes=5
# Classes: Background, Heart, Lung-Left, Lung-Right, Liver
# Output: [B, 5, H, W]

# Pixel [100, 200] logits: [-2.3, 0.5, 3.2, 1.1, -0.8]
# After softmax: ≈ [0.003, 0.055, 0.825, 0.101, 0.015]
# Prediction: Class 2 (Lung-Left) with ~83% confidence
```
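
The softmax step for a single pixel can be worked through in plain Python. This sketch uses the five example logits and class names from the multi-organ scenario; the `softmax` helper is written out by hand for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


classes = ["Background", "Heart", "Lung-Left", "Lung-Right", "Liver"]
logits = [-2.3, 0.5, 3.2, 1.1, -0.8]

probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)

print([round(p, 3) for p in probs])  # [0.003, 0.055, 0.825, 0.101, 0.015]
print(classes[pred])                 # Lung-Left
```

Because softmax is monotonic, taking `argmax` of the raw logits gives the same predicted class; softmax is only needed when calibrated-looking probabilities are wanted.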

### kernel_size=1 Deep Dive:

**What is 1×1 Convolution?**:
```
Regular 3×3 Conv:          1×1 Conv:
Uses 3×3 neighborhood      Uses only center pixel
╔═══════╗                  ╔═══╗
║ · · · ║                  ║ X ║
║ · X · ║                  ╚═══╝
║ · · · ║
╚═══════╝
```

**Mathematical Operation**:
```python
# Input: [B, 64, H, W]
# Kernel (PyTorch weight shape): [n_classes, 64, 1, 1]

# For each spatial position (i, j):
# For each output channel c:
# output[b, c, i, j] = Σ(k=0 to 63) input[b, k, i, j] × weight[c, k] + bias[c]

# It's a LINEAR COMBINATION of input channels at each pixel
# No spatial mixing - purely channel transformation
```
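
The claim that a 1×1 convolution is a per-pixel linear map over channels can be verified numerically. The sketch below compares `nn.Conv2d(..., kernel_size=1)` with an explicit channel-wise sum written via `torch.einsum`; the batch size, spatial size, and `n_classes=3` are arbitrary illustration values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes = 3
conv = nn.Conv2d(64, n_classes, kernel_size=1)
x = torch.randn(2, 64, 8, 8)

with torch.no_grad():
    out_conv = conv(x)

    # The weight is stored as [n_classes, 64, 1, 1]; drop the two 1×1 dims
    w = conv.weight.view(n_classes, 64)
    # Sum over input channels k at every pixel (i, j), then add the bias
    out_manual = torch.einsum("bkij,ck->bcij", x, w) + conv.bias.view(1, -1, 1, 1)

same = torch.allclose(out_conv, out_manual, atol=1e-4)
print(out_conv.shape, same)
```

No neighboring pixel enters the sum, which is why the spatial size is unchanged and no padding is needed.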

**Why 1×1 Instead of 3×3?**:

1. **Efficiency**:
   ```
   3×3 Conv: 3×3×64×n_classes parameters
   1×1 Conv: 1×1×64×n_classes parameters
   # 9× fewer parameters!
   
   Example (n_classes=1):
   3×3: 3×3×64×1 = 576 parameters
   1×1: 1×1×64×1 = 64 parameters
   ```

2. **Spatial Independence**:
   - Each pixel classified independently
   - Spatial context already captured by encoder/decoder
   - Final layer just needs to map features → classes

3. **No Spatial Distortion**:
   - Padding not needed (1×1 doesn't change spatial size)
   - Output size exactly matches input size

4. **Cleaner Semantics**:
   - "For this pixel, given its 64 features, which class is it?"
   - Not mixing with neighbor information at final stage

**Example Visualization**:
```python
# Input from up4: [1, 64, 256, 256]
# 64 feature channels per pixel

# Pixel at (100, 150) has features:
features = [0.2, -0.5, 1.3, ..., 0.8]  # 64 values

# For binary segmentation (n_classes=1):
# Weight vector: w = [w0, w1, ..., w63], plus a scalar bias
# Output logit = Σ(features[i] × w[i]) + bias
# output[0, 0, 100, 150] = features @ w + bias  # Dot product

# If output = 2.5 → sigmoid(2.5) = 0.92 → 92% probability of class 1
```

### No Activation Function?

**Important**: OutConv outputs **raw logits**, not probabilities

**Why no activation here?**:
```python
# Training (in loss function):
# Binary: Uses BCEWithLogitsLoss (combines sigmoid + BCE)
loss = nn.BCEWithLogitsLoss()(output, target)
# Numerically more stable than separate sigmoid + BCE

# Multi-class: Uses CrossEntropyLoss (combines softmax + NLL)
loss = nn.CrossEntropyLoss()(output, target)
# More stable than separate softmax + NLL

# Inference (after training):
# Binary:
probabilities = torch.sigmoid(output)
predictions = (probabilities > 0.5).float()

# Multi-class:
probabilities = torch.softmax(output, dim=1)
predictions = torch.argmax(probabilities, dim=1)
```
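
The equivalence between the fused losses and their two-step counterparts can be checked on random data. This is a sketch with arbitrary small shapes; note the different target formats (float mask for the binary case, integer class indices for the multi-class case).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Binary: logits [B, 1, H, W], float targets in {0, 1} with the same shape
logits = torch.randn(2, 1, 4, 4)
target = torch.randint(0, 2, (2, 1, 4, 4)).float()
fused_bce = nn.BCEWithLogitsLoss()(logits, target)
two_step_bce = nn.BCELoss()(torch.sigmoid(logits), target)

# Multi-class: logits [B, C, H, W], integer class targets [B, H, W]
mc_logits = torch.randn(2, 5, 4, 4)
mc_target = torch.randint(0, 5, (2, 4, 4))
fused_ce = nn.CrossEntropyLoss()(mc_logits, mc_target)
two_step_ce = F.nll_loss(F.log_softmax(mc_logits, dim=1), mc_target)

print(fused_bce.item(), two_step_bce.item())
print(fused_ce.item(), two_step_ce.item())
```

On well-behaved inputs the values agree; the fused versions are preferred because they remain stable for very large or very small logits.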

### Image Processing Flow:

```python
# Binary Segmentation Example:
input = [B, 64, 572, 572]  # 64 rich feature channels
        ↓
OutConv(64, 1)
        ↓
output = [B, 1, 572, 572]  # 1 channel with logits

# Each pixel value is a logit:
# Positive logit → More likely class 1
# Negative logit → More likely class 0

# Multi-class Example:
input = [B, 64, 256, 256]
        ↓
OutConv(64, 10)
        ↓
output = [B, 10, 256, 256]  # 10 classes

# For pixel (50, 50):
# output[0, :, 50, 50] = [1.2, -0.5, 3.1, ..., 0.8]
# After softmax: [0.12, 0.02, 0.78, ..., 0.08]
# Prediction: Class 2 (index with max value)
```

### Complete Pipeline Example:

```python
# Medical Image: Tumor Segmentation
# n_classes = 3: Background, Benign, Malignant

# 1. Network learns features through encoder/decoder
final_features = torch.randn(1, 64, 512, 512)  # stand-in for the last decoder output

# 2. OutConv maps to class space
outc = OutConv(64, 3)
logits = outc(final_features)  # [1, 3, 512, 512]

# 3. Convert to probabilities (inference)
probs = torch.softmax(logits, dim=1)
# probs[0, :, 100, 100] might be:
# [0.05, 0.85, 0.10]  # 85% benign, 10% malignant, 5% background

# 4. Get final prediction
pred_mask = torch.argmax(probs, dim=1)
# pred_mask[0, 100, 100] = 1  # Benign class

# 5. Visualize
# Color code: Background=Black, Benign=Green, Malignant=Red
```

### Parameter Count:

```python
# OutConv(64, 1) - Binary
# Conv2d(64, 1, kernel_size=1)
# Parameters: 1×1×64×1 + 1(bias) = 65 parameters
# Minimal! (< 0.01% of total network)

# OutConv(64, 21) - Multi-class (Pascal VOC)
# Parameters: 1×1×64×21 + 21(bias) = 1,365 parameters
# Still minimal compared to 31M total network parameters

# This is intentional design:
# - Heavy computation in encoder/decoder (feature learning)
# - Light final layer (just class mapping)
```

### Comparison with Alternatives:

```python
# UNet Standard (1×1 Conv)
OutConv(64, n_classes)
# ✓ Efficient
# ✓ Per-pixel classification
# ✓ Standard practice

# Alternative 1: 3×3 Conv
nn.Conv2d(64, n_classes, kernel_size=3, padding=1)
# ✗ 9× more parameters
# ✗ Unnecessary spatial mixing at final stage
# ~ Might capture very fine detail (rarely needed)

# Alternative 2: Fully Connected
nn.Linear(64, n_classes)  # Applied per pixel
# ✓ Same as 1×1 Conv mathematically
# ✗ Less common in segmentation literature
# ✗ Harder to visualize as spatial operation
```

### Real-World Tip:

**When n_classes is large** (e.g., 150 classes in ADE20K dataset):
```python
# Option 1: Direct
OutConv(64, 150)  # 64×150 + 150 (bias) = 9,750 parameters

# Option 2: Wider two-layer head (if needed)
nn.Sequential(
    nn.Conv2d(64, 256, 1),  # Expand
    nn.ReLU(),
    nn.Conv2d(256, 150, 1)  # Project to classes
)
# Sometimes better for very large n_classes
# Adds intermediate representation capacity
```
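To put numbers on the two options, the following sketch compares output shapes and parameter counts. Both heads are built from plain `nn.Conv2d` layers (standing in for `OutConv`), and `count_params` is a hypothetical helper used only here.

```python
import torch
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

direct = nn.Conv2d(64, 150, kernel_size=1)
two_layer = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1),   # Expand
    nn.ReLU(),
    nn.Conv2d(256, 150, kernel_size=1),  # Project to classes
)

x = torch.randn(1, 64, 32, 32)
print(direct(x).shape, count_params(direct))        # 9,750 parameters
print(two_layer(x).shape, count_params(two_layer))  # 55,190 parameters
```

Both heads produce a `[1, 150, 32, 32]` tensor; the two-layer variant has several times as many parameters, which is still small next to the rest of the network.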

For standard UNet, simple 1×1 is sufficient and preferred!

In [6]:
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

## Complete UNet Architecture

### Purpose:
Assembles all building blocks into the complete UNet model for semantic segmentation.

### Architecture Overview:

```
   Input Image (e.g., 572×572×3)                     Output Segmentation Map
              ↓                                                 ↑
                                                         OutConv (1×1)
                                                                ↑
   inc (DoubleConv)          ──── skip (x1) ────→   Up4 (Upsample + Concat + Conv)
   64 channels                                      64 channels
              ↓                                                 ↑
   Down1 (MaxPool + Conv)    ──── skip (x2) ────→   Up3 (Upsample + Concat + Conv)
   128 channels                                     128 channels
              ↓                                                 ↑
   Down2 (MaxPool + Conv)    ──── skip (x3) ────→   Up2 (Upsample + Concat + Conv)
   256 channels                                     256 channels
              ↓                                                 ↑
   Down3 (MaxPool + Conv)    ──── skip (x4) ────→   Up1 (Upsample + Concat + Conv)
   512 channels                                     512 channels
              ↓                                                 ↑
   Down4 (MaxPool + Conv)    ──── features (x5) ────────────────┘
   1024/512 channels
   (Bottleneck - deepest point)

   Decoder channel counts are shown for bilinear=False.
```

### Main Parameters Explanation:

#### 1. **n_channels** (int)
- **What**: Number of channels in input image
- **Why**: Defines input format that network expects
- **Scenarios**:
  ```python
  n_channels=1:  # Grayscale
  # - Medical X-rays
  # - Satellite SAR imagery  
  # - Depth maps
  # Input shape: [B, 1, H, W]
  
  n_channels=3:  # RGB Color
  # - Natural images
  # - Medical histology (stained tissue)
  # - Aerial photography
  # Input shape: [B, 3, H, W]
  
  n_channels=4:  # RGBA or Multi-spectral
  # - Images with alpha channel
  # - Satellite (RGB + NIR)
  # Input shape: [B, 4, H, W]
  
  n_channels=7:  # Satellite Multispectral
  # - e.g., a 7-band subset of Sentinel-2's 13 spectral bands
  # - Each channel captures different wavelength
  # Input shape: [B, 7, H, W]
  ```

- **Impact on Model**:
  ```python
  # First convolution parameters (3×3 kernel, bias=False):
  n_channels=1: Conv2d(1, 64)   → 3×3×1×64 = 576 params
  n_channels=3: Conv2d(3, 64)   → 3×3×3×64 = 1,728 params
  n_channels=7: Conv2d(7, 64)   → 3×3×7×64 = 4,032 params
  # Rest of network unchanged!
  ```
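The first-layer counts can be reproduced with a few lines. This sketch assumes the first convolution is a 3×3 `nn.Conv2d` without bias, which is what the listed figures imply.

```python
import torch.nn as nn

counts = {}
for n_channels in (1, 3, 7):
    # bias=False matches the counts listed above (3*3*n_channels*64)
    first_conv = nn.Conv2d(n_channels, 64, kernel_size=3, padding=1, bias=False)
    counts[n_channels] = first_conv.weight.numel()

print(counts)  # {1: 576, 3: 1728, 7: 4032}
```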

#### 2. **n_classes** (int)
- **What**: Number of segmentation output classes
- **Why**: Determines what the model predicts
- **Scenarios**:

  **n_classes=1: Binary Segmentation**
  ```python
  # Single foreground class
  UNet(n_channels=3, n_classes=1)
  
  # Examples:
  # Medical: Tumor present or not
  # Satellite: Building or no building
  # Industrial: Defect or no defect
  
  # Output: [B, 1, H, W]
  # Loss: BCEWithLogitsLoss
  # Inference: sigmoid → threshold at 0.5
  ```

  **n_classes=2: Binary (Alternative Format)**
  ```python
  # Two classes: Background & Foreground
  UNet(n_channels=3, n_classes=2)
  
  # Output: [B, 2, H, W]
  # Channel 0: Background probability
  # Channel 1: Foreground probability
  # Loss: CrossEntropyLoss
  # Inference: softmax → argmax
  
  # Note: n_classes=1 (sigmoid) and n_classes=2 (softmax) are equivalent
  # formulations of binary segmentation; n_classes=1 outputs half as many logits
  ```

  **n_classes=3-10: Few Classes**
  ```python
  UNet(n_channels=3, n_classes=5)
  
  # Example: Autonomous driving (simplified)
  # Class 0: Background
  # Class 1: Road
  # Class 2: Vehicle
  # Class 3: Pedestrian
  # Class 4: Traffic sign
  
  # Output: [B, 5, H, W]
  # Each pixel gets 5 logits
  ```

  **n_classes=20+: Many Classes**
  ```python
  UNet(n_channels=3, n_classes=21)
  
  # Example: Pascal VOC dataset
  # 20 object classes + 1 background
  # Classes: Person, car, dog, chair, etc.
  
  # More challenging:
  UNet(n_channels=3, n_classes=150)
  # ADE20K dataset - scene parsing
  # Wall, building, sky, floor, tree, etc.
  ```

- **Memory Impact**:
  ```python
  # Logits for one 512×512 image in float32 (4 bytes per value):
  n_classes=1:   512×512×1 = 262,144 values (~1 MB)
  n_classes=10:  512×512×10 = 2,621,440 values (~10 MB)
  n_classes=150: 512×512×150 = 39,321,600 values (~150 MB)
  # Per image, per forward pass; multiply by the batch size
  ```
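The two binary formats listed above (`n_classes=1` with sigmoid, `n_classes=2` with softmax) can be checked against each other. The sketch below builds a two-channel logit tensor whose background logit is fixed at zero, which is one way to express a single-logit output in the two-class format; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(2, 1, 8, 8)                               # n_classes=1 logits
two_channel = torch.cat([torch.zeros_like(z), z], dim=1)  # [B, 2, H, W]

# Foreground probability from each formulation
p_sigmoid = torch.sigmoid(z)[:, 0]
p_softmax = torch.softmax(two_channel, dim=1)[:, 1]
probs_match = torch.allclose(p_sigmoid, p_softmax, atol=1e-6)

# Predicted masks: threshold at 0.5 vs argmax
masks_match = torch.equal(p_sigmoid > 0.5, torch.argmax(two_channel, dim=1) == 1)

# Losses: BCEWithLogits on one channel vs CrossEntropy on two channels
target = torch.randint(0, 2, (2, 8, 8))
bce = F.binary_cross_entropy_with_logits(z[:, 0], target.float())
ce = F.cross_entropy(two_channel, target)
losses_match = torch.allclose(bce, ce, atol=1e-5)

print(probs_match, masks_match, losses_match)
```

Note the target dtypes: the BCE variant expects float targets shaped like the logits, while `cross_entropy` expects integer class indices of shape `[B, H, W]`.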

#### 3. **bilinear** (bool, default=False)
- **What**: Upsampling method for ALL Up blocks
- **Why**: Global architecture decision affecting entire decoder

##### **bilinear=False (Default - Learned Upsampling)**:
```python
model = UNet(n_channels=3, n_classes=1, bilinear=False)

# Architecture changes:
# Bottleneck: Uses 1024 channels
self.down4 = Down(512, 1024)  # factor = 1

# Up blocks use ConvTranspose2d:
self.up1 = Up(1024, 512, bilinear=False)
# Up.up = ConvTranspose2d(1024, 512, kernel_size=2, stride=2)

# Total parameters: ~31M
# GPU memory: ~8-12 GB for training (batch size 4-8; rough estimate)
# Training time: ~100% (baseline)
# Accuracy: usually the stronger option (upsampling weights are learned)
```

##### **bilinear=True (Memory-Efficient)**:
```python
model = UNet(n_channels=3, n_classes=1, bilinear=True)

# Architecture changes:
# Bottleneck: Uses 512 channels only!
self.down4 = Down(512, 512)  # factor = 2

# Up blocks use interpolation:
self.up1 = Up(1024, 256, bilinear=True)
# Up.up = Upsample(scale_factor=2, mode='bilinear')

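# Total parameters: ~17M (about 44% fewer), given mid_channels = in_channels // 2 in Up
# GPU memory: ~5-8 GB for training (same batch size; rough estimate)
# Training time: ~80% (roughly 20% faster; hardware-dependent)
# Accuracy: often ~1-2% lower mIoU, depending on the dataset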
```
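The difference between the two upsampling modes shows up in channel counts and parameters. In this sketch the bottleneck tensors are random stand-ins for `x5` at a 572×572 input, and `align_corners=True` is an assumption rather than something stated above.

```python
import torch
import torch.nn as nn

# bilinear=False: learned upsampling, bottleneck has 1024 channels
x5_learned = torch.randn(1, 1024, 35, 35)
learned = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
out_learned = learned(x5_learned)   # doubles H and W, halves channels

# bilinear=True: interpolation, bottleneck has 512 channels
x5_interp = torch.randn(1, 512, 35, 35)
interp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
out_interp = interp(x5_interp)      # doubles H and W, channels unchanged

learned_params = sum(p.numel() for p in learned.parameters())
interp_params = sum(p.numel() for p in interp.parameters())

print(out_learned.shape, learned_params)  # [1, 512, 70, 70], 2,097,664
print(out_interp.shape, interp_params)    # [1, 512, 70, 70], 0
```

Because interpolation cannot change the channel count, the bilinear variant halves the channels in the preceding block instead, which is where the `factor = 2` comes from.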

##### **When to Choose**:

| Factor | bilinear=False | bilinear=True |
|--------|----------------|---------------|
| GPU memory available | > 8 GB | < 8 GB |
| Training time priority | Low | High |
| Inference speed priority | Low | High |
| Accuracy priority | Critical | Nice-to-have |
| Dataset size | Large | Small |
| Image complexity | High | Moderate |
| Production deployment | Server | Edge device |

**Hybrid Approach** (Advanced):
```python
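# Can customize per layer, but channel counts must stay consistent:
# Upsample (unlike ConvTranspose2d) does not reduce channels, so any block
# that feeds a bilinear Up must output half as many channels.
# Requires self.down4 = Down(512, 1024).
self.up1 = Up(1024, 256, bilinear=False)  # Learn first upsampling
self.up2 = Up(512, 128, bilinear=True)    # Interpolate middle
self.up3 = Up(256, 128, bilinear=True)    # Interpolate middle
self.up4 = Up(128, 64, bilinear=False)    # Learn final upsampling
# Balance quality and efficiency!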
```

### Channel Progression Detailed:

#### **Encoder (Contracting Path)**:
```python
# Progressive downsampling with channel increase
# Spatial size ÷2 at each step, channels ×2

# Starting point:
inc: [B, n_channels, H, W] → [B, 64, H, W]
# Initial feature extraction
# Example: [2, 3, 572, 572] → [2, 64, 572, 572]

# Level 1:
down1: [B, 64, H, W] → [B, 128, H/2, W/2]
# MaxPool: 572 → 286
# Example: [2, 64, 572, 572] → [2, 128, 286, 286]

# Level 2:
down2: [B, 128, H/2, W/2] → [B, 256, H/4, W/4]
# MaxPool: 286 → 143
# Example: [2, 128, 286, 286] → [2, 256, 143, 143]

# Level 3:
down3: [B, 256, H/4, W/4] → [B, 512, H/8, W/8]
# MaxPool: 143 → 71
# Example: [2, 256, 143, 143] → [2, 512, 71, 71]

# Level 4 (Bottleneck):
# bilinear=False:
down4: [B, 512, H/8, W/8] → [B, 1024, H/16, W/16]
# Example: [2, 512, 71, 71] → [2, 1024, 35, 35]

# bilinear=True:
down4: [B, 512, H/8, W/8] → [B, 512, H/16, W/16]
# Example: [2, 512, 71, 71] → [2, 512, 35, 35]
# Half the channels! (factor=2)
```
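The spatial sizes in the example follow from `MaxPool2d(2)` rounding down on odd inputs. A small sketch that traces only the pooling steps (the convolutions are assumed to preserve size):

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2)
x = torch.zeros(1, 1, 572, 572)

sizes = [x.shape[-1]]
for _ in range(4):          # down1 .. down4
    x = pool(x)
    sizes.append(x.shape[-1])

print(sizes)  # [572, 286, 143, 71, 35]
```

The odd sizes (143, 71, 35) matter in the decoder: doubling 35 gives 70, not 71, so the upsampled tensor has to be padded before it can be concatenated with its skip connection.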

#### **Decoder (Expanding Path)**:
```python
# Progressive upsampling with channel decrease
# Spatial size ×2 at each step, channels ÷2

# bilinear=False:
up1: [B, 1024, H/16, W/16] + skip[B, 512, H/8, W/8]
     → [B, 512, H/8, W/8]
# [2, 1024, 35, 35] + [2, 512, 71, 71] → [2, 512, 71, 71]

up2: [B, 512, H/8, W/8] + skip[B, 256, H/4, W/4]
     → [B, 256, H/4, W/4]
# [2, 512, 71, 71] + [2, 256, 143, 143] → [2, 256, 143, 143]

up3: [B, 256, H/4, W/4] + skip[B, 128, H/2, W/2]
     → [B, 128, H/2, W/2]
# [2, 256, 143, 143] + [2, 128, 286, 286] → [2, 128, 286, 286]

up4: [B, 128, H/2, W/2] + skip[B, 64, H, W]
     → [B, 64, H, W]
# [2, 128, 286, 286] + [2, 64, 572, 572] → [2, 64, 572, 572]

# Final:
outc: [B, 64, H, W] → [B, n_classes, H, W]
# [2, 64, 572, 572] → [2, 1, 572, 572]
```
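In the 572×572 example, upsampling 35×35 gives 70×70 while the skip tensor is 71×71, so the two cannot be concatenated directly. A minimal sketch of one common remedy, padding the upsampled tensor to the skip tensor's size, is shown below; `pad_to_match` is a hypothetical helper, and it is an assumption that the `Up` block defined earlier handles the mismatch this way.

```python
import torch
import torch.nn.functional as F

def pad_to_match(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Zero-pad `upsampled` so its H and W equal those of `skip`."""
    diff_y = skip.size(2) - upsampled.size(2)
    diff_x = skip.size(3) - upsampled.size(3)
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(upsampled, [diff_x // 2, diff_x - diff_x // 2,
                             diff_y // 2, diff_y - diff_y // 2])

up = torch.randn(2, 512, 70, 70)     # 35×35 bottleneck after 2× upsampling
skip = torch.randn(2, 512, 71, 71)   # x4 from the encoder

merged = torch.cat([skip, pad_to_match(up, skip)], dim=1)
print(merged.shape)  # torch.Size([2, 1024, 71, 71])
```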

### Forward Pass Explanation:

```python
def forward(self, x):
    # x: Input image [B, n_channels, H, W]
    
    # ENCODER (contracting path)
    x1 = self.inc(x)      # [B, 64, H, W] - Initial features
    x2 = self.down1(x1)   # [B, 128, H/2, W/2] - Low-level features
    x3 = self.down2(x2)   # [B, 256, H/4, W/4] - Mid-level features
    x4 = self.down3(x3)   # [B, 512, H/8, W/8] - High-level features
    x5 = self.down4(x4)   # [B, 1024, H/16, W/16] - Semantic features (512 channels if bilinear=True)
    
    # x1, x2, x3, x4 saved for skip connections!
    
    # DECODER (expanding path with skip connections)
    x = self.up1(x5, x4)  # Takes x5 (from below) + x4 (skip) → [B, 512, H/8, W/8]
    x = self.up2(x, x3)   # Takes x (from below) + x3 (skip) → [B, 256, H/4, W/4]
    x = self.up3(x, x2)   # Takes x (from below) + x2 (skip) → [B, 128, H/2, W/2]
    x = self.up4(x, x1)   # Takes x (from below) + x1 (skip) → [B, 64, H, W]
    
    # OUTPUT
    logits = self.outc(x) # [B, n_classes, H, W] - Final segmentation
    return logits
```

### Skip Connection Flow:
```
Encoder saves:        Decoder uses:
x1 (64 ch, full res)  ────────────→ up4
x2 (128 ch, 1/2 res)  ──────────→ up3
x3 (256 ch, 1/4 res)  ────────→ up2
x4 (512 ch, 1/8 res)  ──────→ up1

Purpose:
- Combine WHERE (high-res encoder) with WHAT (low-res decoder)
- Preserve fine details lost during downsampling
- Enable precise localization
```

### use_checkpointing() Method:

**What it does**:
```python
model.use_checkpointing()
# Wraps each major block with gradient checkpointing
```

**How it works**:
- **Normal**: Stores all intermediate activations (memory-heavy)
  ```
  Forward: Compute + Store all activations
  Backward: Use stored activations for gradients
  Memory: High, Speed: Fast
  ```

- **With Checkpointing**: Stores only some activations
  ```
  Forward: Compute + Store only checkpoints
  Backward: Recompute activations on-the-fly
  Memory: Lower (~50% reduction, rough estimate), Speed: Slower (~20%, from the extra forward pass)
  ```

**When to use**:
```python
# Scenario 1: Large images
input_size = (3, 1024, 1024)  # 1K resolution
model.use_checkpointing()  # Reduces memory by ~5-8 GB

# Scenario 2: Limited GPU
# GPU: 8GB (e.g., RTX 2080)
# Without checkpointing: Batch size = 2
# With checkpointing: Batch size = 4 (2× more!)

# Scenario 3: Huge batch sizes
# For better batch norm statistics
# Trade 20% speed for 100% more batch size
```

**Implementation detail**:
```python
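# torch.utils.checkpoint is a module, not a function, so this raises
# "TypeError: 'module' object is not callable":
#     self.inc = torch.utils.checkpoint(self.inc)

# The callable is torch.utils.checkpoint.checkpoint(fn, *args),
# applied at the point where the block is called:
def forward(self, x):
    x1 = torch.utils.checkpoint.checkpoint(self.inc, x, use_reentrant=False)
    # etc...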
```
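A quick way to build confidence in checkpointing is to confirm that it leaves gradients unchanged. The sketch below uses a small convolutional block as a stand-in for a UNet block (it deliberately omits BatchNorm, whose running statistics would be updated again during recomputation) and compares gradients with and without `checkpoint`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
x = torch.randn(2, 3, 16, 16)

# Regular forward/backward: intermediate activations are stored
block(x).sum().backward()
grad_plain = block[0].weight.grad.clone()
block.zero_grad()

# Checkpointed: activations inside `block` are recomputed during backward
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = block[0].weight.grad.clone()

grads_match = torch.allclose(grad_plain, grad_ckpt)
print(grads_match)
```

The results are the same; only the memory/compute trade-off differs.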

### Model Statistics:

```python
# Example: UNet(n_channels=3, n_classes=1, bilinear=False)

# Total parameters: ~31M
# Breakdown:
# - inc: ~39K
# - down1-4: ~18.8M (most parameters; down4 alone is ~14.2M)
# - up1-4: ~12.2M
# - outc: 65

# Memory usage (training, batch=4, size=512×512):
# - Model parameters: ~124 MB
# - Activations: ~6 GB
# - Gradients: ~124 MB
# - Optimizer states (Adam): ~248 MB
# - Total: ~7 GB GPU memory

# Inference (batch=1):
# - Model: ~124 MB
# - Activations: ~1 GB
# - Total: ~1.1 GB, plus framework overhead (~1.5 GB in practice)
```
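The parameter-related memory figures above are simple arithmetic on the parameter count. A minimal sketch, assuming float32 values (4 bytes each) and an Adam optimizer that keeps two extra tensors per parameter; `training_memory_mb` is a hypothetical helper, and activations are not modelled because they depend on input size and batch size.

```python
import torch.nn as nn

def training_memory_mb(n_params: int, bytes_per_value: int = 4) -> dict:
    """Rough memory (in MB) tied to the parameters themselves."""
    weights = n_params * bytes_per_value / 1e6
    return {
        "weights": weights,
        "gradients": weights,        # one gradient per parameter
        "adam_states": 2 * weights,  # first and second moment estimates
    }

est = training_memory_mb(31_000_000)   # ~31M parameters, as quoted above
print(est)  # {'weights': 124.0, 'gradients': 124.0, 'adam_states': 248.0}

# The same measurement on an actual module (a 1×1 conv standing in for OutConv):
head = nn.Conv2d(64, 1, kernel_size=1)
head_bytes = sum(p.numel() * p.element_size() for p in head.parameters())
print(head_bytes)  # 65 parameters × 4 bytes = 260
```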

### Real-World Configuration Examples:

```python
# Medical Imaging (CT Scans)
model = UNet(n_channels=1, n_classes=3, bilinear=True)
# 1 channel (grayscale), 3 organs, memory-efficient
# Input: [B, 1, 512, 512]
# Output: [B, 3, 512, 512]

# Satellite Imagery (Building Detection)
model = UNet(n_channels=3, n_classes=1, bilinear=False)
# RGB, binary (building/not), best accuracy
# Input: [B, 3, 1024, 1024]
# Output: [B, 1, 1024, 1024]

# Autonomous Driving (Scene Parsing)
model = UNet(n_channels=3, n_classes=19, bilinear=False)
# RGB, 19 Cityscapes classes, accuracy critical
# Input: [B, 3, 1024, 2048]
# Output: [B, 19, 1024, 2048]

# Mobile Deployment (Portrait Segmentation)
model = UNet(n_channels=3, n_classes=1, bilinear=True)
# RGB, binary person mask, speed critical
# Possible: Reduce channels (32 instead of 64)
# Input: [B, 3, 256, 256]
# Output: [B, 1, 256, 256]
```

In [7]:
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        # torch.utils.checkpoint is a module; the callable is
        # torch.utils.checkpoint.checkpoint. Wrap each block so its
        # activations are recomputed during backward instead of stored.
        # Note: wrapping changes state_dict keys (e.g. 'inc.block...'),
        # so call this after loading weights.
        from torch.utils.checkpoint import checkpoint

        class _Checkpointed(nn.Module):
            def __init__(self, block):
                super().__init__()
                self.block = block

            def forward(self, *inputs):
                return checkpoint(self.block, *inputs, use_reentrant=False)

        for name in ('inc', 'down1', 'down2', 'down3', 'down4',
                     'up1', 'up2', 'up3', 'up4', 'outc'):
            setattr(self, name, _Checkpointed(getattr(self, name)))

## Model Summary and Visualization

### Purpose:
Display the complete architecture with layer-by-layer parameter counts and output shapes.

This will show:
- **Layer names** and types
- **Output shapes** at each layer
- **Parameter counts** (trainable weights)
- **Total parameters** in the model
- **Model size** in memory

### What to Expect:
For input size (3, 572, 572):
- Encoder blocks halve the spatial dimensions at each Down step (572→286→143→71→35)
- Channel count increases in encoder (64→128→256→512→1024)
- Channel count decreases in decoder (1024→512→256→128→64)
- Final output: (1, 572, 572) for single-class segmentation

### Understanding the Summary:
- **Input Shape**: (batch_size, channels, height, width)
- **Output Shape**: Shape after each layer
- **Params**: Number of learnable parameters in that layer
- **Total Params**: Sum of all parameters in the network
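
The per-layer counts in the summary can also be reproduced without `torchsummary`. This sketch rebuilds the first block as a plain `nn.Sequential`; it is an assumption that this matches `DoubleConv(3, 64)` as defined earlier (two 3×3 convolutions without bias, each followed by BatchNorm and ReLU).

```python
import torch.nn as nn

# Stand-in for the first block (inc)
inc = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# (layer type, parameter count) for each layer, in order
per_layer = [(type(m).__name__, sum(p.numel() for p in m.parameters()))
             for m in inc]
total = sum(n for _, n in per_layer)

for name, n in per_layer:
    print(f"{name:<12} {n:>8,}")
print(f"{'Total':<12} {total:>8,}")  # 38,848
```

The same one-liner, `sum(p.numel() for p in model.parameters())`, gives the total for the full model.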

In [8]:
from torchsummary import summary
summary(UNet(n_channels=3, n_classes=1).to('cpu'), (3, 572, 572))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 572, 572]           1,728
       BatchNorm2d-2         [-1, 64, 572, 572]             128
              ReLU-3         [-1, 64, 572, 572]               0
            Conv2d-4         [-1, 64, 572, 572]          36,864
       BatchNorm2d-5         [-1, 64, 572, 572]             128
              ReLU-6         [-1, 64, 572, 572]               0
        DoubleConv-7         [-1, 64, 572, 572]               0
         MaxPool2d-8         [-1, 64, 286, 286]               0
            Conv2d-9        [-1, 128, 286, 286]          73,728
      BatchNorm2d-10        [-1, 128, 286, 286]             256
             ReLU-11        [-1, 128, 286, 286]               0
           Conv2d-12        [-1, 128, 286, 286]         147,456
      BatchNorm2d-13        [-1, 128, 286, 286]             256
             ReLU-14        [-1, 128, 2