# 02 — Train VAE (Vision Model V)

Train a convolutional VAE to compress 64×64×3 (RGB) CarRacing frames into 32-dimensional latent vectors.

**Prerequisites:** Run `scripts/collect_data.py` to generate rollout data.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch

sys.path.insert(0, str(Path.cwd().parent))
from src.config import Config
from src.train_vae import train_vae

config = Config()

# Apple-silicon Macs expose the Metal backend via torch.backends.mps. When it
# is present, run training/inference there; otherwise keep the Config default.
mps_available = torch.backends.mps.is_available()
if mps_available:
    config.device = "mps"

print(f"Device: {config.device}")

In [None]:
# Train VAE
data_dir = Path.cwd().parent / config.data.data_dir
# Fail fast with an actionable message if the prerequisite data-collection
# step has not been run, rather than relying on whatever error train_vae
# would raise downstream.
if not data_dir.exists():
    raise FileNotFoundError(
        f"Rollout data not found at {data_dir}. "
        "Run scripts/collect_data.py first."
    )
model = train_vae(config.vae, data_dir, device=config.device)

In [None]:
# Visualize reconstructions
import numpy as np
from src.vae_dataset import VAEDataset

dataset = VAEDataset(data_dir)
model.eval()

# Show up to 8 frames. np.random.choice(..., replace=False) raises if asked
# for more samples than the dataset holds, so cap at the dataset size.
n_show = min(8, len(dataset))
if n_show == 0:
    raise ValueError(f"No frames found in dataset at {data_dir}.")
indices = np.random.choice(len(dataset), n_show, replace=False)

# squeeze=False keeps `axes` 2-D even when only one column is drawn.
fig, axes = plt.subplots(2, n_show, figsize=(2.5 * n_show, 5), squeeze=False)

with torch.no_grad():
    for col, idx in enumerate(indices):
        x = dataset[idx].unsqueeze(0).to(config.device)
        x_recon, _, _, _ = model(x)

        # Drop only the batch dim (index 0) instead of squeeze(), which would
        # also remove any other singleton dimension. Clip to [0, 1] so imshow
        # does not warn about (and silently clip) out-of-range float pixels.
        orig = np.clip(x[0].cpu().permute(1, 2, 0).numpy(), 0.0, 1.0)
        recon = np.clip(x_recon[0].cpu().permute(1, 2, 0).numpy(), 0.0, 1.0)

        for row, img in enumerate((orig, recon)):
            ax = axes[row, col]
            ax.imshow(img)
            # Hide ticks and frame instead of calling axis("off"): axis("off")
            # also hides the y-labels set below, so the row labels never showed.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

axes[0, 0].set_ylabel("Original", fontsize=14)
axes[1, 0].set_ylabel("Reconstructed", fontsize=14)
fig.suptitle("VAE Reconstruction Quality", fontsize=16)
fig.tight_layout()
plt.show()

In [None]:
# Latent space visualization (t-SNE of random frames)
from sklearn.manifold import TSNE

# Cap at the dataset size: np.random.choice(..., replace=False) raises if asked
# for more samples than exist.
n_samples = min(2000, len(dataset))
indices = np.random.choice(len(dataset), n_samples, replace=False)
encode_batch_size = 256
z_list = []

model.eval()
with torch.no_grad():
    # Encode in mini-batches rather than one frame per forward pass.
    # model.encode is fed a leading batch dimension, as in the
    # single-frame (1, ...) calls this replaces.
    for start in range(0, n_samples, encode_batch_size):
        batch_idx = indices[start:start + encode_batch_size]
        x = torch.stack([dataset[i] for i in batch_idx]).to(config.device)
        z = model.encode(x)
        z_list.append(z.cpu().numpy())

z_all = np.concatenate(z_list, axis=0)
# scikit-learn requires perplexity < n_samples; shrink it for tiny datasets.
tsne = TSNE(n_components=2, perplexity=min(30, n_samples - 1), random_state=42)
z_2d = tsne.fit_transform(z_all)

plt.figure(figsize=(8, 8))
plt.scatter(z_2d[:, 0], z_2d[:, 1], alpha=0.3, s=5)
plt.title("VAE Latent Space (t-SNE)")
plt.xlabel("t-SNE dim 1")
plt.ylabel("t-SNE dim 2")
plt.show()