<a href="https://colab.research.google.com/github/goyalpramod/paper_implementations/blob/main/UNET_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Implementing UNET from scratch. Read the original [paper here](https://arxiv.org/pdf/1505.04597)

This is mostly scrappy, quickly written code that was sanity-checked with an AI assistant; I haven't run it yet.

I wanted this to be a more guided approach, but I will be adding it to my diffusion model blog, so I will break down the ideas and components better there.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Hints for DoubleConv class
class DoubleConv(nn.Module):
    """
    Exercise placeholder: this class defines no layers of its own and is
    redefined with the actual implementation in the following cell.

    Implement:
    1. Two 3x3 convolutions, each followed by
    2. Batch normalization (optional but recommended)
    3. ReLU activation
    Use nn.Sequential for a clean implementation.
    """

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels and the second one
        # keeps out_channels; padding=1 keeps the spatial size of a 3x3 conv.
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv -> BN -> ReLU stages."""
        features = self.double_conv(x)
        return features

In [None]:
class Down:
    """
    Exercise placeholder: this class has no behavior of its own and is
    redefined with the actual implementation in the following cell.

    Implement:
    1. Max pooling operation (2x2 window)
    2. Double convolution
    Track the spatial dimension changes!
    """

In [None]:
class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Pool first so the convolutions run on the reduced resolution.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and run the result through the double convolution."""
        reduced = self.maxpool_conv(x)
        return reduced

In [None]:
class Up:
    """
    Exercise placeholder: this class has no behavior of its own and is
    redefined with the actual implementation in the following cell.

    Implement:
    1. Upsampling (either ConvTranspose2d or Upsample)
    2. Concatenation with the skip connection
    3. Double convolution
    Remember to handle different input sizes!
    """

In [None]:
class Up(nn.Module):
    """Decoder step: upsample, merge with the encoder skip tensor, DoubleConv.

    Both upsampling modes halve the channel count of the incoming tensor, so
    concatenating it with a skip tensor of ``in_channels // 2`` channels
    yields exactly the ``in_channels`` the DoubleConv is built for.

    Args:
        in_channels: channels of the tensor being upsampled (``x1``); the skip
            tensor (``x2``) is expected to carry ``in_channels // 2`` channels.
        out_channels: channels produced by the block.
        bilinear: if True, upsample by bilinear interpolation followed by a
            1x1 convolution that halves the channels; otherwise use a learned
            2x2 transposed convolution with stride 2.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # nn.Upsample leaves the channel count untouched, so a 1x1
            # convolution is needed to match the transposed-conv branch;
            # without it the concatenation below would have
            # in_channels + in_channels // 2 channels and self.conv would fail.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, in_channels // 2, kernel_size=1),
            )
        else:
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2`` and fuse the two.

        Args:
            x1: tensor coming from the deeper (lower-resolution) level.
            x2: skip tensor from the encoder at the target resolution.

        Returns:
            The output of the DoubleConv applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # Odd input sizes lose a pixel in the encoder's pooling, so the
        # upsampled tensor can be smaller than the skip tensor; pad to match.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # F.pad order for the last two dims: (left, right, top, bottom).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)

In [None]:
class UNet:
    """
    Exercise placeholder: this class has no behavior of its own and is
    redefined with the actual implementation in the following cell.

    Implement:
    1. Initial double convolution
    2. Downsampling path (typically 4 down steps)
    3. Bottleneck
    4. Upsampling path with skip connections
    5. Final 1x1 convolution to map to the desired number of classes

    Key points to consider:
    - Feature map sizes at each level
    - Channel numbers doubling/halving
    - Skip connection management
    """

In [8]:
class UNet(nn.Module):
    """U-Net with a four-step encoder, a bottleneck and a four-step decoder.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        in_channels: number of channels of the input image (default 3).
        bilinear: forwarded to every Up block. The default (False) selects the
            learned 2x2 transposed convolution, which halves the channel count
            before the skip concatenation.

    NOTE: the channel arithmetic in the decoder comments below assumes the
    upsampling step halves the channel count, which the transposed-convolution
    branch of Up does.
    """

    def __init__(self, num_classes, in_channels=3, bilinear=False):
        super().__init__()

        self.inc = DoubleConv(in_channels, 64)  # Initial convolution

        # Encoder path (feature maps halve, channels double)
        self.down1 = Down(64, 128)    # Output: 128 channels
        self.down2 = Down(128, 256)   # Output: 256 channels
        self.down3 = Down(256, 512)   # Output: 512 channels
        self.down4 = Down(512, 1024)  # Output: 1024 channels (bottleneck)

        # Decoder path (feature maps double, channels halve).
        # Each step: upsampled (in/2) + skip (in/2) = in channels -> out.
        self.up1 = Up(1024, 512, bilinear)  # 512 + 512 = 1024 -> 512
        self.up2 = Up(512, 256, bilinear)   # 256 + 256 = 512 -> 256
        self.up3 = Up(256, 128, bilinear)   # 128 + 128 = 256 -> 128
        self.up4 = Up(128, 64, bilinear)    # 64 + 64 = 128 -> 64

        # Final 1x1 convolution maps features to per-class logits
        self.outc = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw per-class logits for the input batch ``x``."""
        # Store encoder outputs for skip connections
        x1 = self.inc(x)         # [B, 64, H, W]
        x2 = self.down1(x1)      # [B, 128, H/2, W/2]
        x3 = self.down2(x2)      # [B, 256, H/4, W/4]
        x4 = self.down3(x3)      # [B, 512, H/8, W/8]
        x5 = self.down4(x4)      # [B, 1024, H/16, W/16]

        # Decoder path with skip connections
        x = self.up1(x5, x4)     # Use skip connection from x4
        x = self.up2(x, x3)      # Use skip connection from x3
        x = self.up3(x, x2)      # Use skip connection from x2
        x = self.up4(x, x1)      # Use skip connection from x1

        # Final 1x1 convolution
        logits = self.outc(x)    # [B, num_classes, H, W]

        return logits



In [9]:
# Input: [B, 3, H, W]
# Encoder:
# - Level 1: [B, 64, H, W]
# - Level 2: [B, 128, H/2, W/2]
# - Level 3: [B, 256, H/4, W/4]
# - Level 4: [B, 512, H/8, W/8]
# - Bottleneck: [B, 1024, H/16, W/16]

# Decoder (with skip connections):
# - Level 4: [B, 512, H/8, W/8]
# - Level 3: [B, 256, H/4, W/4]
# - Level 2: [B, 128, H/2, W/2]
# - Level 1: [B, 64, H, W]
# Output: [B, num_classes, H, W]

In [None]:
class UNetTrainer:
    """Bundle a model with its loss, optimizer, LR scheduler and train/val loops."""

    def __init__(self, model, device, n_classes):
        """
        Initialize trainer with essential components

        Args:
            model: UNet model
            device: torch.device (or device string) batches are moved to
            n_classes: number of segmentation classes
        """
        self.device = device
        self.n_classes = n_classes
        # Batches are moved to `device` in every step, so the model has to
        # live on the same device or the forward pass raises.
        self.model = model.to(device)

        # 1. Loss Function
        # NOTE(review): BCEWithLogitsLoss needs a target with the same shape
        # as the logits, so for n_classes == 2 the masks must be two-channel
        # (one-hot) -- confirm the dataset provides that. All other values
        # use CrossEntropyLoss with class-index targets.
        if n_classes == 2:
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

        # 2. Optimizer (Adam with learning rate scheduling)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=3)

    def _compute_loss(self, output, target):
        """Apply the criterion with the target dtype it expects."""
        if self.n_classes == 2:
            # Binary segmentation: BCE works on float targets
            return self.criterion(output, target.float())
        # Multi-class segmentation: cross entropy works on integer labels
        return self.criterion(output, target.long())

    def train_epoch(self, train_loader):
        """
        Single training epoch

        Args:
            train_loader: DataLoader for training data

        Returns:
            Mean training loss over the batches as a float.
        """
        self.model.train()
        epoch_loss = 0

        for data, target in train_loader:
            # Move data to device
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self._compute_loss(output, target)

            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()

        return epoch_loss / len(train_loader)

    def validate(self, val_loader):
        """
        Validation step

        Args:
            val_loader: DataLoader for validation data

        Returns:
            Tuple ``(mean validation loss, mean Dice score)``, both floats.
        """
        self.model.eval()
        val_loss = 0
        dice_score = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                val_loss += self._compute_loss(output, target).item()
                # .item() keeps the running sum a plain float instead of a
                # tensor that stays on the training device.
                dice_score += self.compute_dice(output, target).item()

        return val_loss / len(val_loader), dice_score / len(val_loader)

    @staticmethod
    def compute_dice(pred, target):
        """
        Compute Dice coefficient

        Args:
            pred: model logits with the channel axis at dim 1
            target: ground truth masks, either with the same shape as
                ``pred`` or without the channel axis (class indices for
                multi-channel ``pred``, a 0/1 mask for single-channel ``pred``)

        Returns:
            Scalar tensor with the Dice coefficient over the whole batch.
        """
        smooth = 1e-5
        has_channel_axis = target.dim() == pred.dim()

        if pred.shape[1] == 1:
            probs = torch.sigmoid(pred)
            if not has_channel_axis:
                # Add the channel axis so the product below is element-wise
                # instead of broadcasting across the batch dimension.
                target = target.unsqueeze(1)
        else:
            probs = F.softmax(pred, dim=1)
            if not has_channel_axis:
                # Class indices must become one-hot masks; multiplying
                # probabilities by raw label values is not an overlap measure.
                target = F.one_hot(target.long(), num_classes=pred.shape[1])
                target = target.movedim(-1, 1)

        target = target.to(probs.dtype)
        intersection = (probs * target).sum()
        return (2. * intersection + smooth) / \
               (probs.sum() + target.sum() + smooth)

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=100,
                n_classes=None, device=None):
    """Train ``model``, checkpointing the weights with the lowest validation loss.

    Args:
        model: network to train.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        num_epochs: number of passes over ``train_loader``.
        n_classes: number of segmentation classes. When None it is read from
            ``model.outc.out_channels`` (the final 1x1 convolution of the UNet
            defined in this file).
        device: device to train on. When None, CUDA is used if available,
            otherwise the CPU.

    Raises:
        ValueError: if ``n_classes`` is None and cannot be read from the model.

    Side effects:
        Writes the best state dict to ``best_model.pth`` and prints per-epoch
        metrics.
    """
    if n_classes is None:
        out_layer = getattr(model, 'outc', None)
        n_classes = getattr(out_layer, 'out_channels', None)
        if n_classes is None:
            raise ValueError(
                'n_classes was not given and could not be inferred from '
                'model.outc.out_channels')

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    trainer = UNetTrainer(model, device=device, n_classes=n_classes)

    # Lower is better: this tracks the validation loss, not the Dice score.
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Training
        train_loss = trainer.train_epoch(train_loader)

        # Validation
        val_loss, dice_score = trainer.validate(val_loader)

        # Learning rate scheduling
        trainer.scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Dice Score: {dice_score:.4f}')

In [None]:
# transforms = A.Compose([
#     A.RandomRotate90(),
#     A.Flip(),
#     A.ElasticTransform(),
#     A.GridDistortion(),
#     A.Normalize()
# ])

In [None]:
# def compute_metrics(pred, target):
#     """
#     Compute multiple segmentation metrics:
#     - IoU (Intersection over Union)
#     - Dice Coefficient
#     - Pixel Accuracy
#     """
#     pass