# In [1]:
import torch
from torch import nn

class double_conv(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU units followed by Dropout2d.

    Both convolutions use ``padding=1``, so with their default stride of 1
    the spatial size of the input is preserved; only the channel count
    changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stack = []
        # First unit maps in_channels -> out_channels; the second keeps out_channels.
        for source_channels in (in_channels, out_channels):
            stack += [
                nn.Conv2d(source_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Channel-wise dropout (zeroes whole feature maps) as light regularisation.
        stack.append(nn.Dropout2d(p=0.05))

        # Kept as a single Sequential named ``conv`` so state_dict keys are conv.<idx>.*
        self.conv = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the conv/BN/ReLU stack and dropout to ``x``."""
        return self.conv(x)

    
class encode(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``double_conv``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the ``double_conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        # Halve the spatial resolution first, then convolve at the new scale.
        downsample = nn.MaxPool2d(2)
        features = double_conv(in_channels=in_channels, out_channels=out_channels)
        self.mp_conv = nn.Sequential(downsample, features)

    def forward(self, x):
        """Pool ``x`` by 2 and run it through the double convolution."""
        return self.mp_conv(x)

    
class decode(nn.Module):
    """Decoder stage: upsample, concatenate with a skip tensor, then ``double_conv``.

    The deeper feature map ``x1`` is upsampled by a factor of 2 and passed
    through a ReLU. It is then concatenated with the skip tensor ``x2`` along
    dim 1, and the result goes through a ``double_conv``.

    Args:
        in_channels: channel count ``double_conv`` expects, i.e. the channel
            count of the concatenation ``[upsample(x1), x2]``.
        out_channels: channel count produced by this stage.
        bilinear: if truthy, upsample with parameter-free bilinear
            interpolation (``align_corners=True``), which keeps x1's channel
            count. Otherwise use a learned ``ConvTranspose2d``
            (kernel 2, stride 2) mapping ``in_channels`` -> ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        if self.bilinear:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # NOTE(review): in this mode x1 must have ``in_channels`` channels and
            # x2 must have ``in_channels - out_channels`` for the concatenation to
            # match ``double_conv``. In bilinear mode x1 keeps its own channel
            # count instead, so the two modes imply different caller conventions.
            # Confirm against the callers.
            self.upsample = nn.ConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                               kernel_size=2, stride=2)

        self.conv = double_conv(in_channels=self.in_channels, out_channels=self.out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2``, concatenate on dim 1, and convolve.

        Args:
            x1: deeper feature map to be upsampled (assumed NCHW — confirm).
            x2: skip-connection tensor whose spatial size the output takes.

        Returns:
            The ``double_conv`` output of ``cat([upsample(x1), x2], dim=1)``.
        """
        # torch.relu is used so this does not depend on a separate
        # ``F`` (torch.nn.functional) alias being imported.
        x1 = torch.relu(self.upsample(x1))

        # After pooling an odd-sized map, 2x upsampling gives a map one pixel
        # smaller than the skip tensor. Pad x1 symmetrically so torch.cat's
        # spatial sizes agree. The padding order is (left, right, top, bottom).
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)