In [None]:
import torch
import torch.nn as nn

In [None]:
class DoubleConv(nn.Module):
  """Two back-to-back 3x3 convolution -> batch norm -> ReLU stages.

  Both convolutions use stride 1 and padding 1, so the spatial size of the
  input is preserved; only the channel count changes, from ``in_channel``
  to ``out_channel``.
  """

  def __init__(self, in_channel, out_channel):
    """Build the two conv stages.

    Args:
      in_channel: Number of channels of the incoming feature map.
      out_channel: Number of channels produced by both convolutions.
    """
    super().__init__()
    # Settings shared by both convolutions. The convolutions carry no bias
    # term; each one is followed directly by a BatchNorm2d layer.
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
    stages = [
        nn.Conv2d(in_channel, out_channel, **conv_kwargs),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, **conv_kwargs),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through both conv stages and return the result."""
    return self.conv(x)

In [None]:
class Down(nn.Module):
  """Downsampling stage: 2x2 max pooling followed by a ``DoubleConv``.

  The pooling uses kernel size 2 and stride 2, which halves the height and
  width before the convolutions change the channel count.
  """

  def __init__(self, in_channel, out_channel):
    """Build the pooling + convolution pipeline.

    Args:
      in_channel: Number of channels of the incoming feature map.
      out_channel: Number of channels produced by the ``DoubleConv``.
    """
    super().__init__()
    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    convs = DoubleConv(in_channel, out_channel)
    self.down_conv = nn.Sequential(pool, convs)

  def forward(self, x):
    """Pool ``x`` and then apply the double convolution."""
    return self.down_conv(x)
def skip_connection(x, y):
  """Concatenate ``x`` and ``y`` along dim 1, resizing ``x`` if needed.

  If the shapes of ``x`` and ``y`` differ from dimension 2 onwards, ``x`` is
  bilinearly interpolated (``align_corners=True``) to ``y``'s size first.

  Args:
    x: Tensor to be (possibly) resized; placed first in the concatenation.
    y: Tensor whose trailing dimensions define the target size.

  Returns:
    ``torch.cat`` of the (resized) ``x`` and ``y`` along dimension 1.
  """
  target_size = y.shape[2:]
  if x.shape[2:] == target_size:
    # Sizes already agree: no resampling needed.
    return torch.cat((x, y), dim=1)
  resized = torch.nn.functional.interpolate(
      x, size=target_size, mode='bilinear', align_corners=True)
  return torch.cat((resized, y), dim=1)
class Up(nn.Module):
  """Upsampling stage: 2x bilinear upsample, merge with a skip, convolve.

  The upsampled tensor is joined with the skip tensor by
  ``skip_connection`` and the result is passed through a ``DoubleConv``.
  """

  def __init__(self, in_channel, out_channel, x_skip_channel):
    """Build the upsampling layer and the post-merge convolutions.

    Args:
      in_channel: Number of channels of the tensor being upsampled.
      out_channel: Number of channels produced by the ``DoubleConv``.
      x_skip_channel: Number of channels of the skip tensor.
    """
    super().__init__()
    self.up_conv = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    # After concatenation the channel counts of both inputs add up.
    merged_channels = in_channel + x_skip_channel
    self.dconv = DoubleConv(merged_channels, out_channel)

  def forward(self, x, y):
    """Upsample ``x``, merge it with skip tensor ``y`` and convolve.

    Args:
      x: Tensor to upsample.
      y: Skip tensor to merge with the upsampled ``x``.
    """
    upsampled = self.up_conv(x)
    merged = skip_connection(upsampled, y)
    return self.dconv(merged)

In [None]:
class Unet(nn.Module):
  """Encoder/decoder network with skip connections.

  The encoder widens the features 64 -> 128 -> 256 -> 512 -> 1024 through
  four ``Down`` stages; the decoder mirrors it with four ``Up`` stages, each
  merged with the encoder output of matching width. A final 1x1 convolution
  maps the 64 decoder channels to ``out_channel``; no activation is applied
  to its output.
  """

  def __init__(self, in_channel, out_channel):
    """Build the encoder, decoder and output projection.

    Args:
      in_channel: Number of channels of the network input.
      out_channel: Number of channels of the network output.
    """
    super().__init__()
    # Encoder path.
    self.dconv1 = DoubleConv(in_channel, 64)
    self.down1 = Down(64, 128)
    self.down2 = Down(128, 256)
    self.down3 = Down(256, 512)
    self.down4 = Down(512, 1024)
    # Decoder path: each stage receives the previous decoder output plus
    # the encoder feature map with the same channel count as its output.
    self.up1 = Up(in_channel=1024, out_channel=512, x_skip_channel=512)
    self.up2 = Up(in_channel=512, out_channel=256, x_skip_channel=256)
    self.up3 = Up(in_channel=256, out_channel=128, x_skip_channel=128)
    self.up4 = Up(in_channel=128, out_channel=64, x_skip_channel=64)
    # 1x1 projection to the requested number of output channels.
    self.out_conv = nn.Conv2d(64, out_channel, kernel_size=1)

  def forward(self, x):
    """Run the full encoder/decoder and return the 1x1 conv output."""
    # Encoder: keep every intermediate result for the skip connections.
    enc1 = self.dconv1(x)
    enc2 = self.down1(enc1)
    enc3 = self.down2(enc2)
    enc4 = self.down3(enc3)
    bottleneck = self.down4(enc4)

    # Decoder: consume the encoder outputs from deepest to shallowest.
    dec = self.up1(bottleneck, enc4)
    dec = self.up2(dec, enc3)
    dec = self.up3(dec, enc2)
    dec = self.up4(dec, enc1)

    return self.out_conv(dec)