## DDPM Forward process (Noising process) - Image Case 
- 임의의 입력 이미지에 대해 DDPM forward process를 거친 영상을 시각화한다.
- image 차원 x0[C, H, W]에 정규 분포 noise가 각 pixel의 RGB 채널별로 추가됨을 이해하자.
- linear beta로 noise를 schedule했을 때 영상의 변화를 관찰해 보자.
- 학습 데이터에서 추출하는 어떠한 image에 대해서도 N(0,1) 표준 정규 분포의 noise image로 변환됨을 이해해보자.  


- Reference: [annotated diffusion](https://huggingface.co/blog/annotated-diffusion)
- Requirements: pytorch

In [None]:
from IPython.display import Image as DisplayImage
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn.functional as F

In [None]:
# Azure Cloud: network share that hosts the course dataset
n_drive = '\\\\swschoolavdazfiles002.file.core.windows.net\\aias-vision'
# Append the dataset folder using a backslash separator.
dataset_path = '\\'.join((n_drive, 'AI-Application-Specialist-Vision-Dataset'))

In [None]:
# Show the DDPM paper figure stored under the dataset folder (rendered 600 px wide).
# Alternative group-volume location, kept for reference:
#DisplayImage(filename='/group-volume/sr_edu/AI-Application-Specialist-Vision-Dataset/hf-assets/ddpm_paper.png', width=600)
DisplayImage(filename=dataset_path+'/hf-assets/ddpm_paper.png', width=600)

## Define the forward diffusion process

### Define Scheduler

In [None]:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced beta schedule (the DDPM original scheduler).

    Args:
        timesteps: Number of diffusion steps; also the length of the result.
        beta_start: Beta value of the first step (inclusive). Defaults to 0.0001.
        beta_end: Beta value of the last step (inclusive). Defaults to 0.02.

    Returns:
        1-D tensor holding ``timesteps`` evenly spaced values running from
        ``beta_start`` to ``beta_end``.
    """
    # torch.linspace splits [start, end] into `timesteps` evenly spaced points
    # and includes both endpoints.
    # https://pytorch.org/docs/stable/generated/torch.linspace.html
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule as proposed in https://arxiv.org/abs/2102.09672

    Builds a squared-cosine cumulative-alpha curve on ``timesteps + 1`` grid
    points, normalises it by its first value, and converts consecutive ratios
    into betas clipped to [0.0001, 0.9999].

    Args:
        timesteps: Number of diffusion steps; also the length of the result.
        s: Offset added to the normalised time before the cosine is taken.

    Returns:
        1-D tensor of ``timesteps`` beta values.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    alpha_bar = torch.cos(phase) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratio, 0.0001, 0.9999)

In [None]:
timesteps = 1000

# beta schedule (linear)
betas = linear_beta_schedule(timesteps=timesteps)

# alphas and their running product (alpha_bars)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)

# coefficients of x_0 and of the noise used by the forward process q(x_t | x_0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

def extract(a, t, x_shape):
    """Gather the entries of ``a`` at the indices in ``t`` and shape them for broadcasting.

    Args:
        a: Tensor to read from (gathered along its last dimension).
        t: Index tensor holding B indices (B == t.shape[0]).
        x_shape: Shape of the tensor the result will be combined with; only
            its length is used.

    Returns:
        Tensor of shape [B, 1, ..., 1] with ``len(x_shape) - 1`` trailing
        singleton dimensions (e.g. [B, 1, 1, 1] for a 4-D ``x_shape``), placed
        on the device of ``t``.
    """
    n_indices = t.shape[0]
    # Indices are moved to the CPU before the gather.
    picked = a.gather(-1, t.cpu())
    broadcast_shape = (n_indices,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)

In [None]:
# Inspect the schedule: tensor shapes, the last beta, and the first five alphas / alpha_bars.
betas.shape, alphas.shape, alphas_cumprod.shape, betas[-1:], alphas[:5], alphas_cumprod[:5]

In [None]:
# x-axis: time steps 1..timesteps (one point per schedule entry)
tt = torch.linspace(start=1, end=timesteps, steps=timesteps)
# coefficient applied to x_0
plt.plot(tt, sqrt_alphas_cumprod, label=r'$\sqrt {\bar {\alpha}} $')
# coefficient applied to the noise
plt.plot(tt, sqrt_one_minus_alphas_cumprod, label=r'$\sqrt {1- \bar {\alpha} }$')
plt.legend(loc='best')

### Load an image

In [None]:
from PIL import Image
import requests

In [None]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps the cell from hanging forever on an unresponsive server, and
# raise_for_status() surfaces HTTP errors here instead of handing an error
# page to PIL, which would fail with an unrelated decoding error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw) # PIL image of shape HWC
image

### Pre-processing 
- using torchvision transforms
- transform a PIL image to a normalized PyTorch tensor in [-1, 1] 
  * input: image (PIL image, HWC format)
  * output: tensor ([-1, 1], CHW format)
- unsqueeze(0): add axis=0 (for batch processing)

In [None]:
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

In [None]:
image_size = 128

# Pre-processing pipeline: PIL image (HWC) -> normalized torch tensor in [-1, 1]
image2tensor = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    # to a torch Tensor of shape CHW, divided by 255, in the range [0, 1]
    ToTensor(),
    # rescale [0, 1] -> [-1, 1]
    Lambda(lambda img: img * 2 - 1),
])

# Add a leading batch axis so the image can be processed as a batch of one.
x_start = image2tensor(image)[None]
(image.size, image.mode, x_start.shape, x_start.min(), x_start.max())


### Post-processing 
- to transform a PyTorch tensor normalized in [-1,1] to a PIL Image

In [None]:
# Post-processing pipeline: tensor normalized in [-1, 1] (CHW) -> PIL image
tensor2image = Compose([
    # [-1, 1] -> [0, 1]
    Lambda(lambda x: (x + 1) * 0.5),
    # CHW -> HWC
    Lambda(lambda x: x.permute(1, 2, 0)),
    # [0, 1] -> [0, 255]; values outside the range are clamped
    Lambda(lambda x: (x * 255.).clamp(min=0, max=255)),
    # torch tensor -> uint8 numpy array
    Lambda(lambda x: x.numpy().astype(np.uint8)),
    ToPILImage(),
])

In [None]:
# Round-trip check: convert the pre-processed tensor back into a PIL image
# (squeeze() drops the batch axis added during pre-processing).
tensor2image(x_start.squeeze())

### Apply forward diffusion process 
#### to get a noisy image at time=t

In [None]:
def q_sample(x_start, t, noise=None):
    """Sample x_t from the forward diffusion q(x_t | x_0) in closed form.

    Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``
    using the precomputed module-level schedule tensors.

    Args:
        x_start: Clean input x_0.
        t: Index tensor of B time steps used to look up the schedule.
        noise: Optional noise tensor; when None, standard normal noise with
            the shape of ``x_start`` is drawn.

    Returns:
        The noisy tensor x_t; the per-step coefficients have shape
        [B, 1, ..., 1] and are broadcast against ``x_start``.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    signal_coef = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_coef * x_start + noise_coef * noise

In [None]:
def get_noisy_image(x_start, t):
    """Return a PIL image of ``x_start`` (x_0) after forward diffusion to step ``t``.

    Args:
        x_start: Clean input x_0, as produced by the pre-processing pipeline
            (with a leading batch axis of size 1).
        t: Tensor holding a single time-step index.

    Returns:
        The noisy sample converted back to a PIL image by ``tensor2image``.
    """
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # Drop only the batch axis: a bare squeeze() would also remove a size-1
    # channel axis (grayscale input) and break the CHW -> HWC permute.
    noisy_image = tensor2image(x_noisy.squeeze(0))

    return noisy_image

In [None]:
# [0,T-1]에서 임의의 time step t 에서  noise image 그려보자
t = torch.tensor([40])

get_noisy_image(x_start, t)

### To visualize this for various time steps

In [None]:
# use a fixed seed for reproducibility
torch.manual_seed(0)

# source: https://pytorch.org/vision/0.15/auto_examples/plot_transforms.html 
def plot_images(imgs, with_orig=False, row_title=None, *, figsize=(200, 200),
                orig_img=None, **imshow_kwargs):
    """Draw images on a grid of matplotlib axes, one row per list of images.

    Args:
        imgs: A list of images (drawn as a single row) or a list of lists of
            images (one inner list per row).
        with_orig: When True, prepend the original image to every row and
            title the top-left cell 'Original image'.
        row_title: Optional sequence of y-axis labels, one per row.
        figsize: Figure size in inches passed to ``plt.subplots``. The default
            keeps the previously hard-coded (200, 200); pass a smaller value
            for a lighter figure.
        orig_img: Image used as the original when ``with_orig`` is True. When
            None, the notebook-level ``image`` variable is used.
        **imshow_kwargs: Forwarded to ``Axes.imshow`` for every image.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    if with_orig and orig_img is None:
        # Backward-compatible fallback to the image loaded earlier in the notebook.
        orig_img = image

    num_rows = len(imgs)
    # with_orig is a bool, so it adds 0 or 1 extra column
    num_cols = len(imgs[0]) + with_orig
    _, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            # hide ticks and tick labels
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

## forward diffusion process
- $x_t \sim q(x_t \mid x_0) = \mathcal{N}(\mu_t, \sigma_t^2)$, where $\mu_t = \sqrt{\overline{\alpha}_t}\, x_0$, $\sigma_t = \sqrt{1-\overline{\alpha}_t}$
- 정규 분포를 사용하여 데이터 샘플링: $x_t = \sqrt{\overline{\alpha}_t}\, x_0 + \sqrt{1-\overline{\alpha}_t}\, \epsilon$
  * where $\epsilon \sim \mathcal{N}(0, 1^2)$
  

In [None]:
# Noisy versions of the image at several early time steps (t = 0 ... 199), drawn in one row.
plot_images([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

In [None]:
# Noisy versions of the image across the whole schedule, from t = 0 up to the last step (t = 999).
plot_images([get_noisy_image(x_start, torch.tensor([t])) 
             for t in [0, 100, 200, 300, 400, 500, 600, 700, 800, 999]])