In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [68]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x

From the original paper
The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3×3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2×2 max pooling operation with stride 2 for downsampling. At each downsampling step we double the number of feature channels.

In [73]:
class Contracting_path(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Contracting_path, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2, 2)
        
    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        p = self.max_pool(x)
        #print(type(x))
        return x, p

From the original paper
Every step in the expansive path consists of an upsampling of the feature map followed by a 2×2 convolution (“up-convolution” aka transpose convolution) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3×3 convolutions, each followed by a ReLU

In [79]:
class Expansive_path(torch.nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Expansive_path, self).__init__()
        self.up_conv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, padding=0)
        self.conv1 = torch.nn.Conv2d(out_channel + out_channel, out_channel, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(out_channel + out_channel, out_channel, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, prev_feature_map, inputs):
        x = self.up_conv(inputs)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x

In [80]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        ### We have 5 layers are on the left side# ##
        self.c1 = Contracting_path(1, 64)
        self.c2 = Contracting_path(64, 128)
        self.c3 = Contracting_path(128, 256)
        self.c4 = Contracting_path(256, 512)
        
        self.c5 = conv_block(512, 1024)

        #self.c5 = Contracting_path(512, 1024)
        #self.max_pool = nn.MaxPool2d(2,2)

        
        ## And 4 layers are on the right side ###
        self.e1 = Expansive_path(1024, 512)
        self.e2 = Expansive_path(512, 256)
        self.e3 = Expansive_path(256, 128)
        self.e4 = Expansive_path(128, 64)

        #output layer
        self.out = nn.Conv2d(64, 2, kernel_size=1)
       

    def forward(self, inputs):
   
        #forward pass for Left side
        x1 = self.c1(inputs)
        
        #x2 = self.max_pool(x1)
        x2 = self.c2(x1)
        #x4 = self.max_pool(x3)
        x3 = self.c3(x2)
        #x6 = self.max_pool(x5)
        x4 = self.c4(x3)
        #x8 = self.max_pool(x7)
        x5 = self.c5(x4)

         #forward pass for Right side
        x = self.e1(x4, x5)
        x = self.e2(x3, x)
        x = self.e3(x2, x)
        x = self.e4(x1, x)
        x = self.out(x)
        return x

In [82]:
image = torch.rand((1, 1, 572, 572))
model = UNet()
model(image)


TypeError: ignored