## Introduction

In this notebook, we define the model used for lung tumor segmentation.

The architecture is unchanged from the atrium segmentation task: a U-Net, one of the most widely adopted architectures for medical image segmentation, originally proposed in [Ronneberger et al., 2015](https://arxiv.org/abs/1505.04597).

Its encoder-decoder structure with skip connections combines global context with fine spatial detail, which lets it delineate both large and small structures. This matters for identifying lung tumors in CT scans, since tumors can vary considerably in size.


```python
import torch
```

## U-Net Architecture for Medical Image Segmentation

U-Net is a convolutional neural network architecture widely used for biomedical image segmentation. It was originally proposed in 2015 for segmenting cells in microscopy images, but has since become a standard baseline for many medical imaging tasks.

Our implementation is a custom 2D U-Net designed to segment lung tumors slice by slice.

---

### Structure of the U-Net

U-Net follows an **encoder-decoder** structure with skip connections. The idea is to **capture context** using a contracting path (encoder), and then **enable precise localization** using an expanding path (decoder) with high-resolution features from earlier layers.

#### 1. **Encoder (Downsampling Path)**

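The encoder progressively reduces the spatial dimensions while increasing the number of feature channels (64 → 128 → 256 → 512). It consists of:

- **DoubleConv blocks**: two consecutive `Conv2d (3x3, padding=1) + BatchNorm + ReLU` layers, which change the channel count but preserve height and width
- **MaxPooling layers**: 2x2 max-pooling halves the spatial resolution between consecutive blocks (three times in total)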

This part allows the model to extract increasingly abstract features from the image.
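A minimal sketch of one encoder stage, assuming a batch of two 256x256 single-channel slices. The helper `double_conv` is a hypothetical stand-in for the notebook's `DoubleConv` class, rebuilt here so the snippet runs on its own; it shows that the convolutions change only the channel count while max-pooling changes only the resolution.

```python
import torch
import torch.nn as nn

def double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    # Hypothetical helper mirroring DoubleConv: 3x3 convs with padding=1 keep H and W unchanged
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

block = double_conv(1, 64)
pool = nn.MaxPool2d(2)  # 2x2 window, stride 2: halves H and W

x = torch.randn(2, 1, 256, 256)  # [batch, channels, height, width]
features = block(x)              # channels grow, resolution kept
pooled = pool(features)          # resolution halved, channels kept

print(features.shape)  # torch.Size([2, 64, 256, 256])
print(pooled.shape)    # torch.Size([2, 64, 128, 128])
```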

#### 2. **Bottleneck**

After the last encoder block (`down4`), `Dropout2d` with p = 0.2 randomly zeroes entire feature channels during training to reduce overfitting. The bottleneck is a natural place for this: its features have the largest receptive field and hold the most compressed representation of the input.
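To make the channel-wise behavior concrete, here is a small sketch that feeds a bottleneck-sized tensor of ones (an assumed toy input) through `nn.Dropout2d`. In training mode each channel is either zeroed or scaled by 1 / (1 - p) as a whole; in evaluation mode the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout2d(p=0.2)   # same rate as the model's bottleneck
x = torch.ones(1, 512, 32, 32)  # bottleneck-sized feature map of ones

dropout.train()                 # dropout is only active in training mode
y = dropout(x)
per_channel_max = y.amax(dim=(2, 3))  # shape [1, 512]
per_channel_min = y.amin(dim=(2, 3))
# Every channel is constant: either all zeros or all scaled by 1 / (1 - p) = 1.25
print("dropped channels:", int((per_channel_max == 0).sum()), "of", x.shape[1])

dropout.eval()
print(torch.equal(dropout(x), x))  # True: identity at inference time
```

This is why calling `model.eval()` before inference matters: otherwise channels keep being dropped at prediction time.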

#### 3. **Decoder (Upsampling Path)**

The decoder gradually restores the spatial resolution of the feature maps:

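- **Upsample + 1x1 conv**: bilinear upsampling doubles the spatial size, and a 1x1 convolution halves the number of channels so the result matches the corresponding skip connection
- **Concatenation with skip connections**: concatenates the matching encoder feature map along the channel dimension, reintroducing high-resolution detail
- **DoubleConv**: refines the merged features and reduces the channel count back down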

This part allows the network to combine context and detailed information to produce accurate segmentation masks.
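The first decoder step can be sketched in isolation. The random tensors below are assumed stand-ins for the bottleneck output and the `down3` skip feature; the snippet traces why `up_conv1` must accept 512 input channels.

```python
import torch
import torch.nn as nn

# Assumed stand-ins for the bottleneck output (after down4) and the down3 skip feature
bottleneck = torch.randn(1, 512, 32, 32)
skip = torch.randn(1, 256, 64, 64)

up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 32x32 -> 64x64
    nn.Conv2d(512, 256, kernel_size=1),  # 512 -> 256 channels, spatial size unchanged
)

upsampled = up(bottleneck)                    # [1, 256, 64, 64]
merged = torch.cat([upsampled, skip], dim=1)  # channel concat -> [1, 512, 64, 64]
print(upsampled.shape, merged.shape)
```

The 1x1 convolution is a cheap way to match channel counts before concatenation; a transposed convolution is a common alternative that learns the upsampling as well.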

#### 4. **Final Prediction Layer**

- A final head of `Conv2d (3x3) + ReLU + Conv2d (1x1)` reduces the 64 decoder channels to 32 and then to 1.
- The output is a single-channel **logit map** (no sigmoid), intended for a loss such as `BCEWithLogitsLoss`, which applies the sigmoid internally.
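The sketch below, using random logits and an assumed toy mask, checks that `BCEWithLogitsLoss` on raw logits matches `BCELoss` on sigmoid outputs for moderate values, and then shows where the separate version breaks down: for a large logit, `sigmoid` rounds to exactly 1.0 in float32 and `BCELoss` falls back to its documented clamp of the log at -100.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)                     # raw model output (no sigmoid)
targets = torch.randint(0, 2, (2, 1, 8, 8)).float()  # binary ground-truth mask

fused = nn.BCEWithLogitsLoss()(logits, targets)
separate = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(fused, separate, atol=1e-6))    # True for moderate logits

# A confident wrong prediction: sigmoid(40) rounds to exactly 1.0 in float32
big_logit = torch.tensor([[40.0]])
target = torch.tensor([[0.0]])
fused_big = nn.BCEWithLogitsLoss()(big_logit, target)          # ~40, the true loss
separate_big = nn.BCELoss()(torch.sigmoid(big_logit), target)  # log(0) clamped to -100
print(fused_big.item(), separate_big.item())
```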

---

### How the Data Flows Through the Model

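Let’s assume the input is a single CT slice of shape `[1, 256, 256]` (channels, height, width; the batch dimension is omitted throughout).

1. **Input** `[1, 256, 256]` → `down1` → `[64, 256, 256]` (kept as skip `x1`)
2. ↓ `MaxPool` → `down2` → `[128, 128, 128]` (skip `x2`)
3. ↓ `MaxPool` → `down3` → `[256, 64, 64]` (skip `x3`)
4. ↓ `MaxPool` → `down4` → `[512, 32, 32]` → `Dropout2d`
5. ↑ `up1` (Upsample + 1x1 conv) → `[256, 64, 64]` → concat with `x3` → `[512, 64, 64]` → `up_conv1` → `[256, 64, 64]`
6. ↑ `up2` → `[128, 128, 128]` → concat with `x2` → `[256, 128, 128]` → `up_conv2` → `[128, 128, 128]`
7. ↑ `up3` → `[64, 256, 256]` → concat with `x1` → `[128, 256, 256]` → `up_conv3` → `[64, 256, 256]`
8. `final` → output logits `[1, 256, 256]`

Because the encoder pools three times, the input height and width must be divisible by 8 for the skip connections to line up.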
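The same trace can be reproduced as plain shape arithmetic. `trace_unet_shapes` is a hypothetical helper that does not run the model; it assumes the channel widths and the three pooling steps of the architecture described here, and rejects sizes that would make the skip concatenations mismatch.

```python
from __future__ import annotations

def trace_unet_shapes(height: int, width: int) -> list[tuple[str, tuple[int, int, int]]]:
    """Return (stage, (C, H, W)) pairs for this U-Net layout (batch dim omitted)."""
    if height % 8 or width % 8:
        # Three 2x2 max-pools: other sizes break the skip-connection concatenations
        raise ValueError("height and width must be divisible by 8")
    shapes = []
    h, w = height, width
    channels = [64, 128, 256, 512]
    for i, c in enumerate(channels):
        if i > 0:
            h, w = h // 2, w // 2  # MaxPool2d(2) before down2..down4
        shapes.append((f"down{i + 1}", (c, h, w)))
    for i, c in enumerate(reversed(channels[:-1])):  # decoder widths: 256, 128, 64
        h, w = h * 2, w * 2  # bilinear upsample
        shapes.append((f"up{i + 1}", (c, h, w)))          # after the 1x1 conv
        shapes.append((f"concat{i + 1}", (2 * c, h, w)))  # plus the skip connection
        shapes.append((f"up_conv{i + 1}", (c, h, w)))
    shapes.append(("final", (1, h, w)))
    return shapes

for stage, shape in trace_unet_shapes(256, 256):
    print(f"{stage:>9}: {list(shape)}")
```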

---

### Key Enhancements

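- **Batch Normalization**: normalizes activations per channel, which stabilizes training and typically speeds up convergence.
- **Dropout**: `Dropout2d` (p = 0.2) at the bottleneck reduces overfitting.
- **Xavier Initialization**: all convolution weights use Xavier (Glorot) normal initialization, and their biases are set to zero.
- **No Sigmoid**: the model outputs raw logits; `BCEWithLogitsLoss` fuses the sigmoid and binary cross-entropy, which is more numerically stable than applying them separately.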
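As a quick sanity check of what Xavier normal initialization does, the sketch below initializes a conv layer shaped like the first convolution of `down4` (256 → 512 channels, 3x3) and compares the empirical weight standard deviation with the Glorot formula gain * sqrt(2 / (fan_in + fan_out)), where the fans include the kernel area. Xavier's derivation assumes roughly symmetric activations; `nn.init.kaiming_normal_` is the usual alternative when the activations are ReLU.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(256, 512, kernel_size=3, padding=1)  # same shape as down4's first conv
nn.init.xavier_normal_(conv.weight, gain=1.0)
nn.init.constant_(conv.bias, 0)

# Conv weight shape is [out, in, kH, kW]; fans include the receptive field size
fan_in = 256 * 3 * 3
fan_out = 512 * 3 * 3
expected_std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out))
empirical_std = conv.weight.std().item()

print(f"expected std:  {expected_std:.5f}")
print(f"empirical std: {empirical_std:.5f}")
```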

---

### Output

The model returns a 4D tensor of shape `[B, 1, H, W]` (batch, channel, height, width), where each pixel holds an unnormalized score (logit) for the tumor class, not a probability. During inference, apply a sigmoid to obtain probabilities, then threshold them (for example at 0.5) to obtain a binary mask.
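A minimal post-processing sketch, assuming a threshold of 0.5 and a toy `[1, 1, 2, 2]` logit tensor in place of real model output; `logits_to_mask` is a hypothetical helper name. With a real model you would call `model.eval()` and wrap the forward pass in `torch.no_grad()` before this step.

```python
import torch

def logits_to_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: convert [B, 1, H, W] logits to a {0, 1} mask of the same shape."""
    probs = torch.sigmoid(logits)  # logits -> probabilities in (0, 1)
    return (probs > threshold).to(torch.uint8)

logits = torch.tensor([[[[-3.0, 0.1], [2.5, -0.2]]]])  # toy [1, 1, 2, 2] output
print(logits_to_mask(logits).tolist())  # [[[[0, 1], [1, 0]]]]
```

With a 0.5 threshold this is equivalent to `logits > 0`, since sigmoid(0) = 0.5; computing probabilities explicitly is still handy when tuning the threshold on validation data.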


```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """
    Applies two consecutive Conv2D + BatchNorm + ReLU blocks.
    Used for both downsampling and upsampling stages.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """
    U-Net encoder-decoder architecture for binary segmentation tasks.

    Expects input of shape [B, 1, H, W] with H and W divisible by 8
    (three 2x2 max-pools) and returns logits of shape [B, 1, H, W].
    """
    def __init__(self):
        super().__init__()

        # Encoder (downsampling path)
        self.down1 = DoubleConv(1, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)
        self.down4 = DoubleConv(256, 512)
        self.maxpool = nn.MaxPool2d(2)

        # Decoder (upsampling path)
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=1)  # Channel reduction before concat
        )
        self.up_conv1 = DoubleConv(512, 256)  # 256 from upsample + 256 from skip connection

        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=1)
        )
        self.up_conv2 = DoubleConv(256, 128)

        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=1)
        )
        self.up_conv3 = DoubleConv(128, 64)

        # Final prediction head (no activation for logits)
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1)  # Output logits
        )

        # Dropout at bottleneck to prevent overfitting
        self.dropout = nn.Dropout2d(0.2)

        # Xavier initialization for better convergence
        self._initialize_weights()

    def forward(self, x):
        # Encoder path
        x1 = self.down1(x)
        x2 = self.down2(self.maxpool(x1))
        x3 = self.down3(self.maxpool(x2))
        x4 = self.down4(self.maxpool(x3))
        x4 = self.dropout(x4)

        # Decoder path with skip connections
        xu1 = self.up1(x4)
        xu1 = torch.cat([xu1, x3], dim=1)
        xu1 = self.up_conv1(xu1)

        xu2 = self.up2(xu1)
        xu2 = torch.cat([xu2, x2], dim=1)
        xu2 = self.up_conv2(xu2)

        xu3 = self.up3(xu2)
        xu3 = torch.cat([xu3, x1], dim=1)
        xu3 = self.up_conv3(xu3)

        return self.final(xu3)

    def _initialize_weights(self):
        """
        Initialize convolutional layers using Xavier normal initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
```
