In [57]:
import torch
import torch.nn

class ResnetBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, is_debug=False) -> None:
        super().__init__()
        self.conv_1 = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1
        )
        self.group_1 = torch.nn.GroupNorm(out_channels, out_channels)
        self.relu_1 = torch.nn.ReLU()
        self.conv_2 = torch.nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1
        )
        self.group_2 = torch.nn.GroupNorm(out_channels, out_channels)
        self.relu_2 = torch.nn.ReLU()
        self.is_debug = is_debug

    def forward(self, x, t = None, has_attn = False):
        x = self.conv_1(x)
        if self.is_debug:
            print("Conv 1 shape:", x.shape)
        x = self.group_1(x)
        if self.is_debug:
            print("Group 1 shape:", x.shape)
        x = self.relu_1(x)
        return self.relu_2(self.group_2(self.conv_2(x)))
    
class PositionalEmbedding(torch.nn.Module):
    def __init__(self, output_dim) -> None:
        super().__init__()

        self.ln = torch.nn.Linear(1, output_dim)

    def forward(self, x):
        return self.ln(x)

class UNet(torch.nn.Module):
    def __init__(self, in_channels, is_debug = False):
        super().__init__()
        self.is_debug = is_debug
        self.resnet_left_1 = ResnetBlock(in_channels, out_channels=3, is_debug=is_debug)
        self.down_1 = torch.nn.MaxPool2d(kernel_size=2, return_indices=True)
        self.resnet_left_2 = ResnetBlock(in_channels=3, out_channels=3)
        self.down_2 = torch.nn.MaxPool2d(kernel_size=2, return_indices=True)
        self.resnet_left_3 = ResnetBlock(in_channels=3, out_channels=3)
        self.down_3 = torch.nn.MaxPool2d(kernel_size=2, return_indices=True)

        self.time_embedding = PositionalEmbedding(3)
        self.backbone = ResnetBlock(in_channels=3, out_channels=3)

        self.up_1 = torch.nn.MaxUnpool2d(kernel_size=2)
        self.resnet_right_1 = ResnetBlock(in_channels, out_channels=3)
        self.up_2 = torch.nn.MaxUnpool2d(kernel_size=2)
        self.resnet_right_2 = ResnetBlock(in_channels, out_channels=3)
        self.up_3 = torch.nn.MaxUnpool2d(kernel_size=2)
        self.resnet_right_3 = ResnetBlock(in_channels, out_channels=3)

    def forward(self, x, t):
        x_1 = self.resnet_left_1(x)
        if self.is_debug:
            print("Resnet left 1 shape:", x_1.shape)
        x, ind_1 = self.down_1(x_1)
        if self.is_debug:
            print("Down 1 shape:", x.shape)
        x_2 = self.resnet_left_2(x)
        if self.is_debug:
            print("Resnet left 2 shape:", x_2.shape)
        x, ind_2 = self.down_2(x_2)
        if self.is_debug:
            print("Down 2 shape:", x.shape)
        x_3 = self.resnet_left_3(x)
        if self.is_debug:
            print("Resnet left 3 shape:", x_3.shape)
        x, ind_3 = self.down_3(x_3)
        if self.is_debug:
            print("Down 3 shape:", x.shape)
        
        if self.is_debug:
            print("Time:", t.shape)
        time_emb = self.time_embedding(t)
        batch_size, dim = time_emb.shape
        time_emb = time_emb.view(batch_size, dim, 1, 1)
        if self.is_debug:
            print("Time embedding:", time_emb.shape)
        x = self.backbone(x + time_emb)
        if self.is_debug:
            print("Backbone shape:", x.shape)
        
        x = self.up_1(x, indices = ind_3)
        if self.is_debug:
            print("Up 1 shape:", x.shape)
        x = self.resnet_right_1(x + x_3)
        if self.is_debug:
            print("Resnet right 1 shape:", x.shape)
        x = self.up_2(x, indices = ind_2)
        x = self.resnet_right_2(x + x_2)
        x = self.up_3(x, indices = ind_1)
        x = self.resnet_right_1(x + x_1)

        return x
    
unet = UNet(3, is_debug=True)
res = unet(tensor.unsqueeze(dim=0), torch.tensor([1.0]).unsqueeze(0))
res.shape

Conv 1 shape: torch.Size([1, 3, 32, 32])
Group 1 shape: torch.Size([1, 3, 32, 32])
Resnet left 1 shape: torch.Size([1, 3, 32, 32])
Down 1 shape: torch.Size([1, 3, 16, 16])
Resnet left 2 shape: torch.Size([1, 3, 16, 16])
Down 2 shape: torch.Size([1, 3, 8, 8])
Resnet left 3 shape: torch.Size([1, 3, 8, 8])
Down 3 shape: torch.Size([1, 3, 4, 4])
Time: torch.Size([1, 1])
Time embedding: torch.Size([1, 3, 1, 1])
Backbone shape: torch.Size([1, 3, 4, 4])
Up 1 shape: torch.Size([1, 3, 8, 8])
Resnet right 1 shape: torch.Size([1, 3, 8, 8])


torch.Size([1, 3, 32, 32])

In [2]:
import torchvision
import torchvision.transforms
import matplotlib.pyplot as plt

dataset = torchvision.datasets.cifar.CIFAR10(root="./datasets", download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:09<00:00, 17203218.42it/s]


Extracting ./datasets/cifar-10-python.tar.gz to ./datasets


In [3]:
dataset = torchvision.datasets.MNIST(root="./datasets", download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 8432401.54it/s] 


Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 274654.22it/s]


Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 718638.40it/s]


Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1480342.59it/s]

Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw






In [4]:
import torch

In [7]:
T = 1000
beta_schedule = torch.linspace(1e-4, 0.02, T)
alpha_t_schedule = 1 - beta_schedule

tensor([1.0000e-04, 1.1992e-04, 1.3984e-04, 1.5976e-04, 1.7968e-04, 1.9960e-04,
        2.1952e-04, 2.3944e-04, 2.5936e-04, 2.7928e-04, 2.9920e-04, 3.1912e-04,
        3.3904e-04, 3.5896e-04, 3.7888e-04, 3.9880e-04, 4.1872e-04, 4.3864e-04,
        4.5856e-04, 4.7848e-04, 4.9840e-04, 5.1832e-04, 5.3824e-04, 5.5816e-04,
        5.7808e-04, 5.9800e-04, 6.1792e-04, 6.3784e-04, 6.5776e-04, 6.7768e-04,
        6.9760e-04, 7.1752e-04, 7.3744e-04, 7.5736e-04, 7.7728e-04, 7.9720e-04,
        8.1712e-04, 8.3704e-04, 8.5696e-04, 8.7688e-04, 8.9680e-04, 9.1672e-04,
        9.3664e-04, 9.5656e-04, 9.7648e-04, 9.9640e-04, 1.0163e-03, 1.0362e-03,
        1.0562e-03, 1.0761e-03, 1.0960e-03, 1.1159e-03, 1.1358e-03, 1.1558e-03,
        1.1757e-03, 1.1956e-03, 1.2155e-03, 1.2354e-03, 1.2554e-03, 1.2753e-03,
        1.2952e-03, 1.3151e-03, 1.3350e-03, 1.3550e-03, 1.3749e-03, 1.3948e-03,
        1.4147e-03, 1.4346e-03, 1.4546e-03, 1.4745e-03, 1.4944e-03, 1.5143e-03,
        1.5342e-03, 1.5542e-03, 1.5741e-

In [16]:
img, index = dataset[20]

# print(img)
tensor = torchvision.transforms.ToTensor()(img)
print(tensor.shape)

<PIL.Image.Image image mode=RGB size=32x32 at 0x17F6FC1F0>
torch.Size([3, 32, 32])


In [29]:
unet = UNet(3)
unet(tensor)

tensor([[[0.0414, 0.0646, 0.0654,  ..., 0.0673, 0.0790, 0.0936],
         [0.0016, 0.0395, 0.0480,  ..., 0.0462, 0.0413, 0.0592],
         [0.0000, 0.0065, 0.0149,  ..., 0.0233, 0.0448, 0.0649],
         ...,
         [0.0000, 0.0436, 0.0680,  ..., 0.0873, 0.1083, 0.1078],
         [0.0000, 0.0427, 0.0355,  ..., 0.0894, 0.0883, 0.0782],
         [0.0167, 0.0470, 0.0867,  ..., 0.0940, 0.1172, 0.1028]],

        [[0.1162, 0.1424, 0.1424,  ..., 0.1365, 0.1500, 0.1468],
         [0.1038, 0.1404, 0.1309,  ..., 0.1375, 0.1328, 0.1392],
         [0.1060, 0.1398, 0.1427,  ..., 0.1390, 0.1452, 0.1395],
         ...,
         [0.1132, 0.1573, 0.1538,  ..., 0.1366, 0.1425, 0.1224],
         [0.1042, 0.1586, 0.1506,  ..., 0.1706, 0.1463, 0.1282],
         [0.1144, 0.1560, 0.1735,  ..., 0.1776, 0.1863, 0.1508]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.