In [1]:
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, out_channels, num_convs=2, kernel_size=3):
        super().__init__()
        
        layers = []
        for _ in range(num_convs):
            layers.append(nn.LazyConv2d(out_channels, kernel_size=kernel_size))
            layers.append(nn.ReLU())
        
        self.convblock = nn.Sequential(*layers)
        self.poolblock = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, X):
        skip = self.convblock(X)
        Y = self.poolblock(skip)
        return Y, skip
        
class DecoderBlock(nn.Module):
    def __init__(self, out_channels, num_convs=2, kernel_size=3):
        super().__init__()
        
        self.unpoolblock = nn.LazyConvTranspose2d(out_channels, kernel_size=2, stride=2)
        
        layers = []
        for _ in range(num_convs):
            layers.append(nn.LazyConv2d(out_channels, kernel_size=kernel_size))
            layers.append(nn.ReLU())
        self.convblock = nn.Sequential(*layers)
        
    def forward(self, X, skip):
        X_up = self.unpoolblock(X)
        m, n = X_up.shape[-2:]
        d_m = (skip.shape[-2]-m) // 2
        d_n = (skip.shape[-1]-n) // 2
        # Center crop skip features to same height/width as X_up
        # then concatenate them along the output channel dimension
        X_up_cat = torch.cat((X_up,skip[:, :, d_m:m+d_m, d_n:n+d_n]), dim=1)
        Y = self.convblock(X_up_cat)
        return Y
        
class UNet(nn.Module):
    def __init__(self, channels_out, encoder_depth=4, num_classes=2):
        super().__init__()
        
        self.encoder = nn.ModuleList()
        for i in range(0,encoder_depth):
            self.encoder.append(EncoderBlock(channels_out*2**i))
            
        self.bridge = nn.Sequential(
            nn.LazyConv2d(channels_out*2**encoder_depth, kernel_size=3), nn.ReLU(),
            nn.LazyConv2d(channels_out*2**encoder_depth, kernel_size=3), nn.ReLU(),            
        )
        
        self.decoder = nn.ModuleList()
        for i in range(encoder_depth-1,-1,-1):
            self.decoder.append(DecoderBlock(channels_out*2**i))
            
        self.final = nn.LazyConv2d(num_classes, kernel_size=1)
        
    def forward(self, X):
        # H = hidden layer activations
        H = X
        skip = []
        
        for enc_block in self.encoder:
            H, S = enc_block(H)
            skip.append(S)
            
        H = self.bridge(H)
        
        for dec_block in self.decoder:
            H = dec_block(H, skip.pop())
            
        Y = self.final(H)
        return Y
    
    def init_weights(self, inputs):
        self.forward(inputs)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

In [2]:
unet = UNet(64)
I = torch.randn((1,1,572,572))
unet.init_weights(I)
print(unet(I))
print(list(unet.children()))



tensor([[[[-0.1735,  1.3608, -0.5569,  ...,  0.4138, -1.2697,  1.4504],
          [ 0.9092,  1.4605, -1.5454,  ...,  0.1866,  0.7423,  0.7390],
          [-1.4064,  0.6028, -0.0883,  ...,  0.5749,  0.0245,  0.3559],
          ...,
          [ 0.2521, -0.0932,  0.7357,  ..., -0.5910,  0.1675, -0.2962],
          [-0.3453,  0.7124,  0.6962,  ...,  0.4802, -0.7110,  0.0048],
          [ 0.0724, -0.0707,  1.1107,  ...,  0.6663,  1.1248, -0.9702]],

         [[ 0.3904, -0.2253, -1.9845,  ..., -2.1738, -0.8560, -0.7408],
          [-1.4031, -0.8858, -1.8994,  ...,  0.0403, -1.2398, -0.7838],
          [-1.2248, -0.9576, -0.8792,  ..., -0.5345, -0.1558, -1.8179],
          ...,
          [-1.7406, -0.7554, -0.7882,  ..., -2.5429, -1.0109, -1.6083],
          [-0.0935, -0.4155, -0.4309,  ..., -0.7168, -1.9295, -1.0563],
          [-1.6012, -0.5487, -0.9493,  ..., -1.1098, -0.9146, -1.9235]]]],
       grad_fn=<ConvolutionBackward0>)
[ModuleList(
  (0): EncoderBlock(
    (convblock): Sequential(