# Latent Prior Sampling

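This notebook demonstrates how to sample from VQVAE2 models (`LatentVQVAE2` and `VQVAE2`) that use a hierarchical Transformer prior over their latent spaces, and compares the samples of the two priors across temperature and top-k settings.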

## Setup

In [None]:
import torch
from src.models.latent_models import LatentVQVAE2
from src.models.vqvae2 import VQVAE2
from src.models.transformer_prior import HierarchicalTransformerPrior

### Load Stable Diffusion VAE

In [None]:
# Load Stable Diffusion VAE model
from diffusers import AutoencoderKL

# Load the autoencoder stored under the "vae" subfolder of the SD 3.5 (medium) repository.
sd_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    subfolder="vae",
)

# Both calls act on the module in place and return it, so the final
# expression still echoes the model as the cell output.
sd_vae.eval()
sd_vae.cuda()

### Load VQVAE2 models

In [None]:
# Checkpoint and hyper-parameter file of the LatentVQVAE2 run.
ckpt_vqvae = "../models/latent_vqvae2/version_1_2/checkpoints/last.ckpt"
config_yaml = "../models/latent_vqvae2/version_1_2/hparams.yaml"

# Restore the model with its weights mapped to the CPU.
latent_vqvae2 = LatentVQVAE2.load_from_checkpoint(
    ckpt_vqvae, map_location="cpu", hparams_file=config_yaml
)

# Inference only: evaluation mode, gradients disabled for every parameter.
latent_vqvae2.eval()
latent_vqvae2.requires_grad_(False)

In [None]:
# Checkpoint and hyper-parameter file of the VQVAE2 run.
# Note: this rebinds `ckpt_vqvae` / `config_yaml` from the previous cell.
ckpt_vqvae = "../models/vqvae2/version_0_2/checkpoints/last.ckpt"
config_yaml = "../models/vqvae2/version_0_2/hparams.yaml"

# Restore the model with its weights mapped to the CPU.
vqvae2 = VQVAE2.load_from_checkpoint(
    ckpt_vqvae, map_location="cpu", hparams_file=config_yaml
)

# Inference only: evaluation mode, gradients disabled for every parameter.
vqvae2.eval()
vqvae2.requires_grad_(False)

### Load Transformer prior

#### LatentVQVAE2 prior

In [None]:
ckpt_prior = "../models/latent_prior/version_16/checkpoints/epoch_020.ckpt"

# The already-loaded LatentVQVAE2 is handed over as `vqvae`
# (presumably forwarded to the prior's constructor — confirm in HierarchicalTransformerPrior).
latent_vqvae2_prior = HierarchicalTransformerPrior.load_from_checkpoint(
    ckpt_prior, map_location="cuda", vqvae=latent_vqvae2
)

# Evaluation mode, on the GPU.
latent_vqvae2_prior = latent_vqvae2_prior.eval().cuda()

#### VQVAE2 prior

In [None]:
# Note: this rebinds `ckpt_prior` from the previous cell.
ckpt_prior = "../models/latent_prior/version_11/checkpoints/epoch_018.ckpt"

# The already-loaded VQVAE2 is handed over as `vqvae`
# (presumably forwarded to the prior's constructor — confirm in HierarchicalTransformerPrior).
vqvae2_prior = HierarchicalTransformerPrior.load_from_checkpoint(
    ckpt_prior, map_location="cuda", vqvae=vqvae2
)

# Evaluation mode, on the GPU.
vqvae2_prior = vqvae2_prior.eval().cuda()

## Random Samples

### LatentVQVAE2 Prior Sampling

In [None]:
with torch.no_grad():
    latent_samples = latent_vqvae2_prior.sample(
        n=4,               # how many
        temperature=1.0,   # lower = sharper, higher = more varied
        top_k=None,        # restrict to top-k logits (optional, None = full softmax)
        seed=42,           # random seed for reproducibility
    )
    # The next cell passes `samples` straight to imshow, so they must be
    # images. The Grid Comparison section decodes this prior's output through
    # the Stable Diffusion VAE before plotting; do the same here, then move
    # the decoded images to the CPU.
    samples = sd_vae.decode(latent_samples, return_dict=False)[0].cpu()

In [None]:
import matplotlib.pyplot as plt

# Show the four samples on a 2x2 grid.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4, 4))
axes = axes.flatten()
for i, ax in enumerate(axes):
    # Move the first axis last (imshow expects channels last), then clip to
    # [-1, 1] and rescale to [0, 1].
    image = samples[i].permute(1, 2, 0).cpu().numpy()
    ax.imshow(image.clip(-1, 1) * 0.5 + 0.5)
    ax.axis('off')

plt.tight_layout()
plt.show()

### VQVAE2 Prior Sampling

In [None]:
# Draw four samples from the VQVAE2 prior without tracking gradients.
with torch.no_grad():
    drawn = vqvae2_prior.sample(
        n=4,               # number of samples
        temperature=1.0,   # lower = sharper, higher = more varied
        top_k=None,        # keep only the top-k logits (None = full softmax)
        seed=42,           # fixed seed for reproducibility
    )

# Keep the result on the CPU for the following cells.
samples = drawn.cpu()

In [None]:
# Decode the latents.
#
# `samples` was moved to the CPU in the previous cell while `sd_vae` lives on
# the GPU, so the input has to be moved to the VAE's device before decoding.
#
# The decode is applied only when the channel count matches the VAE's latent
# channels. The Grid Comparison section plots this prior's output without any
# decoding, so the samples may already be images, in which case they are left
# untouched. The guard also makes this cell safe to re-run, because decoded
# images are not decoded a second time.
with torch.no_grad():
    if samples.shape[1] == sd_vae.config.latent_channels:
        samples = sd_vae.decode(samples.to(sd_vae.device), return_dict=False)[0].cpu()

In [None]:
import matplotlib.pyplot as plt

# Show the four samples on a 2x2 grid.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4, 4))
axes = axes.flatten()
for i, ax in enumerate(axes):
    # Move the first axis last (imshow expects channels last), then clip to
    # [-1, 1] and rescale to [0, 1].
    image = samples[i].permute(1, 2, 0).cpu().numpy()
    ax.imshow(image.clip(-1, 1) * 0.5 + 0.5)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Grid Comparison

In [None]:
# Collect one image per (top_k, temperature, prior) combination.
# Order: for every top_k, five LatentVQVAE2 images followed by five VQVAE2
# images, i.e. 10 entries per top_k and 50 in total.
temperatures = (0.6, 0.7, 0.8, 0.9, 1.0)
samples = []

with torch.no_grad():
    for top_k in (64, 128, 256, 512, None):
        # LatentVQVAE2 prior: its output is decoded with the SD VAE.
        for temp in temperatures:
            latent = latent_vqvae2_prior.sample(
                n=1, temperature=temp, top_k=top_k, seed=42
            )
            image = sd_vae.decode(latent, return_dict=False)[0]
            samples.append(image.detach().squeeze().cpu())

        # VQVAE2 prior: its output is used as-is.
        for temp in temperatures:
            image = vqvae2_prior.sample(
                n=1, temperature=temp, top_k=top_k, seed=42
            )
            samples.append(image.squeeze().cpu())

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from pathlib import Path

# Build a 5 x 11 grid: five image columns per prior plus a narrow spacer
# column (index 5) between the two blocks.
# constrained_layout is intentionally not requested: the tight_layout() call
# further down replaces it anyway (with a warning), and the hard-coded figure
# coordinates below were tuned against the tight layout.
fig = plt.figure(figsize=(2*2*5*1.01, 2*5))
gs = gridspec.GridSpec(nrows=5, ncols=11, width_ratios=[1,1,1,1,1,0.1,1,1,1,1,1])

# Row-major list of the 50 image axes (spacer column skipped), so that
# axes[r*10 + c] is row r, with c in 0-4 for the left block and 5-9 for the right.
axes = [
    fig.add_subplot(gs[i, j])
    for i in range(5)
    for j in range(11)
    if j != 5  # skip spacer column
]

# Plot the images on the 50 axes; samples[i] is expected to line up with axes[i].
for i, ax in enumerate(axes):
    # Move the first axis last for imshow, clip to [-1, 1], rescale to [0, 1].
    ax.imshow(samples[i].permute(1, 2, 0).cpu().numpy().clip(-1,1) * 0.5 + 0.5)
    # Remove ticks and spines
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Add labels for each column (top row of each block)
for i in range(5):
    axes[i].set_title(f"Temp: {0.6 + 0.1*i:.1f}")
    axes[i+5].set_title(f"Temp: {0.6 + 0.1*i:.1f}")

# Add labels for each row (first axis of every row)
for i in range(5):
    axes[i*10].set_ylabel(f"Top-k: {64 * (2**i) if i < 4 else 'Full'}", rotation=90,
                            ha='center', va='center', labelpad=12, fontsize=12)

# Add the supertitles (placed slightly above the figure; kept by bbox_inches='tight')
fig.text(0.258, 1.02, "LatentVQVAE2 Prior", ha="center", va="center", fontsize=14, fontweight='bold')
fig.text(0.755, 1.02, "VQVAE2 Prior", ha="center", va="center", fontsize=14, fontweight='bold')

# Draw a vertical line between the two blocks in the spacer region.
line = plt.Line2D([0.5055, 0.5055], [0, 1], transform=fig.transFigure,
                  color='black', linewidth=2, linestyle='--')
fig.add_artist(line)

fig.tight_layout()

# Make sure the output directory exists; savefig does not create it.
out_path = Path("vis/latent_prior_samples_comparison.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, bbox_inches='tight')
plt.show()