# VQVAE2 Sampling

This notebook demonstrates how to sample from a VQVAE2 model that uses an autoregressive prior (Transformer or PixelSNAIL) over its latent spaces. The sampled latents are then decoded with the Stable Diffusion VAE and visualized as images.

## Setup

In [None]:
import torch
from src.models.latent_models import LatentVQVAE2
from src.models.transformer_prior import HierarchicalTransformerPrior
from src.models.pixelsnail_prior import HierarchicalPixelSnailPrior

### Load VQVAE model

In [None]:
# Locations of the trained VQVAE2 checkpoint and of the hyperparameter file
# that is handed to the loader alongside it.
ckpt_vqvae = "../models/latent_vqvae2/version_1/checkpoints/last.ckpt"
config_yaml = "../models/latent_vqvae2/version_1/hparams.yaml"

# Restore the model with its weights mapped onto the CPU.
latent_model = LatentVQVAE2.load_from_checkpoint(ckpt_vqvae, hparams_file=config_yaml, map_location="cpu")

# Inference only: switch to eval mode, then freeze every parameter.
# The final expression returns the module, so the cell still displays it.
latent_model.eval()
latent_model.requires_grad_(False)

### Load autoregressive prior

#### Transformer prior

In [None]:
# Build a transformer prior on top of the frozen VQVAE from explicit
# hyperparameters. Note that every prior cell in this notebook assigns to
# `prior`, so the most recently executed one is the one used for sampling.
prior = HierarchicalTransformerPrior(vqvae=latent_model, d_model=768, n_layers=12, n_heads=12, lr=3e-4, weight_decay=0.0)

# Evaluation mode, then move to the GPU (a CUDA device is required here).
prior = prior.eval().cuda()

In [None]:
# Checkpoint of the transformer prior to restore.
ckpt_prior = "../models/latent_prior/version_7/checkpoints/last.ckpt"

# Load the prior on top of the already loaded VQVAE, mapping the stored
# tensors straight onto the GPU.
prior = HierarchicalTransformerPrior.load_from_checkpoint(ckpt_prior, vqvae=latent_model, map_location="cuda")

# Evaluation mode on the GPU (a CUDA device is required here).
prior = prior.eval().cuda()

#### PixelSNAIL prior

In [None]:
# Build a PixelSNAIL prior on top of the frozen VQVAE from explicit
# hyperparameters. This rebinds `prior`, replacing whatever prior was
# assigned by an earlier cell.
prior = HierarchicalPixelSnailPrior(vqvae=latent_model, n_chan=128, n_blocks=8, n_heads=4, lr=3e-4, weight_decay=0.0, dropout=0.1)

# Evaluation mode, then move to the GPU (a CUDA device is required here).
prior = prior.eval().cuda()

In [None]:
# Checkpoint of the PixelSNAIL prior to restore.
ckpt_prior = "../models/latent_prior/version_5/checkpoints/last.ckpt"

# Load the prior on top of the already loaded VQVAE, mapping the stored
# tensors straight onto the GPU.
prior = HierarchicalPixelSnailPrior.load_from_checkpoint(ckpt_prior, vqvae=latent_model, map_location="cuda")

# Evaluation mode on the GPU (a CUDA device is required here).
prior = prior.eval().cuda()

## Sampling

In [None]:
# Draw latent samples from the currently loaded prior. Sampling needs no
# gradients, so it runs under torch.no_grad().
with torch.no_grad():
    sd_latents = prior.sample(
        n=4,              # number of samples to draw
        temperature=0.5,  # lower = sharper, higher = more varied
        top_k=64,         # restrict to top-k logits (optional, None = full softmax)
    )

# Keep the result on the CPU; these are latents that are decoded into images
# by the Stable Diffusion VAE further below.
sd_latents = sd_latents.cpu()

## Decode samples

In [None]:
# Load Stable Diffusion VAE model
from diffusers import AutoencoderKL

# Fetch only the VAE sub-model of Stable Diffusion 3.5 medium and put it in
# evaluation mode. No device is specified when loading.
sd_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    subfolder="vae",
).eval()
sd_vae

In [None]:
# Decode the sampled latents into image space with the Stable Diffusion VAE.
# Decoding is pure inference, so it runs under torch.no_grad(): this avoids
# building (and then discarding) an autograd graph through the whole decoder,
# and the resulting tensor needs no explicit detach.
with torch.no_grad():
    images = sd_vae.decode(sd_latents, return_dict=False)[0]

In [None]:
# Plot the images
import matplotlib.pyplot as plt
def plot_images(imgs, nrow=4):
    """Plot a batch of images on a grid.

    Args:
        imgs: Batch of image tensors. Each item is permuted with (1, 2, 0)
            before display, so it is presumably laid out as (C, H, W) --
            TODO confirm. Values are clipped to [-1, 1] and rescaled to
            [0, 1] for display.
        nrow: Maximum number of images per grid row (i.e. the number of
            columns). Must be at least 1.

    Raises:
        ValueError: If ``nrow`` is smaller than 1.
    """
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")

    n = len(imgs)
    if n == 0:
        # Nothing to draw; an empty batch would otherwise divide by zero below.
        return

    ncols = min(n, nrow)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column; without it a 1 x N or N x 1 grid cannot be indexed by
    # (row, col).
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False
    )
    flat_axes = axes.ravel()  # row-major: index i <-> (i // ncols, i % ncols)

    # Hide the frame of every cell, including trailing cells with no image.
    for ax in flat_axes:
        ax.axis('off')

    for ax, img in zip(flat_axes, imgs):
        # detach() lets tensors that still track gradients be converted.
        pixels = img.detach().permute(1, 2, 0).cpu().numpy()
        ax.imshow((pixels.clip(-1, 1) + 1) / 2)  # Convert to [0, 1]

    plt.tight_layout()
    plt.show()
plot_images(images, nrow=2)  # Show the decoded samples, at most two images per row

## Analyze batch for PixelSNAIL prior

In [None]:
# Build the FFHQ data module, pull a single training batch and encode it with
# the Stable Diffusion VAE so that the prior can be scored on real data.
from argparse import Namespace

from src.dataloader.ffhq import FFHQWeightedDataset
from src.dataloader.weighting import DataWeighter

# Datamodule settings
img_dir = "../data/ffhq/images1024x1024"
img_tensor_dir = "../data/ffhq/pt_images"
attr_path = "../data/ffhq/ffhq_smile_scores.json"
max_property_value = 5
min_property_value = 0
mode = "all"
batch_size = 2  # only one small batch is encoded and scored below
num_workers = 2 # 4
val_split = 0.1
data_device = "cuda"  # "cpu" or "cuda"

# Weighter settings
weight_type = "uniform"
rank_weight_k = 1e-3
weight_quantile = None
dbas_noise = None
rwr_alpha = None

# The dataset and the weighter both read their configuration from one
# argparse-style namespace.
args = Namespace(
    img_dir=img_dir,
    img_tensor_dir=img_tensor_dir,
    attr_path=attr_path,
    max_property_value=max_property_value,
    min_property_value=min_property_value,
    mode=mode,
    batch_size=batch_size,  # single source of truth: the variable defined above
    num_workers=num_workers,
    val_split=val_split,
    weight_type=weight_type,
    rank_weight_k=rank_weight_k,
    weight_quantile=weight_quantile,
    dbas_noise=dbas_noise,
    rwr_alpha=rwr_alpha,
    aug=True,
    data_device=data_device,
)

datamodule = FFHQWeightedDataset(args, DataWeighter(args))

# One batch from the training loader is enough for this analysis.
batch = next(iter(datamodule.train_dataloader()))

# Encoding is pure inference, so no autograd graph is needed. The VAE was
# loaded without an explicit device, while `data_device` may place the batch
# elsewhere (NOTE(review): confirm what the dataset does with it), so the
# batch is aligned with the VAE's device before encoding.
with torch.no_grad():
    latent_batch = sd_vae.encode(batch.to(sd_vae.device)).latent_dist.sample().cpu()

In [None]:
import math
# Score the encoded real batch under the prior. This is evaluation only, so it
# runs under torch.no_grad() to avoid building an autograd graph through the
# prior's parameters.
with torch.no_grad():
    fr_nll = prior._free_run_nll(latent_batch.to(prior.device))

# NOTE(review): exp(NLL) is a perplexity only if the value is a per-token mean
# in nats -- confirm what _free_run_nll returns.
print("Autoregressive NLL (top):", fr_nll.item(), "nats   ",
      "perplexity ≈", math.exp(fr_nll.item()))