In [None]:
# Load the autoreload extension and use mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Set PYDEVD_DISABLE_FILE_VALIDATION=1 in the kernel environment.
# NOTE(review): presumably silences a debugger file-validation warning — confirm.
%env PYDEVD_DISABLE_FILE_VALIDATION=1

In [None]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,Subset
import math
#from diffusers import DDPMScheduler, UNet2DModel
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from matplotlib import pyplot as plt
from icecream import ic
import numpy as np
from datetime import datetime
from cop_diffusion.utils import save_model, load_model

# Run on the MPS backend when PyTorch reports it as available; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the MNIST training split (downloaded into "mnist/" on first use) with
# images converted to tensors, then pull one shuffled batch of 8 to inspect.
train_dataset = torchvision.datasets.MNIST(
    root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
# Pin the one-hot width to the 10 digit classes. Without `num_classes` the
# width is inferred from the largest label present in this particular batch,
# so a batch that happens to contain no 9 would yield a narrower encoding.
oh_y = torch.nn.functional.one_hot(y, num_classes=10)
ic("Input shape:", x.shape)
ic("Labels:", y)
# make_grid tiles the batch into a single image; channel 0 is enough for greyscale digits.
plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")


In [None]:
class BasicUNet(nn.Module):
    """Small three-level convolutional UNet with class/time conditioning in the decoder.

    The encoder applies three 5x5 convolutions (64, 128, 256 channels) with a
    2x max-pool between them. The decoder mirrors this with three 5x5
    convolutions (128, 64, ``out_channels``) and a 2x upsample between them.
    After each upsample the matching encoder activation is added as a skip
    connection, and the result is modulated as ``class_emb * h + time_emb``.

    Args:
        ctx_nb_feats: Number of rows in the class-embedding tables, i.e. the
            number of distinct class indices that ``c`` may contain.
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image.

    Note:
        The time-embedding tables hold a single row, so ``t`` may only
        contain the index 0.
    """

    def __init__(self, ctx_nb_feats:int, in_channels:int=1, out_channels:int=1):
        super().__init__()
        # Encoder convolutions (padding=2 keeps the spatial size for a 5x5 kernel).
        self.d1 = nn.Conv2d(in_channels, 64, kernel_size=5, padding=2)
        self.d2 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.d3 = nn.Conv2d(128, 256, kernel_size=5, padding=2)

        # Decoder convolutions.
        self.u1= nn.Conv2d(256, 128, kernel_size=5, padding=2)
        self.u2= nn.Conv2d(128, 64, kernel_size=5, padding=2)
        self.u3= nn.Conv2d(64, out_channels, kernel_size=5, padding=2)

        # Class (ce*) and time (te*) embeddings for the 128-channel decoder stage.
        self.ce2 = nn.Embedding(ctx_nb_feats,128)
        self.te2 = nn.Embedding(1,128)

        # Class and time embeddings for the 64-channel decoder stage.
        self.ce3 = nn.Embedding(ctx_nb_feats,64)
        self.te3 = nn.Embedding(1,64)

        self.act = nn.SiLU()
        self.downsample = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2)

    def _condition(self, h, t, c, time_embed, ctx_embed):
        """Return ``ctx_embed(c) * h + time_embed(t)``, broadcast over the spatial dims.

        When ``c`` is None the multiplicative class term is skipped and only
        the time embedding is added.
        """
        nb_channels = h.shape[1]
        # Reshape (batch, channels) -> (batch, channels, 1, 1) so it broadcasts over H and W.
        t_emb = time_embed(t).view(-1, nb_channels, 1, 1)
        if c is None:
            return h + t_emb
        c_emb = ctx_embed(c).view(-1, nb_channels, 1, 1)
        return c_emb * h + t_emb

    def forward(self, x, t, c=None):
        """Run the UNet on a batch of images.

        Args:
            x: Image batch of shape (batch, in_channels, H, W). H and W must be
                divisible by 4 so the upsampled maps match the skip connections.
            t: Integer tensor of shape (batch,) indexing the time embeddings.
            c: Optional integer tensor of shape (batch,) with class indices in
                ``[0, ctx_nb_feats)``. If None, no class conditioning is applied.

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        # Encoder: keep the pre-pooling activations for the skip connections.
        xd1 = self.act(self.d1(x))
        xd2 = self.act(self.d2(self.downsample(xd1)))
        xd3 = self.act(self.d3(self.downsample(xd2)))

        xu1 = self.act(self.u1(xd3))

        # Decoder stage 2: upsample, add skip, condition, convolve.
        xus2 = self.upsample(xu1) + xd2
        xus2 = self._condition(xus2, t, c, self.te2, self.ce2)
        xu2 = self.act(self.u2(xus2))

        # Decoder stage 3. The conditioned tensor must be the one fed to u3;
        # storing it under another name would silently drop this conditioning.
        xus3 = self.upsample(xu2) + xd1
        xus3 = self._condition(xus3, t, c, self.te3, self.ce3)
        xu3 = self.act(self.u3(xus3))

        return xu3

# Smoke test: push one MNIST batch through an untrained BasicUNet and compare
# the input grid with the output grid.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
# MNIST labels run from 0 to 9, so the class-embedding tables need 10 rows;
# with 9 rows any batch containing the digit 9 fails with an index error.
model = BasicUNet(ctx_nb_feats=10).to(device)
# The model's time-embedding tables have a single row, so every sample uses index 0.
t = torch.zeros(x.shape[0]).long().to(device)
x = x.to(device)
y = y.long().to(device)
ic(t.shape, y.shape)
# Inference only: no autograd graph is needed for a visual check.
with torch.no_grad():
    ux = model(x,t, y)
ic(ux.shape, x.shape)
ic(torchvision.utils.make_grid(x.cpu())[0].shape)
plt.imshow(torchvision.utils.make_grid(x.cpu())[0], cmap="Greys")
plt.show()
plt.imshow(torchvision.utils.make_grid(ux.cpu())[0], cmap="Greys")
plt.show()

In [None]:
class BasicUNetOld(nn.Module):
    """A minimal UNet with time and class conditioning in the decoder.

    Three 5x5 convolutions go down (32, 64, 64 channels) with a 2x max-pool
    after the first two, and three go up (64, 32, ``out_channels``) with a 2x
    upsample before the last two. After each upsample the stored encoder
    activation is added as a skip connection and the result is modulated as
    ``class_emb * x + time_emb``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image.
        ctx_nb_feats: Number of rows in the class-embedding tables. With the
            default of 0 the tables are empty, so the model can only be run
            without class conditioning (``c=None``).

    Note:
        The time-embedding tables hold a single row, so ``t`` may only
        contain the index 0.
    """

    def __init__(self, in_channels=1, out_channels=1, ctx_nb_feats:int=0):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        # Entry k conditions the input of up_layers[k + 1] (64 then 32 channels).
        self.time_embeds = nn.ModuleList([nn.Embedding(1,64), nn.Embedding(1,32)])
        self.ctx_embeds = nn.ModuleList([nn.Embedding(ctx_nb_feats,64), nn.Embedding(ctx_nb_feats,32)])
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )

        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x, t, c=None):
        """Run the UNet on a batch of images.

        Args:
            x: Image batch of shape (batch, in_channels, H, W). H and W must be
                divisible by 4 so the upsampled maps match the skip connections.
            t: Integer tensor of shape (batch,) indexing the time embeddings.
            c: Optional integer tensor of shape (batch,) with class indices in
                ``[0, ctx_nb_feats)``. If None, no class conditioning is applied.

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        skips = []
        last_down = len(self.down_layers) - 1
        for i, layer in enumerate(self.down_layers):
            x = self.act(layer(x))
            if i < last_down:  # Every down layer except the final one
                skips.append(x)  # Keep the activation for the skip connection
                x = self.downscale(x)

        for i, layer in enumerate(self.up_layers):
            if i > 0:  # Every up layer except the first one
                # Upscale, then add the most recently stored skip activation.
                x = self.upscale(x) + skips.pop()
                nb_channels = x.shape[1]
                # (batch, channels) -> (batch, channels, 1, 1) to broadcast over H and W.
                t_emb = self.time_embeds[i - 1](t).view(-1, nb_channels, 1, 1)
                if c is None:
                    x = x + t_emb
                else:
                    ctx_emb = self.ctx_embeds[i - 1](c).view(-1, nb_channels, 1, 1)
                    x = ctx_emb * x + t_emb
            x = self.act(layer(x))
        return x


# Smoke test for BasicUNetOld, reusing the batch `x`, `y` from the cell above.
# The class-embedding tables need one row per MNIST label (0-9); the
# constructor default of 0 rows makes every label lookup fail.
model = BasicUNetOld(ctx_nb_feats=10).to(device)
# The model's time-embedding tables have a single row, so every sample uses index 0.
t = torch.zeros(x.shape[0]).long().to(device)
x = x.to(device)
y = y.long().to(device)
ic(t.shape, y.shape)
# Inference only: no autograd graph is needed for a visual check.
with torch.no_grad():
    ux = model(x,t, y)
ic(ux.shape, x.shape)
plt.imshow(torchvision.utils.make_grid(x.cpu())[0], cmap="Greys")
plt.show()
plt.imshow(torchvision.utils.make_grid(ux.cpu())[0], cmap="Greys")
plt.show()
