In [1]:
import torch
import torch.nn as nn

In [3]:
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv

def crop_tensor(tensor, target_tensor):
    '''
    Adjusting the shapes of layers to concatenate
    
    '''
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    # Assuming tensor size is larger than target size
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:,:, delta:tensor_size-delta, delta:tensor_size-delta] # adjusting the height and width for concatenation

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
    
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1,64)
        self.down_conv_2 = double_conv(64,128)
        self.down_conv_3 = double_conv(128,256)
        self.down_conv_4 = double_conv(256,512)
        self.down_conv_5 = double_conv(512,1024)
    
        self.up_trans_1 = nn.ConvTranspose2d(in_channels = 1024,
                                             out_channels = 512,
                                             kernel_size =2, 
                                             stride=2)
        
        self.up_conv_1 = double_conv(1024, 512)
        
        
        self.up_trans_2 = nn.ConvTranspose2d(in_channels = 512,
                                             out_channels = 256,
                                             kernel_size =2, 
                                             stride=2)
        
        self.up_conv_2 = double_conv(512, 256)
        
        
        self.up_trans_3 = nn.ConvTranspose2d(in_channels = 256,
                                             out_channels = 128,
                                             kernel_size =2, 
                                             stride=2)
        
        self.up_conv_3 = double_conv(256, 128)
        
        
        self.up_trans_4 = nn.ConvTranspose2d(in_channels = 128,
                                             out_channels = 64,
                                             kernel_size =2, 
                                             stride=2)
        
        self.up_conv_4 = double_conv(128, 64)
        
        self.out = nn.Conv2d(
            in_channels = 64,
            out_channels = 2,
            kernel_size = 1
        )
        
    
    def forward(self, image):
        #batch_size, channels, height, width
        
        # encoder
        x1 = self.down_conv_1(image) 
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2) 
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4) 
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6) 
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        
        #decoder
        x = self.up_trans_1(x9)
        #concatenate x with x7, before that we crop x7
        y = crop_tensor(x7, x)
        x = self.up_conv_1(torch.cat([x,y],1))
        
        x = self.up_trans_2(x)
        y = crop_tensor(x5, x)
        x = self.up_conv_2(torch.cat([x,y],1))
        
        x = self.up_trans_3(x)
        y = crop_tensor(x3, x)
        x = self.up_conv_3(torch.cat([x,y],1))
        
        x = self.up_trans_4(x)
        y = crop_tensor(x1, x)
        x = self.up_conv_4(torch.cat([x,y],1))
        
        x = self.out(x)
        print(x.size())
        return x
        
        
if __name__ == "__main__":
    image = torch.rand((1,1,572,572))
    model = UNet()
    print(model(image))

torch.Size([1, 2, 388, 388])
tensor([[[[0.0239, 0.0222, 0.0227,  ..., 0.0207, 0.0252, 0.0213],
          [0.0275, 0.0244, 0.0248,  ..., 0.0258, 0.0224, 0.0268],
          [0.0252, 0.0221, 0.0270,  ..., 0.0232, 0.0248, 0.0218],
          ...,
          [0.0213, 0.0199, 0.0225,  ..., 0.0232, 0.0235, 0.0250],
          [0.0259, 0.0253, 0.0254,  ..., 0.0211, 0.0239, 0.0215],
          [0.0215, 0.0270, 0.0251,  ..., 0.0224, 0.0217, 0.0225]],

         [[0.1149, 0.1200, 0.1220,  ..., 0.1213, 0.1184, 0.1205],
          [0.1148, 0.1144, 0.1103,  ..., 0.1134, 0.1155, 0.1146],
          [0.1161, 0.1175, 0.1100,  ..., 0.1196, 0.1161, 0.1182],
          ...,
          [0.1180, 0.1185, 0.1159,  ..., 0.1170, 0.1143, 0.1166],
          [0.1155, 0.1119, 0.1175,  ..., 0.1162, 0.1185, 0.1187],
          [0.1183, 0.1166, 0.1205,  ..., 0.1193, 0.1181, 0.1201]]]],
       grad_fn=<ThnnConv2DBackward>)
