<a href="https://colab.research.google.com/github/devangi2000/Deep-Learning/blob/master/U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
import torch.nn as nn

def double_conv(in_c, out_c):
    """Build two stacked 3x3 convolutions, each followed by an in-place ReLU.

    The convolutions use PyTorch's default padding of 0, so each one trims
    one pixel from every border: the block reduces height and width by 4.

    Args:
        in_c: number of input channels.
        out_c: number of output channels (shared by both convolutions).

    Returns:
        ``nn.Sequential`` of Conv2d -> ReLU -> Conv2d -> ReLU.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` spatially so it matches ``target_tensor``.

    Used for the skip connections: the encoder feature map is larger than
    the upsampled decoder map (the unpadded convolutions shrink each level),
    so it must be cropped before the two are concatenated along channels.

    Height (dim 2) and width (dim 3) are cropped independently, so
    non-square maps are supported. When the size difference along a
    dimension is odd, the extra row/column is dropped from the end, so
    the result always has exactly the target size.

    Args:
        tensor: 4-D tensor shaped (batch, channels, H, W) to crop.
        target_tensor: 4-D tensor whose H and W give the desired size.

    Returns:
        A slice (view) of ``tensor`` with the same H and W as
        ``target_tensor``.

    Raises:
        ValueError: if ``target_tensor`` is spatially larger than ``tensor``.
    """
    src_h, src_w = tensor.shape[2], tensor.shape[3]
    tgt_h, tgt_w = target_tensor.shape[2], target_tensor.shape[3]
    if tgt_h > src_h or tgt_w > src_w:
        raise ValueError(
            f"target spatial size ({tgt_h}, {tgt_w}) exceeds "
            f"source size ({src_h}, {src_w})"
        )
    # Offsets of the centered window; floor division puts any odd leftover
    # pixel at the bottom/right edge.
    top = (src_h - tgt_h) // 2
    left = (src_w - tgt_w) // 2
    return tensor[:, :, top:top + tgt_h, left:left + tgt_w]


class UNet(nn.Module):
    """Encoder-decoder segmentation network with cropped skip connections.

    Every ``double_conv`` stage uses unpadded 3x3 convolutions, so feature
    maps shrink as they pass through the network; the decoder therefore
    center-crops each encoder feature map (via ``crop_img``) before
    concatenating it with the upsampled decoder map. For a
    ``(1, 1, 572, 572)`` input the output is ``(1, 2, 388, 388)``.

    Args:
        in_channels: number of channels in the input image. Defaults to 1.
        out_channels: number of channels produced by the final 1x1
            convolution (e.g. one per class). Defaults to 2.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        self.maxpool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path: channel width doubles at each level.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Expanding path: each transposed conv doubles H/W and halves the
        # channel count; the following double_conv takes twice that many
        # channels because the cropped skip map is concatenated onto it.
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2
        )
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_conv_4 = double_conv(128, 64)

        # 1x1 convolution mapping the final 64 features to the output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, image):
        """Run the network on ``image``.

        Args:
            image: tensor shaped (batch, in_channels, height, width).

        Returns:
            Output of the final 1x1 convolution (no activation applied),
            shaped (batch, out_channels, h, w). h and w are smaller than
            the input size because the convolutions are unpadded.
        """
        # Encoder: keep each level's pre-pooling output for the skip paths.
        skip1 = self.down_conv_1(image)
        skip2 = self.down_conv_2(self.maxpool_2x2(skip1))
        skip3 = self.down_conv_3(self.maxpool_2x2(skip2))
        skip4 = self.down_conv_4(self.maxpool_2x2(skip3))
        x = self.down_conv_5(self.maxpool_2x2(skip4))

        # Decoder: upsample, crop the matching encoder map to the same size,
        # concatenate along channels (decoder first), then double_conv.
        stages = (
            (self.up_trans_1, self.up_conv_1, skip4),
            (self.up_trans_2, self.up_conv_2, skip3),
            (self.up_trans_3, self.up_conv_3, skip2),
            (self.up_trans_4, self.up_conv_4, skip1),
        )
        for up_trans, up_conv, skip in stages:
            x = up_trans(x)
            x = up_conv(torch.cat([x, crop_img(skip, x)], dim=1))

        return self.out(x)
        #print(x7.size())
        #print(y.size())

        


if __name__ == "__main__":
    # Smoke test on one random single-channel image shaped (bs, c, h, w).
    image = torch.rand((1, 1, 572, 572))
    model = UNet()
    # Forward-only run: skip building the autograd graph to save memory.
    with torch.no_grad():
        output = model(image)
    print(output.size())
    print(output)

torch.Size([1, 2, 388, 388])
tensor([[[[0.0414, 0.0419, 0.0481,  ..., 0.0424, 0.0440, 0.0479],
          [0.0468, 0.0463, 0.0401,  ..., 0.0418, 0.0482, 0.0452],
          [0.0446, 0.0421, 0.0480,  ..., 0.0428, 0.0429, 0.0457],
          ...,
          [0.0446, 0.0433, 0.0477,  ..., 0.0451, 0.0430, 0.0470],
          [0.0442, 0.0445, 0.0440,  ..., 0.0421, 0.0453, 0.0441],
          [0.0450, 0.0453, 0.0425,  ..., 0.0487, 0.0436, 0.0464]],

         [[0.1355, 0.1375, 0.1363,  ..., 0.1354, 0.1376, 0.1349],
          [0.1344, 0.1359, 0.1358,  ..., 0.1378, 0.1334, 0.1387],
          [0.1340, 0.1364, 0.1359,  ..., 0.1359, 0.1364, 0.1354],
          ...,
          [0.1356, 0.1373, 0.1351,  ..., 0.1368, 0.1372, 0.1388],
          [0.1376, 0.1364, 0.1380,  ..., 0.1396, 0.1358, 0.1380],
          [0.1346, 0.1371, 0.1312,  ..., 0.1362, 0.1364, 0.1373]]]],
       grad_fn=<MkldnnConvolutionBackward>)
