In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F

# U-Net

In [2]:
def conv_block(in_channels, out_channels):
    """Build two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    With no padding, each convolution removes one pixel from every border, so
    the whole block shrinks height and width by 4.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.

    Returns:
        nn.Sequential laid out as conv -> ReLU -> conv -> ReLU.
    """
    first_conv = nn.Conv2d(in_channels, out_channels, 3)
    second_conv = nn.Conv2d(out_channels, out_channels, 3)
    return nn.Sequential(
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    )

In [3]:
class upconv_block(nn.Module):
    """U-Net decoder stage: 2x upsample, concatenate a cropped skip map, double conv.

    Args:
        in_channels: channels of the decoder input ``x``. The transposed
            convolution maps them to ``out_channels``; the following double conv
            expects ``in_channels`` channels after concatenation, so the skip map
            ``y`` is assumed to carry ``in_channels - out_channels`` channels.
        out_channels: channels produced by the upsampling step and by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel 2, stride 2, no padding: exactly doubles height and width
        self.ct = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        self.enc = conv_block(in_channels, out_channels)

    def forward(self, x, y):
        """Upsample ``x``, centre-crop ``y`` to match it, concatenate, convolve.

        Args:
            x: decoder feature map, (N, in_channels, H, W).
            y: encoder skip map of the same depth, (N, C_skip, H_y, W_y) with
                H_y >= 2H and W_y >= 2W.

        Returns:
            The double conv applied to ``cat([cropped y, upsampled x], dim=1)``.
        """
        x = self.ct(x)

        # Centre-crop y to x's spatial size. Height and width are handled
        # independently so non-square maps crop correctly, and each slice has
        # exactly x's length, so an odd size difference (e.g. from MaxPool2d
        # flooring an odd-sized map) no longer leaves y one pixel too large.
        xh, xw = x.shape[2], x.shape[3]
        yh, yw = y.shape[2], y.shape[3]
        top = (yh - xh) // 2
        left = (yw - xw) // 2
        y = y[:, :, top : top + xh, left : left + xw]

        z = torch.cat([y, x], dim=1)

        return self.enc(z)

In [4]:
class Unet(pl.LightningModule):
    """U-Net assembled from unpadded double-conv blocks.

    Encoder: five ``conv_block`` stages (64 -> 1024 channels) separated by 2x2
    max-pooling. Decoder: four ``upconv_block`` stages, each upsampling and
    merging the encoder map of the matching depth. A final 1x1 convolution maps
    64 channels to ``num_classes``. All 3x3 convolutions are unpadded, so the
    output is spatially smaller than the input (this notebook's run maps a
    572x572 input to 388x388).

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        in_channels: number of channels in the input image. Defaults to 1, the
            value that was previously hard-coded.

    NOTE(review): only ``forward`` is defined here; fitting with a Lightning
    ``Trainer`` would also need ``training_step`` and ``configure_optimizers``.
    """

    def __init__(self, num_classes, in_channels=1):
        super().__init__()

        # shared 2x2 downsampling used between encoder stages
        self.pool = nn.MaxPool2d(2)

        # contracting path: channel width doubles at every level
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        self.enc5 = conv_block(512, 1024)

        # expanding path: channel width halves at every level
        self.dec1 = upconv_block(1024, 512)
        self.dec2 = upconv_block(512, 256)
        self.dec3 = upconv_block(256, 128)
        self.dec4 = upconv_block(128, 64)

        # 1x1 projection to per-pixel class scores
        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: input batch; presumably (N, in_channels, H, W) — TODO confirm
                which H, W the caller uses (572x572 in this notebook).

        Returns:
            Tensor with ``num_classes`` channels, spatially smaller than ``x``.
        """
        # encoding: keep each stage's pre-pool output for the skip connections
        dc1 = self.enc1(x)
        pooled_dc1 = self.pool(dc1)

        dc2 = self.enc2(pooled_dc1)
        pooled_dc2 = self.pool(dc2)

        dc3 = self.enc3(pooled_dc2)
        pooled_dc3 = self.pool(dc3)

        dc4 = self.enc4(pooled_dc3)
        pooled_dc4 = self.pool(dc4)

        # bottleneck
        dc5 = self.enc5(pooled_dc4)

        # decoding: pair each upsampling stage with the encoder map of equal depth
        uc1 = self.dec1(dc5, dc4)
        uc2 = self.dec2(uc1, dc3)
        uc3 = self.dec3(uc2, dc2)
        uc4 = self.dec4(uc3, dc1)

        # output
        out = self.out(uc4)

        return out

# Execution

In [5]:
# Smoke test: push one random single-channel 572x572 image through a 2-class
# model and check the output shape.
model = Unet(2)

sample = torch.randn((1, 1, 572, 572))

# Only the shape is inspected, so skip building an autograd graph.
with torch.no_grad():
    output = model(sample)

print("sample.shape: ", sample.shape)
print("model(sample).shape: ", output.shape)

# Unpadded 3x3 convolutions shrink the map: 572x572 in -> 388x388 out.
assert output.shape == (1, 2, 388, 388)

sample.shape:  torch.Size([1, 1, 572, 572])
model(sample).shape:  torch.Size([1, 2, 388, 388])
