In [7]:
import imageio
import ipywidgets
from IPython.display import display, Image as pyImage, HTML
import numpy as np
import torch
from contextlib import contextmanager
from PIL import Image
from tqdm import tqdm
from diffusers import DDPMScheduler
from diffusers.models import UNet2DModel
from torch.optim import AdamW
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop,RandomHorizontalFlip, PILToTensor
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
@contextmanager
def create_gif(images, output_filename, duration=0.2):
    """Write a stack of frames to an animated GIF and yield the open writer.

    Every frame is appended *before* the ``with`` body runs. The writer is
    closed when the ``with`` block exits (also if an error is raised), so the
    body does not need to close it itself.

    Args:
        images: array of frames indexed along the first axis; assumed
            (N, H, W, C) uint8 as produced by ``show_images`` -- TODO confirm
            for other callers. Single-channel frames (C == 1) are repeated
            to 3 channels before writing.
        output_filename: path of the GIF file to write.
        duration: per-frame duration, passed through unchanged to
            ``imageio.get_writer``.

    Yields:
        The imageio writer, already holding every frame.
    """
    # Expand grayscale frames to RGB before handing them to imageio.
    if images.shape[-1] == 1:
        images = np.repeat(images, 3, axis=-1)
    writer = None
    try:
        writer = imageio.get_writer(output_filename, mode='I', duration=duration)
        for frame in images:
            writer.append_data(frame)
        yield writer
    finally:
        # Explicit None check instead of relying on the writer's truthiness.
        if writer is not None:
            writer.close()

In [None]:
# Training preprocessing: light augmentation, then PIL -> float tensor in
# [0, 1], then per-channel (x - 0.5) / 0.5 so pixel values land in [-1, 1].
# NOTE(review): CIFAR-10 images are already 32x32, so RandomCrop(32) with no
# padding returns the image unchanged.
transforms = Compose(
    [
        RandomCrop(size=32),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# CIFAR-10 training split under ./data. download=True fetches the archive on a
# fresh machine and is skipped when the files are already present and verified,
# so "Restart & Run All" works without a manual download step.
tr_data = CIFAR10("data", train=True, download=True, transform=transforms)

In [None]:
# Shuffled mini-batches of 64 (image, label) pairs, loaded by 2 worker processes.
train_loader = DataLoader(
    dataset=tr_data,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Pull a single shuffled (images, labels) batch for the preview grid below.
(x, y) = next(iter(train_loader))

In [None]:
def show_images(x, gif_path="data/sample.gif", width=1280, height=128, duration=0.5):
    """Render a stack of frames as an animated GIF and display it inline.

    Each entry along the first axis of ``x`` becomes one GIF frame. For example,
    ``x`` can be the per-step denoising history of one sample, or a stack of
    ``make_grid`` images.

    Args:
        x: tensor of frames in (N, C, H, W) layout with values in [-1, 1].
            Values outside that range are clipped after rescaling.
        gif_path: where the GIF is written; its parent directory is created
            if it does not exist.
        width: display width in pixels passed to ``IPython.display.Image``.
        height: display height in pixels passed to ``IPython.display.Image``.
        duration: per-frame duration forwarded to ``create_gif``.

    Returns:
        Whatever ``IPython.display.display`` returns (normally ``None``).
    """
    from pathlib import Path

    frames = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
    frames = frames.cpu().permute(0, 2, 3, 1).clip(0, 1).numpy() * 255
    frames = frames.astype(np.uint8)

    Path(gif_path).parent.mkdir(parents=True, exist_ok=True)
    # create_gif writes every frame before yielding and closes the writer on
    # exit, so the body has nothing left to do.
    with create_gif(frames, gif_path, duration):
        pass
    return display(pyImage(filename=gif_path, width=width, height=height))
    

def show(imgs):
    """Plot one or more (C, H, W) image tensors side by side with matplotlib.

    Tensors are expected in [-1, 1]. They are mapped to [0, 1] and clamped, so
    the float -> uint8 conversion inside ``to_pil_image`` cannot wrap
    out-of-range values (e.g. from generated samples).

    Args:
        imgs: a single image tensor or a list of them.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # Detach from autograd and move to host memory before converting to PIL.
        img = img.detach().cpu()
        img = (img * 0.5 + 0.5).clamp(0, 1)
        pil_img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(pil_img, dtype=np.uint8))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Tile the preview batch into one image (8 per row by default) and plot it.
grid = make_grid(tensor=x)
show(grid)

In [None]:
# Unconditional noise-prediction UNet for 32x32 images. The channel counts must
# match the data: the CIFAR-10 loader yields 3-channel (RGB) tensors, so a
# 1-channel model fails on the first forward pass.
model = UNet2DModel(
    sample_size=(32, 32),  # the target image resolution
    in_channels=3,  # number of input channels: 3 for the RGB CIFAR-10 batches
    out_channels=3,  # predicted noise has the same channels as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)
# Assign instead of leaving `model.to(device)` as the last expression, which
# would echo the full module repr into the notebook output.
model = model.to(device)

## Unconditional generation

In [None]:
# AdamW at 4e-4, a 50-step DDPM noise schedule with the cosine ("squaredcos")
# beta schedule, and the MSE loss used to regress the added noise.
optim = AdamW(model.parameters(), lr=4e-4)
sched = DDPMScheduler(
    beta_schedule="squaredcos_cap_v2",
    num_train_timesteps=50,
)
mse = MSELoss()

In [None]:
# Train the unconditional model to predict the noise added at a random
# timestep. After each epoch, draw one sample and display every step of its
# reverse-diffusion trajectory.
num_train_timesteps = sched.config.num_train_timesteps  # keep in sync with the scheduler
for e in range(30):
    model.train()
    running_loss = 0.
    with tqdm(train_loader) as mb_train:
        # Distinct names so the preview batch `x, y` is not silently overwritten.
        for images, _labels in mb_train:
            images = images.to(device)
            bs = images.shape[0]
            noise = torch.randn_like(images)  # already on images.device
            timesteps = torch.randint(0, num_train_timesteps, (bs,), device=device)
            noisy_images = sched.add_noise(images, noise, timesteps=timesteps)
            eps = model(noisy_images, timesteps).sample

            optim.zero_grad()
            loss = mse(eps, noise)
            loss.backward()
            optim.step()

            running_loss += loss.item()
            mb_train.set_postfix({"loss": loss.item()})
    print(f"Epoch {e+1} loss: {running_loss/len(train_loader)}")

    model.eval()
    with torch.no_grad():
        # Take the channel count from the model so sampling matches training.
        sample = torch.randn((1, model.config.in_channels, 32, 32), device=device)
        hist = []
        for t in sched.timesteps:
            res = model(sample, t).sample
            sample = sched.step(res, t, sample).prev_sample
            hist.append(sample)
        hist = torch.concat(hist, dim=0)
        show_images(hist)



## Class-conditional generation

In [None]:
class UNet2DModelCC(torch.nn.Module):
    """Class-conditional wrapper around ``UNet2DModel``.

    A learned embedding of the class label is broadcast over every pixel and
    concatenated to the input as ``n_embed`` extra channels. The wrapped UNet
    then predicts an output with ``in_channels`` channels.

    Args:
        num_classes: number of rows in the label embedding table.
        in_channels: channels of the image input and of the prediction.
        n_embed: embedding size, i.e. the number of extra conditioning
            channels fed to the UNet.
        sample_size: spatial resolution passed to ``UNet2DModel``. Defaults to
            the previously hard-coded (32, 32).
    """

    def __init__(self, num_classes, in_channels, n_embed, sample_size=(32, 32)):
        super().__init__()

        self.embedding = torch.nn.Embedding(num_classes, n_embed)
        self.unet = UNet2DModel(
            sample_size=sample_size,  # the target image resolution
            in_channels=in_channels + n_embed,  # image channels + class-embedding channels
            out_channels=in_channels,  # predict only the image channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    def forward(self, x, timestep, y):
        """Run the UNet on ``x`` at ``timestep``, conditioned on labels ``y``.

        Args:
            x: image batch of shape (B, C, H, W).
            timestep: diffusion timestep(s), forwarded unchanged to the UNet.
            y: integer class labels with one entry per sample, e.g. shape
                (B,) or (B, 1).

        Returns:
            The ``UNet2DModel`` output; the prediction is in its ``.sample``.
        """
        bs, _, height, width = x.shape
        class_embed = self.embedding(y)
        n_embed = class_embed.shape[-1]
        # (B, n_embed) or (B, 1, n_embed) -> (B, n_embed, 1, 1) -> (B, n_embed, H, W).
        class_embed = class_embed.view(bs, n_embed, 1, 1).expand(bs, n_embed, height, width)
        x = torch.cat([x, class_embed], dim=1)
        return self.unet(x, timestep)

In [None]:
# 10 CIFAR-10 classes, 3 RGB image channels, 8-dimensional class embedding.
model = UNet2DModelCC(num_classes=10, in_channels=3, n_embed=8)
model = model.to(device)

In [None]:
# Smoke test: one forward pass on random inputs shaped like a training batch.
# The prediction must have the same shape as the image input.
with torch.no_grad():
    x_check = torch.randn((32, 3, 32, 32), device=device)
    # 1-D (batch,) labels, the same layout the training loop passes in.
    y_check = torch.randint(10, (32,), device=device)
    t_check = torch.randint(50, (32,), device=device)
    out = model(x_check, t_check, y_check).sample
assert out.shape == x_check.shape, f"unexpected output shape {tuple(out.shape)}"
print(out.shape)

In [None]:
# Fresh optimiser, a 100-step cosine DDPM schedule and the MSE loss for the
# class-conditional model.
optim = AdamW(lr=4e-4, params=model.parameters())
sched = DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=100)
mse = MSELoss()
# NOTE(review): lr_sched.step() is not called anywhere in this notebook as
# written, so this scheduler currently has no effect on the learning rate.
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=1000, eta_min=0)

In [None]:
# Train the class-conditional model with gradient accumulation over
# `accum_steps` batches. After each epoch, sample one image per class and
# show every 10th reverse-diffusion step as a single row of class samples.
accum_steps = 2
num_train_timesteps = sched.config.num_train_timesteps  # keep in sync with the scheduler
for e in range(30):
    model.train()
    running_loss = 0.
    # Start each epoch with clean gradients (nothing left over from earlier cells).
    optim.zero_grad()
    with tqdm(train_loader) as mb_train:
        for mb_idx, (x, y) in enumerate(mb_train):
            x = x.to(device)
            y = y.to(device)
            bs = x.shape[0]
            noise = torch.randn_like(x)
            timesteps = torch.randint(0, num_train_timesteps, (bs,), device=device)
            noisy_x = sched.add_noise(x, noise, timesteps=timesteps)
            eps = model(noisy_x, timesteps, y).sample
            loss = mse(eps, noise)
            # Scale so a full accumulation group contributes the mean gradient
            # of its batches rather than their sum.
            (loss / accum_steps).backward()

            # Step once every `accum_steps` batches (not after the very first
            # batch). Also flush a partial group at the end of the epoch so its
            # gradient does not leak into the next epoch.
            is_last_batch = mb_idx + 1 == len(train_loader)
            if (mb_idx + 1) % accum_steps == 0 or is_last_batch:
                optim.step()
                optim.zero_grad()

            running_loss += loss.item()
            mb_train.set_postfix({"loss": loss.item()})

    print(f"Epoch {e+1} loss: {running_loss/len(train_loader)}")

    model.eval()
    with torch.no_grad():
        n_classes = model.embedding.num_embeddings
        n_channels = model.unet.config.out_channels
        sample = torch.randn((n_classes, n_channels, 32, 32), device=device)
        y_ = torch.arange(n_classes, device=device)
        hist = []
        for t in sched.timesteps:
            res = model(sample, t, y_).sample
            sample = sched.step(res, t, sample).prev_sample
            if t % 10 == 0:
                hist.append(make_grid(sample, nrow=n_classes))
        hist = torch.stack(hist, dim=0)
        show_images(hist)

