# In [None]:
# U-Net from scratch with line-by-line comments explaining literally everything.

import torch                         # Core PyTorch tensor library
import torch.nn as nn                # Neural network layers and modules
import torch.nn.functional as F      # Functional ops (conv, relu, interpolate, etc.)

# ---------------------------
# Basic building block: DoubleConv
# ---------------------------
class DoubleConv(nn.Module):
    """
    Apply two "3x3 conv -> ReLU" stages back to back.

    Each convolution uses a 3x3 kernel with padding=1 and the default stride of 1,
    so the output keeps the input's spatial size (H, W); only the channel count
    changes, from ``in_ch`` to ``out_ch``.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        # Layers are appended in the order conv, relu, conv, relu so the
        # Sequential indices (and therefore state_dict keys net.0.* / net.2.*) are stable.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=True))
            stages.append(nn.ReLU(inplace=True))   # in-place is safe: it acts on the fresh conv output
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [N, in_ch, H, W] -> [N, out_ch, H, W]
        features = self.net(x)
        return features


# ---------------------------
# Down block: DoubleConv then downsample (MaxPool)
# ---------------------------
class Down(nn.Module):
    """
    One contracting-path stage of the U-Net.

    Runs a DoubleConv at the incoming resolution, then max-pools with a 2x2 window
    and stride 2. Both tensors are handed back: the pooled one feeds the next stage,
    and the un-pooled one is kept as the skip connection for the decoder.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.encode = DoubleConv(in_ch, out_ch)            # features at the current scale
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 2x spatial reduction (floors odd sizes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # skip: [N, out_ch, H, W]; pooled: [N, out_ch, floor(H/2), floor(W/2)]
        skip = self.encode(x)
        return self.pool(skip), skip


# ---------------------------
# Up block: upsample (by transpose conv or bilinear+1x1) then DoubleConv after concatenating skip
# ---------------------------
class Up(nn.Module):
    """
    One expansive-path stage of the U-Net.

    Upsamples the decoder tensor by 2x and brings it to ``out_ch`` channels. It then
    pads or crops the result to the skip tensor's H/W, concatenates ``[skip, x]`` along
    the channel axis, and fuses the result with a DoubleConv.
    """
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, up_mode: str = "transpose"):
        """
        in_ch:   channels of the decoder tensor before upsampling
        skip_ch: channels of the encoder skip tensor
        out_ch:  channels produced by this block
        up_mode: "transpose" (learned ConvTranspose2d, k=2, s=2) or
                 "bilinear" (fixed Upsample x2 followed by a learned 1x1 conv)
        """
        super().__init__()
        self.up_mode = up_mode

        if up_mode == "transpose":
            # The transposed conv both doubles H/W and changes channels in_ch -> out_ch.
            self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
            self.reduce = None
        elif up_mode == "bilinear":
            # Resizing keeps in_ch channels; a 1x1 conv then maps them to out_ch.
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.reduce = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            raise ValueError("up_mode must be 'transpose' or 'bilinear'.")

        # Both modes leave out_ch channels on the upsampled tensor, which is then
        # concatenated with skip_ch skip channels before fusion.
        self.fuse = DoubleConv(out_ch + skip_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        x:    decoder tensor [N, in_ch, H_dec, W_dec]
        skip: encoder tensor [N, skip_ch, H_skip, W_skip]
        returns [N, out_ch, H_skip, W_skip]
        """
        upsampled = self.up(x)
        if self.up_mode != "transpose":
            upsampled = self.reduce(upsampled)     # bilinear path: fix the channel count

        # Align spatial sizes with the skip. F.pad takes (left, right, top, bottom);
        # positive amounts pad with zeros, negative amounts crop. The odd remainder
        # goes to the right/bottom so the tensor stays (nearly) centered.
        gap_h = skip.size(2) - upsampled.size(2)
        gap_w = skip.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Channel-wise concat: skip channels first, then the upsampled decoder channels.
        merged = torch.cat([skip, upsampled], dim=1)
        return self.fuse(merged)


# ---------------------------
# Full U-Net
# ---------------------------
class UNet(nn.Module):
    """
    Classic U-Net with 4 downsampling steps and 4 upsampling steps.
    - Encoder (in_conv, extra pool, Down blocks): extract context, shrink spatial size.
    - Bottleneck (DoubleConv): deepest features.
    - Decoder (Up blocks): restore spatial size, fuse with encoder skips for fine detail.
    - Final 1x1 conv: map to desired number of classes/channels.

    Scale schedule for an H x W input (f1..f4 = 64, 128, 256, 512):

        x0 = in_conv(x)              f1   @ H        (skip for up0)
        pool(x0)                     f1   @ H/2
        down1 -> skip1, x1           f2   @ H/2 ,  f2 @ H/4
        down2 -> skip2, x2           f3   @ H/4 ,  f3 @ H/8
        down3 -> skip3, x3           f4   @ H/8 ,  f4 @ H/16
        bottleneck(x3)               2*f4 @ H/16
        up3(., skip3) -> f4 @ H/8,  up2(., skip2) -> f3 @ H/4,
        up1(., skip1) -> f2 @ H/2,  up0(., x0)    -> f1 @ H

    With H and W divisible by 16, every Up block's 2x upsample lands exactly on its
    skip's size. Otherwise pooling floors odd sizes and the Up blocks pad the
    upsampled tensor back to the skip's size.
    """
    def __init__(self, in_ch: int = 1, out_ch: int = 1, up_mode: str = "transpose"):
        """
        in_ch:  number of input channels (1 for grayscale, 3 for RGB)
        out_ch: number of output channels (1 for binary mask logits, K for K-class logits)
        up_mode: "transpose" or "bilinear" upsampling method
        """
        super().__init__()

        # Feature sizes in the encoder; double as we go down to increase capacity.
        f1, f2, f3, f4 = 64, 128, 256, 512

        # Initial full-resolution feature extraction (no pooling here).
        self.in_conv = DoubleConv(in_ch, f1)               # [N, f1, H, W]

        # Contracting path. Each Down returns (pooled output, pre-pool skip features);
        # shapes below assume down1 receives the pooled in_conv output (see forward).
        self.down1 = Down(f1, f2)                          # skip f2@H/2,  out [N, f2, H/4,  W/4]
        self.down2 = Down(f2, f3)                          # skip f3@H/4,  out [N, f3, H/8,  W/8]
        self.down3 = Down(f3, f4)                          # skip f4@H/8,  out [N, f4, H/16, W/16]

        # Extra 2x2 pool applied to in_conv's output before down1, so that together with
        # the three Down pools there are 4 downsamplings, as in the original U-Net.
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(f4, f4 * 2)           # Deepest features: channels doubled (1024)

        # Expansive path: upsample and fuse with corresponding skips
        # Up(in_ch from decoder, skip_ch from encoder, out_ch after fusion)
        self.up3 = Up(in_ch=f4 * 2, skip_ch=f4, out_ch=f4, up_mode=up_mode)  # H/16 -> H/8, fuses skip3
        self.up2 = Up(in_ch=f4,     skip_ch=f3, out_ch=f3, up_mode=up_mode)  # H/8  -> H/4, fuses skip2
        self.up1 = Up(in_ch=f3,     skip_ch=f2, out_ch=f2, up_mode=up_mode)  # H/4  -> H/2, fuses skip1
        self.up0 = Up(in_ch=f2,     skip_ch=f1, out_ch=f1, up_mode=up_mode)  # H/2  -> H,   fuses x0

        # Final 1x1 conv maps features to output channels (e.g., 1 for binary, K for classes)
        self.out_conv = nn.Conv2d(f1, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: input image tensor [N, in_ch, H, W]
        returns: logits tensor [N, out_ch, H, W]
        """
        # Encoder
        x0 = self.in_conv(x)                               # f1 @ H: highest-res features, skip for up0

        # Pool BEFORE down1. Down returns its pre-pool features as the skip, so feeding
        # it the un-pooled x0 would give skip1 the same resolution as x0 and shift every
        # deeper scale by one level. In that case up3 would zero-pad its upsampled input
        # to twice its size, and up0 would upsample past H and crop half of it away.
        x1, skip1 = self.down1(self.pool(x0))              # skip1: f2 @ H/2, x1: f2 @ H/4
        x2, skip2 = self.down2(x1)                         # skip2: f3 @ H/4, x2: f3 @ H/8
        x3, skip3 = self.down3(x2)                         # skip3: f4 @ H/8, x3: f4 @ H/16

        # Bottleneck: x3 is already at the deepest scale (4 pools in total).
        x4 = self.bottleneck(x3)                           # 2*f4 @ H/16

        # Decoder (upsample + skip concat + fuse); each 2x upsample targets its skip's scale
        x = self.up3(x4, skip3)                            # f4 @ H/8
        x = self.up2(x,  skip2)                            # f3 @ H/4
        x = self.up1(x,  skip1)                            # f2 @ H/2
        x = self.up0(x,  x0)                               # f1 @ H

        # Project to desired output channels (logits, not probabilities)
        logits = self.out_conv(x)                          # [N, out_ch, H, W]
        return logits


