In [1]:
# !pip install datasets
# !pip install torch
# !pip install torchvision
# !pip install scipy
# !pip install matplotlib

In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# (e.g. the project-local utils.* helpers) are reloaded before each cell runs
# and edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
from datasets import load_dataset
import torchvision.transforms as transforms
from utils.dataset import resize
import torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import tqdm
import utils.constant as c
from torchvision.utils import make_grid
import utils.network as n
import numpy as np

In [3]:
# Optimisation settings, read from the shared constants module.
batch_size, n_epochs, lr = c.batch_size, c.n_epochs, c.lr

# Target device, and the batch size later handed to the sampler.
device, sample_batch_size = c.device, c.sample_batch_size

# Default sigma that gets bound into the SDE helper functions.
sigma = c.sigma

In [4]:
# Load the training split of the cats-vs-dogs dataset.
dataset = load_dataset("cats_vs_dogs", split="train")
# `set_transform` installs `resize` as an on-the-fly "custom" format and
# replaces whatever format was configured before it, so the former
# `set_format(type="torch")` call had no effect and has been dropped.
dataset.set_transform(resize)
# Shuffled mini-batches; num_workers=0 keeps loading in the main process.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

Found cached dataset cats_vs_dogs (/home/onyxia/.cache/huggingface/datasets/cats_vs_dogs/default/1.0.0/d4fe9cf31b294ed8639aa58f7d8ee13fe189011837038ed9a774fde19a911fcb)


In [5]:
def marginal_prob_std(t, sigma):
    r"""
    Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    The value returned is $\sqrt{(\sigma^{2t} - 1) / (2 \log \sigma)}$.

    Args:
        t: Time step(s); a tensor or anything `torch.tensor` accepts
            (Python scalar, list, NumPy array).
        sigma: The $\sigma$ in our SDE.

    Returns:
        A tensor on the notebook-level `device` holding the standard deviation
        for each entry of `t`.
    """
    if torch.is_tensor(t):
        # torch.tensor(<tensor>) emits a copy-construct UserWarning on every
        # call; detach().clone() yields the same detached copy without it.
        t = t.detach().clone().to(device)
    else:
        t = torch.tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.0) / 2.0 / np.log(sigma))

def diffusion_coeff(t, sigma):
    r"""
    Compute the diffusion coefficient of our SDE, $\sigma^t$.

    Args:
        t: Time step(s); a tensor, Python scalar or NumPy array.
        sigma: The $\sigma$ in our SDE.

    Returns:
        A tensor on the notebook-level `device` holding the diffusion
        coefficient for each entry of `t`.
    """
    coeff = sigma ** t
    if torch.is_tensor(coeff):
        # `sigma ** t` is already a tensor when `t` is one; re-wrapping it with
        # torch.tensor() would emit a copy-construct UserWarning on every call.
        return coeff.detach().clone().to(device)
    return torch.tensor(coeff, device=device)
  
# Bind the configured sigma into both SDE helpers so that callers only have
# to supply the time argument.
marginal_prob_std_fn, diffusion_coeff_fn = (
    functools.partial(sde_fn, sigma=sigma)
    for sde_fn in (marginal_prob_std, diffusion_coeff)
)

In [34]:
# TRAINING
# Train one score model per sigma value, saving a checkpoint after every epoch
# and the per-epoch loss history once each run finishes.
import os

import tqdm.notebook  # a bare `import tqdm` does not load the notebook submodule

torch.cuda.empty_cache()
list_sigma = [5, 10, 50, 100, 500]
n_epochs = 40  # NOTE: overrides the n_epochs value read from utils.constant
# torch.save does not create missing directories, so make sure it exists.
os.makedirs("params", exist_ok=True)

for sigma in list_sigma:
    loss_list = []

    # Rebind the SDE helpers to the sigma used for this run.
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
    diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
    score_model = torch.nn.DataParallel(
        n.ScoreNet(marginal_prob_std=marginal_prob_std_fn, channels=[16, 32, 64, 128], embed_dim=128, group_norm=16)
    )
    score_model = score_model.to(device)

    optimizer = Adam(score_model.parameters(), lr=lr)
    tqdm_epoch = tqdm.notebook.trange(n_epochs)
    for epoch in tqdm_epoch:
        avg_loss = 0.0
        num_items = 0
        for data in data_loader:
            # Only the images are used for training; the labels are ignored.
            x = data["image"].to(device)
            loss = n.loss_fn(score_model, x, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
        epoch_loss = avg_loss / num_items
        loss_list.append(epoch_loss)
        # Show the averaged training loss of the epoch that just finished.
        tqdm_epoch.set_description("Average Loss: {:5f}".format(epoch_loss))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f"params/ckpt_{sigma}_16_32_64_128.pth")
        print(epoch_loss)
    torch.save(loss_list, f"loss_{sigma}_16_32_64_128.pth")

In [2]:
# GENERATION
# Rebuild the network for one sigma, restore its checkpoint and draw samples.
from utils.sampler import ode_sampler, pc_sampler, Euler_Maruyama_sampler

torch.cuda.empty_cache()
snr = 0.1  # forwarded to the sampler as its `snr` argument

# Must match a sigma for which a checkpoint was saved during training.
sigma = 5
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

# Same architecture arguments as in training, otherwise the state dict will not fit.
score_model = torch.nn.DataParallel(
        n.ScoreNet(marginal_prob_std=marginal_prob_std_fn, channels=[16, 32, 64, 128], embed_dim=128, group_norm=16)
    )
score_model = score_model.to(device)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# produced yourself (consider weights_only=True where the installed PyTorch
# supports it).
ckpt = torch.load(f"params/ckpt_{sigma}_16_32_64_128.pth", map_location=device)
score_model.load_state_dict(ckpt)
# Switch to evaluation mode before sampling so that any train-only layer
# behaviour (dropout, batch-norm statistics updates) is disabled.
score_model.eval()

sampler = pc_sampler  # ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler']

## Generate samples using the specified sampler.
samples = sampler(
    score_model,
    marginal_prob_std_fn,
    diffusion_coeff_fn,
    sample_batch_size,
    snr=snr,
    device=device,
)

In [None]:
## Sample visualization.
%matplotlib inline
import matplotlib.pyplot as plt

sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(sample_grid.permute(1, 2, 0).cpu())
plt.show()