<a href="https://colab.research.google.com/github/matheus-piah/google_colabs/blob/main/UNET_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Create the UNet Class and Its Convolution Blocks

In [4]:
import torch
from google.colab import drive

# Mount Google Drive at /content/drive (Colab-only; requires interactive auth).
# NOTE(review): none of the cells visible here read from the mount — confirm it is still needed.
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
class DoubleConv(torch.nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size is preserved:
    an input of shape (batch, in_channels, H, W) yields an output of shape
    (batch, out_channels, H, W).

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.step = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
        )

    # `forward` must be a method of the class (not a function nested inside
    # `__init__`) for `nn.Module.__call__` to dispatch to it; otherwise every
    # call raises NotImplementedError.
    def forward(self, X):
        """Apply both conv + ReLU stages to ``X`` and return the result."""
        return self.step(X)
        

In [14]:
class UNet(torch.nn.Module):
    """A small U-Net for image segmentation.

    Encoder: three steps of ``DoubleConv`` followed by 2x2 max-pooling, then a
    bottleneck ``DoubleConv``. Decoder: three steps of bilinear upsampling,
    concatenation with the matching encoder feature map along the channel
    axis (dim=1), and a ``DoubleConv``. A final 1x1 convolution maps the 64
    decoder channels to ``out_channels``. No activation is applied to the
    output.

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of channels of the predicted map (default 1).

    The input is assumed to be (batch, in_channels, H, W). H and W do not
    need to be divisible by 8: each upsampling step resizes to the exact
    spatial size of its skip connection, so the output has the same H and W
    as the input.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder (contracting path)
        self.layer1 = DoubleConv(in_channels, 64)
        self.layer2 = DoubleConv(64, 128)
        self.layer3 = DoubleConv(128, 256)

        # Bottleneck
        self.layer4 = DoubleConv(256, 512)

        # Decoder (expanding path): input channels = upsampled + skip channels
        self.layer5 = DoubleConv(512 + 256, 256)
        self.layer6 = DoubleConv(256 + 128, 128)
        self.layer7 = DoubleConv(128 + 64, 64)

        # 1x1 convolution projecting to the output channels
        self.layer8 = torch.nn.Conv2d(64, out_channels, 1)

        self.maxpool = torch.nn.MaxPool2d(2)

    @staticmethod
    def _up_and_concat(x, skip):
        """Bilinearly upsample ``x`` to ``skip``'s spatial size, then concat on dim=1."""
        # Resizing to the skip's exact size (instead of scale_factor=2) keeps the
        # shapes aligned when an odd H/W was floored by the max-pool. When the
        # skip is exactly twice the size of x, this is numerically identical to
        # scale_factor=2. align_corners=False is what align_corners=None resolves to.
        x = torch.nn.functional.interpolate(
            x, size=skip.shape[-2:], mode='bilinear', align_corners=False
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Run the encoder/decoder and return the raw (pre-activation) prediction."""
        # Encoder
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))

        # Bottleneck
        x4 = self.layer4(self.maxpool(x3))

        # Decoder with skip connections
        x5 = self.layer5(self._up_and_concat(x4, x3))
        x6 = self.layer6(self._up_and_concat(x5, x2))
        x7 = self.layer7(self._up_and_concat(x6, x1))

        # Predicted segmentation
        return self.layer8(x7)