In [3]:
import torch
import numpy as np
import math
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn
import torch.nn.functional as F
import os

In [4]:
def f_cosine(t, T, s):
  return math.cos((t/T + s)/(1 + s))**2

def a_t(t,T,s):
  return f_cosine(t, T, s)/f_cosine(0, T, s)

def beta_t(t, T, s):
  return np.clip(1 - a_t(t, T, s)/a_t(t-1, T, s), a_min=-1, a_max=0.999)

def cosine_schedule(T, s=0.001):
  return [beta_t(t, T, s) for t in range(1,T+1)]

In [5]:
import random

class NoiseTransform:
  def __init__(self, schedule):
    self.schedule = schedule

  def __call__(self, img):
    T = len(self.schedule)
    t = np.random.randint(low=1, high=T)
    alphas_prod = np.prod((1 - np.array(self.schedule))[:t])
    #print(alphas_prod)
    e = torch.normal(mean=0, std=1, size=img.shape)
    img_with_noise = (img * math.sqrt(alphas_prod) + e * math.sqrt(1-alphas_prod)).clip(min=0, max=1)
    return (img_with_noise, t, e)

In [6]:
image_transform_with_noise = transforms.Compose([
    transforms.Resize(size=(28, 28)),
    transforms.ToTensor(),
    NoiseTransform(cosine_schedule(100))
])

In [7]:
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=image_transform_with_noise,
    target_transform=None
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=image_transform_with_noise,
    target_transform=None
)

train_data, test_data

(Dataset MNIST
     Number of datapoints: 60000
     Root location: data
     Split: Train
     StandardTransform
 Transform: Compose(
                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
                ToTensor()
                <__main__.NoiseTransform object at 0x000002313BF95E50>
            ),
 Dataset MNIST
     Number of datapoints: 10000
     Root location: data
     Split: Test
     StandardTransform
 Transform: Compose(
                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
                ToTensor()
                <__main__.NoiseTransform object at 0x000002313BF95E50>
            ))

In [8]:
BATCH_SIZE = 32
NUM_WORKERS = 0

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True)

test_dataloader = DataLoader(dataset=test_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=False)

train_dataloader, test_dataloader, NUM_WORKERS

(<torch.utils.data.dataloader.DataLoader at 0x2315cd1a690>,
 <torch.utils.data.dataloader.DataLoader at 0x2315c756390>,
 0)

In [9]:
class PositionalEncoding(nn.Module):
  def __init__(self, n_dim):
    super(PositionalEncoding, self).__init__()

    self.n_dim = n_dim

  def forward(self, x):
    x_shape = x.shape
    x_flat = x.reshape(-1)
    out = torch.zeros(size=(*x_flat.shape, self.n_dim))
    for idx, elem in enumerate(x_flat):
      out[idx, :] = self.__positional_encoding(elem, self.n_dim)

    out = out.reshape((*x_shape, self.n_dim))
    return out

  def __positional_encoding(self, pos, ndim):
    return torch.tensor([ (math.sin(pos/10000**(int(i/2) / ndim )) if i%2 == 0 else  math.cos(pos/10000**(int(i/2) / ndim )))  for i in range(ndim)]).type(torch.float)


In [10]:
class DoubleConv(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
      )

  def forward(self, x):
    x = self.conv(x)
    return x

In [11]:
class DownLayer(nn.Module):
  def __init__(self, in_ch, out_ch, img_height, img_width, n_dim):
    super(DownLayer, self).__init__()

    self.img_width = img_width
    self.img_height = img_height
    self.out_ch = out_ch

    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    self.conv = DoubleConv(in_ch, out_ch)

    self.silu = nn.SiLU()
    self.lin_pe = nn.Linear(n_dim, out_ch * (img_height // 2) * (img_width // 2))

  def forward(self, x, pet_s):
    x = self.max_pool(x)
    x = self.conv(x)

    pet_s_1 = self.silu(pet_s)
    pet_s_2 = self.lin_pe(pet_s_1)
    pet_s_3 = pet_s_2.view(-1, self.out_ch, self.img_height // 2, self.img_width // 2)

    #print(x.shape, pet_s.shape)
    x = x + pet_s_3

    return x

In [12]:
class Up(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(Up, self).__init__()

    self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

  def forward(self, x1, x2):
    x2 = self.up_scale(x2)

    diffY = x1.size()[2] - x2.size()[2]
    diffX = x1.size()[3] - x2.size()[3]

    x2 = F.pad(x2, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2 ])

    x = torch.cat([x1,x2], dim=1)

    return x


In [13]:
class UpLayer(nn.Module):
  def __init__(self, in_ch, out_ch, img_height, img_width, n_dim):
    super(UpLayer, self).__init__()

    self.img_width = img_width
    self.img_height = img_height
    self.out_ch = out_ch

    self.up = Up(in_ch, out_ch)
    self.conv = DoubleConv(in_ch, out_ch)

    self.silu = nn.SiLU()
    self.lin_pe = nn.Linear(n_dim, out_ch * (img_height * 2) * (img_width * 2))

  def forward(self, x1, x2, pet_s):
    a = self.up(x1, x2)
    x = self.conv(a)

    pet_s_1 = self.silu(pet_s)
    pet_s_2 = self.lin_pe(pet_s_1)
    pet_s_3 = pet_s_2.view(-1, self.out_ch, self.img_height * 2, self.img_width * 2)

    #print(x.shape, pet_s.shape)
    x = x + pet_s_3

    return x

In [14]:
class SelfAttentionBlock(nn.Module):
  def __init__(self, in_ch, num_heads):
    super(SelfAttentionBlock, self).__init__()
    
    self.layer_norm_mha = nn.LayerNorm([in_ch])

    self.mha = nn.MultiheadAttention(embed_dim=in_ch, num_heads=num_heads)

    self.layer_norm = nn.LayerNorm([in_ch])
    self.linear_1 = nn.Linear(in_ch, in_ch)
    self.gelu = nn.GELU()
    self.linear_2 = nn.Linear(in_ch, in_ch)

  def forward(self, x: torch.Tensor):

    x_shape = x.shape

    x = x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1)

    x1 = self.layer_norm_mha(x)
    x2, _ = self.mha(x1, x1, x1) 

    x3 = x + x2

    x3 = self.layer_norm(x3)
    x3 = self.linear_1(x3)
    x3 = self.gelu(x3)
    x3 = self.linear_2(x3)

    x4 = x2 + x3

    x4 = x4.permute(0,2,1).reshape(-1, x_shape[-3],  x_shape[-2], x_shape[-1])

    return x4



In [34]:
class UNet(nn.Module):
  def __init__(self):
    super(UNet, self).__init__()

    self.pos_enc = PositionalEncoding(50)

    self.doubleconv = DoubleConv(1,32)
    self.down1 = DownLayer(32,64,28,28,50)
    self.down1_attention = SelfAttentionBlock(64, 4)
    self.down2 = DownLayer(64,128,14,14,50)
    self.down2_attention = SelfAttentionBlock(128, 4)
    self.middle_conv = DoubleConv(128, 128)
    self.up1 = UpLayer(128,64, 7, 7, 50)
    self.up1_attention = SelfAttentionBlock(64, 4)
    self.up2 = UpLayer(64,32, 14, 14, 50)
    self.up2_attention = SelfAttentionBlock(32, 4)
    self.last_conv = nn.Conv2d(32, 1, 1)

  def forward(self, x, t):
    t = self.pos_enc(t)

    x1 = self.doubleconv(x)
    x2 = self.down1(x1, t)
    x2 = self.down1_attention(x2)
    x3 = self.down2(x2, t)
    x3 = self.down2_attention(x3)
    x3 = self.middle_conv(x3)
    x1_up = self.up1(x2, x3, t)
    x1_up = self.up1_attention(x1_up)
    x2_up = self.up2(x1, x1_up, t)
    x2_up = self.up2_attention(x2_up)
    output = self.last_conv(x2_up)

    return output

In [35]:
model = UNet()

In [36]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(),
                            lr=0.1)

In [38]:
from tqdm.auto import tqdm

epochs = 10

for epoch in range(epochs):
  print(f"Epoch: {epoch}\n----------")

  model.train()
  train_loss = 0
  for (X, t, e), _ in tqdm(train_dataloader):

    e_pred = model(X, t)

    loss = loss_fn(e_pred,e)
    train_loss += loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
      
  train_loss /= len(train_dataloader)

  test_loss = 0
  model.eval()
  with torch.inference_mode():
    for (X_test, t_test, e_test), _ in tqdm(test_dataloader):
      test_pred = model(X_test, t_test)

      test_loss += loss_fn(test_pred, e_test)

    test_loss /= len(test_dataloader)

  print(f"\nTrain loss: {train_loss:.5f} | Test loss: {test_loss:.5f}\n")


Epoch: 0
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:06<00:00,  3.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:41<00:00,  7.49it/s]



Train loss: 0.22708 | Test loss: 0.22939

Epoch: 1
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:34<00:00,  2.95it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:41<00:00,  7.59it/s]



Train loss: 0.22614 | Test loss: 0.40983

Epoch: 2
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:48<00:00,  2.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:42<00:00,  7.39it/s]



Train loss: 0.22522 | Test loss: 0.33875

Epoch: 3
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:11<00:00,  3.07it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.86it/s]



Train loss: 0.22484 | Test loss: 0.23752

Epoch: 4
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:04<00:00,  3.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.93it/s]



Train loss: 0.22374 | Test loss: 0.22470

Epoch: 5
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:04<00:00,  3.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.94it/s]



Train loss: 0.22321 | Test loss: 0.22470

Epoch: 6
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:04<00:00,  3.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.93it/s]



Train loss: 0.22307 | Test loss: 0.22380

Epoch: 7
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:04<00:00,  3.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.94it/s]



Train loss: 0.22255 | Test loss: 0.22226

Epoch: 8
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:04<00:00,  3.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:39<00:00,  7.92it/s]



Train loss: 0.22203 | Test loss: 0.22245

Epoch: 9
----------


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [10:15<00:00,  3.05it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:44<00:00,  6.99it/s]


Train loss: 0.22129 | Test loss: 0.22536






In [23]:
model

UNet(
  (pos_enc): PositionalEncoding()
  (doubleconv): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): DownLayer(
    (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, aff