In [1]:
import torch
import torch.nn as nn
from torchsummaryX import summary as summaryX
from torchsummary import summary

In [2]:
def convLayer(in_channels, out_channels):
    conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size = 3),
        nn.ReLU()
    )
    return conv

In [3]:
def maxPoolLayer(kernel_size, stride = None, padding = 0):
    maxPool = nn.Sequential(
        nn.MaxPool2d(kernel_size, stride, padding)
    )
    return maxPool

In [4]:
def upConvLayer(in_channels, out_channels):
    upConv = nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size= 2, stride=2)
    )
    return upConv

In [5]:
def crop_tensor(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta //= 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

In [6]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.conv1 = convLayer(1, 64)
        self.maxPoolLayer = maxPoolLayer(2, 2)
        self.conv2 = convLayer(64, 128)
        self.conv3 = convLayer(128, 256) 
        self.conv4 = convLayer(256, 512)
        self.conv5 = convLayer(512, 1024)
        self.conv6 = convLayer(1024, 512)
        self.conv7 = convLayer(512, 256)
        self.conv8 = convLayer(256, 128) 
        self.conv9 = convLayer(128, 64)
        self.upConv1 = upConvLayer(1024, 512)
        self.upConv2 = upConvLayer(512, 256)
        self.upConv3 = upConvLayer(256, 128)
        self.upConv4 = upConvLayer(128, 64)
        self.upConv5 = upConvLayer(64, 1)
        self.upConv6 = upConvLayer(64, 2)
    
    def forward(self, image):
        x1 = self.conv1(image) 
        x2 = self.maxPoolLayer(x1)
        x3 = self.conv2(x2) 
        x4 = self.maxPoolLayer(x3)
        x5 = self.conv3(x4) 
        x6 = self.maxPoolLayer(x5)
        x7 = self.conv4(x6) 
        x8 = self.maxPoolLayer(x7)
        x9 = self.conv5(x8)
        x = self.upConv1(x9)
        y = crop_tensor(x7, x)
        x = self.conv6(y)
        x = self.UpConv2(x)
        y = crop_tensor(x5, x)
        x = self.conv7(y)
        x = self.upConv3(x)
        y = crop_tensor(x3, x)
        x = self.conv8(y)
        x = self.upConv4(x)
        y = crop_tensor(x1, x)
        x = self.conv9(y)
        x = self.upConv6(x)

        return x

In [7]:
#Checking the Model

x = torch.randn(2, 1, 572, 572)
model = UNet()
output = model(x)
print(model)
summary(model, (1, 572, 572))
print("Output Shape:", output.shape)

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[2, 512, 56, 56] to have 1024 channels, but got 512 channels instead