# U-Nets with PyTorch

This notebook tests from-scratch PyTorch implementations of several U-Net variants. The model code lives in the `models` directory.

In [1]:
# Standard library; only used to read the first sys.path entry below.
import sys
# Project-local helper (defined outside this notebook).
from path_utils import add_parent_path_to_sys_path
# add the parent directory to the sys path so that we can import the models
# NOTE(review): sys.path[0] is presumably the notebook's own directory, but
# that depends on how the kernel was launched -- confirm.
current_path = sys.path[0]
add_parent_path_to_sys_path(current_path, verbose=False)

import torch

# import the models
# NOTE(review): the wildcard imports hide where ResUNet2DEncoder,
# ResUNet2DDecoder, ResUNet2DDecoder_Attention, ResUNet2D and
# ResUNet2D_Attention come from; consider importing them by name once the
# defining modules are confirmed.
from models.unets_encoders import *
from models.unets_decoders import *
from models.unets import *

Path added to the sys path.


# ResUNet2D

In [2]:
# Hyper-parameters for the plain ResUNet2D shape test.
# Values are grouped by the constructor arguments they feed.

# Overall network shape.
in_channels, out_channels = 3, 8
num_blocks = 4
base_channels = 64
channel_multiplier = 1

# in_* values are passed to the encoder, out_* values to the decoder.
in_ksize, in_stride, in_padding = 3, 1, 1
out_ksize, out_stride, out_padding = 3, 1, 1

# block_* values and the remaining options are passed to both
# the encoder and the decoder.
block_ksize, block_stride, block_padding = 3, 1, 1
padding_mode = 'zeros'
block_mid_channel_multiplier = 1
block_ksize_increment = 0
dropout = False

In [3]:
# Trailing positional arguments that the encoder and decoder constructors
# both receive, kept in the exact order they are passed.
shared_block_args = (
    block_ksize,
    block_stride,
    block_padding,
    block_mid_channel_multiplier,
    block_ksize_increment,
    padding_mode,
    dropout,
)

# Encoder: leading arguments describe the input side (in_*).
resunet2d_encoder = ResUNet2DEncoder(
    in_channels, num_blocks, base_channels, channel_multiplier,
    in_ksize, in_stride, in_padding,
    *shared_block_args,
)

# Decoder: leading arguments describe the output side (out_*).
resunet2d_decoder = ResUNet2DDecoder(
    out_channels, num_blocks, base_channels, channel_multiplier,
    out_ksize, out_stride, out_padding,
    *shared_block_args,
)

# Full model wrapping the two halves built above.
resunet2d = ResUNet2D(resunet2d_encoder, resunet2d_decoder)

In [4]:
# Random stand-in batch for the forward-pass check:
# 5 samples with `in_channels` channels at 256 x 256, values drawn from [0, 1).
input_imgs = torch.rand((5, in_channels, 256, 256))

In [7]:
# Shape smoke test for ResUNet2D.
#
# Runs the encoder and decoder separately, then the wrapped model, and prints
# the shape of every returned feature map for manual inspection.
#
# Only shapes are examined, so the forward passes run under torch.no_grad():
# without it autograd keeps the graph and all intermediate activations alive
# for three full forward passes, which is wasted memory here.
with torch.no_grad():
    # Encoder on its own: final output plus a dict of per-block features.
    enc_out, enc_features = resunet2d_encoder(input_imgs)
    print('Encoder output shape:', enc_out.shape)
    for name, feature in enc_features.items():
        print(f'Encoder feature {name} shape:', feature.shape)

    # Decoder on its own, fed with the encoder output and its features.
    dec_out, dec_features = resunet2d_decoder(enc_out, enc_features)
    print('Decoder output shape:', dec_out.shape)
    for name, feature in dec_features.items():
        print(f'Decoder feature {name} shape:', feature.shape)
    print('-'*50)

    # Wrapped model; note this rebinds enc_features / dec_features.
    out, enc_features, dec_features = resunet2d(input_imgs)
    print('Output shape:', out.shape)
    for name, feature in enc_features.items():
        print(f'Encoder feature {name} shape:', feature.shape)
    for name, feature in dec_features.items():
        print(f'Decoder feature {name} shape:', feature.shape)

# The wrapped model was built from the same encoder/decoder instances, so its
# output shape should match the manually chained decoder output.
assert out.shape == dec_out.shape, (
    f'wrapped model output {tuple(out.shape)} != '
    f'chained decoder output {tuple(dec_out.shape)}'
)

Encoder output shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_0 shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_1 shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_2 shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_3 shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_4 shape: torch.Size([5, 64, 256, 256])
Decoder output shape: torch.Size([5, 8, 256, 256])
Decoder feature decoder_block_4 shape: torch.Size([5, 64, 256, 256])
Decoder feature decoder_block_3 shape: torch.Size([5, 64, 256, 256])
Decoder feature decoder_block_2 shape: torch.Size([5, 64, 256, 256])
Decoder feature decoder_block_1 shape: torch.Size([5, 64, 256, 256])
Decoder feature decoder_output shape: torch.Size([5, 8, 256, 256])
--------------------------------------------------
Output shape: torch.Size([5, 8, 256, 256])
Encoder feature encoder_block_0 shape: torch.Size([5, 64, 256, 256])
Encoder feature encoder_block_1 shape: torch.S

# ResUNet2D_Attention

In [3]:
# Hyper-parameters for the ResUNet2D_Attention shape test.
# Values are grouped by the constructor arguments they feed.

# Overall network shape.
in_channels, out_channels = 1, 8
num_blocks = 4
base_channels = 32
channel_multiplier = 1

# in_* values are passed to the encoder, out_* values to the decoder.
in_ksize, in_stride, in_padding = 3, 1, 1
out_ksize, out_stride, out_padding = 3, 1, 1

# block_* values and the remaining options are passed to both
# the encoder and the attention decoder.
block_ksize, block_stride, block_padding = 3, 1, 1
padding_mode = 'zeros'
block_mid_channel_multiplier = 1
block_ksize_increment = 0
dropout = False

# Extra trailing argument taken only by the attention decoder.
attn_channels = 8

In [4]:
# Trailing positional arguments that the encoder and the attention decoder
# both receive, kept in the exact order they are passed.
shared_block_args = (
    block_ksize,
    block_stride,
    block_padding,
    block_mid_channel_multiplier,
    block_ksize_increment,
    padding_mode,
    dropout,
)

# Encoder: leading arguments describe the input side (in_*).
# This rebinds `resunet2d_encoder` to a new instance built from the current
# configuration values.
resunet2d_encoder = ResUNet2DEncoder(
    in_channels, num_blocks, base_channels, channel_multiplier,
    in_ksize, in_stride, in_padding,
    *shared_block_args,
)

# Attention decoder: same layout as the encoder call, plus `attn_channels`
# as the final positional argument.
resunet2d_decoder_attn = ResUNet2DDecoder_Attention(
    out_channels, num_blocks, base_channels, channel_multiplier,
    out_ksize, out_stride, out_padding,
    *shared_block_args,
    attn_channels,
)

# Full attention model wrapping the two halves built above.
resunet2d_attn = ResUNet2D_Attention(resunet2d_encoder, resunet2d_decoder_attn)

In [5]:
# Random stand-in batch for the forward-pass check:
# 3 samples with `in_channels` channels at 64 x 64, values drawn from [0, 1).
input_imgs = torch.rand((3, in_channels, 64, 64))

In [7]:
# Shape smoke test for ResUNet2D_Attention.
#
# Runs the encoder and attention decoder separately, then the wrapped model,
# and prints the shape of every returned feature and attention map for manual
# inspection.
#
# Only shapes are examined, so the forward passes run under torch.no_grad():
# without it autograd keeps the graph and all intermediate activations alive
# for three full forward passes, which is wasted memory here.
with torch.no_grad():
    # Encoder on its own: final output plus a dict of per-block features.
    enc_out, enc_features = resunet2d_encoder(input_imgs)
    print('Encoder output shape:', enc_out.shape)
    for name, feature in enc_features.items():
        print(f'Encoder feature {name} shape:', feature.shape)
    print('-'*50)

    # Attention decoder on its own; it additionally returns attention maps.
    dec_out, dec_features, attn_features = resunet2d_decoder_attn(enc_out, enc_features)
    print('Decoder output shape:', dec_out.shape)
    for name, feature in dec_features.items():
        print(f'Decoder feature {name} shape:', feature.shape)
    print('-'*50)
    for name, feature in attn_features.items():
        print(f'Attention feature {name} shape:', feature.shape)


    print('='*50)

    # Wrapped model; note this rebinds enc_features / dec_features /
    # attn_features.
    out, enc_features, dec_features, attn_features = resunet2d_attn(input_imgs)
    print('Output shape:', out.shape)
    for name, feature in enc_features.items():
        print(f'Encoder feature {name} shape:', feature.shape)
    print('-'*50)
    for name, feature in dec_features.items():
        print(f'Decoder feature {name} shape:', feature.shape)
    print('-'*50)
    for name, feature in attn_features.items():
        print(f'Attention feature {name} shape:', feature.shape)
    print('-'*50)

Encoder output shape: torch.Size([3, 32, 64, 64])
Encoder feature encoder_block_0 shape: torch.Size([3, 32, 64, 64])
Encoder feature encoder_block_1 shape: torch.Size([3, 32, 64, 64])
Encoder feature encoder_block_2 shape: torch.Size([3, 32, 64, 64])
Encoder feature encoder_block_3 shape: torch.Size([3, 32, 64, 64])
Encoder feature encoder_block_4 shape: torch.Size([3, 32, 64, 64])
--------------------------------------------------
Decoder output shape: torch.Size([3, 8, 64, 64])
Decoder feature decoder_block_4 shape: torch.Size([3, 32, 64, 64])
Decoder feature decoder_block_3 shape: torch.Size([3, 32, 64, 64])
Decoder feature decoder_block_2 shape: torch.Size([3, 32, 64, 64])
Decoder feature decoder_block_1 shape: torch.Size([3, 32, 64, 64])
Decoder feature decoder_output shape: torch.Size([3, 8, 64, 64])
--------------------------------------------------
Attention feature decoder_attn_4 shape: torch.Size([3, 1, 64, 64])
Attention feature decoder_attn_3 shape: torch.Size([3, 1, 64, 64