In [92]:
import numpy as np
import torch
import torch.nn as nn

In [93]:
# Dummy all-zero input batch laid out as (batch=2, channels=1, height=256, width=256).
X = torch.zeros((2, 1, 256, 256))

In [94]:
class UNet(nn.Module):
    """Three-level U-Net: convolutional encoder, bottleneck, and decoder with skip connections.

    The encoder applies three double-convolution stages, each followed by a
    2x2 max-pool, then a bottleneck stage.  The decoder mirrors it with three
    stride-2 transposed convolutions; after each one the upsampled map is
    concatenated (channel axis) with the encoder output of the same
    resolution and passed through another double-convolution stage.  A final
    1x1 convolution maps to ``out_channels``.

    Because the input is halved three times and the upsampled maps are
    concatenated with the encoder outputs, the input height and width should
    be multiples of 8; otherwise the concatenations see mismatched sizes.

    Args:
        in_channels: Number of channels of the input tensor (default 1).
        out_channels: Number of channels produced by the final 1x1 conv (default 1).
        base_channels: Width ``c`` of the first encoder stage; later stages use
            ``2c``, ``4c`` and ``8c`` (default 32, i.e. 32/64/128/256).
        verbose: When True (the default, matching the original behaviour),
            ``forward`` prints the shape of every intermediate tensor.  Pass
            False to silence the trace, e.g. inside a training loop.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=32, verbose=True):
        super().__init__()

        # Plain configuration values; they are read by build_encoder/build_decoder.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels
        self.verbose = verbose

        self.maxpool = nn.MaxPool2d(2)
        self.activation = nn.ReLU()

        # Creation order (encoder, decoder, head) is kept as in the original so
        # that seeded parameter initialisation stays reproducible.
        self.build_encoder()
        self.build_decoder()

        self.final_layer = nn.Conv2d(base_channels, out_channels, 1)

    def build_conv_layer(self, in_channels, out_channels):
        """Return a Sequential of two 3x3 same-padding convs, each followed by ReLU.

        Args:
            in_channels: Channels entering the first convolution.
            out_channels: Channels produced by both convolutions.

        Returns:
            An ``nn.Sequential`` that preserves the spatial size of its input.
        """
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            self.activation,  # the single shared ReLU module is reused everywhere
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            self.activation,
        )
        return layer

    def build_encoder(self):
        """Create the four encoder stages ``E_layer_1`` .. ``E_layer_4``."""
        c = self.base_channels
        self.E_layer_1 = self.build_conv_layer(self.in_channels, c)
        self.E_layer_2 = self.build_conv_layer(c, 2 * c)
        self.E_layer_3 = self.build_conv_layer(2 * c, 4 * c)
        self.E_layer_4 = self.build_conv_layer(4 * c, 8 * c)

    def build_decoder(self):
        """Create the upsampling layers ``up_k`` and decoder stages ``D_layer_k``.

        Each ``D_layer_k`` takes the upsampled channels plus the channels of
        the matching encoder output, because ``forward`` concatenates them.
        """
        c = self.base_channels
        self.up_1 = nn.ConvTranspose2d(8 * c, 8 * c, 2, 2, 0)
        self.D_layer_1 = self.build_conv_layer(8 * c + 4 * c, 4 * c)
        self.up_2 = nn.ConvTranspose2d(4 * c, 4 * c, 2, 2, 0)
        self.D_layer_2 = self.build_conv_layer(4 * c + 2 * c, 2 * c)
        self.up_3 = nn.ConvTranspose2d(2 * c, 2 * c, 2, 2, 0)
        self.D_layer_3 = self.build_conv_layer(2 * c + c, c)

    def _trace(self, label, tensor):
        """Print ``label`` and the tensor's shape when ``self.verbose`` is set."""
        if self.verbose:
            print(f'{label}: {tensor.shape}')

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``; H and W should be
                multiples of 8 (see the class docstring).

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.

        The shape comments below use ``c = base_channels`` and omit the batch axis.
        """
        # Encoder stage 1: (c, H, W)
        E1_out = self.E_layer_1(x)
        self._trace('E1_out', E1_out)

        out = self.maxpool(E1_out)
        self._trace('out', out)
        # Encoder stage 2: (2c, H/2, W/2)
        E2_out = self.E_layer_2(out)
        self._trace('E2_out', E2_out)

        out = self.maxpool(E2_out)
        self._trace('out', out)
        # Encoder stage 3: (4c, H/4, W/4)
        E3_out = self.E_layer_3(out)
        self._trace('E3_out', E3_out)

        # Bottleneck: (8c, H/8, W/8)
        out = self.maxpool(E3_out)
        self._trace('out', out)
        bn = self.E_layer_4(out)
        self._trace('bn', bn)

        # Upsample to (8c, H/4, W/4), then concatenate the stage-3 skip -> (8c+4c, H/4, W/4)
        out = self.up_1(bn)
        self._trace('out', out)
        out = torch.cat([out, E3_out], dim=1)
        self._trace('out + E3_out', out)
        out = self.D_layer_1(out)
        self._trace('D1 out', out)

        # Upsample to (4c, H/2, W/2), then concatenate the stage-2 skip -> (4c+2c, H/2, W/2)
        out = self.up_2(out)
        self._trace('out', out)
        out = torch.cat([out, E2_out], dim=1)
        self._trace('out + E2_out', out)

        out = self.D_layer_2(out)
        self._trace('out', out)

        # Upsample to (2c, H, W), then concatenate the stage-1 skip -> (2c+c, H, W)
        out = self.up_3(out)
        self._trace('out', out)
        out = torch.cat([out, E1_out], dim=1)
        self._trace('out', out)

        out = self.D_layer_3(out)
        self._trace('out', out)

        # 1x1 convolution head: (out_channels, H, W)
        out = self.final_layer(out)

        return out

In [96]:
# Instantiate the model with its default constructor arguments.
net = UNet()

In [97]:
# Run a forward pass on the dummy batch X; forward() prints the shape of each
# intermediate tensor, which produces the trace shown below the cell.
out = net(X)
# Last expression of the cell, so the notebook displays the final output shape.
out.shape

E1_out: torch.Size([2, 32, 256, 256])
out: torch.Size([2, 32, 128, 128])
E2_out: torch.Size([2, 64, 128, 128])
out: torch.Size([2, 64, 64, 64])
E3_out: torch.Size([2, 128, 64, 64])
out: torch.Size([2, 128, 32, 32])
bn: torch.Size([2, 256, 32, 32])
out: torch.Size([2, 256, 64, 64])
out + E3_out: torch.Size([2, 384, 64, 64])
D1 out: torch.Size([2, 128, 64, 64])
out: torch.Size([2, 128, 128, 128])
out + E2_out: torch.Size([2, 192, 128, 128])
out: torch.Size([2, 64, 128, 128])
out: torch.Size([2, 64, 256, 256])
out: torch.Size([2, 96, 256, 256])
out: torch.Size([2, 32, 256, 256])


torch.Size([2, 1, 256, 256])