In [27]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [28]:
def double_conv(in_channels, out_channels, batch_norm=False):
    """Build a U-Net "double convolution" stage: (3x3 conv -> [BN] -> ReLU) x 2.

    Both convolutions use a 3x3 kernel with ``padding=1``, so the spatial size
    of the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        batch_norm: If True, insert ``nn.BatchNorm2d`` after each convolution
            and drop the (then redundant) convolution bias. Defaults to False,
            which reproduces the original layer layout (conv, ReLU, conv, ReLU)
            with biased convolutions, so existing state_dicts still load.

    Returns:
        nn.Sequential: the stacked layers.
    """
    # The previous code passed bias='False' -- a non-empty, hence truthy,
    # string -- so the convolutions actually had a bias. Use a real boolean:
    # keep the bias when nothing normalizes the output, drop it when BatchNorm
    # follows, since BN's learned shift makes a conv bias redundant.
    use_bias = not batch_norm

    layers = []
    for conv_in_channels in (in_channels, out_channels):
        layers.append(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        )
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

In [30]:
class UNet(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel output map.

    The encoder applies a double conv then 2x2 max-pooling four times
    (64 -> 128 -> 256 -> 512 channels) and a 1024-channel bottleneck double
    conv. Each of the four decoder stages upsamples with a stride-2 transposed
    convolution, concatenates the matching encoder feature map (skip
    connection) and applies a double conv. A final 1x1 convolution maps to
    ``out_channels``; no activation is applied, so the output is raw scores.

    Because the 3x3 convolutions are padded, the output has the same height and
    width as the input. Odd intermediate sizes (floor pooling) are handled by
    resizing the upsampled map to the skip connection's size. Height and width
    must be at least 16 so the fourth pooling yields a non-empty map.

    Args:
        in_channels: Channels of the input tensor. Defaults to 1.
        out_channels: Channels of the output tensor. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder. Attribute names (and registration order) are unchanged so
        # existing state_dicts and optimizer param orderings stay valid.
        self.down_conv = double_conv(in_channels, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        # Decoder: each up_conv sees upsampled + skip channels (2x its output).
        self.up_trans1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv1 = double_conv(1024, 512)

        self.up_trans2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv2 = double_conv(512, 256)

        self.up_trans3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv3 = double_conv(256, 128)

        self.up_trans4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, image):
        """Run the network.

        Args:
            image: Batched 4-D tensor (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder: keep each stage's pre-pool output for the skip connections.
        x1_conv = self.down_conv(image)
        x2_conv = self.down_conv2(self.max_pool(x1_conv))
        x3_conv = self.down_conv3(self.max_pool(x2_conv))
        x4_conv = self.down_conv4(self.max_pool(x3_conv))
        x5_conv = self.down_conv5(self.max_pool(x4_conv))

        # Decoder: upsample, fuse with the matching encoder map, refine.
        x = self._up_block(x5_conv, x4_conv, self.up_trans1, self.up_conv1)
        x = self._up_block(x, x3_conv, self.up_trans2, self.up_conv2)
        x = self._up_block(x, x2_conv, self.up_trans3, self.up_conv3)
        x = self._up_block(x, x1_conv, self.up_trans4, self.up_conv4)

        return self.out(x)

    @staticmethod
    def _up_block(x, skip, up_trans, up_conv):
        """One decoder stage: upsample ``x``, match ``skip``'s size, concat, double conv."""
        x = up_trans(x)
        # Floor pooling on odd sizes makes the upsampled map smaller than the
        # skip map; bilinearly resize only in that case (same op TF.resize
        # applies to tensors, without depending on a torchvision image API).
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode="bilinear", align_corners=False
            )
        # Channel order (skip first, then upsampled features) is kept as before
        # so up_conv weights line up with previously saved checkpoints.
        return up_conv(torch.cat((skip, x), dim=1))
        

In [31]:
def test(size=572, in_channels=1, out_channels=1):
    """Smoke-test UNet on a random input and check the output shape.

    Builds a freshly initialised UNet, feeds it a random
    (1, in_channels, size, size) tensor, prints both shapes and asserts the
    output keeps the input's spatial size with ``out_channels`` channels.

    Args:
        size: Height and width of the random input. Must be at least 16 so the
            model's four 2x2 poolings leave a non-empty feature map.
        in_channels: Channels of the random input / model input. Defaults to 1.
        out_channels: Channels the model should produce. Defaults to 1.

    Raises:
        AssertionError: if the output shape is not (1, out_channels, size, size).
    """
    model = UNet(in_channels=in_channels, out_channels=out_channels)
    model.eval()
    image = torch.randn(1, in_channels, size, size)
    print(f"Input shape: {image.shape}")
    # Inference only: skip autograd bookkeeping, which for a full-size U-Net
    # forward pass would otherwise hold every activation in memory.
    with torch.no_grad():
        output = model(image)
    print("Output shape:", output.shape)
    expected = (1, out_channels, size, size)
    assert output.shape == expected, f"expected {expected}, got {tuple(output.shape)}"

In [32]:
# Run the smoke test when executed as a script or from a notebook cell
# (the notebook kernel also sets __name__ to "__main__", as the output shows).
if __name__ == "__main__":
    test()  # prints input/output shapes

Input shape: torch.Size([1, 1, 572, 572])
Output shape: torch.Size([1, 1, 572, 572])
