In [1]:
import torch 
import torch.nn as nn

class u_net(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The contracting path applies four ``double_conv`` + 2x2 max-pool stages
    (3 -> 64 -> 128 -> 256 -> 512 channels) followed by a 1024-channel
    bottleneck.  The expansive path mirrors it with 2x2 transposed
    convolutions, skip connections and ``double_conv`` blocks, and ends with a
    1x1 convolution producing 2 output channels.

    Because no convolution is padded, every ``double_conv`` shrinks the
    feature map by 4 pixels per spatial dimension.  Encoder feature maps are
    therefore larger than the upsampled decoder maps they are concatenated
    with, so each skip connection is center-cropped to the decoder size before
    concatenation.  The output is spatially smaller than the input (e.g. a
    572x572 input gives a 388x388 output; a 188x188 input gives 4x4, which is
    the smallest size that survives all stages).
    """

    def double_conv(self, in_channels, out_channels):
        """Build two stacked (3x3 unpadded conv -> BatchNorm -> ReLU) stages.

        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels produced by both convolutions.

        Returns:
            An ``nn.Sequential`` that reduces each spatial dimension by 4
            (2 pixels per unpadded 3x3 convolution).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(skip, target):
        """Crop ``skip`` around its center to the height/width of ``target``.

        Only the last two dimensions are cropped; leading dimensions are left
        untouched.  ``skip`` is expected to be at least as large as ``target``
        in both spatial dimensions, which holds for the encoder/decoder pairs
        produced in ``forward``.
        """
        skip_h, skip_w = skip.shape[-2:]
        target_h, target_w = target.shape[-2:]
        top = (skip_h - target_h) // 2
        left = (skip_w - target_w) // 2
        return skip[..., top:top + target_h, left:left + target_w]

    def __init__(self):
        super().__init__()

        # 3 input channels (presumably RGB images -- confirm against the data).
        self.down_dconv_1 = self.double_conv(3, 64)
        # 2x2 max pooling operation with stride 2 for downsampling.
        # At each downsampling step we double the number of feature channels.
        self.maxpool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_dconv_2 = self.double_conv(64, 128)
        self.maxpool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_dconv_3 = self.double_conv(128, 256)
        self.maxpool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_dconv_4 = self.double_conv(256, 512)
        self.maxpool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_dconv_5 = self.double_conv(512, 1024)

        # Every step in the expansive path upsamples the feature map with a
        # 2x2 transposed convolution that halves the channel count; the
        # following double_conv consumes the upsampled map concatenated with
        # the matching encoder map (hence the doubled input channels).
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_4 = self.double_conv(1024, 512)

        self.up_trans_3 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_3 = self.double_conv(512, 256)

        self.up_trans_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_2 = self.double_conv(256, 128)

        self.up_trans_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_1 = self.double_conv(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)  # 1x1 convolution

    def forward(self, x):
        """Run the network on ``x`` and return the 2-channel output map.

        Args:
            x: Input tensor with 3 channels on dim 1 and the two spatial
                dimensions last (the layout ``nn.Conv2d`` expects).

        Returns:
            Tensor with 2 channels on dim 1 whose spatial size is smaller than
            the input's, since none of the convolutions are padded.
        """
        # Contracting path: keep each stage's output for the skip connections.
        down_dconv_1 = self.down_dconv_1(x)
        maxpool_1 = self.maxpool_2x2_1(down_dconv_1)

        down_dconv_2 = self.down_dconv_2(maxpool_1)
        maxpool_2 = self.maxpool_2x2_2(down_dconv_2)

        down_dconv_3 = self.down_dconv_3(maxpool_2)
        maxpool_3 = self.maxpool_2x2_3(down_dconv_3)

        down_dconv_4 = self.down_dconv_4(maxpool_3)
        maxpool_4 = self.maxpool_2x2_4(down_dconv_4)

        down_dconv_5 = self.down_dconv_5(maxpool_4)
        # End of contracting path.

        # Expansive path.  The encoder maps are larger than the upsampled
        # decoder maps (border pixels are lost in every unpadded convolution),
        # so they must be center-cropped before torch.cat, which requires
        # equal sizes on every dimension except the concatenation one.
        up_trans_4 = self.up_trans_4(down_dconv_5)
        cat_4 = torch.cat([self._center_crop(down_dconv_4, up_trans_4), up_trans_4], dim=1)
        up_conv_4 = self.up_conv_4(cat_4)

        up_trans_3 = self.up_trans_3(up_conv_4)
        cat_3 = torch.cat([self._center_crop(down_dconv_3, up_trans_3), up_trans_3], dim=1)
        up_conv_3 = self.up_conv_3(cat_3)

        up_trans_2 = self.up_trans_2(up_conv_3)
        cat_2 = torch.cat([self._center_crop(down_dconv_2, up_trans_2), up_trans_2], dim=1)
        up_conv_2 = self.up_conv_2(cat_2)

        up_trans_1 = self.up_trans_1(up_conv_2)
        cat_1 = torch.cat([self._center_crop(down_dconv_1, up_trans_1), up_trans_1], dim=1)
        up_conv_1 = self.up_conv_1(cat_1)

        out = self.out(up_conv_1)
        return out
# Instantiate the network; as the last expression of the notebook cell, its
# repr (the module tree) is what gets displayed as the cell output.
u_net()

u_net(
  (down_dconv_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (maxpool_2x2_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (down_dconv_2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (maxpool_2x2_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (down_dconv_3): Sequential(