In [1]:
%load_ext autoreload 
%autoreload 2

import os
import random
from collections import deque
import numpy as np
import scipy.linalg as sl
from PIL import Image, ImageDraw
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
from IPython import display

import torch
from torch import nn, distributions as dist, autograd
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Compose, Resize, CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, ToTensor, Normalize
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.set_default_dtype(torch.float32)
plt.style.use('seaborn-v0_8')

from imagenet64x64 import Imagenet64x64Dataset, Imagenet64x64Sampler

  from .autonotebook import tqdm as notebook_tqdm


## Imagenet 64x64

In [2]:
img_size = 64
batch_size = 64

In [3]:
train_ds = Imagenet64x64Dataset("train", sizes=batch_size*5000, transforms=Compose([
    ToTensor(),
    CenterCrop(img_size),
    RandomHorizontalFlip(0.1),
    RandomVerticalFlip(0.1),
    Normalize(127., 128.)
]))

val_ds = Imagenet64x64Dataset("valid", transforms=Compose([
    ToTensor(),
    CenterCrop(img_size),
    RandomHorizontalFlip(0.1),
    RandomVerticalFlip(0.1),
    Normalize(127., 128.)
]))
print(len(train_ds), len(val_ds))
train_sampler = Imagenet64x64Sampler(train_ds)

320000 49999


In [4]:
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False, sampler=train_sampler,)

In [5]:
for i, out in enumerate(train_loader):
    if i == 0:
        print(i, out)

0 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63])


## Diffusion Model

In [6]:
class SinusoidalPosition(nn.Module):
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.register_buffer("device_tensor", torch.tensor(1.))
        self.register_buffer("weights", self.init_weights())
    
    def init_weights(self):
        h = self.dim // 2
        idx = torch.arange(0., h, dtype=torch.float32)
        w = idx / (h - 1)
        w = torch.exp(-w * np.log(10000.))
        # w *= idx
        return w      
    
    def forward(self, t):
        # t: (b,)
        w = t[..., None].float() * self.weights[None, ...]
        w = torch.cat([w.sin(), w.cos()], 1)
        return w

# SinusoidalPosition(64)(torch.tensor([0, 1]))

In [7]:
class Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, padding="same")
        self.bn = nn.BatchNorm2d(out_channel)
        self.act = nn.GELU()
    
    def forward(self, x, scale_shift=None):
        x = self.bn(self.conv(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            x = scale * x + shift
        return self.act(x)

In [8]:
class Resblock(nn.Module):
    def __init__(self, in_dim, out_dim, time_dim):
        super().__init__()
        self.block1 = Block(in_dim, out_dim)
        self.block2 = Block(out_dim, out_dim)
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, out_dim * 2), nn.LeakyReLU(0.1))
        self.res_conv = nn.Conv2d(in_dim, out_dim, 3, padding=1) if out_dim != in_dim else nn.Identity()
        self.in_dim = in_dim
        self.out_dim = out_dim
        
    def forward(self, x, t_emb=None):
        _x = x
        scale_shift = None
        if t_emb is not None:
            scale_shift = self.time_mlp(t_emb)
            scale_shift = scale_shift.unsqueeze(-1).unsqueeze(-1)
            scale_shift = scale_shift.chunk(2, dim=1)
            
        x = self.block1(x, scale_shift=scale_shift)
        x = self.block2(x)
        return x + self.res_conv(_x)
        

In [9]:
class DownBlock(nn.Module):
    
    def __init__(self, in_dim, out_dim, time_dim, nblocks=2, reduce=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.blocks = nn.ModuleList([Resblock(in_dim, in_dim, time_dim) for _ in range(nblocks)])
        self.nblocks = nblocks
        stride = 2 if reduce else 1
        self.down_fn = nn.Conv2d(in_dim, out_dim, 3, stride, 1)
            
    def forward(self, x, t, h):
        for fn in self.blocks:
            x = fn(x, t)
            h.append(x)
        x = self.down_fn(x)
        return x, h

In [10]:
class MidBlock(nn.Module):
    
    def __init__(self, dim, time_dim, nblocks=2):
        super().__init__()
        self.blocks = nn.ModuleList([Resblock(dim, dim, time_dim) for _ in range(nblocks)])
        self.nblocks = nblocks
            
    def forward(self, x, t, h):
        for fn in self.blocks:
            x = fn(x, t)
            h.append(x)
        return x, h

In [11]:
class UpBlock(nn.Module):
    
    def __init__(self, in_dim, out_dim, time_dim, nblocks=2, up=True):
        super().__init__()
        self.blocks = nn.ModuleList([Resblock(in_dim + out_dim, out_dim, time_dim) for _ in range(nblocks)])
        self.nblocks = nblocks
        self.up_fn = (
            nn.Conv2d(out_dim, in_dim, 3, padding=1) if not up
            else nn.Sequential(nn.Upsample(scale_factor=2,), 
                               nn.Conv2d(out_dim, in_dim, 3, padding=1)
                               )
        ) 
        self.out_dim = out_dim
        self.in_dim = in_dim
            
    def forward(self, x, t, h):
        for fn in self.blocks:
            down_x = h.pop()
            x = torch.cat([x, down_x], dim=1)
            x = fn(x, t)
        x = self.up_fn(x )
        return x, h

In [12]:


class UNet(nn.Module):
    
    def __init__(self, init_dim=64, resolutions=(1, 2, 4, 8), out_dim=3):
        super().__init__()
        self.resolutions = resolutions
        self.conv = nn.Sequential(
            nn.Conv2d(3, init_dim, 3, 1, padding="same"),
            nn.LeakyReLU(0.1),
        )
        resolutions = [init_dim] + [init_dim * r for r in resolutions]
        dims = list(zip(resolutions[:-1], resolutions[1:]))
        
        time_dim = init_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosition(init_dim),
            nn.Linear(init_dim, time_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(time_dim, time_dim),
            
        )
        
        self.down_block = nn.ModuleList([
            DownBlock(*bdims, time_dim=time_dim, reduce=i<len(dims)-1) for i, bdims in enumerate(dims)
        ])
        
        self.mid_block = MidBlock(dims[-1][-1], time_dim=time_dim, nblocks=2)
        
        self.up_block = nn.ModuleList([
            UpBlock(in_dim, out_dim, time_dim=time_dim, up=i<len(dims) - 1)
            for i, (in_dim, out_dim) in enumerate(reversed(dims))
        ])  
        
        self.out_resblock = Resblock(init_dim * 2, init_dim, time_dim=time_dim)
        self.out_conv = nn.Conv2d(init_dim, out_dim, kernel_size=1) 
            
    def forward(self, x, t):
        x = self.conv(x)
        _x = x
        
        t = self.time_mlp(t)
        h = []
        for db_fn in self.down_block:
            x, h = db_fn(x, t, h)
        
        x, _ = self.mid_block(x, t, [])

        for up_fn in self.up_block:
            x, h = up_fn(x, t, h)
        
        x = self.out_resblock(torch.cat([x, _x], dim=1))
        x = self.out_conv(x)
        return x
        

In [13]:
F.pad(torch.tensor([1., 2.]), (1, 0), value=0.)

tensor([0., 1., 2.])

In [14]:
class DiffusionModel(nn.Module):
    
    def __init__(self, model, timesteps=1000, img_size=64):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.img_size = img_size
        
        
        beta = self.noise_scheduler(timesteps)
        alpha = 1 - beta
        alpha_b = torch.cumprod(alpha, dim=0)
        # TODO check
        alpha_b_prev = F.pad(alpha_b[:-1], (1, 0), value=1.)
        
        self.register_buffer("beta", beta)
        self.register_buffer("alpha", alpha)
        self.register_buffer("one_minus_alpha", 1 - alpha)
        self.register_buffer("recip_sqrt_alpha", alpha.sqrt().reciprocal())
        
        self.register_buffer("alpha_b", alpha_b)
        self.register_buffer("alpha_b_prev", alpha_b_prev)
        self.register_buffer("sqrt_alpha_b", torch.sqrt(alpha_b))
        self.register_buffer("sqrt_one_minus_alpha_b", torch.sqrt( 1 - alpha_b))
        self.register_buffer("recip_sqrt_one_minus_alpha_b", 1. / torch.sqrt( 1 - alpha_b))
        self.register_buffer("beta_h", beta * (1 - alpha_b_prev) / (1 - alpha))
        self.register_buffer("sigma", self.beta_h.sqrt())
        
    def noise_scheduler(self, T):
        s = 1000. / T
        start = s * 1e-4
        end = s * 2e-2
        return torch.linspace(start, end, T).float()        
        
    def forward(self, x0):
        b, c, h, w = x0.size()

        # Sample time 
        t = torch.randint(0, self.timesteps, (b,), device=x0.device, dtype=torch.long)

        # Sample Noise
        eps = torch.randn_like(x0)
        
        # Sample xt
        xt = self.qsample(x0, t, eps)
        
        # Noise output from network
        eps_out = self.model(xt, t)
        
        return eps_out 
    
    def qsample(self, x0, t, eps):
        return (self.extract(self.sqrt_alpha_b, t, x0.size() ) * x0 + 
                self.extract(self.sqrt_one_minus_alpha_b, t, eps.size()) * eps 
                )
        
    def extract(self, v, t, shape):
        b = shape[0]
        v = torch.index_select(v, -1, t)
        return v.reshape((b, ) + ((1,) * (len(shape) - 1)))
    
    @torch.no_grad()
    def sample(self, nsamples):
        device = self.beta.device
        xt = torch.randn((nsamples, 3, self.img_size, self.img_size), device=device)
        for step in range(self.timesteps-1, -1, -1):
            t = torch.full((nsamples, ), fill_value=step, dtype=torch.int64, device=device)
            noise_pred = self.model(xt, t)
            mean =  xt - (
                self.extract(self.one_minus_alpha, t, noise_pred.size()) * 
                noise_pred * 
                self.extract(self.recip_sqrt_one_minus_alpha_b, t, noise_pred.size())
            )
            noise = torch.randn_like(noise_pred)
            xt = (
                self.extract(self.recip_sqrt_alpha, t, mean.size()) * mean +
                self.extract(self.sigma, t, noise.size()) * noise
            )
        # TODO During image saving clamp it to -1 and 1
        return xt
            

In [15]:
class Criterion(nn.Module):
    def __init__(self, loss_tp="l2"):
        super().__init__()
        self.loss_fn = {"l2": F.mse_loss, "l1": F.l1_loss}[loss_tp]
    
    def forward(self, input, target):
        loss = self.loss_fn(input, target, reduction="none")
        loss = loss.mean((1, 2, 3))
        # TODO add loss weights
        return loss.mean()
        
    

In [16]:
model = UNet()
ddpm = DiffusionModel(model, timesteps=100)

In [17]:
# a = ddpm(torch.randn((2, 3, 64, 64)))

In [18]:

# Criterion()(a, torch.randn((2, 3, 64, 64)))

In [19]:
ddpm.sample(16)

tensor([[[[-3.4241e+02,  3.8705e+02, -1.7379e+03,  ..., -1.5311e+03,
           -4.4898e+02, -5.4209e+02],
          [ 2.4153e+02, -4.1740e+01, -1.0047e+03,  ..., -1.2116e+03,
           -7.3685e+00, -2.3450e+02],
          [-1.2019e+03,  2.3336e+02, -7.6618e+02,  ..., -1.7671e+03,
           -2.6930e+03, -2.5952e+02],
          ...,
          [ 7.3020e+02, -2.4198e+02, -4.9901e+02,  ..., -3.1467e+02,
           -2.4908e+01, -5.1794e+02],
          [-9.3919e+01,  3.7231e+02,  5.7960e+02,  ..., -1.4336e+02,
            9.2236e+01, -5.7243e+01],
          [ 7.2627e+02, -9.4281e+02,  5.9360e+01,  ..., -1.6989e+02,
            7.4079e+02, -2.1061e+02]],

         [[ 2.3291e+02,  8.0047e+02,  3.7754e+02,  ..., -1.2854e+03,
            6.3292e+02,  3.3361e+01],
          [ 4.1602e+02, -5.1918e+02,  2.7547e+02,  ...,  3.8816e+02,
           -6.2494e+01, -3.9876e+02],
          [ 1.6522e+02,  1.1154e+03,  2.7612e+02,  ..., -1.0493e+02,
           -2.2295e+02,  5.1885e+02],
          ...,
     