# U-Net (paper implementation)

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F


def double_conv(in_channels, out_channels):
    """
    This function applies two convolutional layers
    each followed by a ReLU activation function
    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :return: a down-conv layer
    """
    conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3),
                         nn.ReLU(inplace=True),
                         nn.Conv2d(out_channels, out_channels, kernel_size=3),
                         nn.ReLU(inplace=True))
    return conv


def crop_tensor(tensor, target_tensor):
    """
    Center crops a tensor to size of a given target tensor size
    Please note that this function is applicable only to
    this implementation of unet. There are a few assumptions
    in this implementation that might not be applicable to all
    networks and all other use-cases.
    Both tensors are of shape (bs, c, h, w)
    :param tensor: a tensor that needs to be cropped
    :param target_tensor: target tensor of smaller size
    :return: cropped tensor
    """
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]


class UNet(nn.Module):

    def __init__(self):
        super(UNet, self).__init__()
        # we need only one max_pool as it is not learned
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,
                                             out_channels=512,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,
                                             out_channels=256,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,
                                             out_channels=128,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,
                                             out_channels=64,
                                             kernel_size=2,
                                             stride=2)
        self.up_conv_4 = double_conv(128, 64)
        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    def forward(self, image):
        # encoder
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)

        # decoder
        x = self.up_trans_1(x9)
        y = crop_tensor(x7, x)
        x = self.up_conv_1(torch.cat([x, y], axis=1))
        x = self.up_trans_2(x)
        y = crop_tensor(x5, x)
        x = self.up_conv_2(torch.cat([x, y], axis=1))
        x = self.up_trans_3(x)
        y = crop_tensor(x3, x)
        x = self.up_conv_3(torch.cat([x, y], axis=1))
        x = self.up_trans_4(x)
        y = crop_tensor(x1, x)
        x = self.up_conv_4(torch.cat([x, y], axis=1))

        # output layer
        out = self.out(x)
        return out


In [2]:
# initialize the model and run it on some random data
model = UNet()
x = torch.randn(1, 1, 572, 572)
y = model(x)
print(y.shape)

torch.Size([1, 2, 388, 388])


# PredNet Implementation (Zhen Shen 2019) 256x256


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.max_pool(x)
        return x


class DecoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels,
                                     in_channels // 2,
                                     kernel_size=2,
                                     stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x, enc):
        x = self.up(x)
        diffY = enc.size()[2] - x.size()[2]
        diffX = enc.size()[3] - x.size()[3]
        x = F.pad(
            x,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([enc, x], dim=1)
        x = self.conv(x)
        return x


class PredNet(nn.Module):

    def __init__(self):
        super(PredNet, self).__init__()
        self.enc1 = EncoderBlock(1, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)
        self.dec1 = DecoderBlock(512, 256)
        self.dec2 = DecoderBlock(256, 128)
        self.dec3 = DecoderBlock(128, 64)
        self.out = nn.Sequential(nn.Conv2d(64, 2, kernel_size=1),
                                 nn.BatchNorm2d(2), nn.Sigmoid())

    def forward(self, x):
        # x = F.interpolate(x, size=(256, 256))
        print(x.shape)
        enc1 = self.enc1(x)
        print(enc1.shape)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        dec1 = self.dec1(enc4, enc3)
        dec2 = self.dec2(dec1, enc2)
        dec3 = self.dec3(dec2, enc1)
        print(dec3.shape)
        out = self.out(dec3)
        # out = F.interpolate(out, size=(256, 256))
        return out


In [4]:
# initialize the model and run it on some random data
model = PredNet()
x = torch.randn(1, 1, 256, 256)
y = model(x)
print(y.shape)


torch.Size([1, 1, 256, 256])
torch.Size([1, 64, 128, 128])
torch.Size([1, 64, 128, 128])
torch.Size([1, 2, 128, 128])


# PredNet Implementation (Zhen Shen 2019) 32x32x32

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True))
        self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.max_pool(x)
        return x


class DecoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels,
                                     in_channels // 2,
                                     kernel_size=2,
                                     stride=2)
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x, enc):
        x = self.up(x)
        diffZ = enc.size()[2] - x.size()[2]
        diffY = enc.size()[3] - x.size()[3]
        diffX = enc.size()[4] - x.size()[4]
        x = F.pad(
            x,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])
        x = torch.cat([enc, x], dim=1)
        x = self.conv(x)
        return x
class PredNet(nn.Module):

    def __init__(self):
        super(PredNet, self).__init__()
        self.enc1 = EncoderBlock(1, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)
        self.dec1 = DecoderBlock(512, 256)
        self.dec2 = DecoderBlock(256, 128)
        self.dec3 = DecoderBlock(128, 64)
        self.out = nn.Sequential(nn.Conv3d(64, 1, kernel_size=1),
                                 nn.BatchNorm3d(1), nn.Sigmoid())

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        dec1 = self.dec1(enc4, enc3)
        dec2 = self.dec2(dec1, enc2)
        dec3 = self.dec3(dec2, enc1)
        out = self.out(dec3)
        out = F.interpolate(out,
                             size=x.shape[2:],
                             mode='trilinear',
                             align_corners=False)
        return out


In [6]:
# initialize the model and run it on some random data
model = PredNet()
x = torch.randn(1, 1, 32, 32, 32)
y = model(x)
print(y.shape)


torch.Size([1, 1, 32, 32, 32])


# Final PredNet Implementation (Zhen Shen 2019) 16x16x16

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True))
        self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.max_pool(x)
        return x


class SpatialBranch(nn.Module):
    def __init__(self):
        super(SpatialBranch, self).__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv3d(256, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.pool4(x)
        return x


class PredNet(nn.Module):

    def __init__(self):
        super(PredNet, self).__init__()
        self.enc1 = EncoderBlock(1, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)
        self.spatial = SpatialBranch()
        self.flatten = nn.Flatten()
        # step 4
        self.fc = nn.Sequential(nn.ReLU(inplace=True), nn.Linear(1024, 4096),
                                nn.Sigmoid())

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)

        # step 3
        enc4 = self.flatten(enc4)

        # output of the spatial
        spatial = self.spatial(x)
        spatial = self.flatten(spatial)

        # It is combined with the output of the spatial data branch
        x = torch.cat((enc4, spatial), dim=1)

        x = self.fc(x)
        x = x.view(-1, 16, 16, 16)

        return x


In [11]:
# initialize the model and run it on some random data
model = PredNet()
x = torch.randn(1, 1, 16, 16, 16)
y = model(x)
print('output:', y.shape)

output: torch.Size([1, 16, 16, 16])
