In [1]:
from pathlib import Path
import torch
from torch import nn
from torchvision import datasets, transforms
from diffusion.noise_scheduler import NoiseScheduler

In [2]:
DATA_PATH = Path("../data")
test_data = datasets.CelebA(
    root=DATA_PATH,
    split="test",
    transform=transforms.ToTensor(),
    download=True
)
data = test_data[0][0].reshape(1, 3, 218, 178)

Files already downloaded and verified


In [3]:
class ResBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, activation, mode, kernel_size=3, stride=1, padding=1) -> None:
        super().__init__()
        self.normalize = nn.LayerNorm((in_c, *shape))
        self.Conv = nn.Conv2d if mode == "down" else nn.ConvTranspose2d

        self.conv0 = self.Conv(in_c, out_c, kernel_size, stride, padding)
        self.conv1 = self.Conv(out_c, out_c, kernel_size, stride, padding)
        self.conv2 = self.Conv(out_c, out_c, kernel_size, stride, padding)
        self.activation = activation

        self.shortcut = self.Conv(in_c, out_c, kernel_size, stride, padding)

    def forward(self, xs):
        xs = self.normalize(xs)

        out = self.conv0(xs)
        out = self.activation(out)
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)

        out += self.shortcut(xs)

        return out

In [4]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("embedding_w", torch.randn(256//2))
        self.noise_scheduler = NoiseScheduler(0, 0.02, 1000)

        self.channels = [3, 16, 32, 64, 128]
        self.activation = nn.SiLU()

        self.l0 = nn.Linear(256, self.channels[0])
        self.b0 = nn.Sequential(
            ResBlock((218, 178), self.channels[0], self.channels[1], self.activation, "down"),
            ResBlock((218, 178), self.channels[1], self.channels[1], self.activation, "down"),
            ResBlock((218, 178), self.channels[1], self.channels[1], self.activation, "down"),
        )
        self.down0 = nn.Conv2d(self.channels[1], self.channels[1], kernel_size=3, stride=2, padding=1)

        self.l1 = nn.Linear(256, self.channels[1])
        self.b1 = nn.Sequential(
            ResBlock((109, 89), self.channels[1], self.channels[2], self.activation, "down"),
            ResBlock((109, 89), self.channels[2], self.channels[2], self.activation, "down"),
            ResBlock((109, 89), self.channels[2], self.channels[2], self.activation, "down"),
        )
        self.down1 = nn.Conv2d(self.channels[2], self.channels[2], kernel_size=3, stride=2, padding=1)

        self.l2 = nn.Linear(256, self.channels[2])
        self.b2 = nn.Sequential(
            ResBlock((55, 45), self.channels[2], self.channels[3], self.activation, "down"),
            ResBlock((55, 45), self.channels[3], self.channels[3], self.activation, "down"),
            ResBlock((55, 45), self.channels[3], self.channels[3], self.activation, "down"),
        )
        self.down2 = nn.Conv2d(self.channels[3], self.channels[3], kernel_size=3, stride=2, padding=1)

        self.l3 = nn.Linear(256, self.channels[3])
        self.b3 = nn.Sequential(
            ResBlock((28, 23), self.channels[3], self.channels[4], self.activation, "down"),
            ResBlock((28, 23), self.channels[4], self.channels[4], self.activation, "down"),
            ResBlock((28, 23), self.channels[4], self.channels[4], self.activation, "down"),
        )
        self.down3 = nn.Conv2d(self.channels[4], self.channels[4], kernel_size=3, stride=2, padding=1)


        self.lt3 = nn.Linear(256, self.channels[4])
        self.bt3 = nn.Sequential(
            ResBlock((14, 12), self.channels[4], self.channels[3], self.activation, "up"),
            ResBlock((14, 12), self.channels[3], self.channels[3], self.activation, "up"),
            ResBlock((14, 12), self.channels[3], self.channels[3], self.activation, "up"),
        )
        self.up3 = nn.ConvTranspose2d(self.channels[3], self.channels[3], kernel_size=(4, 3), stride=2, padding=1)

        self.lt2 = nn.Linear(256, self.channels[3])
        self.bt2 = nn.Sequential(
            ResBlock((28, 23), 2*self.channels[3], self.channels[2], self.activation, "up"),
            ResBlock((28, 23), self.channels[2], self.channels[2], self.activation, "up"),
            ResBlock((28, 23), self.channels[2], self.channels[2], self.activation, "up"),
        )
        self.up2 = nn.ConvTranspose2d(self.channels[2], self.channels[2], kernel_size=3, stride=2, padding=1)

        self.lt1 = nn.Linear(256, self.channels[2])
        self.bt1 = nn.Sequential(
            ResBlock((55, 45), 2*self.channels[2], self.channels[1], self.activation, "up"),
            ResBlock((55, 45), self.channels[1], self.channels[1], self.activation, "up"),
            ResBlock((55, 45), self.channels[1], self.channels[1], self.activation, "up"),
        )
        self.up1 = nn.ConvTranspose2d(self.channels[1], self.channels[1], kernel_size=3, stride=2, padding=1)

        self.lt0 = nn.Linear(256, self.channels[1])
        self.bt0 = nn.Sequential(
            ResBlock((109, 89), 2*self.channels[1], self.channels[0], self.activation, "up"),
            ResBlock((109, 89), self.channels[0], self.channels[0], self.activation, "up"),
            ResBlock((109, 89), self.channels[0], self.channels[0], self.activation, "up"),
        )
        self.up0 = nn.ConvTranspose2d(self.channels[0], self.channels[0], kernel_size=4, stride=2, padding=1)

    def preprocess(self, xs):
        ts = torch.randint(1, self.noise_scheduler.steps, (len(xs),))
        xs, ys = self.noise_scheduler.forward_process(xs, ts)

        return xs, ts, ys

    def embed(self, ts, linear):
        embedding = 30 * torch.outer(ts, self.embedding_w)
        embedding = torch.cat([torch.sin(embedding), torch.cos(embedding)], dim=1)
        embedding = self.activation(linear(embedding))
        embedding = embedding.reshape(*embedding.shape, 1, 1)

        return embedding

    def forward(self, xs, ts):
        print(f"xs = {xs.shape}")

        h0 = self.down0(self.b0(xs + self.embed(ts, self.l0)))
        print(f"h0 = {h0.shape}")
        h1 = self.down1(self.b1(h0 + self.embed(ts, self.l1)))
        print(f"h1 = {h1.shape}")
        h2 = self.down2(self.b2(h1 + self.embed(ts, self.l2)))
        print(f"h2 = {h2.shape}")
        h3 = self.down3(self.b3(h2 + self.embed(ts, self.l3)))
        print(f"h3 = {h3.shape}")

        h = self.up3(self.bt3(h3 + self.embed(ts, self.lt3)))
        print(f"ht1 = {h.shape}")
        h = self.up2(self.bt2(torch.cat((h + self.embed(ts, self.lt2), h2), dim=1)))
        print(f"ht1 = {h.shape}")
        h = self.up1(self.bt1(torch.cat((h + self.embed(ts, self.lt1), h1), dim=1)))
        print(f"ht1 = {h.shape}")
        h = self.up0(self.bt0(torch.cat((h + self.embed(ts, self.lt0), h0), dim=1)))
        print(f"ht0 = {h.shape}")

        return h


model = Model()
xs, ts, _ = model.preprocess(data)
model(xs, ts)

xs = torch.Size([1, 3, 218, 178])
h0 = torch.Size([1, 16, 109, 89])
h1 = torch.Size([1, 32, 55, 45])
h2 = torch.Size([1, 64, 28, 23])
h3 = torch.Size([1, 128, 14, 12])
ht1 = torch.Size([1, 64, 28, 23])
ht1 = torch.Size([1, 32, 55, 45])
ht1 = torch.Size([1, 16, 109, 89])
ht0 = torch.Size([1, 3, 218, 178])


tensor([[[[-0.0822,  0.0922,  0.0397,  ...,  0.0521,  0.0818,  0.0111],
          [ 0.0144,  0.1985, -0.1695,  ..., -0.0273,  0.0137, -0.0807],
          [ 0.0523,  0.1237,  0.0165,  ..., -0.0982,  0.1035,  0.0761],
          ...,
          [-0.0008, -0.0903,  0.0863,  ...,  0.0272, -0.0009, -0.0244],
          [ 0.0195,  0.1126, -0.0852,  ..., -0.0110, -0.2137, -0.0883],
          [ 0.0169, -0.0595, -0.0274,  ..., -0.0188, -0.0211,  0.0199]],

         [[ 0.0749,  0.0782,  0.0335,  ...,  0.0597,  0.2029,  0.0909],
          [ 0.0694,  0.2537,  0.4160,  ...,  0.1461,  0.0308,  0.0354],
          [ 0.0579,  0.0284, -0.1273,  ..., -0.0588, -0.0159,  0.0448],
          ...,
          [ 0.0301,  0.0847, -0.0307,  ..., -0.0933,  0.3191,  0.0439],
          [-0.0017,  0.1677,  0.0938,  ..., -0.0585,  0.1483,  0.0453],
          [ 0.0287,  0.0421, -0.0141,  ...,  0.0241,  0.1371, -0.0094]],

         [[ 0.1474,  0.0933,  0.1095,  ..., -0.0047,  0.0128,  0.0774],
          [-0.0521, -0.0583, -