In [1]:
# Understand U-Net first

import torch
import torch.nn as nn

# Fix the global RNG so the dummy input and every layer initialisation below are reproducible.
torch.manual_seed(0)

# Synthetic stand-in for one RGB image: (batch=1, channels=3, height=128, width=192).
x = torch.rand((1, 3, 128, 192))

In [3]:
# Base channel width of the network; every block below is sized relative to it.
model_dim = 16

# 1x1 convolution that lifts the 3 input channels to `model_dim` channels
# without touching the spatial resolution.
input_conv = nn.Conv2d(in_channels=3, out_channels=model_dim, kernel_size=1)

# Shape check: only the channel axis should change.
input_conv(x).shape

torch.Size([1, 16, 128, 192])

In [4]:
class UBlockD(nn.Module):
    """Contracting-path ("down") block: two convolution + ReLU stages.

    The first convolution keeps the channel count at ``model_dim``; the second
    doubles it to ``model_dim * 2``. The block itself does no pooling.

    Args:
        model_dim: number of channels of the input feature map.
        kernel_size: kernel size shared by both convolutions (int or
            ``(kh, kw)`` tuple). Padding is derived as ``kernel_size // 2`` per
            axis, so odd kernel sizes leave height and width unchanged.
    """

    def __init__(self, model_dim, kernel_size=3):
        super().__init__()
        # Derive "same" padding from the kernel instead of hard-coding 1, which
        # is only correct for kernel_size=3 and silently shrinks/grows the
        # feature map for any other kernel.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        # Creation order (conv1, then conv2) is kept so seeded initialisation is unchanged.
        conv1 = nn.Conv2d(model_dim, model_dim, kernel_size, padding=padding)
        conv2 = nn.Conv2d(model_dim, model_dim * 2, kernel_size, padding=padding)
        self.block = nn.Sequential(conv1, nn.ReLU(), conv2, nn.ReLU())

    def forward(self, x):
        """Apply both conv + ReLU stages to ``x`` and return the result.

        For odd kernel sizes: (N, model_dim, H, W) -> (N, 2 * model_dim, H, W).
        """
        return self.block(x)

In [5]:
class UBlockU(nn.Module):
    """Expanding-path ("up") block: upsample, concatenate the skip link, convolve.

    ``x_up`` (``model_dim * 2`` channels) is upsampled 2x by a transposed
    convolution that also halves its channels to ``model_dim``. The result is
    concatenated on the channel axis with ``x_same``, giving ``model_dim * 2``
    channels, which two conv + ReLU stages reduce to ``model_dim`` and then
    ``model_dim // 2``.

    Args:
        model_dim: number of channels of the skip-connection input ``x_same``.
        kernel_size: kernel size shared by the two regular convolutions (int or
            ``(kh, kw)`` tuple). Padding is derived as ``kernel_size // 2`` per
            axis, so odd kernel sizes leave height and width unchanged.
    """

    def __init__(self, model_dim, kernel_size=3):
        super().__init__()
        # Derive "same" padding from the kernel instead of hard-coding 1, which
        # is only correct for kernel_size=3; otherwise the output would no
        # longer match the resolution of the skip connection.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        # Creation order (conv1, conv2, upsample_conv) is kept so seeded
        # initialisation is unchanged.
        conv1 = nn.Conv2d(model_dim * 2, model_dim, kernel_size, padding=padding)
        conv2 = nn.Conv2d(model_dim, model_dim // 2, kernel_size, padding=padding)
        # kernel_size=2 with stride=2 doubles height and width exactly.
        self.upsample_conv = nn.ConvTranspose2d(model_dim * 2, model_dim, kernel_size=2, stride=2)
        self.block = nn.Sequential(conv1, nn.ReLU(), conv2, nn.ReLU())

    def forward(self, x_same, x_up):
        '''
        Merge the skip connection with the upsampled deeper features.

        x_same: skip link from the same level of the contracting path,
                shape (N, model_dim, H, W)
        x_up:   feature map from the level below, to be up-sampled,
                shape (N, 2 * model_dim, H / 2, W / 2); after the 2x
                upsampling its height/width must equal those of x_same,
                otherwise the concatenation fails

        Returns a tensor of shape (N, model_dim // 2, H, W) for odd kernel sizes.
        '''
        x = torch.cat([x_same, self.upsample_conv(x_up)], dim=1)
        return self.block(x)

In [7]:
# Contracting path: each down block doubles the channel width, so the running
# width is updated after every block is built.
current_model_dim = model_dim
b1 = UBlockD(current_model_dim)
pool = nn.MaxPool2d(2)  # 2x spatial down-sampling between levels
current_model_dim = current_model_dim * 2
b2 = UBlockD(current_model_dim)
current_model_dim = current_model_dim * 2
# Bottom of the "U" reached: step the width back down one level so the up
# block is sized by its skip connection.
current_model_dim = current_model_dim // 2
u1 = UBlockU(current_model_dim)

In [9]:
# Contracting path: level-1 features at full resolution, then pool and run
# the level-2 block on the down-sampled map.
h1 = b1(input_conv(x))
h2 = b2(pool(h1))
print("Gating signal x_up to next up-sampling layer", h2.shape)

# gating signal: the deeper, coarser feature map
gs = h2

# attention is computed on h1 and gs
print("Reference of attention x_sample", h1.shape)

# Expanding path: merge the skip connection h1 with the up-sampled h2.
g1 = u1(h1, h2)
print(g1.shape)

Gating signal x_up to next up-sampling layer torch.Size([1, 64, 64, 96])
Reference of attention x_sample torch.Size([1, 32, 128, 192])
torch.Size([1, 16, 128, 192])
