In [2]:
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One contracting stage: stacked conv+ReLU pairs followed by 2x2 max pooling.

    Args:
        out_channels: Output channel count used by every convolution in the stack.
        num_convs: Number of conv+ReLU pairs to stack.
        kernel_size: Kernel size handed to each convolution.
    """

    def __init__(self, out_channels, num_convs=2, kernel_size=3):
        super().__init__()

        # Lazy convolutions defer the input channel count to the first forward pass.
        # No padding argument is passed, so the library default applies.
        stack = [
            layer
            for _ in range(num_convs)
            for layer in (nn.LazyConv2d(out_channels, kernel_size=kernel_size), nn.ReLU())
        ]

        self.convblock = nn.Sequential(*stack)
        self.poolblock = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        """Run the stage on ``X``.

        Returns:
            A ``(pooled, features)`` tuple: ``features`` is the output of the
            convolution stack (kept for the skip connection) and ``pooled`` is
            that same tensor after max pooling.
        """
        features = self.convblock(X)
        return self.poolblock(features), features
        
class DecoderBlock(nn.Module):
    """One expanding stage: transposed-conv upsampling, skip concatenation, conv stack.

    Args:
        out_channels: Output channel count of the transposed convolution and of
            every convolution in the stack.
        num_convs: Number of conv+ReLU pairs applied after the concatenation.
        kernel_size: Kernel size handed to each convolution in the stack.
    """

    def __init__(self, out_channels, num_convs=2, kernel_size=3):
        super().__init__()

        # Kernel 2 / stride 2 transposed convolution used as the upsampling step.
        self.unpoolblock = nn.LazyConvTranspose2d(out_channels, kernel_size=2, stride=2)

        stack = []
        for _ in range(num_convs):
            stack += [nn.LazyConv2d(out_channels, kernel_size=kernel_size), nn.ReLU()]
        self.convblock = nn.Sequential(*stack)

    def forward(self, X, skip):
        """Upsample ``X``, merge it with the cropped ``skip`` features and convolve.

        Args:
            X: Input to the upsampling layer.
            skip: Skip-connection features; indexed as a 4-D tensor whose last
                two axes are cropped to the upsampled height/width.

        Returns:
            Output of the convolution stack applied to the concatenated tensor.
        """
        upsampled = self.unpoolblock(X)
        height, width = upsampled.shape[-2:]

        # Offsets of a centered window of size (height, width) inside skip.
        top = (skip.shape[-2] - height) // 2
        left = (skip.shape[-1] - width) // 2
        cropped = skip[:, :, top:top + height, left:left + width]

        # Upsampled features come first, cropped skip features second, along dim 1.
        merged = torch.cat((upsampled, cropped), dim=1)
        return self.convblock(merged)
        
class UNet(nn.Module):
    """Encoder / bridge / decoder network assembled from lazily-shaped convolutions.

    Args:
        channels_out: Channel count of the first encoder stage. Encoder stage
            ``i`` uses ``channels_out * 2**i`` channels and the bridge uses
            ``channels_out * 2**encoder_depth``.
        encoder_depth: Number of encoder stages (and of matching decoder stages).
        num_classes: Output channel count of the final 1x1 convolution.
    """

    def __init__(self, channels_out, encoder_depth=4, num_classes=2):
        super().__init__()

        # Contracting path: channel count doubles at each stage.
        self.encoder = nn.ModuleList()
        for i in range(0, encoder_depth):
            self.encoder.append(EncoderBlock(channels_out*2**i))

        self.bridge = nn.Sequential(
            nn.LazyConv2d(channels_out*2**encoder_depth, kernel_size=3), nn.ReLU(),
            nn.LazyConv2d(channels_out*2**encoder_depth, kernel_size=3), nn.ReLU(),
        )

        # Expanding path: mirrors the encoder, deepest stage first.
        self.decoder = nn.ModuleList()
        for i in range(encoder_depth-1, -1, -1):
            self.decoder.append(DecoderBlock(channels_out*2**i))

        self.final = nn.LazyConv2d(num_classes, kernel_size=1)

    def forward(self, X):
        """Run the full network on ``X`` and return the output of the final conv."""
        # H = hidden layer activations
        H = X
        skip = []

        for enc_block in self.encoder:
            H, S = enc_block(H)
            skip.append(S)

        H = self.bridge(H)

        # pop() hands back the skips last-in-first-out, so the deepest encoder
        # output is paired with the first decoder stage.
        for dec_block in self.decoder:
            H = dec_block(H, skip.pop())

        Y = self.final(H)
        return Y

    def init_weights(self, inputs):
        """Materialize the lazy layers with a dry run, then re-initialize conv weights.

        Args:
            inputs: Example input used for the dry-run forward pass.

        Every ``nn.Conv2d`` weight is re-drawn with Kaiming-normal initialization
        (``fan_in`` mode, ReLU gain). Biases are left untouched.
        NOTE(review): modules that are not ``nn.Conv2d`` instances (e.g. the
        decoder's transposed convolutions) are not matched by the check below and
        keep their default initialization - confirm this is intended.
        """
        # The output of the dry run is discarded, so skip autograd bookkeeping.
        # Call the module itself rather than .forward() so registered hooks run.
        with torch.no_grad():
            self(inputs)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

In [3]:
# Build a U-Net with 64 base channels and a random single-channel 572x572 input.
unet = UNet(64)
I = torch.randn(1, 1, 572, 572)
# init_weights does a dry-run forward pass on I before re-initializing the conv weights.
unet.init_weights(I)
# Show the raw network output, then the model's top-level submodules.
print(unet(I))
print([*unet.children()])



tensor([[[[ 1.7961e-01,  5.8380e-01,  7.3720e-01,  ...,  8.7482e-01,
            1.0781e+00,  1.4031e+00],
          [ 1.4914e-01,  1.8723e-01,  2.3184e-01,  ...,  1.2077e+00,
            3.9147e-01,  1.7680e-01],
          [ 5.6973e-01,  6.5522e-01,  9.0722e-01,  ...,  8.8139e-01,
           -7.5916e-03,  6.5021e-01],
          ...,
          [ 4.8358e-01,  6.8153e-01, -5.7561e-01,  ...,  2.9997e-01,
           -1.3754e-01,  5.0778e-02],
          [ 6.0585e-01,  3.9180e-01, -1.2003e-01,  ...,  7.4532e-01,
            1.0719e-01,  7.9113e-01],
          [ 7.2439e-01, -2.2501e-02,  5.5179e-01,  ...,  5.2347e-01,
            1.4869e+00,  1.4965e-01]],

         [[ 1.4314e+00, -2.7129e-01,  1.7169e+00,  ...,  3.1603e-01,
            1.6929e+00,  6.7742e-01],
          [ 2.1997e-01, -5.0266e-01, -2.5307e-01,  ...,  1.7772e-01,
            9.9475e-01, -4.2054e-01],
          [ 8.6702e-01,  4.9469e-01,  1.2595e+00,  ...,  8.2479e-01,
            1.5853e+00,  1.4610e+00],
          ...,
     