In [3]:
import numpy as np
import torch
import torch.nn as nn

In [4]:
# Dummy batch for a shape check: 2 single-channel 256x256 images.
# torch.Tensor(*sizes) returns *uninitialised* memory (it may hold NaN/inf
# garbage and differs between runs), so draw seeded random values instead.
torch.manual_seed(0)
X = torch.randn(2, 1, 256, 256)

In [76]:
class UNet(nn.Module):
    """Compact 2D U-Net: three pooling stages, a bottleneck, and three
    up-sampling stages joined to the encoder by skip connections.

    Input is ``(batch, in_channels, H, W)``; the output is
    ``(batch, out_channels, H, W)`` with the same spatial size as the input.
    H and W must each be at least 8 (three 2x max-pools), but they need not be
    divisible by 8: when pooling floors an odd size, the up-sampled map is
    zero-padded to match its skip connection before concatenation.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels produced by the final convolution (default 1).
        base_channels: width of the first encoder stage; each deeper stage
            doubles it (default 32, i.e. 32/64/128/256).
        verbose: if True, ``forward`` prints the shape of the input and of
            every stage's output (default True, the original behaviour).
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=32, verbose=True):
        super().__init__()

        # Configuration read by build_encoder / build_decoder / forward.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels
        self.verbose = verbose

        self.maxpool = nn.MaxPool2d(2)
        # One stateless ReLU instance is shared by every conv block.
        self.activation = nn.ReLU()

        # 3x3 "same" conv mapping the last decoder stage to the output channels.
        self.final_layer = nn.Conv2d(base_channels, out_channels, 3, 1, 1)

        self.build_encoder()
        self.build_decoder()

    def build_conv_layer(self, in_channels, out_channels):
        """Return two 3x3 stride-1 padding-1 convs, each followed by ReLU.

        The padding keeps the spatial size unchanged; only channels change.
        """
        conv_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            self.activation,
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            self.activation,
        )
        return conv_layer

    def build_encoder(self):
        """Create the four encoder stages (widths: c, 2c, 4c, 8c for c=base)."""
        c = self.base_channels
        self.E1_layer = self.build_conv_layer(self.in_channels, c)
        self.E2_layer = self.build_conv_layer(c, 2 * c)
        self.E3_layer = self.build_conv_layer(2 * c, 4 * c)
        self.E4_layer = self.build_conv_layer(4 * c, 8 * c)  # bottleneck

    def build_decoder(self):
        """Create the three decoder stages.

        Each stage is a kernel-2 stride-2 transposed conv (doubles H and W,
        keeps channels) followed by a conv block whose input is that map
        concatenated with the matching encoder output.
        """
        c = self.base_channels
        self.up1 = nn.ConvTranspose2d(8 * c, 8 * c, 2, 2, 0)
        self.D1_layer = self.build_conv_layer(8 * c + 4 * c, 4 * c)
        self.up2 = nn.ConvTranspose2d(4 * c, 4 * c, 2, 2, 0)
        self.D2_layer = self.build_conv_layer(4 * c + 2 * c, 2 * c)
        self.up3 = nn.ConvTranspose2d(2 * c, 2 * c, 2, 2, 0)
        self.D3_layer = self.build_conv_layer(2 * c + c, c)

    def _log(self, tensor):
        # Shape tracing for debugging; silenced with verbose=False.
        if self.verbose:
            print(tensor.shape)

    @staticmethod
    def _upsample_and_concat(up, below, skip):
        """Up-sample ``below`` with ``up`` and concatenate ``skip`` on channels.

        Max-pooling floors odd sizes, so the up-sampled map can be one pixel
        smaller than ``skip``; it is zero-padded (split evenly across sides)
        so torch.cat does not fail on inputs not divisible by 8.
        """
        out = up(below)
        dy = skip.size(2) - out.size(2)
        dx = skip.size(3) - out.size(3)
        if dy or dx:
            out = nn.functional.pad(out, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return torch.cat([out, skip], dim=1)

    def forward(self, x):
        """Run the U-Net on ``x`` and return the ``final_layer`` output."""
        self._log(x)

        # Encoder: every stage after the first sees the max-pooled (half-size)
        # output of the previous stage.
        E1_out = self.E1_layer(x)
        self._log(E1_out)
        E2_out = self.E2_layer(self.maxpool(E1_out))
        self._log(E2_out)
        E3_out = self.E3_layer(self.maxpool(E2_out))
        self._log(E3_out)
        bn = self.E4_layer(self.maxpool(E3_out))
        self._log(bn)

        # Decoder: up-sample, join the same-resolution encoder output, convolve.
        D1_out = self.D1_layer(self._upsample_and_concat(self.up1, bn, E3_out))
        self._log(D1_out)
        D2_out = self.D2_layer(self._upsample_and_concat(self.up2, D1_out, E2_out))
        self._log(D2_out)
        D3_out = self.D3_layer(self._upsample_and_concat(self.up3, D2_out, E1_out))
        self._log(D3_out)

        return self.final_layer(D3_out)

In [77]:
# Instantiate the network with its default configuration.
net: nn.Module = UNet()

In [78]:
# Smoke test: one forward pass without autograd (no gradients are needed, so
# skip building the graph) and check the output matches the input's shape.
with torch.no_grad():
    y_pred = net(X)
assert y_pred.shape == X.shape, f"unexpected output shape {tuple(y_pred.shape)}"

torch.Size([2, 1, 256, 256])
torch.Size([2, 32, 256, 256])
torch.Size([2, 64, 128, 128])
torch.Size([2, 128, 64, 64])
torch.Size([2, 256, 32, 32])
torch.Size([2, 128, 64, 64])
torch.Size([2, 64, 128, 128])
torch.Size([2, 32, 256, 256])
