In [6]:
import torch 
import torch.nn as nn

class u_net(nn.Module):
    """Encoder (contracting-path) layers of a U-Net.

    Five ``double_conv`` stages take the input to 64, 128, 256, 512 and
    1024 channels, with a shared 2x2/stride-2 max pool for downsampling
    between stages.

    Only the layers are defined so far. There is no ``forward`` method yet,
    so calling the module raises ``NotImplementedError``.

    Args:
        in_channels: number of channels in the input image (default 3).
    """

    def double_conv(self, in_channels, out_channels):
        """Build two unpadded 3x3 convolutions, each followed by a ReLU.

        With no padding, each convolution trims 2 pixels from both height
        and width, so the whole block shrinks the spatial size by 4.

        Args:
            in_channels: channels fed to the first convolution.
            out_channels: channels produced by both convolutions.

        Returns:
            ``nn.Sequential`` of Conv2d -> ReLU -> Conv2d -> ReLU.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels=3):
        super().__init__()

        self.dconv_down1 = self.double_conv(in_channels, 64)
        # 2x2 max pooling with stride 2 for downsampling. Pooling has no
        # learnable parameters, so a single instance is reused between
        # every pair of encoder stages instead of being re-created per stage.
        self.maxpool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # At each downsampling step the number of feature channels doubles.
        self.dconv_down2 = self.double_conv(64, 128)
        self.dconv_down3 = self.double_conv(128, 256)
        self.dconv_down4 = self.double_conv(256, 512)
        self.dconv_down5 = self.double_conv(512, 1024)

        

# Instantiate the model as a bare expression so the notebook displays its
# module summary (repr) as the cell output.
# NOTE(review): the captured output below lists no dconv_down5, although the
# class defines it — the output appears to predate that line; re-run the cell.
u_net()

u_net(
  (dconv_down1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (maxpool_2x2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dconv_down2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv_down3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv_down4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
)