In [4]:
import torch
import torch.nn as nn


In [5]:
def doub_conv(in_c,out_c):
  conv = nn.Sequential(nn.Conv2d(in_c,out_c,kernel_size = 3),
                       nn.ReLU(),
                       nn.Conv2d(out_c,out_c,kernel_size = 3),
                       nn.ReLU())
  return conv

In [15]:
def conv_transpose(in_c,out_c):
  conv_tr = nn.ConvTranspose2d(in_c,out_c,kernel_size = 2,stride = 2)

  return conv_tr

In [18]:
import torchvision.transforms as transforms

In [40]:
class UNetModel(nn.Module):
  def __init__(self,):
    super().__init__()

    self.max_pool = nn.MaxPool2d(kernel_size = 2,stride = 2)
    self.down_conv1 = doub_conv(1,64)
    self.down_conv2 = doub_conv(64,128)
    self.down_conv3 = doub_conv(128,256)
    self.down_conv4 = doub_conv(256,512)
    self.down_conv5 = doub_conv(512,1024)

    self.tr_conv1 = conv_transpose(1024,512)
    self.tr_conv2 = conv_transpose(512,256)
    self.tr_conv3 = conv_transpose(256,128)
    self.tr_conv4 = conv_transpose(128,64)

    self.crop1 = transforms.CenterCrop((392,392))
    self.crop2 = transforms.CenterCrop((200,200))
    self.crop3 = transforms.CenterCrop((104,104))
    self.crop4 = transforms.CenterCrop((56,56))

    self.up_conv1 = doub_conv(1024,512)
    self.up_conv2 = doub_conv(512,256)
    self.up_conv3 = doub_conv(256,128)
    self.up_conv4 = doub_conv(128,64)



  def forward(self,image):
    x1 = self.down_conv1(image)
    x2 = self.max_pool(x1)
    x3 = self.down_conv2(x2)
    x4 = self.max_pool(x3)
    x5 = self.down_conv3(x4)
    x6 = self.max_pool(x5)
    x7 = self.down_conv4(x6)
    x8 = self.max_pool(x7)
    x9 = self.down_conv5(x8)
    x10 = self.tr_conv1(x9)
    x11 = self.up_conv1(torch.cat([x10,self.crop4(x7)],dim=1))
    x12 = self.up_conv2(torch.cat([self.tr_conv2(x11),self.crop3(x5)],dim = 1))
    x13 = self.up_conv3(torch.cat([self.tr_conv3(x12),self.crop2(x3)],dim = 1))
    x14 = self.up_conv4(torch.cat([self.tr_conv4(x13),self.crop1(x1)],dim = 1))

    final_output = nn.Conv2d(64,2,kernel_size = 1)
    return final_output(x14)



In [41]:
if __name__ == "__main__":
  image = torch.rand((1,1,572,572))
  model = UNetModel()
  print(model(image)   )


tensor([[[[-0.0767, -0.0769, -0.0771,  ..., -0.0796, -0.0796, -0.0777],
          [-0.0756, -0.0776, -0.0775,  ..., -0.0776, -0.0783, -0.0773],
          [-0.0785, -0.0788, -0.0779,  ..., -0.0773, -0.0742, -0.0751],
          ...,
          [-0.0781, -0.0787, -0.0771,  ..., -0.0775, -0.0778, -0.0778],
          [-0.0784, -0.0774, -0.0784,  ..., -0.0765, -0.0746, -0.0778],
          [-0.0804, -0.0772, -0.0753,  ..., -0.0777, -0.0799, -0.0765]],

         [[ 0.0420,  0.0432,  0.0448,  ...,  0.0417,  0.0413,  0.0424],
          [ 0.0428,  0.0478,  0.0422,  ...,  0.0432,  0.0459,  0.0396],
          [ 0.0438,  0.0461,  0.0467,  ...,  0.0408,  0.0443,  0.0454],
          ...,
          [ 0.0444,  0.0400,  0.0428,  ...,  0.0434,  0.0422,  0.0424],
          [ 0.0412,  0.0434,  0.0420,  ...,  0.0418,  0.0410,  0.0446],
          [ 0.0440,  0.0415,  0.0434,  ...,  0.0454,  0.0438,  0.0422]]]],
       grad_fn=<ConvolutionBackward0>)
