In [60]:
import numpy as np
import torch
import torch.nn as nn

In [61]:
# Dummy all-zero input laid out as (batch, channels, height, width)
X = torch.zeros((2, 1, 256, 256))

In [62]:
class UNet(nn.Module):
    """Three-level U-Net made of 3x3 double-convolution stages with skip connections.

    Encoder: four double-conv stages (``E_layer_1`` .. ``E_layer_4``) with a
    2x2 max pool between consecutive stages; the channel count doubles at
    every stage. Decoder: three transposed-conv upsamplings (``up_1`` ..
    ``up_3``), each followed by a channel-wise concatenation with the
    matching encoder output and another double-conv stage. A final 3x3
    convolution maps to ``out_channels``; no activation is applied to it.

    Args:
        in_channels: number of channels of the input tensor (default 1).
        out_channels: number of channels produced by the final layer (default 1).
        base_channels: width of the first encoder stage (default 32); deeper
            stages use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
        super().__init__()

        # Stored so build_encoder()/build_decoder() can stay argument-free.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels

        self.maxpool = nn.MaxPool2d(2)
        # One stateless ReLU module is shared by every conv stage.
        self.activation = nn.ReLU()

        # A kernel size of 1 would also work for this projection.
        self.final_layer = nn.Conv2d(base_channels, out_channels, 3, stride=1, padding=1)

        self.build_encoder()
        self.build_decoder()

    def build_conv_layer(self, in_channels, out_channels):
        """Return a (conv3x3 -> ReLU -> conv3x3 -> ReLU) stage.

        Both convolutions use stride 1 and padding 1, so the spatial size of
        the input is preserved; only the channel count changes.
        """
        layer = nn.Sequential(
            # in_channels, out_channels, kernel_size, stride=1, padding=0
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            self.activation,
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            self.activation,
        )
        return layer

    def build_encoder(self):
        """Create the four encoder stages; channels go c -> 2c -> 4c -> 8c."""
        c = self.base_channels
        self.E_layer_1 = self.build_conv_layer(self.in_channels, c)
        self.E_layer_2 = self.build_conv_layer(c, 2 * c)
        self.E_layer_3 = self.build_conv_layer(2 * c, 4 * c)
        self.E_layer_4 = self.build_conv_layer(4 * c, 8 * c)

    def build_decoder(self):
        """Create the upsampling layers and decoder stages.

        Each transposed convolution (kernel 2, stride 2) doubles the spatial
        size and keeps the channel count; the following stage receives the
        upsampled map concatenated with the skip tensor, hence the summed
        input widths.
        """
        c = self.base_channels
        self.up_1 = nn.ConvTranspose2d(8 * c, 8 * c, 2, 2, 0)
        self.D_layer_1 = self.build_conv_layer(8 * c + 4 * c, 4 * c)
        self.up_2 = nn.ConvTranspose2d(4 * c, 4 * c, 2, 2, 0)
        self.D_layer_2 = self.build_conv_layer(4 * c + 2 * c, 2 * c)
        self.up_3 = nn.ConvTranspose2d(2 * c, 2 * c, 2, 2, 0)
        self.D_layer_3 = self.build_conv_layer(2 * c + c, c)

    @staticmethod
    def _concat_skip(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` along the channel axis.

        Max pooling floors odd sizes, so after upsampling the map can be one
        pixel smaller than the skip tensor. In that case ``upsampled`` is
        zero-padded to the skip's height/width first; when the sizes already
        agree this is a plain ``torch.cat``.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom)
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, in_channels, H, W). H and W must be at
                least 8 so that three successive 2x2 poolings leave a
                non-empty map; they do not need to be multiples of 8.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder
        E1_out = self.E_layer_1(x)  # conv affects num channels

        out = self.maxpool(E1_out)  # max pool affects dimensions, not num channels
        E2_out = self.E_layer_2(out)

        out = self.maxpool(E2_out)
        E3_out = self.E_layer_3(out)

        # Bottleneck
        out = self.maxpool(E3_out)
        bn = self.E_layer_4(out)

        # Decoder
        out = self.up_1(bn)
        out = self._concat_skip(out, E3_out)
        D1_out = self.D_layer_1(out)

        out = self.up_2(D1_out)
        out = self._concat_skip(out, E2_out)
        D2_out = self.D_layer_2(out)

        out = self.up_3(D2_out)
        out = self._concat_skip(out, E1_out)
        out = self.D_layer_3(out)

        out = self.final_layer(out)

        return out

In [63]:
# Instantiate the U-Net defined above
net = UNet()

In [64]:
# Sanity check: push the dummy batch through the network and inspect the output shape
net(X).shape

torch.Size([2, 1, 256, 256])

In [65]:
def compute_dims(i, k, p, s):
    """Evaluate ``floor((i - k + 2*p) / s) + 1``.

    This is the usual output-size formula for a sliding window along one
    axis (e.g. a convolution or pooling layer without dilation), reading
    ``i`` as the input size, ``k`` as the kernel size, ``p`` as the padding
    applied on each side and ``s`` as the stride.

    Returns:
        The value as produced by ``np.floor`` plus one, i.e. a NumPy float
        (not an ``int``) for scalar arguments.
    """
    padded_span = i - k + 2 * p
    full_steps = np.floor(padded_span / s)
    return full_steps + 1