# Build a U-Net architecture for image segmentation using PyTorch


![U-Net architecture diagram](UNET.webp "UNET Architecture")

## 1. Import libraries

In [4]:
from torch import nn
import torch 

## 2. Define architecture

In [14]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Keep conv -> ReLU -> conv -> ReLU in one Sequential so parameter
        # names ("conv_op.0.weight", "conv_op.2.weight", ...) stay stable.
        self.conv_op = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> conv -> ReLU."""
        return self.conv_op(x)

class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x2 max pooling (stride 2).

    ``forward`` returns both the pre-pool feature map (kept for the skip
    connection) and the pooled map (passed to the next, deeper stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)`` where ``pooled = maxpool(features)``."""
        features = self.conv(x)
        return features, self.pool(features)

class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv.

    The upsampled tensor is concatenated along dim 1 *after* the skip
    tensor, and the result feeds ``DoubleConv(in_channels, out_channels)``.
    The skip's channel count plus ``out_channels`` must therefore equal
    ``in_channels`` (e.g. 1024 -> 512 with a 512-channel skip).

    Args:
        in_channels: Channels of the deeper feature map ``x1``.
        out_channels: Channels produced by the upsampling and by the
            final DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to the skip ``x2``, concatenate, convolve.

        Args:
            x1: Deeper feature map to upsample (assumed (N, C, H, W)).
            x2: Skip-connection feature map from the matching encoder stage.

        Returns:
            DoubleConv output with the spatial size of ``x2``.
        """
        x1 = self.up(x1)
        # If the encoder pooled an odd-sized map, the 2x-upsampled x1 is
        # smaller than the skip x2. Zero-pad it (split evenly, extra pixel
        # on the bottom/right) so torch.cat does not fail on a size mismatch.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """U-Net style encoder/decoder for per-pixel prediction.

    Four encoder stages (DownSample), a DoubleConv bottleneck, four decoder
    stages (UpSample) that each consume the matching encoder skip output,
    and a final 1x1 convolution. The output is raw per-class scores; no
    activation is applied.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output channels.
        base_channels: Width of the first encoder stage. Each deeper level
            doubles it: the default 64 gives 64/128/256/512 encoder widths
            and a 1024-channel bottleneck.

    Raises:
        ValueError: If ``base_channels`` is not positive.
    """

    def __init__(self, in_channels, num_classes, base_channels=64):
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be positive, got {base_channels}")
        # Widths per level: c, 2c, 4c, 8c for the encoder and 16c at the bottleneck.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        self.down_convolution_1 = DownSample(in_channels, c1)
        self.down_convolution_2 = DownSample(c1, c2)
        self.down_convolution_3 = DownSample(c2, c3)
        self.down_convolution_4 = DownSample(c3, c4)

        self.bottom_neck = DoubleConv(c4, c5)

        self.up_convolution_1 = UpSample(c5, c4)
        self.up_convolution_2 = UpSample(c4, c3)
        self.up_convolution_3 = UpSample(c3, c2)
        self.up_convolution_4 = UpSample(c2, c1)

        self.out = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class scores for ``x``.

        ``x`` goes through four 2x2 poolings. For H and W divisible by 16, the
        output shape is (N, num_classes, H, W). Other sizes depend on how
        UpSample reconciles mismatched skip sizes.
        """
        down1, p1 = self.down_convolution_1(x)
        down2, p2 = self.down_convolution_2(p1)
        down3, p3 = self.down_convolution_3(p2)
        down4, p4 = self.down_convolution_4(p3)

        bottom = self.bottom_neck(p4)

        # Decoder: walk back up, pairing each level with its encoder skip.
        up1 = self.up_convolution_1(bottom, down4)
        up2 = self.up_convolution_2(up1, down3)
        up3 = self.up_convolution_3(up2, down2)
        up4 = self.up_convolution_4(up3, down1)

        return self.out(up4)


torch.Size([1, 3, 512, 512]) torch.Size([1, 10, 512, 512])
