# Stable Diffusion VAE Reconstruction Demo

This notebook shows image reconstructions obtained by encoding and then decoding images with the VAE of a Stable Diffusion model.

## Setup

### Load Stable Diffusion Model

In [None]:
# Instantiate the pretrained VAE that ships with a Stable Diffusion checkpoint
from diffusers import AutoencoderKL

# Alternative checkpoint, kept for quick switching:
# sd_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-medium", subfolder="vae")
sd_vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
)
# Switch to evaluation mode; as the last expression of the cell this also
# makes the notebook display the model.
sd_vae.eval()

## Encoding-Decoding Pipeline

### 1 Sample Pipeline

In [None]:
# Load one example image
import matplotlib.pyplot as plt
import numpy as np
import torch

IMAGE_TENSOR_PATH = "../data/ffhq/pt_images/65432.pt"

# map_location="cpu" restores the tensor onto the CPU no matter which device
# it was saved from, so it matches the VAE (which this notebook never moves
# to another device) and can be converted to NumPy below.
img_tensor = torch.load(IMAGE_TENSOR_PATH, map_location="cpu")

# Add batch dimension: Change from (C, H, W) to (B, C, H, W)
img_tensor = img_tensor.unsqueeze(0)

# Show the image: drop the batch dimension and move channels last for matplotlib
img = img_tensor.squeeze(0).permute(1, 2, 0).numpy()
# (img + 1) / 2 rescales to [0, 1]; assumes the stored tensor is in [-1, 1] -- TODO confirm
plt.imshow((img + 1) / 2)
plt.axis('off')
plt.show()
print("Image shape:", img_tensor.shape)

In [None]:
# Encode the image using the Stable Diffusion VAE.
# This is inference only, so run it under no_grad to avoid recording an
# autograd graph through the encoder that is never used.
# Note: sample() draws from the encoder's latent distribution, so the latent
# (and the reconstruction derived from it) can differ between runs.
with torch.no_grad():
    sd_latent = sd_vae.encode(img_tensor).latent_dist.sample()

print("SD Latent shape:", sd_latent.shape)

In [None]:
# Decode the latent using the Stable Diffusion VAE.
# This is inference only, so run it under no_grad to avoid recording an
# autograd graph through the decoder that is never used.
with torch.no_grad():
    sd_recon = sd_vae.decode(sd_latent).sample

print("SD Recon shape:", sd_recon.shape)

In [None]:
# Show the recon image
# Drop the batch dimension and move channels last so matplotlib can render it.
# recon_img itself stays unscaled and unclipped because later cells reuse it.
recon_img = sd_recon.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
# The decoder output is not constrained to [-1, 1], so after rescaling some
# values can fall outside [0, 1]. Clip explicitly instead of relying on
# imshow to clip them (which it does with a warning).
plt.imshow(((recon_img + 1) / 2).clip(0, 1))
plt.axis('off')
plt.show()

In [None]:
# Plot input and recon next to each other
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = (("Input Image", img), ("Reconstructed Image", recon_img))
for axis, (title, image) in zip(ax, panels):
    # Rescale from [-1, 1] to [0, 1] and clip: the reconstruction can slightly
    # overshoot that range, and imshow only accepts floats in [0, 1] for RGB.
    axis.imshow(((image + 1) / 2).clip(0, 1))
    axis.set_title(title)
    axis.axis('off')
plt.show()

### 5 Sample Comparison

In [None]:
# Plot NUM_SAMPLES random inputs and their reconstructions next to each other
import glob
import os
import random

# Number of images to compare (fewer are shown if the folder holds fewer)
NUM_SAMPLES = 5

# Load the images from the folder
IMAGE_FOLDER = "../data/ffhq/pt_images/"
# glob returns paths in arbitrary, filesystem-dependent order; sort them first
# so the seeded shuffle below selects the same images on every machine.
IMAGE_TENSOR_PATHS = sorted(glob.glob(os.path.join(IMAGE_FOLDER, "*.pt")))
if not IMAGE_TENSOR_PATHS:
    raise FileNotFoundError(f"No .pt image tensors found in {IMAGE_FOLDER!r}")

# Select NUM_SAMPLES random images (fixed seed for a reproducible selection)
random.seed(42)
random.shuffle(IMAGE_TENSOR_PATHS)
selected_paths = IMAGE_TENSOR_PATHS[:NUM_SAMPLES]

# Initialize lists to store images and reconstructions
input_images = []
reconstructed_images = []

# Inference only: no_grad avoids recording autograd graphs for every sample
with torch.no_grad():
    for path in selected_paths:
        # Load onto the CPU regardless of the device the tensor was saved from,
        # then add the batch dimension: (C, H, W) -> (B, C, H, W)
        img_tensor = torch.load(path, map_location="cpu").unsqueeze(0)
        # Encode, then decode, using the Stable Diffusion VAE
        sd_latent = sd_vae.encode(img_tensor).latent_dist.sample()
        sd_recon = sd_vae.decode(sd_latent).sample
        # Store both as channels-last NumPy arrays for plotting
        input_images.append(img_tensor.squeeze(0).permute(1, 2, 0).numpy())
        reconstructed_images.append(sd_recon.squeeze(0).permute(1, 2, 0).detach().cpu().numpy())

# Plot the input and reconstructed images, one row per sample.
# squeeze=False keeps `ax` two-dimensional even when only one image was found.
num_rows = len(selected_paths)
fig, ax = plt.subplots(num_rows, 2, figsize=(10, 5 * num_rows), squeeze=False)
for row, (original, recon) in enumerate(zip(input_images, reconstructed_images)):
    # Rescale from [-1, 1] to [0, 1] and clip: reconstructions can slightly
    # overshoot that range, and imshow only accepts floats in [0, 1] for RGB.
    ax[row, 0].imshow(((original + 1) / 2).clip(0, 1))
    ax[row, 0].set_title("Input Image")
    ax[row, 0].axis('off')
    ax[row, 1].imshow(((recon + 1) / 2).clip(0, 1))
    ax[row, 1].set_title("Reconstructed Image")
    ax[row, 1].axis('off')
plt.tight_layout()
plt.show()