In [1]:
import cv2
import numpy as np
from numpy import clip
import matplotlib.pyplot as plt
import os

import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup

from tqdm import tqdm

In [2]:
# Dataset-level normalisation statistics, loaded from disk and divided by 255.
# NOTE(review): the /255 suggests the arrays were saved on the 0-255 pixel
# scale -- confirm against the script that produced the .npy files.
_STAT_FILES = ("data/global_mean.npy", "data/global_std.npy")
GLOBAL_MEAN, GLOBAL_STD = (np.load(stat_file) / 255 for stat_file in _STAT_FILES)

In [3]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
device

device(type='cuda')

In [None]:
# --- Configuration --------------------------------------------------------
resize_size = 64                       # side length (pixels) of the square network input
data_path = 'data/data0/lsun/bedroom'  # ImageFolder root (expects class sub-directories)
batch_size = 16
subset_size = 10000                    # number of randomly chosen images to train on

# Resize -> tensor -> standardise with the dataset-level statistics.
transform = transforms.Compose([
    transforms.Resize((resize_size, resize_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=GLOBAL_MEAN, std=GLOBAL_STD)
])

# Keep the full dataset and the random subset under separate names so that
# re-running this cell never wraps a Subset inside another Subset.
full_dataset = ImageFolder(root=data_path, transform=transform)

# Slicing the permutation also covers datasets smaller than `subset_size`;
# `.tolist()` hands Subset plain Python ints instead of 0-d tensors.
subset_indices = torch.randperm(len(full_dataset))[:subset_size].tolist()
image_dataset = Subset(full_dataset, subset_indices)
train_dataloader = DataLoader(image_dataset, batch_size=batch_size, shuffle=True)

# Pull a single batch for inspection; the cells below index it as batch[0].
batch = next(iter(train_dataloader))

In [None]:
# Display one randomly chosen image from the sampled batch.
# NOTE(review): the batch went through transforms.Normalize, so values may fall
# outside [0, 1] and matplotlib will clip them -- consider de-normalising first.
sample_idx = np.random.randint(0, batch_size)
sample_img = batch[0][sample_idx].permute(1, 2, 0)  # move axis 0 last for imshow

fig, ax = plt.subplots()
ax.imshow(sample_img)
ax.set_title("Random transformed image.")
ax.axis('off')
plt.show()

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    """Time-conditioned convolutional block.

    Two parallel 3x3 convolutions (padding 1, so height and width are kept) are
    applied to the input. The first branch is gated multiplicatively by a
    learned projection of the time embedding, added to the second branch, and
    the result is passed through ReLU and a LayerNorm over
    ``(out_channels, size, size)``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced. The time embedding passed
            to ``forward`` must have this many features as well, because it is
            projected with ``nn.Linear(out_channels, out_channels)``.
        size: Spatial side length of the (square) feature map; fixes the
            LayerNorm shape, so inputs must be exactly ``size x size``.
    """

    def __init__(self, in_channels, out_channels, size):
        super().__init__()
        # Remembered so forward() can reshape the time gate; the original code
        # referenced the bare name `out_channels` there and raised NameError.
        self.out_channels = out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.dense_time = nn.Linear(out_channels, out_channels)
        self.norm = nn.LayerNorm([out_channels, size, size])

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Feature map of shape ``(batch, in_channels, size, size)``.
            t: Time embedding of shape ``(batch, out_channels)``.

        Returns:
            Tensor of shape ``(batch, out_channels, size, size)``.
        """
        x_parameter = F.relu(self.conv1(x))

        # One scalar gate per (sample, channel), broadcast over height and width.
        time_parameter = F.relu(self.dense_time(t))
        time_parameter = time_parameter.view(-1, self.out_channels, 1, 1)

        x_parameter = x_parameter * time_parameter
        x_out = F.relu(self.conv2(x) + x_parameter)
        return self.norm(x_out)

class UNet(nn.Module):
    """Small time-conditioned U-Net mapping an RGB batch to a same-shaped output.

    With ``s = img_size`` the network is laid out as follows:

    * Down path: three ``Block``s at resolutions ``s``, ``s/2`` and ``s/4``,
      separated by 2x max-pooling.
    * Bottleneck: the deepest feature map is flattened, concatenated with the
      time embedding, passed through two dense layers (each followed by
      LayerNorm and ReLU) and reshaped to ``(32, s/4, s/4)``.
    * Up path: three ``Block``s; each receives the running features
      concatenated with the skip connection of matching resolution, with
      bilinear 2x up-sampling in between.
    * Head: a 1x1 convolution down to 3 channels.

    Args:
        img_size: Side length of the square input images. Defaults to the
            notebook-level ``resize_size``. Must be a positive multiple of 4
            because the input is halved twice on the way down.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, img_size=None):
        super().__init__()

        if img_size is None:
            img_size = resize_size
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 4, got {img_size}"
            )

        # Block projects the time embedding with Linear(out_channels, out_channels),
        # so the embedding width has to equal the block width.
        width = 128
        bottleneck_channels = 32
        half, quarter = img_size // 2, img_size // 4

        # Each Block's LayerNorm is tied to the resolution it actually sees.
        self.block_down_1 = Block(in_channels=3, out_channels=width, size=img_size)
        self.block_down_2 = Block(in_channels=width, out_channels=width, size=half)
        self.block_down_3 = Block(in_channels=width, out_channels=width, size=quarter)

        # Input channels on the way up = running features + skip connection.
        self.block_up_1 = Block(in_channels=bottleneck_channels + width, out_channels=width, size=quarter)
        self.block_up_2 = Block(in_channels=2 * width, out_channels=width, size=half)
        self.block_up_3 = Block(in_channels=2 * width, out_channels=width, size=img_size)

        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.mlp_t_initial = nn.Linear(1, width)
        self.norm_t_initial = nn.LayerNorm([width])

        # Bottleneck MLP: flattened deepest features + time embedding -> width
        # -> enough features to rebuild a (32, s/4, s/4) map.
        self.bottleneck_shape = (bottleneck_channels, quarter, quarter)
        bottleneck_features = bottleneck_channels * quarter * quarter
        self.mlp_dense = nn.Linear(width * quarter * quarter + width, width)
        self.norm_dense = nn.LayerNorm([width])
        self.mlp_expand = nn.Linear(width, bottleneck_features)
        self.norm_expand = nn.LayerNorm([bottleneck_features])

        self.conv_out = nn.Conv2d(width, 3, kernel_size=1, padding=0)

    def forward(self, x_img, x_ts):
        """Run the network.

        Args:
            x_img: Float tensor of shape ``(batch, 3, img_size, img_size)``.
            x_ts: Float tensor of shape ``(batch, 1)`` holding the timestep.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)``.
        """
        # Timestep scalar -> embedding shared by every block.
        x_ts = self.mlp_t_initial(x_ts)
        x_ts = self.norm_t_initial(x_ts)
        x_ts = F.relu(x_ts)

        # ----- down path (keep every level for the skip connections) -----
        x1 = self.block_down_1(x_img, x_ts)
        x2 = self.block_down_2(self.maxpool(x1), x_ts)
        x3 = self.block_down_3(self.maxpool(x2), x_ts)

        # ----- bottleneck MLP on the deepest features -----
        x = self.flatten(x3)
        x = torch.cat((x, x_ts), dim=1)
        x = F.relu(self.norm_dense(self.mlp_dense(x)))
        x = F.relu(self.norm_expand(self.mlp_expand(x)))
        x = x.view(-1, *self.bottleneck_shape)

        # ----- up path -----
        x = torch.cat((x, x3), dim=1)
        x = self.block_up_1(x, x_ts)
        x = F.interpolate(x, scale_factor=2, mode='bilinear')

        x = torch.cat((x, x2), dim=1)
        x = self.block_up_2(x, x_ts)
        x = F.interpolate(x, scale_factor=2, mode='bilinear')

        x = torch.cat((x, x1), dim=1)
        x = self.block_up_3(x, x_ts)

        # ----- output head -----
        return self.conv_out(x)

# Instantiate the network together with its training objects.
model = UNet()

# Adam over every model parameter; mean-squared error as the training loss.
optimizer = torch.optim.Adam(params=model.parameters(), lr=8e-4)
loss_func = torch.nn.MSELoss()

In [None]:
# Smoke test: one forward pass on the sampled batch with a constant timestep.
x_batch = batch[0]

# Size the timestep tensor from the actual batch, not the configured
# batch_size, so a short batch cannot cause a shape mismatch.
t = torch.full([x_batch.shape[0], 1], 10, dtype=torch.float)

# No gradients are needed for a shape check.
with torch.no_grad():
    pred = model(x_batch, t)

assert pred.shape == x_batch.shape, "model output must match the input image shape"
pred.shape  # show the shape instead of dumping the whole tensor