## Introduction
In this notebook we will create the model for atrium segmentation.  
We will use U-Net (https://arxiv.org/abs/1505.04597), one of the most widely used architectures for this task.

## Imports:
1. torch for model creation

In [1]:
import torch

## What is U-Net?

U-Net is a deep learning architecture designed specifically for biomedical image segmentation. It predicts a detailed, pixel-wise mask of an object (such as an organ) in an image.

The structure of U-Net is based on two main ideas: **an Encoder-Decoder architecture** and **skip connections** between the encoder and decoder.

- **Encoder**:  
  The encoder is the first part of the network. It takes the input image and processes it through several stages. At each stage it applies **convolutions** and **downsampling**, which reduce the spatial size (width and height) of the feature maps while capturing increasingly abstract features.  
  Think of the encoder as compressing the image, keeping only the most essential information.

- **Decoder**:  
  The decoder is the second part of the network. Its job is the opposite of the encoder: it **reconstructs** the spatial structure by gradually increasing the width and height of the feature maps through **upsampling**. The final output has the same size as the input image, but instead of being an image, it is a **segmentation mask** that shows which pixels belong to the object of interest.

- **Skip Connections**:  
  A key feature of U-Net is its use of skip connections. At each level where the encoder compresses the image, a copy of the features is sent directly to the corresponding level of the decoder.  
  This helps the decoder recover fine details that would otherwise be lost during downsampling. It also makes the training process much easier and leads to higher-quality segmentation results.

In simple terms:
- The encoder captures **what** is in the image.
- The decoder reconstructs **where** it is in the image.
- Skip connections help keep **all the important details**.

Thanks to this design, U-Net can produce accurate segmentation masks even when training data is limited, which makes it particularly popular for medical imaging tasks where high precision is crucial.

![title](../images/unet.png)
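
The three ideas above can be illustrated on a single feature map. This is a minimal sketch, not the model built later in this notebook: the tensor shape is a hypothetical example, and `Upsample` stands in for whatever upsampling layer the decoder uses.

```python
import torch

# Hypothetical feature map: batch of 1, 64 channels, 32x32 pixels.
features = torch.randn(1, 64, 32, 32)

# Encoder step: max pooling halves width and height.
down = torch.nn.MaxPool2d(2)(features)

# Decoder step: upsampling doubles them again (nearest-neighbour by default).
up = torch.nn.Upsample(scale_factor=2)(down)

# Skip connection: stack the upsampled and the encoder features along the channel axis.
merged = torch.cat([up, features], dim=1)

print(down.shape)    # torch.Size([1, 64, 16, 16])
print(up.shape)      # torch.Size([1, 64, 32, 32])
print(merged.shape)  # torch.Size([1, 128, 32, 32])
```

Concatenation doubles the channel count, which is why each decoder block has to accept more input channels than the upsampling layer produces.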


## Convolutions
First, we define a convolution block consisting of two 3x3 convolutions, each followed by a ReLU.
This block is applied between every downsampling or upsampling step.

In [2]:
class DoubleConv(torch.nn.Module):
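    """
    Helper class implementing the intermediate convolutions:
    two 3x3 convolutions, each followed by a ReLU.
    padding=1 keeps the spatial size of the input unchanged.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()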
        self.step = torch.nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
                                        torch.nn.ReLU(),
                                        torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
                                        torch.nn.ReLU())
        
    def forward(self, X):
        return self.step(X)
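
The choice of `padding=1` in the block above matters for the rest of the architecture. The sketch below, using a hypothetical 64x64 input and 8 output channels, shows that a padded 3x3 convolution preserves the spatial size, while an unpadded one shrinks each dimension by 2.

```python
import torch

x = torch.randn(1, 1, 64, 64)  # hypothetical single-channel 64x64 input

padded = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)
unpadded = torch.nn.Conv2d(1, 8, kernel_size=3)  # padding defaults to 0

print(padded(x).shape)    # torch.Size([1, 8, 64, 64])
print(unpadded(x).shape)  # torch.Size([1, 8, 62, 62])
```

Because the size is preserved, encoder feature maps can later be concatenated with decoder feature maps without cropping.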
    

## UNET
We now define the U-Net model used for left atrium segmentation.  
The architecture follows the classic encoder–decoder structure:

- The **encoder** consists of repeated `DoubleConv` blocks with `MaxPool2d` for downsampling between them. Each step halves the spatial resolution while increasing the number of feature channels.
- The **decoder** mirrors the encoder: it uses `ConvTranspose2d` to double the spatial resolution, concatenates the corresponding feature maps from the encoder (skip connections), and then applies another `DoubleConv`.

This structure allows the network to capture both global context and fine details, making it well suited to segmentation tasks.

> Note: `ConvTranspose2d` provides learnable upsampling. You may replace it with `Upsample` for simplicity, but `Upsample` does not change the number of channels, so the input channels of the following `DoubleConv` would need to be adjusted.
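
The difference between the two upsampling options, `ConvTranspose2d` and `Upsample`, can be seen by comparing output shapes and parameter counts. This is a sketch on a hypothetical 512-channel feature map; the bilinear mode is an assumption chosen for illustration.

```python
import torch

x = torch.randn(1, 512, 32, 32)  # hypothetical bottleneck feature map

# Learnable: doubles the spatial size and halves the channels in one layer.
transposed = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)

# Parameter-free: doubles the spatial size, leaves the channel count unchanged.
upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

print(transposed(x).shape)  # torch.Size([1, 256, 64, 64])
print(upsample(x).shape)    # torch.Size([1, 512, 64, 64])
print(sum(p.numel() for p in transposed.parameters()))  # 524544
print(sum(p.numel() for p in upsample.parameters()))    # 0
```

With `Upsample`, concatenating the 512-channel result with a 256-channel skip connection would give 768 channels, so the next `DoubleConv` would need 768 input channels instead of 512.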

In [6]:
class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        # Encoder
        self.layer1 = DoubleConv(1, 64)
        self.layer2 = DoubleConv(64, 128)
        self.layer3 = DoubleConv(128, 256)
        self.layer4 = DoubleConv(256, 512)
        
        # Decoder (with skip connections)
        self.upconv5 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.layer5 = DoubleConv(512, 256)

        self.upconv6 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.layer6 = DoubleConv(256, 128)

        self.upconv7 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.layer7 = DoubleConv(128, 64)

        # Final output layer
        self.layer8 = torch.nn.Conv2d(64, 1, kernel_size=1)
        
        self.maxpool = torch.nn.MaxPool2d(2)

    def forward(self, x):
        # Encoder path
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))
        x4 = self.layer4(self.maxpool(x3))

        # Decoder path with skip connections
        x5 = self.upconv5(x4)
        x5 = torch.cat([x5, x3], dim=1)
        x5 = self.layer5(x5)

        x6 = self.upconv6(x5)
        x6 = torch.cat([x6, x2], dim=1)
        x6 = self.layer6(x6)

        x7 = self.upconv7(x6)
        x7 = torch.cat([x7, x1], dim=1)
        x7 = self.layer7(x7)

        # Final output
        return self.layer8(x7)

## Testing
Before training, we perform a simple test to verify that the U-Net architecture works as expected.  
We generate a random input tensor of shape `(1, 1, 256, 256)`, simulating a single grayscale image, and pass it through the model.

The output is expected to have the same spatial dimensions as the input, but with one output channel representing the segmentation mask.

If the output shape matches, this confirms that:
- The encoder and decoder paths are correctly balanced.
- The `ConvTranspose2d` operations restore the resolution properly.
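
Whether the paths are balanced depends on the input size. The helper below is a hypothetical illustration (`trace_sizes` is not part of this notebook) that follows one spatial dimension through three pooling steps and three stride-2 upsampling steps, assuming `MaxPool2d(2)` rounds odd sizes down, which is its default behaviour.

```python
def trace_sizes(size, depth=3):
    """Follow one spatial dimension through `depth` pool/upsample pairs."""
    encoder = [size]
    for _ in range(depth):
        size = size // 2          # MaxPool2d(2) floors odd sizes
        encoder.append(size)
    for skip in reversed(encoder[:-1]):
        size = size * 2           # ConvTranspose2d(kernel_size=2, stride=2)
        if size != skip:
            return f"mismatch: upsampled {size} vs skip {skip}"
    return f"ok: output {size}"

print(trace_sizes(256))  # ok: output 256
print(trace_sizes(250))  # mismatch: upsampled 124 vs skip 125
```

With three pooling steps, width and height should be divisible by 8; otherwise the concatenation in a skip connection fails because the two tensors differ in size.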

In [7]:
model = UNet()

In [8]:
random_input = torch.randn(1, 1, 256, 256)
output = model(random_input)
assert output.shape == torch.Size([1, 1, 256, 256])
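
The final layer is a plain 1x1 convolution without an activation, so the model returns unbounded scores (logits) rather than a binary mask. One common way to turn them into a mask is sketched below; it is an assumption, not something this notebook prescribes. The random tensor stands in for the model output, and the 0.5 threshold is an illustrative choice.

```python
import torch

logits = torch.randn(1, 1, 256, 256)  # stand-in for the model output

probabilities = torch.sigmoid(logits)  # squashes scores into the range 0 to 1
mask = probabilities > 0.5             # boolean mask; 0.5 is an assumed threshold

print(mask.shape, mask.dtype)  # torch.Size([1, 1, 256, 256]) torch.bool
```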