In [1]:
!python3 -m pip install --upgrade pip



In [2]:
!pip install torch torchvision 
!pip install opencv-python scipy matplotlib 
!pip install scikit-image scikit-learn pandas jupyterlab numpy



In [3]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
from torch.optim import Adam
import math

In [70]:
class Block(nn.Module):
    """Conv building block of SimpleUnet that down-samples or up-samples by 2.

    Pipeline: conv1 (3x3, 'same') -> ReLU -> BatchNorm, then conv2 (1x1) ->
    ReLU -> BatchNorm, then a stride-2 resampling layer (kernel 4, padding 1).

    Args:
        in_ch: channel count of the block's nominal input. With ``up=True``
            conv1 expects ``2 * in_ch`` channels, i.e. the caller concatenates
            an ``in_ch``-channel skip tensor onto an ``in_ch``-channel input.
        out_ch: channel count produced by every layer in the block.
        up: if True, resample with ``nn.ConvTranspose2d`` (doubles H and W);
            otherwise with a strided ``nn.Conv2d`` (halves even H and W).
    """

    def __init__(self, in_ch, out_ch, up=False):
        super().__init__()
        # Up blocks receive [features, skip] concatenated along the channel dim.
        conv1_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding='same')
        # Spatial output size with kernel=4, stride=2, padding=1:
        #   Conv2d (down):        floor((W - 4 + 2*1) / 2) + 1 = W/2 for even W,
        #                         e.g. 28 -> 14 -> 7
        #   ConvTranspose2d (up): (W - 1)*2 - 2*1 + 4 = 2W,
        #                         e.g. 7 -> 14 -> 28
        # https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        resample = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 1, padding='same')
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run both conv -> ReLU -> BatchNorm stages, then resample spatially."""
        for conv, norm in ((self.conv1, self.bnorm1), (self.conv2, self.bnorm2)):
            x = norm(self.relu(conv(x)))
        return self.transform(x)

In [24]:
# Inspect a single down-sampling Block (stride-2 Conv2d resampler).
Block(in_ch=32, out_ch=32)

Block(
  (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (transform): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (bnorm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
)

In [25]:
# Inspect a single up-sampling Block: conv1 takes 2*in_ch channels (features + skip).
Block(in_ch=32, out_ch=32, up=True)

Block(
  (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (transform): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (bnorm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
)

In [27]:
# Scratch check: build the same down/up Block stacks that SimpleUnet uses,
# one Block per adjacent pair of channel counts.
down_channels = (32, 32, 32)
up_channels = (32, 32, 32)
downs = nn.ModuleList(
    [Block(c_in, c_out) for c_in, c_out in zip(down_channels[:-1], down_channels[1:])]
)
ups = nn.ModuleList(
    [Block(c_in, c_out, up=True) for c_in, c_out in zip(up_channels[:-1], up_channels[1:])]
)

In [28]:
# Display the stack of down-sampling Blocks built in the scratch cell above.
downs

ModuleList(
  (0-1): 2 x Block(
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (transform): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
    (bnorm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
)

In [29]:
# Display the stack of up-sampling Blocks built in the scratch cell above.
ups

ModuleList(
  (0-1): 2 x Block(
    (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (transform): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
    (bnorm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
)

In [71]:
class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    conv0 (1x1) projects the input to ``down_channels[0]`` feature maps; two
    down Blocks halve H and W (28 -> 14 -> 7 for a 28x28 input); two up Blocks
    double them back, each consuming the current features concatenated with
    the output of the matching down Block; a final 1x1 conv maps to
    ``out_dim`` channels. H and W should be divisible by 4 so that the output
    resolution matches the input.

    Args:
        time_emb_dim: width of the (currently unused) ``time_mlp``.
        image_channels: number of channels in the input image.
        out_dim: number of channels in the output map.
    """
    def __init__(self, time_emb_dim=32, image_channels=1, out_dim=1):
        super().__init__()
        down_channels = (32, 32, 32)
        up_channels = (32, 32, 32)
        # Time embedding. time_emb_dim is a constructor argument rather than a
        # notebook global so the class builds on a fresh kernel.
        # NOTE(review): time_mlp is created (and counted in the parameter
        # total) but never used in forward(); presumably reserved for timestep
        # conditioning - TODO wire it in or remove it.
        self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial 1x1 projection; the kernel size is fixed at 1 and does not
        # depend on image_channels.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 1, padding='same')
        # Downsample: one Block per adjacent pair of channel counts.
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1])
                    for i in range(len(down_channels)-1)])
        # Upsample: each Block's conv1 takes 2*up_channels[i] (features + skip).
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], up=True)
                    for i in range(len(up_channels)-1)])
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x):
        """Map ``x`` (assumed (batch, image_channels, H, W)) to ``out_dim`` channels at the input resolution."""
        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x)
            residual_inputs.append(x)
        for up in self.ups:
            # Concatenate the matching down-path features as extra channels.
            # For the first up Block the popped tensor is the bottleneck output
            # itself, so it is concatenated with itself.
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x)
        return self.output(x)

In [72]:
# Build the network and report its size and layer structure.
model = SimpleUnet()
n_params = sum(param.numel() for param in model.parameters())
print("Num params: ", n_params)
print(model)

Num params:  126977
SimpleUnet(
  (time_mlp): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
  )
  (conv0): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (downs): ModuleList(
    (0-1): 2 x Block(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (transform): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (bnorm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
  )
  (ups): ModuleList(
    (0-1): 2 x Block(
      (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (transform): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 

In [73]:
# Dummy batch of 10 single-channel 28x28 images with values in [0, 1), used only
# to check shapes. A locally seeded generator makes it reproducible without
# changing the global RNG state that model initialisation relies on.
x = torch.rand([10, 1, 28, 28], generator=torch.Generator().manual_seed(0))
print(x.shape)

torch.Size([10, 1, 28, 28])


In [74]:
# Smoke test: the U-Net must return a map at the input's batch size and resolution.
# Run in eval mode under no_grad so this check neither builds an autograd graph
# nor updates BatchNorm running statistics; restore the previous mode afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model(x)
finally:
    model.train(was_training)
assert y.shape[0] == x.shape[0] and y.shape[2:] == x.shape[2:], y.shape
print(y.shape)

torch.Size([10, 1, 28, 28])
