# Imports & Globals

In [None]:
import os
import json
import tempfile

import torch
from torchvision import transforms
from torchvision.utils import make_grid

from scripts.diffusion_utils import DiffusionManager
from scripts.unet_openai import UNetModel

import ffmpeg


In [None]:
# Location of the JSON file describing hyperparameters, data and model setup.
CONFIG_PTH = "configs/afhq.json"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Parse config

In [None]:
# Read the experiment configuration. The encoding is pinned to UTF-8 so the
# file parses the same way regardless of the platform's default locale.
with open(CONFIG_PTH, encoding="utf-8") as f:
    config = json.load(f)

In [None]:
# Pull the three top-level sections of the config into their own names.
hyperparams, data_dir, model_params = (
    config[section] for section in ("hyperparams", "data_dir", "model")
)

# Noise scheduler

In [None]:
# The config stores the training target as a boolean flag; translate it into
# the string identifier passed to DiffusionManager.
if hyperparams["v_prediction"]:
    training_method = "v_prediction"
else:
    training_method = "noise"

# Build the noise scheduler from the configured step count and beta range.
diffusion_manager = DiffusionManager(
    num_steps=hyperparams["num_steps"],
    beta_start=hyperparams["beta_start"],
    beta_end=hyperparams["beta_end"],
    beta_schedule=hyperparams["scheduler_mode"],
    training_method=training_method,
)

# Model

In [None]:
# Same procedure as https://github.com/VSehwag/minimal-diffusion/blob/main/unets.py:
# the config lists attention resolutions as a comma-separated string, and each
# entry is converted to the integer ratio img_size // resolution.
attention_resolutions = model_params["attention_resolutions"]
attention_ds = [
    model_params["img_size"] // int(res)
    for res in attention_resolutions.split(",")
]

In [None]:
# Keyword arguments for the U-Net. Sizes and channel layout come from the
# config; the remaining values are fixed for this notebook.
unet_kwargs = dict(
    image_size=model_params["img_size"],
    in_channels=model_params["in_channels"],
    model_channels=model_params["base_width"],
    # The network outputs as many channels as it receives.
    out_channels=model_params["in_channels"],
    num_res_blocks=3,
    attention_resolutions=tuple(attention_ds),
    dropout=0.1,
    channel_mult=model_params["channel_mult"],
    num_classes=None,
    use_checkpoint=False,
    use_fp16=False,
    num_heads=4,
    num_head_channels=64,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    use_new_attention_order=True,
)

model = UNetModel(**unet_kwargs)
model = model.to(DEVICE)
# Inference only: switch the module to evaluation mode.
model.eval()

In [None]:
# Remap the checkpoint's tensors onto DEVICE while loading. Without
# map_location, a checkpoint saved from a GPU run cannot be deserialized on a
# CPU-only machine, even though DEVICE already falls back to "cpu".
state_dict = torch.load(
    "weights/afhq/weights_11-6-2025_7_59_28.pth", map_location=DEVICE
)
model.load_state_dict(state_dict)

# Image sampler

In [None]:
def images_to_video(frame_dir, fname, framerate=200, overwrite=True):
    """Encode every PNG in ``frame_dir`` into a single video file.

    Frames are selected with the glob pattern ``<frame_dir>/*.png`` and passed
    to ffmpeg through the ``ffmpeg`` Python bindings.

    Args:
        frame_dir: Directory containing the PNG frames.
        fname: Path of the video file to write.
        framerate: Frame rate passed to ffmpeg for the input image sequence.
        overwrite: When True (default), replace ``fname`` if it already
            exists. Otherwise ffmpeg refuses to overwrite, which makes
            re-running the notebook fail on the second pass.
    """
    frames = ffmpeg.input(
        f"{frame_dir}/*.png", pattern_type="glob", framerate=framerate
    )
    frames.output(fname).run(overwrite_output=overwrite)
    

In [None]:
def _batch_to_grid_image(batch, nrow):
    """Turn a batch of image tensors into one PIL grid image.

    Values are clamped to [-1, 1] and rescaled to [0, 1] before the batch is
    tiled with ``make_grid`` and converted with ``ToPILImage``.

    Args:
        batch: Tensor holding the images to tile.
        nrow: Forwarded to ``make_grid`` as ``nrow``.

    Returns:
        The tiled batch as a PIL image.
    """
    batch = torch.clamp(batch, -1.0, 1.0).detach().cpu()
    batch = (batch + 1) / 2
    return transforms.ToPILImage()(make_grid(batch, nrow=nrow))


@torch.no_grad()
def sample(model, config, out_dir):
    """Denoise random noise into images and save every intermediate frame.

    Starting from samples of a normal distribution with mean 0 and variance 1,
    the noise is handed to ``diffusion_manager.sample``. Each element of the
    returned sequence is written to ``out_dir`` as a zero-padded
    ``NNNN.png`` grid, and the last element is shown in the notebook.

    Args:
        model: Network passed to ``diffusion_manager.sample``.
        config: Parsed configuration. Reads ``config["sampling"]``
            (``num_samples``, ``num_grid_rows``) and ``config["model"]``
            (``in_channels``, ``img_size``).
        out_dir: Existing directory that receives the PNG frames.

    Returns:
        The final grid as a PIL image, upscaled by a factor of 3.
    """
    sampling_cfg = config["sampling"]
    # Take the tensor shape from the config argument rather than from the
    # notebook-level ``model_params`` global, so the function depends only on
    # what it is given.
    model_cfg = config["model"]
    # NOTE(review): per the torchvision docs, make_grid's ``nrow`` is the
    # number of images per row, so "num_grid_rows" effectively sets the
    # column count. Confirm that this is intended.
    nrow = sampling_cfg["num_grid_rows"]

    noise = torch.randn(
        (
            sampling_cfg["num_samples"],
            model_cfg["in_channels"],
            model_cfg["img_size"],
            model_cfg["img_size"],
        )
    ).to(DEVICE)

    # Presumably one batch per denoising step; verify against
    # DiffusionManager.sample.
    imgs = diffusion_manager.sample(model, noise)
    for counter, item in enumerate(imgs):
        frame = _batch_to_grid_image(item, nrow)
        frame.save(os.path.join(out_dir, f"{counter:04}.png"))

    # Show denoised samples, enlarged for easier viewing.
    img = _batch_to_grid_image(imgs[-1], nrow)
    img = img.resize((img.width * 3, img.height * 3))
    # ``display`` is not imported in this file; it is expected to come from
    # the notebook environment.
    display(img)

    return img

In [None]:
# Frames go into a scratch directory. Using the TemporaryDirectory as a
# context manager removes it even when sampling or video encoding raises.
temp_dir = tempfile.TemporaryDirectory()
with temp_dir:
    final_image = sample(model, config, temp_dir.name)
    final_image.save("demo.jpg")
    images_to_video(temp_dir.name, "demo.mp4")