# Latent Model Reconstructions

This notebook compares the reconstructions produced by different latent models (and hyperparameter settings) that encode and decode Stable Diffusion VAE latents, as a qualitative check of their training results.

## Setup

In [None]:
from src.models.latent_models import *
import yaml

### Model Configuration

In [None]:
# Manual Configuration
# Each entry is a (MODEL_TYPE, VERSION) pair; VERSION is the suffix of the
# `version_<VERSION>` run directory the checkpoint and hparams are read from.
MODEL_LIST = [
    ("LatentVAE", 16),
    ("LatentAE", 1),
    ("LatentLinearAE", 4),
    ("LatentVQVAE", 7),
    ("LatentVQVAE", "8_2"),
    ("LatentVQVAE2", 0),
    ("LatentVQVAE2", "1_2"),
]

In [None]:
# Automatic Configuration
# Model class name -> sub-directory under ../models/ holding that model type's runs.
model_type_mapping = dict(
    LatentVAE="latent_vae",
    LatentVQVAE="latent_vqvae",
    LatentVQVAE2="latent_vqvae2",
    LatentAE="latent_ae",
    LatentLinearAE="latent_linear_ae",
)

# Per-entry run sub-directory, hparams file and last checkpoint, aligned with MODEL_LIST.
model_type_sub = [model_type_mapping[entry[0]] for entry in MODEL_LIST]
config_path = [
    f"../models/{sub}/version_{entry[1]}/hparams.yaml"
    for sub, entry in zip(model_type_sub, MODEL_LIST)
]
ckpt_path = [
    f"../models/{sub}/version_{entry[1]}/checkpoints/last.ckpt"
    for sub, entry in zip(model_type_sub, MODEL_LIST)
]

### Load Latent Model

In [None]:
# Parse every run's hparams.yaml; the model-initialization cell reads its
# "ddconfig" and "lossconfig" entries.
config = []
for hparams_file in config_path:
    with open(hparams_file, "r") as f:
        config.append(yaml.safe_load(f))

# Model class name -> class object (classes come from src.models.latent_models).
model_cls_mapping = dict(
    LatentVAE=LatentVAE,
    LatentVQVAE=LatentVQVAE,
    LatentVQVAE2=LatentVQVAE2,
    LatentAE=LatentAE,
    LatentLinearAE=LatentLinearAE,
)
model_cls = [model_cls_mapping[entry[0]] for entry in MODEL_LIST]

In [None]:
# Initialize models
# Build each latent model from its hparams and restore its weights from ckpt_path.
# The models are only used for inference in this notebook, so they are put in eval
# mode and their parameters are frozen (no gradients are ever needed).
model = []
for i in range(len(MODEL_LIST)):
    curr_model = model_cls[i](
        ddconfig=config[i]["ddconfig"],
        lossconfig=config[i]["lossconfig"],
        ckpt_path=ckpt_path[i],
        # NOTE(review): presumably skips loss-module weights when restoring the
        # checkpoint — confirm against src.models.latent_models.
        ignore_keys=['loss'],
    )
    curr_model.eval().requires_grad_(False)
    model.append(curr_model)

### Load Stable Diffusion Model

In [None]:
# Load Stable Diffusion VAE model
# The SD 3.5 (medium) VAE maps images into the latent space the latent models operate in,
# and decodes their reconstructed latents back to images.
# NOTE: the Hugging Face repo may be gated; accept its license and log in first if loading fails.
from diffusers import AutoencoderKL

sd_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-medium", subfolder="vae")
# Inference only: eval mode + frozen weights. The trailing ';' suppresses the long module repr.
sd_vae.eval().requires_grad_(False);

## Encoding–Decoding Pipeline

### Run Forward Pass

In [None]:
import torch

NUM_SAMPLES = 5
# Seeds the global RNG used by the stochastic latent sampling below, so figures are reproducible.
SEED = 0

# Load (subset of) eval batch
eval_batch = torch.load("../data/ffhq/eval/batch_256.pt")[:NUM_SAMPLES]
assert len(eval_batch) == NUM_SAMPLES, f"eval batch only holds {len(eval_batch)} samples"

# input_img[s] is sample s as an (H, W, C) array;
# recon_img[m][s] is model m's reconstruction of sample s, decoded to image space.
input_img = []
recon_img = [[] for _ in range(len(MODEL_LIST))]

torch.manual_seed(SEED)
# Pure inference: no autograd graph is needed for any of these forward passes.
with torch.no_grad():
    for sample_idx in range(NUM_SAMPLES):
        # Add batch dimension: Change from (C, H, W) to (B, C, H, W)
        img_tensor = eval_batch[sample_idx].unsqueeze(0)
        # Encode the image with the SD VAE (samples from its latent distribution)
        sd_latent = sd_vae.encode(img_tensor).latent_dist.sample()
        for model_idx, model_i in enumerate(model):
            # Encode and decode the SD latent with the latent model
            recon = model_i(sd_latent, return_only_recon=True)
            # Decode the reconstructed latent back to an image with the SD VAE
            sd_recon = sd_vae.decode(recon).sample
            # Clip the reconstructions to the range [-1, 1]
            sd_recon = torch.clamp(sd_recon, -1, 1)
            recon_img[model_idx].append(sd_recon.squeeze(0).permute(1, 2, 0).cpu().numpy())
        input_img.append(img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy())

In [None]:
import matplotlib.pyplot as plt

num_cols = len(MODEL_LIST) + 1  # input column + one column per model
num_rows = NUM_SAMPLES

# squeeze=False keeps `ax` 2-D even when there is a single row (NUM_SAMPLES == 1).
fig, ax = plt.subplots(num_rows, num_cols, figsize=(num_cols*5, num_rows*5), squeeze=False)

# Column titles only need to be set once, on the first row.
ax[0, 0].set_title("Input Image", fontsize=25)
for j, (model_type, version) in enumerate(MODEL_LIST):
    ax[0, j + 1].set_title(f"{model_type} (v{version})", fontsize=25)

# (x + 1) / 2 maps [-1, 1] (the clamp range of the reconstructions) to imshow's [0, 1].
for i in range(num_rows):
    ax[i, 0].imshow((input_img[i] + 1) / 2)
    ax[i, 0].axis('off')
    for j in range(len(MODEL_LIST)):
        ax[i, j + 1].imshow((recon_img[j][i] + 1) / 2)
        ax[i, j + 1].axis('off')

plt.tight_layout()
plt.show()

## Custom Plots

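### LatentVAEs

This figure's column titles assume `MODEL_LIST` in the Model Configuration cell has been set to the list below, and that the loading and forward-pass cells have been rerun: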

```python
MODEL_LIST = [
    # MODEL_TYPE, VERSION
    ("LatentVAE", 5),
    ("LatentVAE", 8),
    ("LatentVAE", 16),
    ("LatentVAE", 17),
    ("LatentVAE", 19),
    ("LatentVAE", 18),
    ("LatentVAE", 20),
]
```

In [None]:
import os
import warnings

import matplotlib.pyplot as plt

# The hard-coded column titles below describe exactly this model list.
LATENT_VAE_MODEL_LIST = [
    ("LatentVAE", 5),
    ("LatentVAE", 8),
    ("LatentVAE", 16),
    ("LatentVAE", 17),
    ("LatentVAE", 19),
    ("LatentVAE", 18),
    ("LatentVAE", 20),
]
latent_vae_titles = [
    "16k → 128 (VAE Loss)",        # v5
    "16k → 128 (VAE+Disc. Loss)",  # v8
    "16k → 512",                   # v16
    "16k → 512 (Alt. Design)",     # v17
    "16k → 4k",                    # v19
    "16k → 8k",                    # v18
    "16k → 8k ($\\ell_1$ Loss)",   # v20
]

# Guard against silently mislabeling columns (e.g. when MODEL_LIST is still the default list).
titles_match = MODEL_LIST == LATENT_VAE_MODEL_LIST
if not titles_match:
    warnings.warn(
        "MODEL_LIST differs from the LatentVAE list these column titles describe; "
        "titles may be wrong and the PDF will not be saved."
    )

num_cols = len(MODEL_LIST) + 1  # input column + one column per model
num_rows = NUM_SAMPLES

# squeeze=False keeps `ax` 2-D even when there is a single row.
fig, ax = plt.subplots(num_rows, num_cols, figsize=(num_cols*5, num_rows*5), squeeze=False)

# Custom column titles, set once on the first row.
ax[0, 0].set_title("Original", fontsize=25)
for j, title in enumerate(latent_vae_titles[:len(MODEL_LIST)]):
    ax[0, j + 1].set_title(title, fontsize=25)

# (x + 1) / 2 maps [-1, 1] (the clamp range of the reconstructions) to imshow's [0, 1].
for i in range(num_rows):
    ax[i, 0].imshow((input_img[i] + 1) / 2)
    ax[i, 0].axis('off')
    for j in range(len(MODEL_LIST)):
        ax[i, j + 1].imshow((recon_img[j][i] + 1) / 2)
        ax[i, j + 1].axis('off')

plt.tight_layout()
if titles_match:
    os.makedirs("vis", exist_ok=True)
    plt.savefig("vis/latent_vae_recon_comparison.pdf", bbox_inches='tight')
plt.show()

### LatentVAE, LatentAE, LatentLinearAE, LatentVQVAE, LatentVQVAE2, VQVAE2

This figure uses the default `MODEL_LIST` from the Model Configuration cell:

```python
MODEL_LIST = [
	# MODEL_TYPE, VERSION
	("LatentVAE", 16),
	("LatentAE", 1),
	("LatentLinearAE", 4),
	("LatentVQVAE", 7),
	("LatentVQVAE", "8_2"),
	("LatentVQVAE2", 0),
	("LatentVQVAE2", "1_2"),
]
```

The `VQVAE2` model is evaluated separately because it operates directly on images rather than in the Stable Diffusion latent space. Its reconstructions are appended to `recon_img` as one extra column.

In [None]:
# Load VQVAE2 model
from src.models.vqvae2 import VQVAE2

ckpt_vqvae = "../models/vqvae2/version_0_2/checkpoints/last.ckpt"
config_yaml = "../models/vqvae2/version_0_2/hparams.yaml"

vqvae2 = VQVAE2.load_from_checkpoint(
    ckpt_vqvae,
    hparams_file=config_yaml,
    map_location="cpu",
)
vqvae2.eval().requires_grad_(False)

# Reconstruct the same eval images. VQVAE2 works on images directly, so no SD VAE
# encode/decode step is involved.
recon_vqvae2 = []
with torch.no_grad():
    for i in range(NUM_SAMPLES):
        # Add batch dimension: Change from (C, H, W) to (B, C, H, W)
        img_tensor = eval_batch[i].unsqueeze(0)
        recon = vqvae2(img_tensor, return_only_recon=True)
        # Clip the reconstructions to the range [-1, 1]
        recon = torch.clamp(recon, -1, 1)
        recon_vqvae2.append(recon.squeeze(0).permute(1, 2, 0).cpu().numpy())

# Put the VQVAE2 column right after the MODEL_LIST columns. The list is rebuilt rather than
# appended to, so re-running this cell does not keep adding duplicate VQVAE2 entries.
recon_img = recon_img[:len(MODEL_LIST)] + [recon_vqvae2]

In [None]:
import os
import warnings

import matplotlib.pyplot as plt

# The hard-coded column titles below describe exactly this model list (plus VQVAE2).
COMPARISON_MODEL_LIST = [
    ("LatentVAE", 16),
    ("LatentAE", 1),
    ("LatentLinearAE", 4),
    ("LatentVQVAE", 7),
    ("LatentVQVAE", "8_2"),
    ("LatentVQVAE2", 0),
    ("LatentVQVAE2", "1_2"),
]
comparison_titles = [
    "LatentVAE",
    "LatentAE",
    "LatentLinearAE",
    "LatentVQVAE",            # v7
    "+ larger codebook",      # v8_2
    "LatentVQVAE2",           # v0
    "+ larger latent space",  # v1_2
]

# Guard against silently mislabeling columns when MODEL_LIST was changed.
titles_match = MODEL_LIST == COMPARISON_MODEL_LIST
if not titles_match:
    warnings.warn(
        "MODEL_LIST differs from the list these column titles describe; "
        "titles may be wrong and the PDF will not be saved."
    )

# recon_img must hold one list per MODEL_LIST entry followed by the VQVAE2 reconstructions.
assert len(recon_img) >= len(MODEL_LIST) + 1, "Run the VQVAE2 cell before this plot."

num_cols = len(MODEL_LIST) + 2  # input column + latent models + VQVAE2
num_rows = NUM_SAMPLES

# squeeze=False keeps `ax` 2-D even when there is a single row.
fig, ax = plt.subplots(num_rows, num_cols, figsize=(num_cols*5, num_rows*5), squeeze=False)

# Custom column titles, set once on the first row.
ax[0, 0].set_title("Original", fontsize=25)
for j, title in enumerate(comparison_titles[:len(MODEL_LIST)]):
    ax[0, j + 1].set_title(title, fontsize=25)
ax[0, num_cols - 1].set_title("VQVAE2", fontsize=25)

# (x + 1) / 2 maps [-1, 1] (the clamp range of the reconstructions) to imshow's [0, 1].
for i in range(num_rows):
    ax[i, 0].imshow((input_img[i] + 1) / 2)
    ax[i, 0].axis('off')
    for j in range(len(MODEL_LIST) + 1):
        ax[i, j + 1].imshow((recon_img[j][i] + 1) / 2)
        ax[i, j + 1].axis('off')

plt.tight_layout()
if titles_match:
    os.makedirs("vis", exist_ok=True)
    plt.savefig("vis/latent_model_recon_comparison.pdf", bbox_inches='tight')
plt.show()