In [None]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torchvision.transforms import ToPILImage
from IPython.display import display
import math
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Dataset

In [None]:
def show_images(dataset, num_samples = 20, cols = 4):
    """Plot the first `num_samples` items of `dataset` in a grid with `cols` columns.

    Each item is indexed with `[0]` to obtain the image (e.g. the image of an
    (image, label) pair), which is handed straight to `plt.imshow`.

    Args:
        dataset: Iterable of indexable samples whose element 0 is an image.
        num_samples: Number of images to show.
        cols: Number of grid columns.
    """
    # Just enough rows for num_samples panels (the old `num_samples / cols + 1`
    # always added a spare empty row).
    rows = max(1, math.ceil(num_samples / cols))
    plt.figure(figsize = (15, 15))

    # Iterate the dataset that was passed in, not a global.
    for i, sample in enumerate(dataset):
        if i == num_samples:
            break
        plt.subplot(rows, cols, i + 1)
        plt.axis('off')
        plt.imshow(sample[0])
    
# Preview a grid of raw CIFAR-10 images (no transform applied).
data = torchvision.datasets.CIFAR10(download=True, root='./cifar-10')
show_images(dataset=data)

# Forward Process

In [None]:
# Beta scheduling

def linear_beta_schedule(timesteps, start = 0.0001, end = 0.02):
    """Return `timesteps` betas spaced evenly from `start` to `end` (both inclusive)."""
    schedule = torch.linspace(start=start, end=end, steps=timesteps)
    return schedule


def cosine_beta_schedule(timesteps, start = 0.0001, end = 0.02):
    """Return `timesteps` betas rising from `start` to `end` along a half-cosine curve.

    beta_i = start + (end - start) * (1 - cos(pi * s_i)) / 2, with s_i evenly
    spaced on [0, 1]. As with `linear_beta_schedule`, the first beta equals
    `start` and the last equals `end`, so the injected noise grows with t; the
    curve is just flatter near both ends than a straight line.

    Note: this is a cosine-shaped interpolation of beta itself, not the
    alpha-bar based cosine schedule of Nichol & Dhariwal (2021).
    """
    s = torch.linspace(0, 1, timesteps)
    # (1 - cos) / 2 ramps 0 -> 1 as s goes 0 -> 1. Using (1 + cos) / 2 would ramp
    # 1 -> 0 and give a *decreasing* schedule (largest noise at t = 0).
    ramp = 0.5 * (1 - torch.cos(s * math.pi))
    return start + (end - start) * ramp

In [None]:
# Number of diffusion steps and the beta schedule used throughout the notebook.
# `linear_beta_schedule(timesteps=T)` is a drop-in alternative schedule.
T = 100
betas = cosine_beta_schedule(timesteps=T)

In [None]:
# Per-timestep quantities derived from the beta schedule (one entry per timestep).
alphas = 1 - betas                                    # alpha_t = 1 - beta_t
alphas_cumprod = torch.cumprod(alphas, dim=0)         # alpha_bar_t = product of alpha_s for s <= t
# alpha_bar_{t-1}: shift right by one, using 1 (the empty product) for the first step
alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)          # 1/sqrt(alpha_t), scales the reverse-step mean
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)      # coefficient of x_0 when noising
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)  # coefficient of the noise when noising
# Variance of the posterior q(x_{t-1} | x_t, x_0); used as the fixed sampling variance.
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

In [None]:
def get_index_from_list(vals, t, x_shape):
    """Look up per-sample schedule values for timesteps `t`, shaped to broadcast over a batch.

    Args:
        vals: 1-D tensor with one value per timestep (the schedule tensors in
            this notebook are 1-D).
        t: 1-D integer tensor of timesteps, one per sample in the batch.
        x_shape: Shape of the batch the result will be multiplied with; only its
            number of dimensions is used.

    Returns:
        Tensor of shape (batch_size, 1, ..., 1) with len(x_shape) dimensions,
        on the same device as `t`.
    """
    batch_size = t.shape[0]
    # Gather on vals' own device so this works whether the schedule is on CPU or GPU.
    out = vals.gather(-1, t.to(vals.device))
    # Trailing singleton dims let the result broadcast against x (e.g. over C, H, W).
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# Forward process: noise a clean batch x_0 directly to timestep t
def forward_diffusion_sample(x_0, t, device="cpu"):
    """Sample x_t ~ q(x_t | x_0) in closed form and return it with the noise used.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)

    Args:
        x_0: Clean batch of images.
        t: 1-D tensor of timesteps, one per sample.
        device: Device the returned tensors are moved to.

    Returns:
        (x_t, eps), both on `device`.
    """
    eps = torch.randn_like(x_0)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape)

    eps = eps.to(device)
    x_t = signal_scale.to(device) * x_0.to(device) + noise_scale.to(device) * eps
    return x_t, eps


In [None]:
# Data loader
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Image side length images are resized to, and number of images per training batch.
IMG_SIZE, BATCH_SIZE = 64, 128

def load_transformed_dataset():
    """Return the CIFAR-10 train and test splits concatenated into one dataset.

    Each image is resized to IMG_SIZE x IMG_SIZE, converted to a tensor with
    ToTensor, and then mapped to [-1, 1], the range the diffusion model works in.
    Both splits are downloaded to './cifar-10' if they are not already there.
    """
    to_model_range = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # ToTensor's [0, 1] -> [-1, 1]
    ])

    # Train split first, then test split.
    splits = [
        torchvision.datasets.CIFAR10(root='./cifar-10', download=True, train=is_train, transform=to_model_range)
        for is_train in (True, False)
    ]
    return torch.utils.data.ConcatDataset(splits)

def show_tensor_image(image):
    """Display a [-1, 1] channels-first image tensor with matplotlib.

    Accepts a single image (C, H, W) or a batch (B, C, H, W); for a batch only
    the first image is shown. Values are clamped to the displayable range
    before the uint8 cast, so tensors that leave [-1, 1] (e.g. outputs of the
    forward noising process) don't wrap around to wrong colors.
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),             # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),       # out-of-range values would overflow uint8
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),      # CHW -> HWC
        transforms.Lambda(lambda t: t * 255.0),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage()
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0]
    # detach/cpu so .numpy() also works for GPU tensors and tensors that require grad.
    plt.imshow(reverse_transforms(image.detach().cpu()))

data = load_transformed_dataset()
# drop_last=True keeps every batch at exactly BATCH_SIZE images. Code that assumes a
# fixed batch size (e.g. sampling BATCH_SIZE timesteps per batch) then never meets
# a short final batch, and BatchNorm never sees a tiny last batch.
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


# Reverse Process
$$ p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I}) $$
$$ p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t) $$
$$ x_{t-1} \approx x_t - \text{(predicted noise)} $$

## Unet
- input: noisy image
- timestep encoding: sinusoidal embeddings
- variance is fixed (not learned)

## Timestep Encoding
- 각 timestep의 위치 벡터 생성

In [None]:
# Sinusoidal positional encoding for diffusion timesteps

class SinusodialPositionEmbeddings(nn.Module):
    """Map timesteps to fixed sinusoidal embedding vectors.

    For a timestep t the output is
        [sin(t * w_0), ..., sin(t * w_{h-1}), cos(t * w_0), ..., cos(t * w_{h-1})]
    with h = dim // 2 and w_i = 10000 ** (-i / (h - 1)), i.e. frequencies spaced
    geometrically from 1 down to 1/10000. When `dim` is odd, one zero column is
    appended so the output always has exactly `dim` features.

    Args:
        dim: Size of the embedding produced for each timestep.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps.

        Args:
            timesteps: Tensor of shape (batch,).

        Returns:
            Float tensor of shape (batch, dim).
        """
        device = timesteps.device
        half_dim = self.dim // 2  # first half sin, second half cos
        # Log-spaced frequency step; guard the h == 1 case, which would divide by zero.
        scale = math.log(10000) / (half_dim - 1) if half_dim > 1 else 0.0
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = timesteps[:, None] * freqs[None, :]  # (batch, half_dim)
        embeddings = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # Odd dim: pad one zero feature so layers expecting `dim` inputs still match.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

In [None]:
class Block(nn.Module):
    """One U-Net stage: two 3x3 convolutions with the timestep embedding added in
    between, followed by a 4x4 stride-2 resampling layer.

    Down blocks (`up=False`) end with a strided convolution that halves H and W
    (for even sizes). Up blocks (`up=True`) expect their input to already carry
    a skip connection concatenated on the channel axis, hence `2 * in_ch` input
    channels, and end with a transposed convolution that doubles H and W.

    Args:
        in_ch: Input channels (per branch, before skip concatenation for up blocks).
        out_ch: Output channels.
        time_emb_dim: Size of the timestep embedding passed to `forward`.
        up: Build an upsampling block instead of a downsampling one.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up = False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        conv1_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1)
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # conv -> ReLU -> BatchNorm
        features = self.bnorm1(self.relu(self.conv1(x)))
        # Project the timestep embedding to out_ch and add two trailing axes so it broadcasts over H, W.
        time_bias = self.relu(self.time_mlp(t))[..., None, None]
        features = self.bnorm2(self.relu(self.conv2(features + time_bias)))
        # Down- or up-sample
        return self.transform(features)

In [None]:
class SimpleUnet(nn.Module):
    """A small U-Net that predicts the noise in an image, conditioned on the timestep.

    Layout: a 3x3 input projection, a stack of downsampling `Block`s whose
    outputs are kept as skip connections, a mirrored stack of upsampling
    `Block`s that consume those skips (concatenated on the channel axis), and a
    1x1 output convolution. Each down block halves H and W and each up block
    doubles them, so the input height/width must be divisible by
    2 ** (len(down_channels) - 1).

    Args:
        image_channels: Channels of the input image.
        down_channels: Encoder widths, starting with the input projection. The
            decoder uses the same widths in reverse so skips line up.
        out_dim: Channels of the output (the noise estimate).
        time_emb_dim: Size of the sinusoidal timestep embedding.
    """

    def __init__(self, image_channels=3, down_channels=(64, 128, 256, 512, 1024), out_dim=3, time_emb_dim=32):
        super().__init__()
        down_channels = list(down_channels)
        up_channels = down_channels[::-1]

        # Time Embeddings
        self.time_mlp = nn.Sequential(
            SinusodialPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection. padding=1 keeps H and W unchanged. Without it the 3x3
        # conv shrinks 64x64 to 62x62, the down path yields odd sizes (31, 15, 7, 3),
        # and the skip concatenation in the up path fails on a size mismatch.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample. The attribute name `donws` (sic) is kept so state_dict keys of
        # previously saved models still match.
        self.donws = nn.ModuleList([Block(down_channels[i], down_channels[i+1], time_emb_dim) for i in range(len(down_channels) - 1)])

        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], time_emb_dim, up=True) for i in range(len(up_channels) - 1)])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timesteps):
        """Predict the noise in `x` at `timesteps`.

        Args:
            x: Batch of noisy images, (batch, image_channels, H, W).
            timesteps: 1-D tensor of timesteps, one per image.

        Returns:
            Tensor of shape (batch, out_dim, H, W).
        """
        # Embedding time
        t = self.time_mlp(timesteps)
        # initial projection
        x = self.conv0(x)
        # Encoder: remember every stage output for the decoder
        skips = []
        for down in self.donws:
            x = down(x, t)
            skips.append(x)
        # Decoder: pair each stage with the matching encoder output (deepest first)
        for up in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = up(x, t)
        return self.output(x)
    
model = SimpleUnet()
n_params = sum(p.numel() for p in model.parameters())
print("Num parameters:", n_params)
model
            

# Loss

In [None]:
def get_loss(model, x_0, t, device=None):
    """Simplified DDPM loss: mean squared error between true and predicted noise.

    Noises `x_0` to timesteps `t` with `forward_diffusion_sample`, asks `model`
    to predict the noise that was added, and compares the two.

    Args:
        model: Noise-prediction network, called as `model(x_t, t)`.
        x_0: Clean batch of images.
        t: 1-D tensor of timesteps, one per image.
        device: Device to run on; defaults to the device of `model`'s parameters,
            so this function does not rely on a global defined in a later cell.

    Returns:
        Scalar loss tensor.
    """
    if device is None:
        device = next(model.parameters()).device
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    # torch.nn.functional has no `l2_loss`; mse_loss is the (mean) squared-L2 loss.
    return F.mse_loss(noise_pred, noise)

# Sample

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """One reverse-diffusion step: estimate x_{t-1} from x_t.

    Uses the global `model` to predict the noise in `x`, forms the mean

        mu = 1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x, t))

    and, for samples whose timestep is not 0, adds Gaussian noise with the
    fixed posterior variance. Samples at t == 0 get the mean only.

    Args:
        x: Batch of images at timestep t.
        t: 1-D long tensor of timesteps, one per image (may differ per image).

    Returns:
        Tensor shaped like `x`.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # t is 0-based here, so t == 0 is the final step: no noise is added there.
    is_final = (t == 0)
    if bool(is_final.all()):
        return model_mean

    noise = torch.randn_like(x)
    # Per-sample mask: a plain `if t == 0` is ambiguous when the batch holds
    # more than one timestep.
    add_noise = (~is_final).to(device=x.device, dtype=x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    return model_mean + add_noise * torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image(num_images=10):
    """Generate one image from pure noise and plot snapshots of the reverse process.

    Runs all T reverse steps with `sample_timestep`, clamping to [-1, 1] after
    each step, and shows the current image every T // num_images steps in a
    single row (the final sample, t = 0, is the leftmost panel).

    The global `model` is put in eval mode for the duration. BatchNorm then uses
    its running statistics instead of this single image's statistics, and does
    not update them. The previous mode is restored afterwards.

    Args:
        num_images: Maximum number of snapshots to plot.
    """
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')
    # At least 1 so the modulo below is safe when T < num_images.
    stepsize = max(1, T // num_images)

    was_training = model.training
    model.eval()
    try:
        for i in reversed(range(T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t)
            # Keep the sample inside the data range [-1, 1] between steps.
            img = torch.clamp(img, -1.0, 1.0)
            panel = i // stepsize
            # panel < num_images guards against T not being a multiple of num_images.
            if i % stepsize == 0 and panel < num_images:
                plt.subplot(1, num_images, panel + 1)
                show_tensor_image(img.detach().cpu())
    finally:
        model.train(was_training)
    plt.show()

# Train

In [None]:
from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr = 0.0001)
epochs = 100

# Ensure BatchNorm layers use batch statistics while training.
model.train()
for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        images = batch[0]  # element 0 of the batch holds the images
        optimizer.zero_grad()

        # One random timestep per image in *this* batch. Sizing by the actual batch
        # rather than BATCH_SIZE keeps shapes consistent if the final batch is smaller.
        t = torch.randint(0, T, (images.shape[0],), device = device)
        loss = get_loss(model, images, t)
        loss.backward()
        optimizer.step()

        # Log once per logged epoch instead of once for every step of it.
        if epoch % 10 == 0 and step == 0:
            print("Epoch: {}, Step: {}, Loss: {}".format(epoch, step, loss.item()))