# Score-Based Generative Modeling on Flowers102

In [None]:
# IPython magics: load the autoreload extension and set it to mode 2, presumably
# so edits to the ensae_deep_learning package are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import functools
from pathlib import Path

import torch
import torchvision.transforms as transforms
from torchvision.datasets import Flowers102

from ensae_deep_learning.sampling import euler_maruyama_sampler, ode_sampler, pc_sampler
from ensae_deep_learning.sde import ScoreNet, diffusion_coeff, marginal_prob_std
from ensae_deep_learning.utils import (
    plot_dataset,
    plot_loss_history,
    plot_samples,
    print_sde_dim,
    run_training,
    summary_model,
)

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly) on
# machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Single noise-scale hyper-parameter, bound once into both SDE helpers so that
# training and every sampler below see the same value.
sigma = 25.0
marginal_prob_std_fn, diffusion_coeff_fn = (
    functools.partial(sde_helper, sigma=sigma)
    for sde_helper in (marginal_prob_std, diffusion_coeff)
)

## 1. Data preprocessing

In [None]:
# Reference copy of Flowers102 (default split) with no augmentation: images are
# only converted to tensors. It is plotted next to the augmented dataset below.
dataset_original = Flowers102(
    root="data/",
    transform=transforms.ToTensor(),
    download=True,
)

In [None]:
# Experiment FLOWERS_1: small 25x25 images with flips and mild colour jitter.
# No normalisation is applied, so there is no inverse transform (unnormalize).
config = dict(
    model_name="FLOWERS_1",
    image_size=(25, 25),
    in_channels=3,
    channels=(128, 256, 512, 1024),
    embed_dim=256,
    n_epochs=100,
    lr=3e-4,
    batch_size=64,
)

unnormalize = None

transformation = transforms.Compose(
    [
        transforms.Resize(config["image_size"]),
        # Random flips along both image axes.
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # Mild photometric augmentation.
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
    ]
)
dataset_transformed = Flowers102(root="data/", transform=transformation, download=True)

In [None]:
# Experiment FLOWERS_2: 41x41 images with flips, random rotations and colour
# jitter, followed by per-channel normalisation. `unnormalize` is the inverse
# transform, used to map normalised tensors back for plotting.
config = {
    "model_name": "FLOWERS_2",
    "image_size": (41, 41),
    "in_channels": 3,
    "channels": (128, 256, 512, 1024),
    "embed_dim": 256,
    "n_epochs": 100,
    "lr": 3e-4,
    # The other experiment configs define the batch size and the whole config is
    # handed to run_training, so it is set here too (presumably read there).
    "batch_size": 64,
}

# Per-channel normalisation statistics (not the ImageNet defaults; presumably
# estimated on Flowers102 — confirm).
mean = torch.tensor([0.4329, 0.3819, 0.2963])
std = torch.tensor([0.2945, 0.2465, 0.2734])

normalize = transforms.Normalize(mean, std)
# Inverse of `normalize`: (x - (-mean/std)) / (1/std) == x * std + mean.
unnormalize = transforms.Normalize((-mean / std), (1.0 / std))

transformation = transforms.Compose(
    [
        transforms.Resize(config["image_size"]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=[0, 180]),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.ToTensor(),
        normalize,
    ]
)
dataset_transformed = Flowers102("data/", transform=transformation, download=True)

In [None]:
# Alternative experiment FLOWERS_3 (disabled): 65x65 images, larger embedding,
# 200 epochs, normalised with the ImageNet mean/std. Uncomment this cell instead
# of the previous config cells to use it.
# config = {
#     "model_name": "FLOWERS_3",
#     "image_size": (65, 65),
#     "in_channels": 3,
#     "channels": (128, 256, 512, 1024),
#     "embed_dim": 512,
#     "n_epochs": 200,
#     "lr": 3e-4,
#     "batch_size": 64,
# }

# mean = torch.tensor([0.485, 0.456, 0.406])
# std = torch.tensor([0.229, 0.224, 0.225])

# normalize = transforms.Normalize(mean, std)
# unnormalize = transforms.Normalize((-mean / std), (1.0 / std))

# transformation = transforms.Compose(
#     [
#     # transforms.Grayscale(),
#     transforms.Resize(config["image_size"]),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.RandomRotation(degrees=[0, 180]),
#     transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
#     transforms.ToTensor(),
#     normalize,
#     ]
# )
# dataset_transformed = Flowers102("data/", transform=transformation, download=True)

In [None]:
# Per-experiment output directory: dataset plots go here, and the sampling
# cell later loads the checkpoint `model.pth` from it.
save_dir = Path("models", config["model_name"])
# parents=True also creates "models/" on a fresh checkout; without it mkdir
# raises FileNotFoundError when the parent directory does not exist yet.
save_dir.mkdir(parents=True, exist_ok=True)

plot_dataset(dataset_original, save_dir, "dataset_original.pdf")
# `unnormalize` is None for un-normalised experiments (FLOWERS_1).
plot_dataset(dataset_transformed, save_dir, "dataset_transformed.pdf", unnormalize)

## 2. Training

In [None]:
# Instantiate the score network from the active config, wrap it in
# DataParallel, move it to the selected device and print a model summary.
score_model = torch.nn.DataParallel(
    ScoreNet(
        marginal_prob_std=marginal_prob_std_fn,
        in_channels=config["in_channels"],
        embed_dim=config["embed_dim"],
        channels=config["channels"],
    )
).to(device)

summary_model(score_model, config)

In [None]:
# Report SDE-related dimensions for the active config (exact output is defined
# by ensae_deep_learning.utils.print_sde_dim).
print_sde_dim(config)

In [None]:
# Train the score model on the augmented dataset with the active config.
# `marginal_prob_std_fn` is passed for the training loss (presumably denoising
# score matching — confirm in ensae_deep_learning.utils). Outputs go to
# `save_dir`, presumably including the `model.pth` loaded by the sampling cell.
# The returned loss history is plotted in the next cell.
loss_history = run_training(
    score_model,
    dataset_transformed,
    config,
    marginal_prob_std_fn,
    save_dir,
    device,
)

In [None]:
# Plot the training loss curve returned by run_training into the experiment's
# output directory.
plot_loss_history(loss_history, save_dir)

## 3. Sampling

In [None]:
# Load the pre-trained checkpoint from disk.
ckpt = torch.load(Path(save_dir, "model.pth"), map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 36
for sampler in [
    euler_maruyama_sampler,
    pc_sampler,
    ode_sampler,
]:
    # Generate samples using the specified sampler.
    samples = sampler(
        score_model,
        marginal_prob_std_fn,
        diffusion_coeff_fn,
        config,
        sample_batch_size,
        device=device,
    )
    samples = unnormalize(samples)

    plot_samples(samples, sampler, save_dir)