U-Nets are used for classification and segmentation.

* Need to know: what convolutions, strides, padding, max pooling, transposed convolutions, and ReLU are.

* The input is a single-channel 572x572 image. There is no padding, so each 3x3 convolution shrinks the height and width by 2 pixels (572 to 570 to 568). This pattern repeats at every level.

* The `double_conv` helper should apply two convolutions; max pooling follows, then another two convolutions, and so on down the encoder.

* At the middle we start to apply up-convolutions. 
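
The size reduction described in the bullets above can be traced with plain arithmetic. This is a minimal sketch, assuming 3x3 unpadded convolutions, 2x2 max pooling with stride 2, and five encoder levels; the helper names `conv_out`, `pool_out`, and `trace_encoder` are made up for illustration.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output length of max pooling along one spatial axis."""
    return (size - kernel) // stride + 1


def trace_encoder(size, levels=5):
    """Return the size after the double convolution at each encoder level."""
    sizes = []
    for level in range(levels):
        size = conv_out(conv_out(size))  # two unpadded 3x3 convolutions
        sizes.append(size)
        if level < levels - 1:
            size = pool_out(size)  # 2x2 max pooling, stride 2
    return sizes


print(trace_encoder(572))  # [568, 280, 136, 64, 28]
```

The last value, 28, is the size at the bottom of the "U", where the up-convolutions start.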




#### Lookup

+ bi-linear upsampling (lookup)

+ up-convolutions (lookup)

## More on the architecture:

The contracting path (encoder) consists of the repeated application of two 3x3 unpadded convolutions, each followed by a rectified linear unit (ReLU), and a 2x2 max pooling operation with stride 2 for downsampling.
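
As a quick shape check, one such contracting step can be sketched in PyTorch. The channel counts (1 to 64) and the 572x572 input are taken from the notes above; the variable name `step` is only for illustration.

```python
import torch
import torch.nn as nn

# One contracting step: two unpadded 3x3 convolutions, each followed by ReLU,
# then 2x2 max pooling with stride 2.
step = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

x = torch.rand(1, 1, 572, 572)  # (batch, channels, height, width)
with torch.no_grad():
    y = step(x)
print(y.shape)  # torch.Size([1, 64, 284, 284])
```

The two convolutions take 572 down to 568, and the pooling halves that to 284.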


Transposed convolutions increase the spatial size. Convolutions without padding decrease it.

You can use `ConvTranspose2d` from PyTorch: from a feature map of size 28x28 we need to get one of size 56x56.
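
A minimal sketch of that upsampling step, assuming `kernel_size=2` and `stride=2` (the settings used later in this notebook) and the 1024-to-512 channel reduction of the first decoder level:

```python
import torch
import torch.nn as nn

# With kernel_size=2 and stride=2 (no padding), the output size is
# (in - 1) * stride + kernel = (28 - 1) * 2 + 2 = 56.
up = nn.ConvTranspose2d(in_channels=1024, out_channels=512,
                        kernel_size=2, stride=2)

x = torch.rand(1, 1024, 28, 28)
with torch.no_grad():
    y = up(x)
print(y.shape)  # torch.Size([1, 512, 56, 56])
```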

In [None]:
## Unet.py
import torch
import torch.nn as nn

In [None]:
def double_conv(in_c, out_c):
  # Helper that applies two unpadded 3x3 convolutions, each followed by ReLU
  conv = nn.Sequential(
     nn.Conv2d(in_c, out_c, kernel_size = 3),
     nn.ReLU(inplace = True),
     nn.Conv2d(out_c,out_c, kernel_size = 3),
     nn.ReLU(inplace = True)
     )
  return conv



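def crop_image(orig_tensor, target_tensor):
  # Center-crop orig_tensor (bs, c, h, w) to the spatial size of target_tensor.
  # Assumes square feature maps, as everywhere else in this notebook.
  target_size = target_tensor.size()[2]
  orig_tensor_size = orig_tensor.size()[2]
  delta = (orig_tensor_size - target_size) // 2
  # Slice by delta + target_size so an odd size difference still gives target_size
  return orig_tensor[:, :, delta:delta + target_size, delta:delta + target_size]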
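
The cropping step can be checked in isolation. The sketch below uses a stand-alone helper named `center_crop` (a hypothetical name, not part of the notebook) that slices by start offset plus target size, and applies it to the first skip connection: a 64x64 encoder map cropped to match a 56x56 upsampled map.

```python
import torch


def center_crop(source, target):
    """Crop the last two dims of `source` to the spatial size of `target`."""
    th, tw = target.shape[2], target.shape[3]
    top = (source.shape[2] - th) // 2
    left = (source.shape[3] - tw) // 2
    return source[:, :, top:top + th, left:left + tw]


skip = torch.rand(1, 512, 64, 64)       # encoder feature map
upsampled = torch.rand(1, 512, 56, 56)  # decoder feature map after up-convolution
cropped = center_crop(skip, upsampled)
print(cropped.shape)  # torch.Size([1, 512, 56, 56])

# An odd size difference (65 vs 56) still yields the target size.
odd = center_crop(torch.rand(1, 512, 65, 65), upsampled)
print(odd.shape)  # torch.Size([1, 512, 56, 56])
```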

In [None]:
class UNet(nn.Module):
    def __init__(self):
        # The input is a 572x572 image that first goes through two convolutions:
        # the one-channel image is converted to 64 channels, then to 64 channels again.
        super(UNet, self).__init__()
        # we need max pooling
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size = 2, stride = 2) #2x2 max pooling with stride of 2
        # we also need the double convolution
        self.down_conv1 = double_conv(1, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d(
                                  in_channels = 1024, 
                                  out_channels = 512, 
                                  kernel_size = 2,
                                  stride = 2) # doubles the height and width

        self.up_conv_1 =  double_conv(1024, 512) # up_trans_1 outputs 512 channels; concatenating the cropped skip connection (512 channels) gives 1024
        self.up_trans_2 = nn.ConvTranspose2d(
                                  in_channels = 512, 
                                  out_channels = 256, 
                                  kernel_size = 2,
                                  stride = 2) # doubles the height and width

        self.up_conv_2 =  double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
                                  in_channels = 256, 
                                  out_channels = 128, 
                                  kernel_size = 2,
                                  stride = 2) # doubles the height and width

        self.up_conv_3 =  double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
                                  in_channels = 128, 
                                  out_channels = 64, 
                                  kernel_size = 2,
                                  stride = 2) # doubles the height and width

        self.up_conv_4 =  double_conv(128, 64)


        self.out =  nn.Conv2d(in_channels = 64, out_channels = 2, # two output channels, one per class (two-class segmentation)
                              kernel_size = 1)
        
 
    
    def forward(self, image):
        #encoder
        #bs, c, h, w
        x1 = self.down_conv1(image)# we need to pass this to the last up-conv layer
        print("X1 after 1st conv", x1.size())
        x1_pooled = self.max_pool_2x2(x1) 
        x2 = self.down_conv2(x1_pooled) # this needs to pass to the second-to-last up-conv layer
        x2_pooled = self.max_pool_2x2(x2)
        x3 = self.down_conv3(x2_pooled) # this goes to third to last
        x3_pooled = self.max_pool_2x2(x3)
        x4 = self.down_conv4(x3_pooled) # this goes to fourth to last
        print("x4 after conv", x4.size())
        x4_pooled = self.max_pool_2x2(x4)
        x5 = self.down_conv5(x4_pooled)
        print("X5 after last down convolution", x5.size())


        # Decoder
        x = self.up_trans_1(x5) # we need to concatenate x4 [1, 512, 64, 64] to x which is [1, 512, 56, 56]
        y = crop_image(x4, x)
        x = self.up_conv_1(torch.cat([x,y], 1))
        print("result of up_conv", x.shape)
        x = self.up_trans_2(x) # we need to concatenate x3 [1, 256, 136, 136] to x which is [1, 256, 104, 104]
        y = crop_image(x3, x)
        x = self.up_conv_2(torch.cat([x,y], 1))

        x = self.up_trans_3(x) # we need to concatenate x2 [1, 128, 280, 280] to x which is [1, 128, 200, 200]
        y = crop_image(x2, x)
        x = self.up_conv_3(torch.cat([x,y], 1))

        x = self.up_trans_4(x) # we need to concatenate x1 [1, 64, 568, 568] to x which is [1, 64, 392, 392]
        y = crop_image(x1, x)
        x = self.up_conv_4(torch.cat([x,y], 1))

        x = self.out(x)
        print("results", x.shape)
        return x
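
The input channel counts of the `up_conv_*` blocks follow from the concatenation in `forward`. A small sketch with random tensors of the first decoder level's shapes (the variable names are illustrative) shows that concatenating along dimension 1 adds the channel counts and leaves the spatial size unchanged:

```python
import torch

upsampled = torch.rand(1, 512, 56, 56)     # output of the first up-convolution
cropped_skip = torch.rand(1, 512, 56, 56)  # cropped encoder feature map

# dim=1 is the channel dimension in (batch, channels, height, width)
merged = torch.cat([upsampled, cropped_skip], dim=1)
print(merged.shape)  # torch.Size([1, 1024, 56, 56])
```

This is why the double convolution that follows takes 1024 input channels even though the up-convolution produced 512.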



In [None]:
# if __name__ == "__main__":

image = torch.rand(1, 1, 572, 572)
model = UNet()
model(image)

X1 after 1st conv torch.Size([1, 64, 568, 568])
x4 after conv torch.Size([1, 512, 64, 64])
X5 after last down convolution torch.Size([1, 1024, 28, 28])
result of up_conv torch.Size([1, 512, 52, 52])
results torch.Size([1, 2, 388, 388])


tensor([[[[ 1.9717e-03,  2.8767e-03,  4.3008e-03,  ...,  3.7842e-03,
            2.4997e-03,  2.9996e-03],
          [ 2.7762e-03,  3.3202e-03,  2.5759e-03,  ...,  3.1935e-03,
            9.7582e-05,  3.6280e-03],
          [ 6.1720e-03,  3.6901e-03,  7.1466e-03,  ...,  2.0450e-03,
            2.9575e-03,  2.7717e-03],
          ...,
          [ 2.5501e-03,  2.6572e-03,  3.0521e-03,  ...,  3.3637e-03,
            7.6788e-04,  4.2437e-03],
          [ 5.7757e-04,  2.2286e-04,  2.8272e-03,  ...,  5.5916e-03,
            3.1097e-03,  4.5494e-03],
          [ 1.6284e-03,  3.2798e-03,  5.1131e-03,  ...,  2.2878e-03,
           -2.9861e-04,  2.3731e-03]],

         [[-1.0958e-01, -1.1289e-01, -1.1193e-01,  ..., -1.0959e-01,
           -1.0844e-01, -1.1241e-01],
          [-1.0923e-01, -1.0782e-01, -1.1300e-01,  ..., -1.0924e-01,
           -1.0926e-01, -1.0921e-01],
          [-1.1016e-01, -1.0771e-01, -1.1001e-01,  ..., -1.1016e-01,
           -1.0844e-01, -1.0752e-01],
          ...,
     