# GP-VAE Training on Google Colab

This notebook trains the **GP-VAE (Gaussian Process Variational Autoencoder)** model using Google Colab's free GPU.

## What is GP-VAE?
GP-VAE adds a **Gaussian Process prior** to the VAE latent space to model structured correlations:
- **VAE**: Learns image ↔ latent code mapping
- **GP Prior**: Models correlations between latent codes based on:
  - Object identity (same person's face)
  - View angle (front, side, profile)
  - Other factors of variation
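To make the GP prior concrete, here is a minimal NumPy sketch that scores latent codes under a Gaussian process whose covariance combines object identity and view angle. The product kernel, the noise term, the random 2-D embeddings and the helper names are illustrative assumptions; the notebook's own `GP` and `Vmodel` classes may parameterize this differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 4 people x 3 view angles, one image per (person, view) pair
n_obj, n_view, zdim = 4, 3, 8
obj_ids = np.repeat(np.arange(n_obj), n_view)   # object (person) index of each image
view_ids = np.tile(np.arange(n_view), n_obj)    # view-angle index of each image


def unit_rows(a):
    """Scale each embedding row to unit length."""
    return a / np.linalg.norm(a, axis=1, keepdims=True)


X = unit_rows(rng.normal(size=(n_obj, 2)))   # object embeddings (random stand-ins)
W = unit_rows(rng.normal(size=(n_view, 2)))  # view embeddings (random stand-ins)

# Assumed product kernel: two images covary when identity AND view are similar
K_obj = X[obj_ids] @ X[obj_ids].T
K_view = W[view_ids] @ W[view_ids].T
v0, v1 = 0.9, 0.1                            # structured vs. noise variance
K = v0 * K_obj * K_view + v1 * np.eye(len(obj_ids))


def gp_nll(Z, K):
    """Sum over latent dimensions (columns of Z) of -log N(z_col | 0, K)."""
    n, d = Z.shape
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, Z)             # L^{-1} Z, so sum(alpha**2) = tr(Z^T K^{-1} Z)
    logdet = 2.0 * np.log(np.diag(L)).sum()
    return 0.5 * (np.sum(alpha ** 2) + d * logdet + n * d * np.log(2 * np.pi))


Z = rng.normal(size=(len(obj_ids), zdim))    # stand-in latent codes
print(f"GP NLL of random latents: {gp_nll(Z, K):.2f}")
```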

## Prerequisites ⚠️
**You MUST have trained VAE weights first!** This notebook loads a pre-trained VAE and fine-tunes it jointly with the GP.

Required files:
- ✅ `out/vae_colab/YYYYMMDD_HHMMSS/vae.cfg.p` - VAE configuration
- ✅ `out/vae_colab/YYYYMMDD_HHMMSS/weights/weights.00000.pt` - Trained VAE weights

## Output Directory Structure:

Each training run creates a **timestamped directory** to avoid overwriting previous runs:
- Format: `./out/gppvae_colab/YYYYMMDD_HHMMSS/`
- Example: `./out/gppvae_colab/20251224_143530/weights/weights.00100.pt`
- This allows you to compare different training runs and keep a history!
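A minimal sketch of the timestamped-directory convention, using hypothetical helpers `make_run_dir` and `find_latest_run` (not part of the GPPVAE code). Because `YYYYMMDD_HHMMSS` is zero-padded, plain string sorting orders runs chronologically, which is the same property cell 6 relies on.

```python
import os
from datetime import datetime


def make_run_dir(base_dir, now=None):
    """Create base_dir/YYYYMMDD_HHMMSS/weights and return the run directory."""
    stamp = (now or datetime.now()).strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(base_dir, stamp)
    os.makedirs(os.path.join(run_dir, "weights"), exist_ok=True)
    return run_dir


def find_latest_run(base_dir):
    """Return the newest timestamped run directory, or None if there is none.

    Zero-padded YYYYMMDD_HHMMSS names sort chronologically as plain strings.
    """
    if not os.path.isdir(base_dir):
        return None
    runs = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d)) and d[:8].isdigit()
    )
    return os.path.join(base_dir, runs[-1]) if runs else None
```

For example, `find_latest_run("./out/vae_colab")` returns the most recent VAE run directory.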

Cell 6 below lists your VAE training runs and recommends the latest checkpoint; copy the suggested paths into `CONFIG` in cell 7.

## Setup Instructions:

1. **Open this notebook in VS Code**
2. **Connect to Colab**: Click kernel picker → "Connect to Colab" → Choose **GPU runtime (T4)**
3. **Important**: When prompted with "Alias your server", press Enter
4. **Run cell 2**: it automatically detects your project location


## 1. Check GPU Availability

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not detected! Go to Runtime → Change runtime type → GPU")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 15.83 GB


## 2. Auto-Detect Project Path

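On Colab, this cell mounts Google Drive and looks for the project at `MyDrive/gppvae`; when run locally (e.g. via VS Code sync), it uses the current directory, or its parent if you are inside `notebooks/`.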

In [2]:
import os
import sys

# Get current directory
current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

# Check if on Colab and need to mount Drive
if current_dir == '/content':
    print("\nüîÑ Mounting Google Drive...")

    try:
        from google.colab import drive
        drive.mount('/content/drive')

        # Check for project in Drive
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"✅ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"\n⚠️  Project not found at: {drive_path}")
            print("\nPlease upload your gppvae folder to Google Drive!")
            print("Required structure:")
            print("  MyDrive/gppvae/")
            print("    ├── GPPVAE/")
            print("    ├── data/faceplace/data_faces.h5")
            print("    └── out/vae_colab/YYYYMMDD_HHMMSS/")
            print("        ├── vae.cfg.p")
            print("        └── weights/weights.00000.pt")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    # Running via VS Code sync
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(current_dir)
    else:
        PROJECT_PATH = current_dir
    print(f"üíª Using project path: {PROJECT_PATH}")

# Verify structure
print(f"\nüìÅ Contents of {PROJECT_PATH}:")
if os.path.exists(PROJECT_PATH):
    items = os.listdir(PROJECT_PATH)
    for item in sorted(items)[:15]:
        item_path = os.path.join(PROJECT_PATH, item)
        if os.path.isdir(item_path):
            print(f"   üìÇ {item}/")
        else:
            print(f"   üìÑ {item}")

    # Check required files (with timestamped directory structure)
    print(f"\nüîç Checking required files:")
    required = {
        'GPPVAE code': os.path.exists(os.path.join(PROJECT_PATH, 'GPPVAE')),
        'data/faceplace': os.path.exists(os.path.join(PROJECT_PATH, 'data/faceplace')),
        'data_faces.h5': os.path.exists(os.path.join(PROJECT_PATH, 'data/faceplace/data_faces.h5')),
    }

    # Check for VAE runs (timestamped subdirectories)
    vae_base_dir = os.path.join(PROJECT_PATH, 'out/vae_colab')
    vae_run_found = False
    vae_weights_found = False

    if os.path.exists(vae_base_dir):
        # Look for timestamped subdirectories
        potential_runs = [d for d in os.listdir(vae_base_dir)
                         if os.path.isdir(os.path.join(vae_base_dir, d)) and d[0].isdigit()]

        for run_dir in potential_runs:
            run_path = os.path.join(vae_base_dir, run_dir)
            cfg_path = os.path.join(run_path, 'vae.cfg.p')
            weights_dir = os.path.join(run_path, 'weights')

            if os.path.exists(cfg_path):
                vae_run_found = True

            if os.path.exists(weights_dir):
                weight_files = [f for f in os.listdir(weights_dir) if f.endswith('.pt')]
                if weight_files:
                    vae_weights_found = True
                    break

    required['VAE config'] = vae_run_found
    required['VAE weights'] = vae_weights_found

    for name, exists in required.items():
        status = "✅" if exists else "❌"
        print(f"   {status} {name}")

    # Show VAE runs if they exist
    if os.path.exists(vae_base_dir):
        potential_runs = sorted([d for d in os.listdir(vae_base_dir)
                                if os.path.isdir(os.path.join(vae_base_dir, d)) and d[0].isdigit()],
                               reverse=True)

        if potential_runs:
            print(f"\nüì¶ Found {len(potential_runs)} VAE training run(s):")
            for i, run_dir in enumerate(potential_runs[:3], 1):  # Show latest 3
                run_path = os.path.join(vae_base_dir, run_dir)
                weights_dir = os.path.join(run_path, 'weights')

                if os.path.exists(weights_dir):
                    weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
                    print(f"   {i}. {run_dir}/ ({len(weight_files)} checkpoints)")
                    if weight_files:
                        print(f"      Latest: {weight_files[-1]}")

            if len(potential_runs) > 3:
                print(f"   ... and {len(potential_runs) - 3} more")

            print(f"\nüí° Cell 6 below will help you choose which run to use")

    if not all(required.values()):
        print(f"\n‚ö†Ô∏è  Missing required files!")
        if not required['VAE weights']:
            print("\nüö® CRITICAL: No trained VAE weights found!")
            print("   You must train VAE first before running GP-VAE")
            print("   Use the train_vae_colab.ipynb notebook")
else:
    print(f"❌ Path doesn't exist: {PROJECT_PATH}")


📍 Current directory: /content

🔄 Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Found project in Google Drive: /content/drive/MyDrive/gppvae

📁 Contents of /content/drive/MyDrive/gppvae:
   📂 GPPVAE/
   📂 data/
   📄 environment.yml
   📂 notebooks/
   📂 out/

🔍 Checking required files:
   ✅ GPPVAE code
   ✅ data/faceplace
   ✅ data_faces.h5
   ✅ VAE config
   ✅ VAE weights

📦 Found 1 VAE training run(s):
   1. 20251224_120136/ (16 checkpoints)
      Latest: weights.00140.pt

💡 Cell 6 below will help you choose which run to use


## 3. Install Dependencies

In [3]:
# Install required packages
# Note: these pinned versions may fail to build on newer Colab Python runtimes (see the
# error below). pip then installs nothing and the imports fall back to Colab's preinstalled versions.
!pip install -q wandb==0.12.21 imageio==2.15.0 pyyaml

# Verify installations (this only checks that the packages import)
import wandb
import imageio
import yaml
import numpy as np
print("✅ All dependencies installed successfully!")

  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... error
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
✅ All dependencies installed successfully!


## 4. Login to Weights & Biases (Optional)

Track your experiments with W&B for better monitoring.

In [4]:
import wandb
wandb.login()

# Alternatively, skip wandb.login() and log offline (runs are stored locally and can be
# uploaded later with `wandb sync`), or set CONFIG['use_wandb'] = False in cell 7:
# import os
# os.environ['WANDB_MODE'] = 'offline'

wandb: Currently logged in as: minh1008 (minh1008-ludwig-maximilianuniversity-of-munich) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

## 5. Navigate to Project Directory

In [5]:
import os
import sys

os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")

# Add to Python path
sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

print("\nProject structure:")
!ls -la

Current directory: /content/drive/MyDrive/gppvae

Project structure:
total 17
drwx------ 3 root root 4096 Dec 23 14:09 data
-rw------- 1 root root  258 Dec 23 11:40 environment.yml
drwx------ 3 root root 4096 Dec 23 14:09 GPPVAE
drwx------ 2 root root 4096 Dec 23 14:09 notebooks
drwx------ 3 root root 4096 Dec 23 14:21 out


## 6. Verify VAE Weights

**Critical check:** Make sure you have trained VAE weights!

In [6]:
import os
import pickle
import glob

# Check for VAE runs (may be in timestamped subdirectories)
vae_base_dir = './out/vae_colab'
vae_runs = []

if os.path.exists(vae_base_dir):
    # Look for timestamped subdirectories
    potential_runs = [d for d in os.listdir(vae_base_dir) if os.path.isdir(os.path.join(vae_base_dir, d))]
    for run_dir in sorted(potential_runs, reverse=True):  # Most recent first
        run_path = os.path.join(vae_base_dir, run_dir)
        cfg_path = os.path.join(run_path, 'vae.cfg.p')
        weights_dir = os.path.join(run_path, 'weights')

        if os.path.exists(cfg_path) and os.path.exists(weights_dir):
            weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
            if weight_files:
                vae_runs.append({
                    'run_dir': run_dir,
                    'cfg_path': cfg_path,
                    'weights_dir': weights_dir,
                    'weight_files': weight_files
                })

if vae_runs:
    print(f"✅ Found {len(vae_runs)} VAE training run(s):\n")

    for i, run in enumerate(vae_runs, 1):
        print(f"Run {i}: {run['run_dir']}")

        # Load and show config
        with open(run['cfg_path'], 'rb') as f:
            vae_cfg = pickle.load(f)
        print(f"   Config: zdim={vae_cfg.get('zdim', 'N/A')}, nf={vae_cfg.get('nf', 'N/A')}")

        # Show checkpoints
        print(f"   Checkpoints: {len(run['weight_files'])} files")
        if len(run['weight_files']) <= 3:
            for wf in run['weight_files']:
                print(f"      📦 {wf}")
        else:
            print(f"      📦 {run['weight_files'][0]} ... {run['weight_files'][-1]}")
        print()

    # Recommendation
    latest_run = vae_runs[0]
    latest_weight = latest_run['weight_files'][-1]
    recommended_path = os.path.join(latest_run['weights_dir'], latest_weight)

    print(f"üí° Recommendation:")
    print(f"   Use latest run: {latest_run['run_dir']}")
    print(f"   Latest checkpoint: {latest_weight}")
    print(f"   \n   Set in next cell:")
    print(f"   CONFIG['vae_cfg'] = '{latest_run['cfg_path']}'")
    print(f"   CONFIG['vae_weights'] = '{recommended_path}'")

else:
    print("‚ùå No trained VAE runs found!")
    print("\n   Please train VAE first using train_vae_colab.ipynb")
    print(f"   Expected location: {vae_base_dir}/YYYYMMDD_HHMMSS/")


✅ Found 1 VAE training run(s):

Run 1: 20251224_120136
   Config: zdim=256, nf=32
   Checkpoints: 16 files
      📦 weights.00000.pt ... weights.00140.pt

💡 Recommendation:
   Use latest run: 20251224_120136
   Latest checkpoint: weights.00140.pt
   
   Set in next cell:
   CONFIG['vae_cfg'] = './out/vae_colab/20251224_120136/vae.cfg.p'
   CONFIG['vae_weights'] = './out/vae_colab/20251224_120136/weights/weights.00140.pt'
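Cell 6 picks the latest checkpoint by sorting file names, which works because epochs are zero-padded to five digits. Parsing the epoch number instead is robust to unpadded names and also lets you select the `vae_weights`/`gp_weights` files that the GP-VAE loop below writes. A sketch with a hypothetical helper `latest_checkpoint`:

```python
import os
import re


def latest_checkpoint(weights_dir, prefix="weights"):
    """Return (epoch, path) for the highest-epoch '<prefix>.NNNNN.pt' file, or None.

    prefix="weights" matches the VAE run; "vae_weights" / "gp_weights" match the
    files written by the GP-VAE training loop in this notebook.
    """
    pattern = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.pt$")
    best = None
    for name in os.listdir(weights_dir):
        match = pattern.match(name)
        if match:
            epoch = int(match.group(1))
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(weights_dir, name))
    return best
```

For the run listed above, `latest_checkpoint(latest_run['weights_dir'])` would return epoch 140 together with the path to `weights.00140.pt`.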


## 7. Configure GP-VAE Training

Adjust these parameters as needed:

In [13]:
from datetime import datetime

# GP-VAE Training configuration
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG = {
    'data': './data/faceplace/data_faces.h5',
    'outdir': f'./out/gppvae_colab/{timestamp}',  # Timestamped directory for each run
    'vae_cfg': './out/vae_colab/20251224_120136/vae.cfg.p',
    'vae_weights': './out/vae_colab/20251224_120136/weights/weights.00140.pt',  # ⬅️ Change this to use a different checkpoint

    # Training hyperparameters
    'epochs': 50,  # Short run; 100+ gives better results, 1000+ for publication quality
    'batch_size': 64,
    'vae_lr': 0.001,  # Learning rate for VAE (fine-tuning)
    'gp_lr': 0.01,    # Learning rate for GP and Vmodel
    'xdim': 64,        # Rank of object linear covariance

    # Logging
    'epoch_cb': 10,    # Save checkpoint every N epochs
    'use_wandb': True,
    'wandb_project': 'gppvae',
    'wandb_run_name': f'colab_gppvae_gpu_{timestamp}',  # Also add timestamp to run name
    'seed': 0,
}

print("GP-VAE Training Configuration:")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"  {key:20s}: {value}")
print("=" * 60)

# Verify VAE weights path
if not os.path.exists(CONFIG['vae_weights']):
    print(f"\n⚠️  WARNING: VAE weights not found at:")
    print(f"   {CONFIG['vae_weights']}")
    print(f"\n   Please update CONFIG['vae_weights'] to point to a valid checkpoint.")

GP-VAE Training Configuration:
  data                : ./data/faceplace/data_faces.h5
  outdir              : ./out/gppvae_colab/20251224_111332
  vae_cfg             : ./out/vae_colab/20251224_120136/vae.cfg.p
  vae_weights         : ./out/vae_colab/20251224_120136/weights/weights.00140.pt
  epochs              : 50
  batch_size          : 64
  vae_lr              : 0.001
  gp_lr               : 0.01
  xdim                : 64
  epoch_cb            : 10
  use_wandb           : True
  wandb_project       : gppvae
  wandb_run_name      : colab_gppvae_gpu_20251224_111332
  seed                : 0
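`CONFIG` is sent to W&B through `wandb.init(config=CONFIG)`, but nothing in the notebook writes it into the run directory itself. A small sketch with hypothetical helpers `save_config`/`load_config` that stores it as JSON next to the weights, so each timestamped run stays self-describing even with W&B disabled:

```python
import json
import os


def save_config(config, outdir, filename="config.json"):
    """Write the run configuration into outdir (created if needed) and return the file path."""
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, filename)
    with open(path, "w") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    return path


def load_config(path):
    """Read a configuration previously written by save_config."""
    with open(path) as f:
        return json.load(f)
```

Usage: `save_config(CONFIG, CONFIG['outdir'])`. All values in this `CONFIG` are strings, numbers or booleans, so they serialize to JSON directly.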


## 8. Import Training Modules

In [15]:
# Change to training script directory
os.chdir(os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

# Import modules
import matplotlib
matplotlib.use('Agg')

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
import h5py
import numpy as np
import logging
import pylab as pl
from utils import smartSum, smartAppendDict, smartAppend, export_scripts
from callbacks import callback_gppvae
from data_parser import read_face_data, FaceDataset
import pickle
import time
import wandb

print("✅ All modules imported successfully!")

✅ All modules imported successfully!


## 9. Setup Training Environment

In [16]:
# Go back to project root
os.chdir(PROJECT_PATH)

# Create output directories
outdir = CONFIG['outdir']
wdir = os.path.join(outdir, "weights")
fdir = os.path.join(outdir, "plots")
os.makedirs(wdir, exist_ok=True)
os.makedirs(fdir, exist_ok=True)

# Setup device (GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Setup logging
log_format = "%(asctime)s %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    datefmt="%m/%d %I:%M:%S %p",
)
fh = logging.FileHandler(os.path.join(outdir, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# Copy code to output
export_scripts(os.path.join(outdir, "scripts"))

print("✅ Training environment setup complete!")
print(f"   Outputs will be saved to: {outdir}")

Using device: cuda:0
✅ Training environment setup complete!
   Outputs will be saved to: ./out/gppvae_colab/20251224_111332


## 10. Initialize Models and Data

This cell:
1. Loads pre-trained VAE
2. Creates GP and Vmodel
3. Loads dataset
4. Sets up optimizers

In [17]:
# Set random seed
torch.manual_seed(CONFIG['seed'])

# Initialize W&B
if CONFIG['use_wandb']:
    wandb.init(
        project=CONFIG['wandb_project'],
        name=CONFIG['wandb_run_name'],
        config=CONFIG
    )

# Load VAE configuration
with open(CONFIG['vae_cfg'], "rb") as f:
    vae_cfg = pickle.load(f)
print(f"VAE config: {vae_cfg}")

# Load pre-trained VAE
print("\nLoading pre-trained VAE...")
vae = FaceVAE(**vae_cfg).to(device)
vae_state = torch.load(CONFIG['vae_weights'], map_location=device)
vae.load_state_dict(vae_state)
print(f"✅ VAE loaded from {CONFIG['vae_weights']}")
print(f"   Total VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")

# Load data
print("\nLoading dataset...")
img, obj, view = read_face_data(CONFIG['data'])
train_data = FaceDataset(img["train"], obj["train"], view["train"])
val_data = FaceDataset(img["val"], obj["val"], view["val"])
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)
print(f"✅ Data loaded:")
print(f"   Training samples: {len(train_data)}")
print(f"   Validation samples: {len(val_data)}")

# Create object and view variables for GP
Dt = Variable(obj["train"][:, 0].long(), requires_grad=False).cuda()
Wt = Variable(view["train"][:, 0].long(), requires_grad=False).cuda()
Dv = Variable(obj["val"][:, 0].long(), requires_grad=False).cuda()
Wv = Variable(view["val"][:, 0].long(), requires_grad=False).cuda()

# Initialize GP and Vmodel
print("\nInitializing GP-VAE components...")
P = np.unique(obj["train"]).shape[0]  # Number of unique objects (people)
Q = np.unique(view["train"]).shape[0]  # Number of unique views (angles)
print(f"   Objects (people): {P}")
print(f"   Views (angles): {Q}")

vm = Vmodel(P, Q, CONFIG['xdim'], Q).cuda()
gp = GP(n_rand_effs=1).to(device)

# Combine GP parameters (Vmodel + GP)
gp_params = nn.ParameterList()
gp_params.extend(vm.parameters())
gp_params.extend(gp.parameters())

print(f"✅ GP-VAE components initialized:")
print(f"   Vmodel parameters: {sum(p.numel() for p in vm.parameters()):,}")
print(f"   GP parameters: {sum(p.numel() for p in gp.parameters()):,}")
print(f"   Total trainable: {sum(p.numel() for p in vae.parameters()) + sum(p.numel() for p in gp_params):,}")

# Create optimizers (separate for VAE and GP)
vae_optimizer = optim.Adam(vae.parameters(), lr=CONFIG['vae_lr'])
gp_optimizer = optim.Adam(gp_params, lr=CONFIG['gp_lr'])
print(f"\n✅ Optimizers created:")
print(f"   VAE optimizer: Adam(lr={CONFIG['vae_lr']})")
print(f"   GP optimizer: Adam(lr={CONFIG['gp_lr']})")

VAE config: {'nf': 32, 'zdim': 256, 'vy': 0.002}

Loading pre-trained VAE...
✅ VAE loaded from ./out/vae_colab/20251224_120136/weights/weights.00140.pt
   Total VAE parameters: 553,304

Loading dataset...
✅ Data loaded:
   Training samples: 3868
   Validation samples: 484

Initializing GP-VAE components...
   Objects (people): 542
   Views (angles): 9
✅ GP-VAE components initialized:
   Vmodel parameters: 34,769
   GP parameters: 2
   Total trainable: 588,075

✅ Optimizers created:
   VAE optimizer: Adam(lr=0.001)
   GP optimizer: Adam(lr=0.01)
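The printed parameter counts can be cross-checked by hand. This assumes `Vmodel(P, Q, xdim, Q)` holds one `xdim`-dimensional embedding per object and one `Q`-dimensional embedding per view, and that the GP has one variance for its single random effect plus one noise variance; both are assumptions about the GPPVAE classes, but they reproduce the printed numbers exactly.

```python
# Counts printed by the cell above
n_objects, n_views, x_dim = 542, 9, 64
vae_count = 553_304

# Assumed layout: one x_dim embedding per object + one n_views-dim embedding per view
vmodel_count = n_objects * x_dim + n_views * n_views
# Assumed: one variance for the single random effect (n_rand_effs=1) + one noise variance
gp_count = 1 + 1
total_count = vae_count + vmodel_count + gp_count

print(f"Vmodel: {vmodel_count:,}  GP: {gp_count}  Total: {total_count:,}")
# prints: Vmodel: 34,769  GP: 2  Total: 588,075
```

Almost all GP-side parameters are object embeddings (542 × 64), so `xdim` is the main knob on their size.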


## 11. Define Training Functions

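These functions implement one epoch of GP-VAE training: `encode_Y` encodes the whole training set, `eval_step` measures reconstruction and out-of-sample prediction error on the validation set, and `backprop_and_update` performs the joint VAE + GP update.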

In [18]:
def encode_Y(vae, train_queue):
    """Encode all training images to get latent codes"""
    vae.eval()

    with torch.no_grad():
        n = train_queue.dataset.Y.shape[0]
        Zm = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).cuda()
        Zs = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).cuda()

        for batch_i, data in enumerate(train_queue):
            y = data[0].cuda()
            idxs = data[-1].cuda()
            zm, zs = vae.encode(y)
            Zm[idxs], Zs[idxs] = zm.detach(), zs.detach()

    return Zm, Zs


def eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv):
    """Evaluate model on validation set"""
    rv = {}

    with torch.no_grad():
        _X = vm.x().data.cpu().numpy()
        _W = vm.v().data.cpu().numpy()
        covs = {"XX": np.dot(_X, _X.T), "WW": np.dot(_W, _W.T)}
        rv["vars"] = gp.get_vs().data.cpu().numpy()

        # Out-of-sample prediction
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
        Kiz = gp.solve(Zm, U, UBi, vs)
        Zo = vs[0] * Vv.mm(Vt.transpose(0, 1).mm(Kiz))

        mse_out = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).cuda()
        mse_val = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).cuda()

        for batch_i, data in enumerate(val_queue):
            idxs = data[-1].cuda()
            Yv = data[0].cuda()
            Zv = vae.encode(Yv)[0].detach()
            Yr = vae.decode(Zv)
            Yo = vae.decode(Zo[idxs])
            mse_out[idxs] = ((Yv - Yo) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            mse_val[idxs] = ((Yv - Yr) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()

            # Store examples for visualization
            if batch_i == 0:
                imgs = {}
                imgs["Yv"] = Yv[:24].data.cpu().numpy().transpose(0, 2, 3, 1)
                imgs["Yr"] = Yr[:24].data.cpu().numpy().transpose(0, 2, 3, 1)
                imgs["Yo"] = Yo[:24].data.cpu().numpy().transpose(0, 2, 3, 1)

        rv["mse_out"] = float(mse_out.data.mean().cpu())
        rv["mse_val"] = float(mse_val.data.mean().cpu())

    return rv, imgs, covs


def backprop_and_update(vae, gp, vm, train_queue, Dt, Wt, Eps, Zb, Vbs, vbs, vae_optimizer, gp_optimizer):
    """Joint optimization of VAE and GP"""
    rv = {}

    vae_optimizer.zero_grad()
    gp_optimizer.zero_grad()
    vae.train()
    gp.train()
    vm.train()

    for batch_i, data in enumerate(train_queue):
        # Get batch data
        y = data[0].cuda()
        eps = Eps[data[-1]]
        _d = Dt[data[-1]]
        _w = Wt[data[-1]]
        _Zb = Zb[data[-1]]
        _Vbs = [Vbs[0][data[-1]]]

        # Forward through VAE
        zm, zs = vae.encode(y)
        z = zm + zs * eps
        yr = vae.decode(z)
        recon_term, mse = vae.nll(y, yr)

        # Forward through GP
        _Vs = [vm(_d, _w)]
        gp_nll_fo = gp.taylor_expansion(z, _Vs, _Zb, _Vbs, vbs) / vae.K

        # Penalization term
        pen_term = -0.5 * zs.sum(1)[:, None] / vae.K

        # Joint loss and backward
        loss = (recon_term + gp_nll_fo + pen_term).sum()
        loss.backward()

        # Accumulate metrics
        _n = train_queue.dataset.Y.shape[0]
        smartSum(rv, "mse", float(mse.data.sum().cpu()) / _n)
        smartSum(rv, "recon_term", float(recon_term.data.sum().cpu()) / _n)
        smartSum(rv, "pen_term", float(pen_term.data.sum().cpu()) / _n)

    # Update both optimizers
    vae_optimizer.step()
    gp_optimizer.step()

    return rv

print("✅ Training functions defined")

✅ Training functions defined
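`eval_step` predicts validation latents as `Zo = vs[0] * Vv.mm(Vt.transpose(0, 1).mm(Kiz))` with `Kiz = gp.solve(Zm, ...)`, which has the form of a GP posterior mean. The NumPy sketch below computes that prediction assuming the training covariance is `K = v0 * Vt Vtᵀ + v1 * I`, using the Woodbury identity so only a small r × r system is solved. The function name is hypothetical, and we do not claim this matches `gp.solve` line for line.

```python
import numpy as np


def predict_latents(Z_train, V_train, V_val, v0, v1):
    """GP posterior mean of validation latents given training latents.

    Assumes K = v0 * V_train V_train^T + v1 * I over the training points and
    uses the Woodbury identity, so only an r x r system is solved
    (r = V_train.shape[1]) instead of an n x n one.
    """
    r = V_train.shape[1]
    inner = (v1 / v0) * np.eye(r) + V_train.T @ V_train
    # K^{-1} Z_train via Woodbury
    KiZ = (Z_train - V_train @ np.linalg.solve(inner, V_train.T @ Z_train)) / v1
    # Cross-covariance v0 * V_val V_train^T applied to K^{-1} Z_train (same form as Zo)
    return v0 * V_val @ (V_train.T @ KiZ)


# Toy usage with random stand-ins for vm(Dt, Wt), vm(Dv, Wv) and the encoded latents
rng = np.random.default_rng(0)
n_train, n_val, r, zdim = 50, 10, 6, 4
V_train = rng.normal(size=(n_train, r))
V_val = rng.normal(size=(n_val, r))
Z_train = rng.normal(size=(n_train, zdim))
Z_pred = predict_latents(Z_train, V_train, V_val, v0=0.8, v1=0.2)
print(Z_pred.shape)  # (10, 4)
```

The low-rank route matters here because the training set has 3,868 images, while the Vmodel features have far fewer columns.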


## 12. Train GP-VAE Model 🚀

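**This is joint optimization!** The VAE, GP, and Vmodel are updated together once per epoch: gradients from all mini-batches are accumulated before a single optimizer step.

Training process per epoch:
1. Encode all training images and sample latent codes (VAE)
2. Evaluate reconstruction and out-of-sample prediction on the validation set
3. Compute first-order Taylor coefficients of the GP prior likelihood at the sampled latents
4. Backpropagate the joint loss (reconstruction + linearized GP term + penalty) over all mini-batches
5. Update the VAE, GP, and Vmodel with one step of each optimizer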
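Why accumulate over the whole epoch? The GP likelihood couples all training images through `K`, so it cannot be split across mini-batches directly. As we understand the code, `gp.taylor_coeff` linearizes the GP term around the latents sampled at the start of the epoch and `gp.taylor_expansion` evaluates that linear surrogate per batch; summing the per-batch gradients before stepping then reproduces the full-batch gradient with respect to the latents. The NumPy sketch below checks this for the Z-dependent part of a Gaussian NLL with an arbitrary SPD `K`. It illustrates the idea only and is not the notebook's GP code.

```python
import numpy as np

rng = np.random.default_rng(1)
n, zdim, n_batches = 12, 3, 3
A = rng.normal(size=(n, n))
K = A @ A.T + n * np.eye(n)          # any SPD matrix stands in for the GP covariance
Z_bar = rng.normal(size=(n, zdim))   # latents sampled once at the start of the epoch


def gp_quad(Z):
    """Z-dependent part of the GP NLL: 0.5 * tr(Z^T K^{-1} Z)."""
    return 0.5 * np.sum(Z * np.linalg.solve(K, Z))


# "Taylor coefficients": gradient of gp_quad at Z_bar, computed once per epoch
G = np.linalg.solve(K, Z_bar)

# Per-batch linear surrogate sum_i <G_i, z_i>; its gradient w.r.t. z_i is just G_i.
# Accumulating over all batches (no optimizer step in between) fills every row once.
grad = np.zeros_like(Z_bar)
for idx in np.array_split(rng.permutation(n), n_batches):
    grad[idx] += G[idx]

# Central finite differences of the exact, non-separable objective at Z_bar
h = 1e-6
fd = np.zeros_like(Z_bar)
for i in range(n):
    for j in range(zdim):
        step = np.zeros_like(Z_bar)
        step[i, j] = h
        fd[i, j] = (gp_quad(Z_bar + step) - gp_quad(Z_bar - step)) / (2 * h)

print(f"max |accumulated - exact| = {np.abs(grad - fd).max():.2e}")
```

Stepping after every batch instead would move the latents away from `Z_bar` while the coefficients stay fixed, which is one reason the loop takes a single step per epoch.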

In [19]:
import time
from IPython.display import clear_output

history = {}
start_time = time.time()

print(f"üöÄ Starting GP-VAE training for {CONFIG['epochs']} epochs...")
print("=" * 80)
print("Training mode: JOINT OPTIMIZATION (VAE + GP updated together)")
print("=" * 80)

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()

    # 1. Encode all training images
    Zm, Zs = encode_Y(vae, train_queue)

    # 2. Sample latent codes
    Eps = Variable(torch.randn(*Zs.shape), requires_grad=False).cuda()
    Z = Zm + Eps * Zs

    # 3. Compute variance matrices
    Vt = vm(Dt, Wt).detach()
    Vv = vm(Dv, Wv).detach()

    # 4. Evaluate on validation set
    rv_eval, imgs, covs = eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv)

    # 5. Compute GP Taylor expansion coefficients
    Zb, Vbs, vbs, gp_nll = gp.taylor_coeff(Z, [Vt])
    rv_eval["gp_nll"] = float(gp_nll.data.mean().cpu()) / vae.K

    # 6. Joint training step (VAE + GP)
    rv_back = backprop_and_update(
        vae, gp, vm, train_queue, Dt, Wt, Eps,
        Zb, Vbs, vbs, vae_optimizer, gp_optimizer
    )
    rv_back["loss"] = rv_back["recon_term"] + rv_eval["gp_nll"] + rv_back["pen_term"]

    # Store history
    smartAppendDict(history, rv_eval)
    smartAppendDict(history, rv_back)
    smartAppend(history, "vs", gp.get_vs().data.cpu().numpy())

    epoch_time = time.time() - epoch_start
    total_time = time.time() - start_time

    # Print progress
    if epoch % 5 == 0 or epoch == CONFIG['epochs'] - 1:
        print(f"Epoch {epoch:4d}/{CONFIG['epochs']} | "
              f"MSE val: {rv_eval['mse_val']:.6f} | "
              f"MSE out: {rv_eval['mse_out']:.6f} | "
              f"GP NLL: {rv_eval['gp_nll']:.4f} | "
              f"Loss: {rv_back['loss']:.4f} | "
              f"Time: {epoch_time:.1f}s")

    # Log to W&B
    if CONFIG['use_wandb']:
        wandb.log({
            "epoch": epoch,
            "mse_val": rv_eval["mse_val"],
            "mse_out": rv_eval["mse_out"],
            "gp_nll": rv_eval["gp_nll"],
            "recon_term": rv_back["recon_term"],
            "pen_term": rv_back["pen_term"],
            "loss": rv_back["loss"],
            "vars": rv_eval["vars"],
            "time/epoch_seconds": epoch_time,
        })

    # Save checkpoint
    if epoch % CONFIG['epoch_cb'] == 0 or epoch == CONFIG['epochs'] - 1:
        logging.info(f"Epoch {epoch} - saving checkpoint")

        # Save VAE weights
        vae_file = os.path.join(wdir, f"vae_weights.{epoch:05d}.pt")
        torch.save(vae.state_dict(), vae_file)

        # Save GP weights
        gp_file = os.path.join(wdir, f"gp_weights.{epoch:05d}.pt")
        torch.save({
            'gp_state': gp.state_dict(),
            'vm_state': vm.state_dict(),
            'gp_params': gp_params.state_dict(),
        }, gp_file)

        # Save visualization
        ffile = os.path.join(fdir, f"plot.{epoch:05d}.png")
        callback_gppvae(epoch, history, covs, imgs, ffile)

        if CONFIG['use_wandb']:
            wandb.log({
                "reconstructions": wandb.Image(ffile),
                "covariances/XX": wandb.Image(ffile),  # Uses same plot
            })

        print(f"  ✓ Checkpoint saved at epoch {epoch}")

total_time = time.time() - start_time
print("\n" + "=" * 80)
print(f"✅ GP-VAE training complete!")
print(f"   Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"   Average time per epoch: {total_time/CONFIG['epochs']:.1f} seconds")
print(f"   Final validation MSE: {rv_eval['mse_val']:.6f}")
print(f"   Final out-of-sample MSE: {rv_eval['mse_out']:.6f}")
print(f"   Final GP NLL: {rv_eval['gp_nll']:.4f}")

if CONFIG['use_wandb']:
    wandb.finish()
    print("\nüîó View detailed results in W&B dashboard")

🚀 Starting GP-VAE training for 50 epochs...
Training mode: JOINT OPTIMIZATION (VAE + GP updated together)
Epoch    0/50 | MSE val: 0.004238 | MSE out: 0.068923 | GP NLL: 0.0020 | Loss: -2.1612 | Time: 6.1s
  ✓ Checkpoint saved at epoch 0
Epoch    5/50 | MSE val: 0.031072 | MSE out: 0.050269 | GP NLL: -0.0001 | Loss: 4.6122 | Time: 6.2s
Epoch   10/50 | MSE val: 0.015151 | MSE out: 0.037485 | GP NLL: -0.0001 | Loss: 0.5970 | Time: 6.0s
  ✓ Checkpoint saved at epoch 10
Epoch   15/50 | MSE val: 0.011732 | MSE out: 0.034518 | GP NLL: -0.0000 | Loss: -0.1987 | Time: 6.1s
Epoch   20/50 | MSE val: 0.009168 | MSE out: 0.031358 | GP NLL: -0.0003 | Loss: -0.8755 | Time: 6.1s
  ✓ Checkpoint saved at epoch 20
Epoch   25/50 | MSE val: 0.007705 | MSE out: 0.029252 | GP NLL: -0.0003 | Loss: -1.2451 | Time: 6.1s
Epoch   30/50 | MSE val: 0.006775 | MSE out: 0.029119 | GP NLL: -0.0003 | Loss: -1.4836 | Time: 6.0s
  ✓ Checkpoint saved at epoch 30
Epoch   35/50 | MSE val: 0.006400 | MSE out: 0.0

Run history: sparkline charts of epoch, gp_nll, loss, mse_out, mse_val, pen_term, recon_term and time/epoch_seconds (view them in the W&B dashboard)

Run summary:
  epoch               49.0
  gp_nll              -0.00065
  loss                -1.87432
  mse_out             0.02801
  mse_val             0.00529
  pen_term            -2e-05
  recon_term          -1.87364
  time/epoch_seconds  6.10906



🔗 View detailed results in W&B dashboard


## 13. Download Results

Download the trained model and visualizations to your computer:

In [None]:
# Compress output folder
output_zip = '/content/gppvae_output.zip'
!zip -r {output_zip} {CONFIG['outdir']}

# Download
from google.colab import files
print("Preparing download...")
files.download(output_zip)
print("\n✅ Download started! Extract the zip on your local machine.")
print(f"\nContents include:")
print(f"  - Trained VAE weights (fine-tuned)")
print(f"  - GP + Vmodel weights")
print(f"  - Visualization plots")
print(f"  - Training logs")
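Because the project lives on Google Drive, everything under `CONFIG['outdir']` is already saved there; downloading is only needed for a local copy. If you prefer to build the archive in Python (for example when the `zip` command is unavailable), the standard library's `shutil.make_archive` does the same job. A sketch with a hypothetical helper `zip_run`; note that entries are stored relative to the run directory rather than under the full `out/gppvae_colab/...` path that `zip -r` keeps.

```python
import os
import shutil


def zip_run(outdir, zip_base):
    """Zip the contents of outdir into zip_base + '.zip' and return the archive path.

    Keep zip_base outside outdir so the archive does not try to include itself.
    """
    return shutil.make_archive(zip_base, "zip", root_dir=outdir)
```

Usage on Colab: `zip_path = zip_run(CONFIG['outdir'], '/content/gppvae_output')`, then `files.download(zip_path)`.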

## 14. Visualize Results

View the latest reconstruction and covariance plots:

In [None]:
from IPython.display import Image, display
import glob

# Get latest plot
plot_files = sorted(glob.glob(os.path.join(fdir, "*.png")))
if plot_files:
    latest_plot = plot_files[-1]
    print(f"Latest visualization: {latest_plot}")
    display(Image(filename=latest_plot))
else:
    print("No plots generated yet")

## 15. Analyze Learned Structure

Examine what the GP-VAE learned:

In [None]:
# Get learned variance components
vs = gp.get_vs().data.cpu().numpy()
frac = vs / vs.sum()  # shares of the total variance (no-op if get_vs() already sums to one)
print("Learned variance components:")
print(f"  Object/view GP variance: {vs[0]:.4f} ({frac[0]*100:.1f}%)")
print(f"  Noise variance: {vs[1]:.4f} ({frac[1]*100:.1f}%)")
print(f"\nInterpretation:")
print(f"  {frac[0]*100:.1f}% of latent variation explained by the object/view GP term")
print(f"  {frac[1]*100:.1f}% unexplained (noise)")

# Get object and view embeddings
X_embed = vm.x().data.cpu().numpy()
V_embed = vm.v().data.cpu().numpy()
print(f"\nLearned embeddings:")
print(f"  Object embeddings shape: {X_embed.shape}")
print(f"  View embeddings shape: {V_embed.shape}")

# Compute correlation structures
XX = np.dot(X_embed, X_embed.T)
VV = np.dot(V_embed, V_embed.T)
print(f"\nCovariance matrices:")
print(f"  Object-object correlation range: [{XX.min():.3f}, {XX.max():.3f}]")
print(f"  View-view correlation range: [{VV.min():.3f}, {VV.max():.3f}]")
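`XX` and `VV` are Gram matrices of the learned embeddings, so their entries are raw inner products. Rescaling by the diagonal turns them into correlations in [-1, 1], which makes the printed ranges comparable across runs; if the embeddings are already unit-norm this changes nothing. A sketch with a hypothetical helper `cov_to_corr`:

```python
import numpy as np


def cov_to_corr(C):
    """Rescale a covariance (or Gram) matrix to a correlation matrix with unit diagonal."""
    d = np.sqrt(np.diag(C))
    return C / np.outer(d, d)
```

Usage: `XX_corr = cov_to_corr(XX)`, then print `XX_corr.min()` and `XX_corr.max()` as above.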

## 16. Compare with VAE-only Model

Compare GP-VAE's out-of-sample prediction with standard VAE reconstruction:

In [None]:
print("Performance Summary:")
print("=" * 60)
print(f"VAE reconstruction MSE:        {rv_eval['mse_val']:.6f}")
print(f"GP-VAE out-of-sample MSE:      {rv_eval['mse_out']:.6f}")
print(f"Difference:                     {rv_eval['mse_out'] - rv_eval['mse_val']:.6f}")
print("=" * 60)

if rv_eval['mse_out'] < rv_eval['mse_val'] * 1.1:
    print("✅ Excellent! GP-VAE predicts unseen views almost as well as the VAE reconstructs the images it encodes")
    print("   The learned object/view structure is enough to generate new views of known objects")
elif rv_eval['mse_out'] < rv_eval['mse_val'] * 1.5:
    print("✓ Good! GP-VAE can predict unseen views reasonably well")
    print("  Consider training longer for better results")
else:
    print("⚠️ GP-VAE out-of-sample prediction is significantly worse")
    print("  Try:")
    print("  - Training for more epochs")
    print("  - Adjusting learning rates")
    print("  - Increasing xdim (covariance rank)")
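Plugging the final values from the run summary into the same thresholds shows which branch this 50-epoch run lands in. Some gap is expected by construction: `mse_val` decodes each validation image's own encoding, while `mse_out` must predict that image from training images only. The function name below is hypothetical.

```python
def judge_out_of_sample(mse_val, mse_out):
    """Apply the thresholds used in the cell above and return (ratio, verdict)."""
    ratio = mse_out / mse_val
    if ratio < 1.1:
        verdict = "excellent"
    elif ratio < 1.5:
        verdict = "good"
    else:
        verdict = "significantly worse"
    return ratio, verdict


# Final values from the W&B run summary of the 50-epoch run above
ratio, verdict = judge_out_of_sample(mse_val=0.00529, mse_out=0.02801)
print(f"mse_out / mse_val = {ratio:.2f} -> {verdict}")
# prints: mse_out / mse_val = 5.29 -> significantly worse
```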

## 17. Next Steps & Tips

### Understanding the Results:
- **mse_val**: How well the VAE reconstructs the validation images it encodes (baseline)
- **mse_out**: How well GP-VAE predicts **unseen views** of known objects
- **gp_nll**: GP prior negative log-likelihood of the latent codes (lower is better)
- **vars**: Variance decomposition (object/view GP term vs noise)
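The training loop reports `loss` as `recon_term + gp_nll + pen_term`, where `gp_nll` is the value returned by `gp.taylor_coeff` rather than the linearized term used for gradients. Re-adding the final values from the run summary confirms the bookkeeping up to rounding:

```python
# Final values from the W&B run summary (rounded to 5 decimals there)
summary = {"recon_term": -1.87364, "gp_nll": -0.00065, "pen_term": -2e-05, "loss": -1.87432}

recomputed = summary["recon_term"] + summary["gp_nll"] + summary["pen_term"]
print(f"recon_term + gp_nll + pen_term = {recomputed:.5f}  (logged loss: {summary['loss']:.5f})")
```

The remaining difference of about 1e-5 comes from rounding in the summary. At this stage nearly all of the loss is the reconstruction term, so the `gp_nll` curve is easier to read on its own than inside `loss`.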

### To improve results:
1. **Train longer**: Try 1000+ epochs for publication quality
2. **Adjust xdim**: Increase from 64 to 128 for more expressive covariances
3. **Tune learning rates**: Lower vae_lr if fine-tuning is too aggressive
4. **Better VAE**: Train VAE for more epochs before GP-VAE

### What you've learned:
✅ GP-VAE enables **structured latent representations**  
✅ Can predict **new viewpoints** of known objects  
✅ Learns **disentangled** object identity vs view factors  
✅ Joint optimization of VAE + GP works!  

### Performance vs Local:
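- **Colab GPU (T4)**: ~6 s per epoch in the run above, i.e. roughly 10 minutes for 100 epochs
- **M1 Pro CPU**: estimated at 20-50 hours for 100 epochs
- **Speedup**: substantial; time one epoch locally to get the exact factor 🚀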
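Using the ~6.1 s per epoch measured on the T4 in the training log above, a back-of-the-envelope estimate for longer runs (ignoring checkpoint and plotting overhead, and assuming the same data and batch size) looks like this; `estimate_runtime` is a hypothetical helper:

```python
def estimate_runtime(epochs, sec_per_epoch=6.1):
    """Rough wall-clock estimate in minutes; checkpoint/plot overhead is ignored."""
    return epochs * sec_per_epoch / 60.0


for epochs in (50, 100, 1000):
    print(f"{epochs:5d} epochs ~ {estimate_runtime(epochs):6.1f} min")
```

So the suggested 1000+ epochs for publication quality is roughly a 1.7-hour job on a T4 at this speed.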

Congratulations on training a GP-VAE! 🎉