# **U-NET Implementation**

AUTHOR: Alejandro Meza Tudela

This notebook provides a complete, from-scratch implementation of U-Net, a foundational architecture for image segmentation.

In [17]:
import torch
import torch.nn as nn
from torchsummary import summary

## Introduction

U-Net is a convolutional neural network architecture for image segmentation. It consists of a contracting path that captures context and a symmetric expanding path that enables precise localization. When it was introduced, it was praised for its strong segmentation accuracy, its fast inference (segmenting a 512x512 image took less than a second on a GPU of that time), and its ability to perform well with a limited number of training samples.

Note: To better understand this architecture, we recommend that you review the architecture diagram while looking at this code.





## Code Implementation

The U-Net architecture will be built by creating several key building blocks, which will then be assembled to form the complete model.

### CONVOLUTION BLOCK
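The convolution block in the U-Net architecture consists of two consecutive 3x3 convolutions, each followed by a Rectified Linear Unit (ReLU) activation function. This implementation also adds Batch Normalization after each convolution and uses `padding=1`, so the block preserves the spatial size of its input. The original paper used unpadded convolutions and no Batch Normalization.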

In [8]:
class convolution_block(nn.Module):
    """
    A basic convolutional block used in the U-Net architecture.
    Consists of two 3x3 convolutional layers, each followed by
    Batch Normalization and a ReLU activation.
    """
    def __init__(self, input_channels, output_channels):
        super().__init__()
        # Define the operations inside the block.
        # padding=1 keeps H and W unchanged; bias=False because BatchNorm adds its own shift.
        self.block = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
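To see what `padding=1` changes, the standalone sketch below compares two stacked 3x3 convolutions with and without padding on a 64x64 input. The helper `two_convs` is hypothetical (not part of the notebook) and omits Batch Normalization for brevity. With `padding=1` the spatial size is preserved; without padding, as in the original paper, each convolution trims one pixel per side, which is how the paper's 572x572 input becomes 568x568 after the first block.

```python
import torch
import torch.nn as nn

def two_convs(in_ch, out_ch, padding):
    """Hypothetical helper: two 3x3 convolutions, each followed by a ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
    )

x = torch.randn(1, 3, 64, 64)
padded_out = two_convs(3, 64, padding=1)(x)    # "same" padding, as in this notebook
unpadded_out = two_convs(3, 64, padding=0)(x)  # unpadded, as in the original paper
print(padded_out.shape)    # torch.Size([1, 64, 64, 64])
print(unpadded_out.shape)  # torch.Size([1, 64, 60, 60])
```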

### CONTRACTING PATH DEFINITION

The contracting path repeatedly applies the convolution block followed by a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step, the number of feature channels is doubled (in this implementation, by the first convolution of the next block).

In [9]:
class downsampling_block(nn.Module):
  def __init__(self, input_channels, output_channels):
    """
        Initializes the downsampling block.

        Args:
            input_channels (int): The number of channels in the input feature map.
            output_channels (int): The number of channels after the convolution block.
    """
    super().__init__()
    self.conv_block = convolution_block(input_channels, output_channels)
    self.pool = nn.MaxPool2d((2,2))

  def forward(self,x):
    x = self.conv_block(x)
    p = self.pool(x)
    # Return both the pre-pooled feature map for the skip connection
    # and the pooled feature map for the next layer in the encoder.
    return x,p
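Because every pooling step halves the spatial size (rounding down for odd sizes), the decoder can only match the skip connections exactly when the input height and width are divisible by 2^4 = 16. The standalone sketch below uses a toy 8-channel tensor, not the U-Net itself, to show the problem at a single level: a 255x255 map pools to 127x127, and a 2x2 up-convolution brings it back to 254x254, one pixel short of the 255x255 skip connection, so concatenating the two along the channel dimension would fail.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2)  # 2x2 max pooling; stride defaults to the kernel size (2)
up = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)  # 2x2 up-convolution

results = {}
for size in (256, 255):
    skip = torch.randn(1, 8, size, size)  # pre-pooled map kept for the skip connection
    pooled = pool(skip)                   # halves H and W, rounding down for odd sizes
    restored = up(pooled)                 # doubles H and W
    results[size] = (tuple(pooled.shape[-2:]), tuple(restored.shape[-2:]))
    print(size, *results[size])
# 256 (128, 128) (256, 256)
# 255 (127, 127) (254, 254)
```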

### EXPANSIVE PATH DEFINITION

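Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution ("up-convolution") that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. Here, the upsampling and the 2x2 up-convolution are combined in a single `nn.ConvTranspose2d` with `kernel_size=2` and `stride=2`. Because the convolutions are padded, the skip connection already has the same spatial size as the upsampled map (for inputs whose height and width are divisible by 16), so no cropping is needed.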

In [12]:
class upsampling_block(nn.Module):
  """
    An upsampling block for the U-Net expansive path.
    It combines upsampling with a skip connection from the encoder path.
  """
  def __init__(self, input_channels, output_channels):
    super().__init__()
    #define the up-convolution
    self.up_convolution = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2, padding=0)
    #define the convolution: current output + skip connection
    self.convolution_block = convolution_block(
                                               output_channels+output_channels, #current + skip connection
                                               output_channels
                                               )
  def forward(self, x, skip_connection):
    x = self.up_convolution(x)
    # Concatenate the upsampled map with the skip connection along the channel dimension
    x = torch.cat([x, skip_connection], dim=1)
    # Apply the convolution block to the concatenated features
    x = self.convolution_block(x)
    return x
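For readers who want to follow the original paper's unpadded variant instead, the encoder map is larger than the upsampled decoder map and must be cropped before concatenation. The sketch below shows one way to do this with a hypothetical `center_crop` helper, using the top-level sizes from the paper's figure (a 196x196 map up-convolved to 392x392, and a 568x568 encoder map), with 4 and 8 channels instead of the paper's 64 and 128 to keep it light.

```python
import torch
import torch.nn as nn

def center_crop(skip, target_hw):
    """Hypothetical helper: crop an (N, C, H, W) encoder map to (h, w) around its center."""
    _, _, h, w = skip.shape
    th, tw = target_hw
    top = (h - th) // 2
    left = (w - tw) // 2
    return skip[:, :, top:top + th, left:left + tw]

up = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)  # halves channels, doubles H and W
decoder_x = up(torch.randn(1, 8, 196, 196))             # -> (1, 4, 392, 392)
skip = torch.randn(1, 4, 568, 568)                      # encoder map kept for the skip connection
cropped = center_crop(skip, decoder_x.shape[-2:])       # -> (1, 4, 392, 392)
merged = torch.cat([decoder_x, cropped], dim=1)         # concatenate along channels
print(merged.shape)  # torch.Size([1, 8, 392, 392])
```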

## UNET Architecture

In summary, the U-Net architecture consists of two main components: a contracting path (encoder) and an expanding path (decoder), which together form a symmetric, U-shaped structure.

- The contracting path is composed of 4 downsampling blocks, which progressively reduce the spatial dimensions while increasing the number of feature channels.

- At the bottom of the U, there's a bottleneck layer that captures the most abstract representation of the input.

- The expanding path consists of 4 upsampling blocks, which gradually recover the spatial dimensions and refine the segmentation output.
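The stages listed above can be tabulated for a 256x256 input. The sketch below is a hypothetical helper (`unet_shape_schedule` is not part of the notebook) that assumes padded convolutions, 64 base channels, and 4 levels, matching the model defined in this section. It lists the output channels and spatial size of each stage: the encoder doubles channels while halving the size, the bottleneck reaches 1024 channels at 16x16, and the decoder mirrors the encoder back to 64 channels at 256x256.

```python
def unet_shape_schedule(in_size=256, base_channels=64, depth=4):
    """Hypothetical helper: (stage, output channels, spatial size) for a padded U-Net."""
    schedule = []
    size = in_size
    for i in range(depth):  # contracting path: convolution block, then 2x2 max pooling
        schedule.append((f"encoder_{i + 1}", base_channels * 2 ** i, size))
        size //= 2
    schedule.append(("bottleneck", base_channels * 2 ** depth, size))
    for i in range(depth):  # expansive path: the 2x2 up-convolution doubles the size
        size *= 2
        schedule.append((f"decoder_{i + 1}", base_channels * 2 ** (depth - 1 - i), size))
    return schedule

for stage, channels, size in unet_shape_schedule():
    print(f"{stage:>10}: {channels:4d} channels, {size}x{size}")
```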

![U-Net architecture diagram](https://idiotdeveloper.com/wp-content/uploads/2021/01/u-net-architecture.png)

In [13]:
class UNET_Architecture(nn.Module):
  def __init__(self):
    super().__init__()
    # CONTRACTING PATH
    self.encoder_1 = downsampling_block(3,64)
    self.encoder_2 = downsampling_block(64,128)
    self.encoder_3 = downsampling_block(128,256)
    self.encoder_4 = downsampling_block(256,512)

    # BOTTLENECK
    self.bottleneck = convolution_block(512,1024)

    # EXPANSIVE PATH
    self.decoder_1 = upsampling_block(1024,512)
    self.decoder_2 = upsampling_block(512,256)
    self.decoder_3 = upsampling_block(256,128)
    self.decoder_4 = upsampling_block(128,64)

    # OUTPUT LAYER
    self.output_layer = nn.Conv2d(64,1,kernel_size=1)

  def forward(self,x):
    #contracting path
    x1,p1 = self.encoder_1(x)
    x2,p2 = self.encoder_2(p1)
    x3,p3 = self.encoder_3(p2)
    x4,p4 = self.encoder_4(p3)

    #bottleneck
    b = self.bottleneck(p4)

    #expanding path
    d1 = self.decoder_1(b,x4) #use of skip connection
    d2 = self.decoder_2(d1,x3) #use of skip connection
    d3 = self.decoder_3(d2,x2) #use of skip connection
    d4 = self.decoder_4(d3,x1) #use of skip connection

    #output layer
    output = self.output_layer(d4)

    return output
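As written, `UNET_Architecture` only produces matching skip-connection shapes when the input height and width are divisible by 16 (four 2x2 poolings). One common workaround, sketched below under that assumption, is to zero-pad the input up to the next multiple of 16 and crop the prediction back afterwards. `pad_to_multiple` is a hypothetical helper, not part of the notebook; reflection padding (`mode="reflect"` in `F.pad`) is an alternative to zeros.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16):
    """Hypothetical helper: zero-pad H and W of an (N, C, H, W) tensor up to the next multiple."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # For the last two dims, F.pad's tuple is (left, right, top, bottom): pad right and bottom only
    padded = F.pad(x, (0, pad_w, 0, pad_h))
    return padded, (h, w)

x = torch.randn(1, 3, 250, 300)
padded, (h, w) = pad_to_multiple(x)
print(padded.shape)  # torch.Size([1, 3, 256, 304])

# After running the model on `padded`, crop its prediction back to the original size.
# Here the padded input itself stands in for the model output.
restored = padded[..., :h, :w]
```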

## Try the implementation

In [22]:
model = UNET_Architecture() #create an instance of the model

In [19]:
# Uncomment the next line to print a layer-by-layer summary (on a CPU-only machine, also pass device="cpu")
#summary(model, input_size=(3, 256, 256))
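If `torchsummary` is not installed, the parameter count can also be read directly from `model.parameters()`. The helper `count_parameters` below is a hypothetical name, shown here on a single 3x3 convolution from 3 to 64 channels without bias (3 * 64 * 3 * 3 = 1728 weights); calling `count_parameters(model)` gives the totals for the full U-Net.

```python
import torch.nn as nn

def count_parameters(model):
    """Hypothetical helper: return (total, trainable) parameter counts of a module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

layer = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
print(count_parameters(layer))  # (1728, 1728)
```

Note that BatchNorm running statistics are buffers rather than parameters, so they are not included in these counts.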

Next, let's pass a random tensor through the model to check that the forward pass runs and that the output has the expected shape.

In [21]:
dummy_input = torch.randn(1, 3, 256, 256)
print(f"Input tensor shape: {dummy_input.shape}")

#Run the dummy input through the model
try:
  output = model(dummy_input)
  # Print the output shape to verify everything worked
  print(f"Output tensor shape: {output.shape}")

except RuntimeError as e:
  print(f"An error occurred during the forward pass: {e}")
  print("This usually indicates a dimension mismatch somewhere in the model.")

Input tensor shape: torch.Size([1, 3, 256, 256])
Output tensor shape: torch.Size([1, 1, 256, 256])


The output shape confirms that:

- The model is correctly assembled: the forward pass ran without any dimension mismatches or errors. This shows that the downsampling_blocks, the bottleneck, and the upsampling_blocks are connected with compatible shapes, and that every skip connection meets a decoder feature map of the same spatial size.

- The output layer is working as intended: The final output tensor has a channel count of 1 and the same spatial dimensions (256x256) as the input. This is exactly what is needed for a binary segmentation task, where the single channel represents the logit for each pixel.
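Because the model returns raw logits, the probability map and the binary mask have to be derived explicitly. A minimal sketch, assuming a fixed threshold of 0.5 (a common default, not something this notebook specifies), on a small hand-written logit tensor that stands in for the model output:

```python
import torch

# Stand-in for the model output: a (1, 1, 2, 2) tensor of per-pixel logits
logits = torch.tensor([[[[-2.0, 0.5], [3.0, -0.1]]]])
probabilities = torch.sigmoid(logits)  # map logits to [0, 1]
mask = (probabilities > 0.5).long()    # threshold at 0.5, equivalent to logits > 0
print(mask.tolist())  # [[[[0, 1], [1, 0]]]]
```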

So, all I can say is that this U-Net model is ready to be trained! The implementation phase has been completed successfully.
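As a starting point for training, here is a minimal sketch of a single optimization step for binary segmentation. `BCEWithLogitsLoss` pairs naturally with the raw logits the model outputs, and Adam with `lr=1e-3` is an illustrative choice, not a tuned setting. To keep the block self-contained and fast, a 1x1 convolution (`stand_in_model`, a hypothetical placeholder) stands in for the U-Net because it has the same input and output shapes; replace it with `UNET_Architecture()` in practice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in with the same input/output contract as UNET_Architecture:
# (N, 3, H, W) images in, (N, 1, H, W) logits out.
stand_in_model = nn.Conv2d(3, 1, kernel_size=1)
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, expects raw logits
optimizer = torch.optim.Adam(stand_in_model.parameters(), lr=1e-3)  # illustrative learning rate

images = torch.randn(2, 3, 32, 32)                     # dummy batch of images
targets = torch.randint(0, 2, (2, 1, 32, 32)).float()  # dummy binary masks
weight_before = stand_in_model.weight.detach().clone()

stand_in_model.train()
optimizer.zero_grad()
logits = stand_in_model(images)
loss = criterion(logits, targets)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```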

References:
- The U-Net architecture, proposed by Ronneberger et al. in 2015, is widely used for image segmentation ([paper link](https://arxiv.org/abs/1505.04597)).
