In [1]:
import torch
from torch import nn

In [2]:
image = torch.randn((1,1,572,572))  # B,H,W,H


In [12]:
def double_convs(in_channels, out_channels, kernel_size = 3):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size),
        nn.ReLU(inplace=True),
    )

In [51]:
def crop_image(tensor, target_tensor):
    target_size = target_tensor.size()[2]  # 56
    tensor_size = tensor.size()[2] # 64
    #delta [1, 512, 64 - 4 , 64 - 4 ]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

In [76]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        #building blocks for UNet
        ##encoder part -> first (  right  ) half 
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size= 2, stride= 2)
        self.down_conv_1 = double_convs(1,64)
        self.down_conv_2 = double_convs(64,128)
        self.down_conv_3 = double_convs(128,256)
        self.down_conv_4 = double_convs(256,512)
        self.down_conv_5 = double_convs(512,1024)

        ##decoder part -> second (  left  ) half 
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,
                                             out_channels=512,
                                             kernel_size=2,
                                             stride=2)      #green arrow
        self.up_conv_1 = double_convs(1024, 512)            # blue arrow
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,
                                             out_channels=256,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_2 = double_convs(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,
                                             out_channels=128,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_3 = double_convs(256,128)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,
                                             out_channels=64,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_4 = double_convs(128,64)
        self.out_conv = nn.Conv2d(64,2,kernel_size=1)


    def forward(self, x:torch.Tensor) -> torch.Tensor:  #return torch.Tensor
        #specify how the tensor move through the convolution architecture
        #encoder part -> first (  right  ) half #input x -> image
        X1 = self.down_conv_1(x)    #--> will be concatenated
        X2 = self.max_pool_2x2(X1)
        X3 = self.down_conv_2(X2)   #--> will be concatenated
        X4 = self.max_pool_2x2(X3)
        X5 = self.down_conv_3(X4)   #--> will be concatenated
        X6 = self.max_pool_2x2(X5)
        X7 = self.down_conv_4(X6)   #--> will be concatenated
        X8 = self.max_pool_2x2(X7)
        X9 = self.down_conv_5(X8)


        ##decoder part -> second (  left  ) half
        X10 = self.up_trans_1(X9)
        crop = crop_image(X7, X10)
        X11 = self.up_conv_1(torch.concat((X10, crop), dim=1)) #concatenate x10 with x7
        print(X11.shape) # -> torch.Size([1, 512, 52, 52])

        # print(f"this is X7{X7.shape},\n \t X10{X10.shape},\n")
        # print(f"this is crop shape {crop.shape}")

        X12 = self.up_trans_2(X11)
        crop = crop_image(X5, X12)
        X13 = self.up_conv_2(torch.concat((X12, crop), dim=1)) #concatenate X12 with X5
        print(X13.shape) # -> torch.Size([1, 256, 100, 100])


        X14 = self.up_trans_3(X13)
        crop = crop_image(X3, X14)
        X15 = self.up_conv_3(torch.concat((X14, crop), dim=1)) #concatenate X14 with X3
        print(X15.shape) # -> torch.Size([1, 128, 196, 196])



        X16 = self.up_trans_4(X15)
        crop = crop_image(X1, X16)
        X17 = self.up_conv_4(torch.concat((X16, crop), dim=1)) #concatenate X16 with X1
        print(X16.shape) # -> torch.Size([1, 64, 392, 392])

        X18 = self.out_conv(X17)
        print(X18.shape) # -> torch.Size([1, 2, 388, 388])




            

In [77]:
medel_1 = UNet()
t = medel_1(image)
t

torch.Size([1, 512, 52, 52])
torch.Size([1, 256, 100, 100])
torch.Size([1, 128, 196, 196])
torch.Size([1, 64, 392, 392])
torch.Size([1, 2, 388, 388])
