In [None]:
# Colab-only cell: mount the user's Google Drive into the runtime's filesystem
# at /content/drive. The training images are read from Drive by a later cell,
# so this cell must run first on a fresh kernel.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
from PIL import Image

In [None]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 32      # images per optimisation step
IMAGE_SIZE = 24      # images are IMAGE_SIZE x IMAGE_SIZE pixels
CHANNELS_IMG = 4     # channels per image (used to size the Normalize transform)
NUM_EPOCHS = 20      # full passes over the dataset
TIME_STEPS = 200     # number of diffusion timesteps T

# --- Reproducibility -------------------------------------------------------
# Training draws random timesteps, Gaussian noise and a shuffled batch order.
# Seeding the global PyTorch generator makes Restart & Run All repeatable.
SEED = 42
torch.manual_seed(SEED)

# --- Device ----------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class PunkDataset(Dataset):
    """Dataset of sequentially numbered PNG images (``0.png``, ``1.png``, ...).

    Item ``idx`` is loaded from ``<image_folder>/<idx>.png``.

    Args:
        image_folder: Directory that contains the numbered ``.png`` files.
        transform: Optional callable applied to the loaded PIL image. When it
            is ``None`` the PIL image itself is returned.
    """

    def __init__(self, image_folder, transform=None):
        self.transform = transform
        self.image_folder = image_folder
        # Count only PNG files. A bare os.listdir() also counts stray entries
        # (e.g. a ``.ipynb_checkpoints`` directory or a metadata file), which
        # would make the DataLoader ask for an index whose file does not exist.
        self.n = sum(
            1 for name in os.listdir(image_folder) if name.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of PNG files found in ``image_folder``."""
        return self.n

    def __getitem__(self, idx):
        """Load image ``idx`` and return it, transformed if a transform is set."""
        img_path = os.path.join(self.image_folder, str(idx) + '.png')
        # Read the pixel data inside the context manager so the underlying
        # file handle is closed deterministically instead of being left to
        # lazy loading / garbage collection.
        with Image.open(img_path) as img:
            img.load()
        if self.transform is not None:
            img = self.transform(img)
        return img

In [None]:
class DiffusionModel(nn.Module):
    """Small U-Net-style network that predicts the noise added to an image.

    Layout for ``features=(f0, f1, ...)`` and an ``H x W`` input:

    * encoder: one 3x3 conv per entry; the first keeps the resolution, every
      later one halves it (stride 2). Each activation is kept as a skip.
    * bottleneck: a stride-2 3x3 conv to ``2 * features[-1]`` channels, to
      which a learned embedding of the timestep ``t`` is added.
    * decoder: one stride-2 transposed conv per entry (doubling the
      resolution), each followed by concatenation with the matching skip.
    * final: a 3x3 conv from ``2 * features[0]`` channels to ``out_channels``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted noise.
        features: Channel widths of the encoder levels, shallowest first.
        time_dim: Size of the sinusoidal timestep embedding (positive, even).

    Raises:
        ValueError: If ``features`` is empty or ``time_dim`` is not a
            positive even number.
    """

    def __init__(self, in_channels=4, out_channels=4, features=(64, 128), time_dim=64):
        super(DiffusionModel, self).__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")
        if time_dim < 2 or time_dim % 2 != 0:
            raise ValueError("time_dim must be a positive even number")

        self.time_dim = time_dim
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # Downsampling path. The first level works at full resolution; deeper
        # levels use stride 2 so that every decoder upsample has a skip of the
        # matching spatial size.
        for level, feature in enumerate(features):
            stride = 1 if level == 0 else 2
            self.downs.append(
                nn.Conv2d(in_channels, feature, 3, stride=stride, padding=1)
            )
            in_channels = feature

        bottleneck_channels = features[-1] * 2
        self.bottleneck = nn.Conv2d(features[-1], bottleneck_channels, 3, padding=1, stride=2)

        # Projects the sinusoidal timestep embedding to the bottleneck width.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, bottleneck_channels),
            nn.ReLU(),
            nn.Linear(bottleneck_channels, bottleneck_channels),
        )

        # Upsampling path. ``channels`` tracks what actually enters each
        # transposed conv: after concatenating a skip the tensor has
        # 2 * feature channels, which is the input of the next level.
        channels = bottleneck_channels
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(channels, feature, kernel_size=4, stride=2, padding=1)
            )
            channels = feature * 2

        # ``channels`` is now 2 * features[0] (last upsample + shallowest skip).
        self.final = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def _timestep_embedding(self, t):
        """Return a ``(len(t), time_dim)`` sinusoidal embedding of timesteps ``t``."""
        half = self.time_dim // 2
        exponent = torch.arange(half, device=t.device, dtype=torch.float32) / half
        freqs = torch.pow(10000.0, -exponent)
        angles = t.to(torch.float32).unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: Image batch of shape ``(B, in_channels, H, W)``.
            t: Timesteps, one per batch element (shape ``(B,)``), or a single
                value that is shared by the whole batch.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.
        """
        skip_connections = []

        for down in self.downs:
            x = F.relu(down(x))
            skip_connections.append(x)

        # Condition on the timestep by adding its embedding per channel.
        t = torch.as_tensor(t, device=x.device).reshape(-1)
        if t.numel() == 1:
            t = t.expand(x.shape[0])
        t_emb = self.time_mlp(self._timestep_embedding(t).to(x.dtype))
        x = F.relu(self.bottleneck(x) + t_emb[:, :, None, None])

        for up, skip in zip(self.ups, reversed(skip_connections)):
            x = up(x)
            # Inputs whose size is not divisible by 2 ** len(features) come
            # back one pixel off after the transposed conv; resize to the skip.
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
            # Concatenate along channel dimension
            x = F.relu(torch.cat((skip, x), dim=1))

        return self.final(x)

In [None]:
# Convert each PIL image to a float tensor in [0, 1], then shift and scale
# every channel with mean 0.5 / std 0.5 so pixel values land in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG)
])

# Drive is mounted at /content/drive, so address the data by absolute path.
# The previous relative path only resolved while the working directory was
# /content and broke after any directory change.
dataset = PunkDataset('/content/drive/MyDrive/data', transform=transform)
trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = DiffusionModel().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
# Forward-process noise schedule: linearly spaced betas. The usual 1e-4..0.02
# range is defined for 1000 steps, so it is rescaled by 1000 / TIME_STEPS to
# keep the total amount of noise comparable with fewer steps. The clamp keeps
# every beta below 1 for very small TIME_STEPS.
schedule_scale = 1000 / TIME_STEPS
betas = torch.linspace(
    schedule_scale * 1e-4, schedule_scale * 0.02, TIME_STEPS, device=device
).clamp(max=0.999)
# alphas_cumprod[k] is alpha-bar for timestep k + 1 (timesteps are 1-based).
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

model.train()
for epoch in range(NUM_EPOCHS):
    loss_sum = 0.0
    images_seen = 0
    for images in trainloader:
        images = images.to(device)
        # Uniformly sample timesteps t for each element in the batch
        t = torch.randint(1, TIME_STEPS + 1, (images.shape[0],), device=device)

        # Sample standard gaussian noise (randn_like already matches the
        # device and dtype of `images`)
        epsilon = torch.randn_like(images)

        # 1. Add noise to images based on t:
        #    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        alpha_bar_t = alphas_cumprod[t - 1].view(-1, 1, 1, 1)
        noise_images = alpha_bar_t.sqrt() * images + (1.0 - alpha_bar_t).sqrt() * epsilon

        # 2. Model predicts epsilon
        predicted_epsilon = model(noise_images, t)

        # 3. Compute loss and optimize
        loss = F.mse_loss(predicted_epsilon, epsilon)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * images.shape[0]
        images_seen += images.shape[0]

    # Mean per-image loss over the epoch; max() guards an empty loader.
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - loss: {loss_sum / max(images_seen, 1):.4f}")