In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import os
from PIL import Image
from enum import Enum

In [25]:
def build_conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
    )

In [23]:
class ConvBlock(nn.Module):
    skip_connections = []

    def __init__(self, in_channels, out_channels, encode=True):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.encode = encode
        self.conv = build_conv_block(in_channels=self.in_channels, out_channels=self.out_channels)
    
    def forward(self, X):
        if self.encode:
            X = self.conv(X)
            self.skip_connections.append(X)
            return nn.MaxPool2d(kernel_size=2, stride=2)(X)
        else:
            X = nn.ConvTranspose2d(in_channels=self.in_channels, out_channels=self.in_channels//2, kernel_size=2, stride=2)(X)
            X = torch.cat((X, self.skip_connections.pop()), dim=1)
            return self.conv(X)

In [24]:
class YouNet(nn.Module):
    def __init__(self, in_channels=64, out_channels=64):
        super(YouNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The down-sampling layers
        self.contractive_path = nn.ModuleDict({
            'encode0': ConvBlock(in_channels=self.in_channels, out_channels=64, encode=True),
            'encode1': ConvBlock(in_channels=64, out_channels=128, encode=True),
            'encode2': ConvBlock(in_channels=128, out_channels=256, encode=True),
            'encode3': ConvBlock(in_channels=256, out_channels=512, encode=True),
        })

        # The bottleneck
        self.trough = build_conv_block(in_channels=512, out_channels=1024)

        # The up-sampling layers
        # in_channels takes input from previous layer and skip connections
        self.expansive_path = nn.ModuleDict({
            'decode3': ConvBlock(in_channels=512*2, out_channels=512, encode=False),
            'decode2': ConvBlock(in_channels=256*2, out_channels=256, encode=False),
            'decode1': ConvBlock(in_channels=128*2, out_channels=128, encode=False),
            'decode0': ConvBlock(in_channels=64*2, out_channels=64, encode=False)
        })

        # The prediction layer
        self.final = nn.Conv2d(in_channels=64, out_channels=self.out_channels, kernel_size=1)
    
    def forward(self, X):
        # Train the contractive path
        for conv_block in self.contractive_path:
            X = self.contractive_path[conv_block](X)
        
        # Train the trough
        X = self.trough(X)
        
        # Train the expansive path
        for conv_block in self.expansive_path:
            X = self.expansive_path[conv_block](X)
        
        return self.final(X)
    
    def print_hook_shape(self, module, input, output):
        '''Prints the input and output tensor of a given layer. Used by YouNet.print_forward_hooks'''
        print(f'{module.__class__.__name__}(input shape: {input[0].shape}, output shape: {output.shape})')

    def print_forward_hooks(self):
        '''Prints the input and output tensors of each layer.'''
        for name, layer in self.contractive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.trough.register_forward_hook(self.print_hook_shape)
        
        for name, layer in self.expansive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.final.register_forward_hook(self.print_hook_shape)
    
    def print_network(self):
        '''Prints the entire network architecture.'''
        for name, conv_block in self.contractive_path.items():
            print('Layer:', name)
            print(conv_block)

        print('Layer: bottleneck')
        print(self.trough)

        for name, conv_block in self.expansive_path.items():
            print('Layer:', name)
            print(conv_block)
        
        print('Layer: final\n', self.final, sep='')

net = YouNet(8, 8)
net.print_network()
X = torch.randn((32, 8, 160, 160))
preds = net(X)
print()
print(X.shape)
print(preds.shape)
assert preds.shape == X.shape

Layer: encode0
ConvBlock(
  (conv): Sequential(
    (0): Conv2d(8, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)
Layer: encode1
ConvBlock(
  (conv): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)
Layer: encode2
ConvBlock(
  (conv): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, moment