# GP-VAE Training on Google Colab

This notebook trains the **GP-VAE (Gaussian Process Variational Autoencoder)** model using Google Colab's free GPU.

## What is GP-VAE?
GP-VAE adds a **Gaussian Process prior** to the VAE latent space to model structured correlations:
- **VAE**: Learns image ↔ latent code mapping
- **GP Prior**: Models correlations between latent codes based on:
  - Object identity (same person's face)
  - View angle (front, side, profile)
  - Other factors of variation
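
To make the role of the GP prior concrete, here is a minimal sketch. It is not the GPPVAE implementation. It assumes an indicator kernel over identities and a squared-exponential kernel over view angles (both hypothetical choices, as are all function names), and multiplies them so that latent codes are correlated a priori only for the same person at nearby angles.

```python
import math

def identity_kernel(person_a, person_b):
    # Illustrative assumption: 1.0 for the same person, 0.0 otherwise.
    return 1.0 if person_a == person_b else 0.0

def view_kernel(angle_a, angle_b, lengthscale=45.0):
    # Illustrative assumption: squared-exponential similarity in degrees.
    return math.exp(-0.5 * ((angle_a - angle_b) / lengthscale) ** 2)

def prior_covariance(samples):
    # samples: list of (person_id, view_angle_in_degrees) pairs.
    # Each entry is the product of the identity and view kernels.
    return [
        [identity_kernel(pa, pb) * view_kernel(va, vb) for (pb, vb) in samples]
        for (pa, va) in samples
    ]

samples = [("anna", 0.0), ("anna", 30.0), ("ben", 0.0)]
K = prior_covariance(samples)
for row in K:
    print(["%.3f" % value for value in row])
```

The two images of the same person receive a non-zero prior covariance, while images of different people are uncorrelated under this toy prior.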

## Prerequisites ⚠️
**You MUST have trained VAE weights first!** This model loads a pre-trained VAE and fine-tunes it jointly with the GP.

Required files:
- ✅ `out/vae_colab/YYYYMMDD_HHMMSS/vae.cfg.p` - VAE configuration
- ✅ `out/vae_colab/YYYYMMDD_HHMMSS/weights/weights.00000.pt` - Trained VAE weights

## Output Directory Structure:

Each training run creates a **timestamped directory** to avoid overwriting previous runs:
- Format: `./out/gppvae_colab/<kernel>_<split mode>_YYYYMMDD_HHMMSS/` (the configuration cell prefixes the timestamp with the kernel name and split mode)
- Example: `./out/gppvae_colab/rbf_random_20260116_185111/`
- This lets you compare training runs and keep a history.

Cell 6 below will automatically find your latest VAE training run.
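
Because the `YYYYMMDD_HHMMSS` timestamp is zero-padded, sorting directory names as plain strings also sorts them chronologically, which is what makes a "latest run" lookup simple. The following is a minimal sketch of that idea using a temporary directory; the helper names `new_run_dir` and `latest_run_dir` are hypothetical and do not come from the project code.

```python
import os
import tempfile
from datetime import datetime

def new_run_dir(base_dir, now=None):
    # Create base_dir/YYYYMMDD_HHMMSS/weights and return the run path.
    now = now or datetime.now()
    run_dir = os.path.join(base_dir, now.strftime("%Y%m%d_%H%M%S"))
    os.makedirs(os.path.join(run_dir, "weights"), exist_ok=True)
    return run_dir

def latest_run_dir(base_dir):
    # Zero-padded timestamps sort chronologically as plain strings.
    runs = [d for d in os.listdir(base_dir)
            if os.path.isdir(os.path.join(base_dir, d)) and d[:1].isdigit()]
    return os.path.join(base_dir, max(runs)) if runs else None

with tempfile.TemporaryDirectory() as base:
    new_run_dir(base, datetime(2025, 12, 24, 12, 1, 36))
    new_run_dir(base, datetime(2025, 12, 24, 17, 18, 41))
    latest_name = os.path.basename(latest_run_dir(base))
    print(latest_name)
```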

## Setup Instructions:

1. **Open this notebook in VS Code**
2. **Connect to Colab**: Click kernel picker → "Connect to Colab" → Choose **GPU runtime (T4)**
3. **Important**: When prompted with "Alias your server", press Enter
4. **Run cell 2** - it will automatically detect your project location


## 1. Check GPU Availability

In [None]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not detected! Go to Runtime → Change runtime type → GPU")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU Device: NVIDIA A100-SXM4-40GB
GPU Memory: 42.47 GB


## 2. Auto-Detect Project Path

This automatically finds your project files on the Colab runtime.

In [18]:
import os
import sys

# Get current directory
current_dir = os.getcwd()
print(f"📍 Current directory: {current_dir}")

# Check if on Colab and need to mount Drive
if current_dir == '/content':
    print("\n🔄 Mounting Google Drive...")

    try:
        from google.colab import drive
        drive.mount('/content/drive')

        # Check for project in Drive
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"✅ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"\n⚠️  Project not found at: {drive_path}")
            print("\nPlease upload your gppvae folder to Google Drive!")
            print("Required structure:")
            print("  MyDrive/gppvae/")
            print("    ├── GPPVAE/")
            print("    ├── data/faceplace/data_faces.h5")
            print("    └── out/vae_colab/YYYYMMDD_HHMMSS/")
            print("        ├── vae.cfg.p")
            print("        └── weights/weights.00000.pt")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    # Running via VS Code sync
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(current_dir)
    else:
        PROJECT_PATH = current_dir
    print(f"💻 Using project path: {PROJECT_PATH}")

# Verify structure
print(f"\n📁 Contents of {PROJECT_PATH}:")
if os.path.exists(PROJECT_PATH):
    items = os.listdir(PROJECT_PATH)
    for item in sorted(items)[:15]:
        item_path = os.path.join(PROJECT_PATH, item)
        if os.path.isdir(item_path):
            print(f"   📂 {item}/")
        else:
            print(f"   📄 {item}")

    # Check required files (with timestamped directory structure)
    print(f"\n🔍 Checking required files:")
    required = {
        'GPPVAE code': os.path.exists(os.path.join(PROJECT_PATH, 'GPPVAE')),
        'data/faceplace': os.path.exists(os.path.join(PROJECT_PATH, 'data/faceplace')),
        'data_faces.h5': os.path.exists(os.path.join(PROJECT_PATH, 'data/faceplace/data_faces.h5')),
    }

    # Check for VAE runs (timestamped subdirectories)
    vae_base_dir = os.path.join(PROJECT_PATH, 'out/vae_colab')
    vae_run_found = False
    vae_weights_found = False

    if os.path.exists(vae_base_dir):
        # Look for timestamped subdirectories
        potential_runs = [d for d in os.listdir(vae_base_dir)
                         if os.path.isdir(os.path.join(vae_base_dir, d)) and d[0].isdigit()]

        for run_dir in potential_runs:
            run_path = os.path.join(vae_base_dir, run_dir)
            cfg_path = os.path.join(run_path, 'vae.cfg.p')
            weights_dir = os.path.join(run_path, 'weights')

            if os.path.exists(cfg_path):
                vae_run_found = True

            if os.path.exists(weights_dir):
                weight_files = [f for f in os.listdir(weights_dir) if f.endswith('.pt')]
                if weight_files:
                    vae_weights_found = True
                    break

    required['VAE config'] = vae_run_found
    required['VAE weights'] = vae_weights_found

    for name, exists in required.items():
        status = "✅" if exists else "❌"
        print(f"   {status} {name}")

    # Show VAE runs if they exist
    if os.path.exists(vae_base_dir):
        potential_runs = sorted([d for d in os.listdir(vae_base_dir)
                                if os.path.isdir(os.path.join(vae_base_dir, d)) and d[0].isdigit()],
                               reverse=True)

        if potential_runs:
            print(f"\n📦 Found {len(potential_runs)} VAE training run(s):")
            for i, run_dir in enumerate(potential_runs[:3], 1):  # Show latest 3
                run_path = os.path.join(vae_base_dir, run_dir)
                weights_dir = os.path.join(run_path, 'weights')

                if os.path.exists(weights_dir):
                    weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
                    print(f"   {i}. {run_dir}/ ({len(weight_files)} checkpoints)")
                    if weight_files:
                        print(f"      Latest: {weight_files[-1]}")

            if len(potential_runs) > 3:
                print(f"   ... and {len(potential_runs) - 3} more")

            print(f"\n💡 Cell 6 below will help you choose which run to use")

    if not all(required.values()):
        print(f"\n⚠️  Missing required files!")
        if not required['VAE weights']:
            print("\n🚨 CRITICAL: No trained VAE weights found!")
            print("   You must train VAE first before running GP-VAE")
            print("   Use the train_vae_colab.ipynb notebook")
else:
    print(f"❌ Path doesn't exist: {PROJECT_PATH}")


📍 Current directory: /content/drive/MyDrive/gppvae
💻 Using project path: /content/drive/MyDrive/gppvae

📁 Contents of /content/drive/MyDrive/gppvae:
   📂 GPPVAE/
   📂 data/
   📄 environment.yml
   📂 notebooks/
   📂 out/

🔍 Checking required files:
   ✅ GPPVAE code
   ✅ data/faceplace
   ✅ data_faces.h5
   ✅ VAE config
   ✅ VAE weights

📦 Found 3 VAE training run(s):
   1. 20251224_171841/ (11 checkpoints)
      Latest: weights.00099.pt
   2. 20251224_171753/ (0 checkpoints)
   3. 20251224_120136/ (16 checkpoints)
      Latest: weights.00140.pt

💡 Cell 6 below will help you choose which run to use


## 3. Install Dependencies

In [19]:
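# Install required packages
!pip install -q wandb==0.12.21 imageio==2.15.0 pyyaml

# Verify installations.
# Note: these imports succeed if any version of each package is already
# available on the runtime, so they do not confirm that the pinned versions
# above were installed. Check the pip output for errors.
import wandb
import imageio
import yaml
import numpy as np
print("✅ All dependencies installed successfully!")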

  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... error
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
✅ All dependencies installed successfully!
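
The pip output above reports a metadata-generation failure, yet the import check still passes, presumably because some version of each package is already present on the runtime. A minimal sketch of a stricter check, assuming you want to compare installed versions against the pins (the helper names `installed_version` and `check_pins` are hypothetical):

```python
from importlib import metadata

def installed_version(package):
    # Return the installed version string, or None if the package is absent.
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

def check_pins(pins):
    # pins: dict mapping package name -> expected version string.
    # Returns package -> (expected, installed) for every mismatch.
    mismatches = {}
    for package, expected in pins.items():
        found = installed_version(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches

pins = {"wandb": "0.12.21", "imageio": "2.15.0"}
for package, (expected, found) in check_pins(pins).items():
    print(f"{package}: expected {expected}, found {found}")
```

If the loop prints nothing, the pinned versions are the ones installed. Otherwise each line names a package whose installed version differs from the pin or is missing.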


## 4. Login to Weights & Biases (Optional)

Track your experiments with W&B for better monitoring.

In [20]:
import wandb
wandb.login()

# Or run offline without W&B:
# import os
# os.environ['WANDB_MODE'] = 'offline'



False

## 5. Navigate to Project Directory

In [21]:
import os
import sys

os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")

# Add to Python path
sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

print("\nProject structure:")
!ls -la

Current directory: /content/drive/MyDrive/gppvae

Project structure:
total 17
drwx------ 3 root root 4096 Dec 23 14:09 data
-rw------- 1 root root  258 Dec 23 11:40 environment.yml
drwx------ 3 root root 4096 Dec 23 14:09 GPPVAE
drwx------ 2 root root 4096 Dec 23 14:09 notebooks
drwx------ 5 root root 4096 Dec 23 14:21 out


## 6. Verify VAE Weights

**Critical check:** Make sure you have trained VAE weights!

In [22]:
import os
import pickle
import glob

# Check for VAE runs (may be in timestamped subdirectories)
vae_base_dir = './out/vae_colab_interpolation'
vae_runs = []

if os.path.exists(vae_base_dir):
    # Look for timestamped subdirectories
    potential_runs = [d for d in os.listdir(vae_base_dir) if os.path.isdir(os.path.join(vae_base_dir, d))]
    for run_dir in sorted(potential_runs, reverse=True):  # Most recent first
        run_path = os.path.join(vae_base_dir, run_dir)
        cfg_path = os.path.join(run_path, 'vae.cfg.p')
        weights_dir = os.path.join(run_path, 'weights')

        if os.path.exists(cfg_path) and os.path.exists(weights_dir):
            weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
            if weight_files:
                vae_runs.append({
                    'run_dir': run_dir,
                    'cfg_path': cfg_path,
                    'weights_dir': weights_dir,
                    'weight_files': weight_files
                })

if vae_runs:
    print(f"✅ Found {len(vae_runs)} VAE training run(s):\n")

    for i, run in enumerate(vae_runs, 1):
        print(f"Run {i}: {run['run_dir']}")

        # Load and show config (the context manager closes the file)
        with open(run['cfg_path'], 'rb') as cfg_file:
            vae_cfg = pickle.load(cfg_file)
        print(f"   Config: zdim={vae_cfg.get('zdim', 'N/A')}, nf={vae_cfg.get('nf', 'N/A')}")

        # Show checkpoints
        print(f"   Checkpoints: {len(run['weight_files'])} files")
        if len(run['weight_files']) <= 3:
            for wf in run['weight_files']:
                print(f"      📦 {wf}")
        else:
            print(f"      📦 {run['weight_files'][0]} ... {run['weight_files'][-1]}")
        print()

    # Recommendation
    latest_run = vae_runs[0]
    latest_weight = latest_run['weight_files'][-1]
    recommended_path = os.path.join(latest_run['weights_dir'], latest_weight)

    print(f"💡 Recommendation:")
    print(f"   Use latest run: {latest_run['run_dir']}")
    print(f"   Latest checkpoint: {latest_weight}")
    print(f"   \n   Set in next cell:")
    print(f"   CONFIG['vae_cfg'] = '{latest_run['cfg_path']}'")
    print(f"   CONFIG['vae_weights'] = '{recommended_path}'")

else:
    print("❌ No trained VAE runs found!")
    print("\n   Please train VAE first using train_vae_colab.ipynb")
    print(f"   Expected location: {vae_base_dir}/YYYYMMDD_HHMMSS/")


✅ Found 4 VAE training run(s):

Run 1: 20260103_192743
   Config: zdim=256, nf=32
   Checkpoints: 11 files
      📦 weights.00000.pt ... weights.00499.pt

Run 2: 20260103_191647
   Config: zdim=256, nf=32
   Checkpoints: 2 files
      📦 weights.00000.pt
      📦 weights.00050.pt

Run 3: 20260103_190918
   Config: zdim=256, nf=32
   Checkpoints: 2 files
      📦 weights.00000.pt
      📦 weights.00050.pt

Run 4: 20260103_184756
   Config: zdim=256, nf=32
   Checkpoints: 5 files
      📦 weights.00000.pt ... weights.00200.pt

💡 Recommendation:
   Use latest run: 20260103_192743
   Latest checkpoint: weights.00499.pt
   
   Set in next cell:
   CONFIG['vae_cfg'] = './out/vae_colab_interpolation/20260103_192743/vae.cfg.p'
   CONFIG['vae_weights'] = './out/vae_colab_interpolation/20260103_192743/weights/weights.00499.pt'
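
The recommendation above takes the last file name after a plain string sort, which works as long as the epoch number stays zero-padded to the same width. As a hedged alternative, the sketch below parses the epoch number out of `weights.NNNNN.pt` and compares integers. The helper names and the regular expression are assumptions for illustration, not part of the project code.

```python
import re

CHECKPOINT_PATTERN = re.compile(r"^weights\.(\d+)\.pt$")

def checkpoint_epoch(filename):
    # Return the integer epoch encoded in 'weights.NNNNN.pt', or None.
    match = CHECKPOINT_PATTERN.match(filename)
    return int(match.group(1)) if match else None

def latest_checkpoint(filenames):
    # Pick the file with the largest epoch number, ignoring other files.
    candidates = [f for f in filenames if checkpoint_epoch(f) is not None]
    return max(candidates, key=checkpoint_epoch) if candidates else None

files = ["weights.00000.pt", "weights.00499.pt", "weights.00050.pt", "notes.txt"]
print(latest_checkpoint(files))
```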


## 7. Choose View Kernel 🔬

Choose a view kernel here, then adjust the training parameters in the configuration cell at the end of this section.

**NEW: Kernel Selection for View Correlations**

The view kernel models how correlations between face angles (-90°, -60°, -45°, -30°, 0°, +30°, +45°, +60°, +90°) are structured.

### Available Kernels:

1. **`'legacy'`** - Original implementation (normalized embeddings, 81 params)
   - Most flexible but can overfit
   - Good baseline for comparison

2. **`'fullrank'`** - Direct full-rank covariance (45 params)
   - Flexible but still many parameters
   - Better than legacy due to fewer constraints

3. **`'periodic'`** ⭐ **RECOMMENDED** - Periodic kernel (1 param: lengthscale)
   - Knows that 0° = 360° (periodicity!)
   - Smooth correlations between nearby angles
   - Massive regularization (only 1 parameter)
   - Best for rotation data

4. **`'vonmises'`** ⭐ **RECOMMENDED** - Von Mises kernel (1 param: kappa)
   - Designed specifically for circular/angular data
   - Similar to Periodic but different parameterization
   - Also best for rotation data

5. **`'matern'`** - Matérn kernel (1 param: lengthscale)
   - More realistic than RBF, less smooth
   - Good for modeling realistic correlations
   - Can choose smoothness: nu=1.5 or nu=2.5

6. **`'linear'`** - Low-rank linear (rank×9 params)
   - Original GP-VAE kernel from Casale et al. (2018)
   - Good middle-ground

7. **`'rbf'`** - RBF/Gaussian (1 param: lengthscale)
   - Smooth but NOT periodic
   - Use only if views don't wrap around

8. **`'spectral_mixture'`** ⭐ **NEW** - Spectral Mixture kernel (3×3 params)
   - Learns mixture of frequencies in the spectral domain
   - Very flexible - can model periodic AND non-periodic patterns
   - Each component has: weight, mean frequency, lengthscale
   - Good for complex correlation structures
   - Requires continuous angle encoding

### Expected Performance:

| Metric | Legacy | FullRank | Periodic | VonMises | Matérn | Spectral |
|--------|--------|----------|----------|----------|--------|----------|
| Val MSE | Medium | Medium | **Best** | **Best** | Good | **Excellent** |
| Out-of-sample | Worst | Bad | **Best** | **Best** | Good | **Excellent** |
| Overfitting | High | Medium | Low | Low | Low | Medium |
| Parameters | 81 | 45 | 1 | 1 | 1 | 9 (3 comp) |
| Smoothness | - | - | Very smooth | Very smooth | Adjustable | Very flexible |
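
The parameter counts in the table can be reproduced with simple arithmetic for Q = 9 views. The sketch below is one plausible reading of where the numbers come from (a Q×Q embedding for the legacy kernel, the free entries of a symmetric Q×Q matrix for the full-rank kernel, three values per spectral component); the function names are hypothetical and the actual parameterization in the code may differ.

```python
def legacy_params(n_views):
    # A Q x q embedding matrix with q = Q.
    return n_views * n_views

def fullrank_params(n_views):
    # Free entries of a symmetric Q x Q matrix (diagonal included).
    return n_views * (n_views + 1) // 2

def linear_params(n_views, rank):
    # Low-rank factor with `rank` values per view.
    return rank * n_views

def spectral_mixture_params(n_components):
    # Weight, mean frequency, and lengthscale per component.
    return 3 * n_components

Q = 9
print(legacy_params(Q), fullrank_params(Q), linear_params(Q, 3), spectral_mixture_params(3))
```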

**Recommendation**:
- **Best for rotations**: `'periodic'` or `'vonmises'`
- **More realistic**: `'matern'` (less smooth than periodic)
- **Most flexible**: `'spectral_mixture'` (can learn complex patterns)

In [23]:
# ============================================================================
# VIEW SPLIT CONFIGURATION - For Interpolation Experiment
# ============================================================================

# Experiment mode
VIEW_SPLIT_MODE = 'interpolation'  # 'interpolation' (split by view) or 'random'

# View angle mapping (after angular ordering fix):
# Index 0: 90L (-90°), 1: 60L (-60°), 2: 45L (-45°), 3: 30L (-30°), 4: 00F (0°),
# Index 5: 30R (+30°), 6: 45R (+45°), 7: 60R (+60°), 8: 90R (+90°)

if VIEW_SPLIT_MODE == 'interpolation':
    # EXPERIMENT 1 (Interpolation): Train on boundaries, test on intermediate views
    TRAIN_VIEW_INDICES = [0, 1, 3, 4, 5, 7, 8]  # 90L, 60L, 30L, 00F, 30R, 60R, 90R (boundaries)
    VAL_VIEW_INDICES = [2, 6]  # 45L, 45R (intermediate angles)

    print("🔬 EXPERIMENT MODE: Interpolation (Train boundaries, test intermediate)")
    print("=" * 60)
    print("Training views (boundaries):")
    print("  Index 0: 90L (-90°)")
    print("  Index 1: 60L (-60°)")
    print("  Index 3: 30L (-30°)")
    print("  Index 4: 00F (  0°)")
    print("  Index 5: 30R (+30°)")
    print("  Index 7: 60R (+60°)")
    print("  Index 8: 90R (+90°)")
    print("\nValidation views (intermediate):")
    print("  Index 2: 45L (-45°)")
    print("  Index 6: 45R (+45°)")
    print("=" * 60)
    print("\n💡 Research Question:")
    print("   Do structured kernels improve interpolation performance?")
    print("   Expected: Periodic/VonMises/Matérn > FullRank > Legacy")
else:
    TRAIN_VIEW_INDICES = None
    VAL_VIEW_INDICES = None
    print("📊 Standard Mode: Random 90/10 train/val split")

🔬 EXPERIMENT MODE: Interpolation (Train boundaries, test intermediate)
Training views (boundaries):
  Index 0: 90L (-90°)
  Index 1: 60L (-60°)
  Index 3: 30L (-30°)
  Index 4: 00F (  0°)
  Index 5: 30R (+30°)
  Index 7: 60R (+60°)
  Index 8: 90R (+90°)

Validation views (intermediate):
  Index 2: 45L (-45°)
  Index 6: 45R (+45°)

💡 Research Question:
   Do structured kernels improve interpolation performance?
   Expected: Periodic/VonMises/Matérn > FullRank > Legacy
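
The split itself amounts to filtering samples by view index. The internals of `read_face_data` are not shown in this notebook, so the following is only a minimal sketch of the idea on toy records; `split_by_view` and the record layout are hypothetical.

```python
TRAIN_VIEW_INDICES = [0, 1, 3, 4, 5, 7, 8]
VAL_VIEW_INDICES = [2, 6]

def split_by_view(records, train_views, val_views):
    # records: list of (image_id, person_id, view_index) tuples.
    train_set, val_set = set(train_views), set(val_views)
    if train_set & val_set:
        raise ValueError("train and validation views overlap")
    train = [r for r in records if r[2] in train_set]
    val = [r for r in records if r[2] in val_set]
    return train, val

# Toy data: 2 people x 9 views.
records = [(f"img_{p}_{v}", p, v) for p in range(2) for v in range(9)]
train, val = split_by_view(records, TRAIN_VIEW_INDICES, VAL_VIEW_INDICES)
print(len(train), len(val))
```

Every person appears in both subsets, but the 45° views are seen only at validation time, which is what makes this an interpolation test.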


In [None]:
# ============================================================================
# KERNEL CONFIGURATION - Choose one option below
# ============================================================================

# Option 1: Periodic kernel (RECOMMENDED for face rotations)
# KERNEL_CONFIG = {
#     'view_kernel': 'periodic',
#     'kernel_kwargs': {'lengthscale': 1.0}
# }

# Option 2: Von Mises kernel (RECOMMENDED alternative)
# KERNEL_CONFIG = {
#     'view_kernel': 'vonmises',
#     'kernel_kwargs': {'kappa': 1.0}
# }

# Option 3: Matérn kernel (realistic, less smooth than periodic)
# KERNEL_CONFIG = {
#     'view_kernel': 'matern',
#     'kernel_kwargs': {'lengthscale': 1.0, 'nu': 1.5}  # nu=1.5 or nu=2.5
# }

# Option 4: Legacy (original implementation - baseline)
# KERNEL_CONFIG = {
#     'view_kernel': 'legacy',
#     'kernel_kwargs': {}
# }

# Option 5: Full Rank (flexible, 45 params)
# KERNEL_CONFIG = {
#     'view_kernel': 'fullrank',
#     'kernel_kwargs': {}
# }

# Option 6: Linear low-rank (original GP-VAE paper)
# KERNEL_CONFIG = {
#     'view_kernel': 'linear',
#     'kernel_kwargs': {'rank': 3}
# }

# Option 7: RBF (smooth but not periodic)
# KERNEL_CONFIG = {
#     'view_kernel': 'rbf',
#     'kernel_kwargs': {'lengthscale': 1.0, 'angle_scale': 'normalized'}
# }

# Option 8: Spectral Mixture (flexible frequency-domain kernel)
KERNEL_CONFIG = {
    'view_kernel': 'spectral_mixture',
    'kernel_kwargs': {'n_components': 2, 'angle_scale': 'normalized'}
}

print("Selected Kernel Configuration:")
print("=" * 60)
print(f"Kernel type: {KERNEL_CONFIG['view_kernel']}")
if KERNEL_CONFIG['kernel_kwargs']:
    print(f"Parameters: {KERNEL_CONFIG['kernel_kwargs']}")
else:
    print("Parameters: (default)")
print("=" * 60)

Selected Kernel Configuration:
Kernel type: rbf
Parameters: {'lengthscale': 1.0, 'angle_scale': 'normalized'}


In [None]:
from datetime import datetime

# GP-VAE Training configuration
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
kernel_name = KERNEL_CONFIG['view_kernel']

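# Include view split mode in directory name - "interpolation" for this experiment.
# VIEW_SPLIT_MODE is set to 'interpolation' above; 'by_view' is treated the same way here.
view_mode_str = 'interpolation' if VIEW_SPLIT_MODE in ('interpolation', 'by_view') else 'random'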

CONFIG = {
    'data': './data/faceplace/data_faces.h5',
    # Output directory now includes kernel name AND experiment type
    'outdir': f'./out/gppvae_colab/{kernel_name}_{view_mode_str}_{timestamp}',
    'vae_cfg': './out/vae_colab_interpolation/20260103_192743/vae.cfg.p',
    'vae_weights': './out/vae_colab_interpolation/20260103_192743/weights/weights.00499.pt',

    # Training hyperparameters
    'epochs': 500,
    'batch_size': 64,
    'vae_lr': 0.001,
    'gp_lr': 0.001,
    'xdim': 64,

    # Kernel configuration
    'view_kernel': KERNEL_CONFIG['view_kernel'],
    'kernel_kwargs': KERNEL_CONFIG['kernel_kwargs'],
    
    # Angle encoding (will be determined automatically based on kernel type)
    'use_angle_encoding': KERNEL_CONFIG['view_kernel'] in ['rbf', 'matern', 'spectral_mixture'],

    # Experiment configuration (NEW)
    'view_split_mode': VIEW_SPLIT_MODE,
    'train_view_indices': TRAIN_VIEW_INDICES,
    'val_view_indices': VAL_VIEW_INDICES,

    # Logging
    'epoch_cb': 100,
    'use_wandb': True,
    'wandb_project': 'gppvae',
    'wandb_run_name': f'interpolation_{kernel_name}_{timestamp}',
    'seed': 0,
}

print("GP-VAE Training Configuration:")
print("=" * 60)
for key, value in CONFIG.items():
    if key in ['train_view_indices', 'val_view_indices'] and value is not None:
        print(f"  {key:20s}: {value}")
    elif key not in ['train_view_indices', 'val_view_indices']:
        print(f"  {key:20s}: {value}")
print("=" * 60)

# Verify VAE weights path
if not os.path.exists(CONFIG['vae_weights']):
    print(f"\n⚠️  WARNING: VAE weights not found at:")
    print(f"   {CONFIG['vae_weights']}")

print(f"\n✅ Output will be saved to:")
print(f"   {CONFIG['outdir']}")
print(f"\n   Directory name includes kernel type AND experiment mode!")
print(f"\n💡 Experiment: Interpolation (boundaries → intermediate views)")

GP-VAE Training Configuration:
  data                : ./data/faceplace/data_faces.h5
  outdir              : ./out/gppvae_colab/rbf_random_20260116_185111
  vae_cfg             : ./out/vae_colab_interpolation/20260103_192743/vae.cfg.p
  vae_weights         : ./out/vae_colab_interpolation/20260103_192743/weights/weights.00499.pt
  epochs              : 1000
  batch_size          : 64
  vae_lr              : 0.001
  gp_lr               : 0.001
  xdim                : 64
  view_kernel         : rbf
  kernel_kwargs       : {'lengthscale': 1.0, 'angle_scale': 'normalized'}
  view_split_mode     : interpolation
  train_view_indices  : [0, 1, 3, 4, 5, 7, 8]
  val_view_indices    : [2, 6]
  epoch_cb            : 100
  use_wandb           : True
  wandb_project       : gppvae
  wandb_run_name      : interpolation_rbf_20260116_185111
  seed                : 0

✅ Output will be saved to:
   ./out/gppvae_colab/rbf_random_20260116_185111

   Directory name includes kernel type AND experiment mode!

## 9. Import Training Modules

In [26]:
os.chdir(os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

# Import modules
import matplotlib
matplotlib.use('Agg')

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
import h5py
import numpy as np
import logging
import pylab as pl
from utils import smartSum, smartAppendDict, smartAppend, export_scripts
from callbacks import callback_gppvae
import pickle
import time
import wandb

# IMPORTANT: Use interpolation data parser with angle encoding
from data_parser_interpolation import read_face_data, FaceDataset

print("✅ All modules imported successfully!")
print("✅ Using data_parser_interpolation for interpolation experiment")

✅ All modules imported successfully!
✅ Using data_parser_interpolation for interpolation experiment


## 10. Setup Training Environment

In [27]:
# Go back to project root
os.chdir(PROJECT_PATH)

# Create output directories
outdir = CONFIG['outdir']
wdir = os.path.join(outdir, "weights")
fdir = os.path.join(outdir, "plots")
os.makedirs(wdir, exist_ok=True)
os.makedirs(fdir, exist_ok=True)

# Setup device (GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Setup logging
log_format = "%(asctime)s %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    datefmt="%m/%d %I:%M:%S %p",
)
fh = logging.FileHandler(os.path.join(outdir, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# Copy code to output
export_scripts(os.path.join(outdir, "scripts"))

print("✅ Training environment setup complete!")
print(f"   Outputs will be saved to: {outdir}")

Using device: cuda:0
✅ Training environment setup complete!
   Outputs will be saved to: ./out/gppvae_colab/rbf_random_20260116_185111
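
One thing to watch in a notebook: re-running the setup cell attaches another `FileHandler` to the root logger each time, which can duplicate log lines within a kernel session. A minimal sketch of a guard, assuming you want at most one handler per log file (the helper `add_file_handler_once` and the logger name are hypothetical):

```python
import logging
import os
import tempfile

def add_file_handler_once(logger, log_path, fmt="%(asctime)s %(message)s"):
    # Reuse an existing FileHandler for the same file instead of adding another.
    target = os.path.abspath(log_path)
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            return handler
    handler = logging.FileHandler(target)
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return handler

with tempfile.TemporaryDirectory() as tmp:
    demo_logger = logging.getLogger("gppvae_demo")  # hypothetical logger name
    demo_logger.setLevel(logging.INFO)
    log_file = os.path.join(tmp, "log.txt")
    first = add_file_handler_once(demo_logger, log_file)
    second = add_file_handler_once(demo_logger, log_file)  # simulates a cell re-run
    demo_logger.info("epoch 0 done")
    n_file_handlers = sum(isinstance(h, logging.FileHandler) for h in demo_logger.handlers)
    # Close the handler so the file is flushed before it is read and removed.
    first.close()
    demo_logger.removeHandler(first)
    with open(log_file) as fh:
        n_lines = len(fh.read().splitlines())

print(n_file_handlers, n_lines)
```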


## 11. Initialize Models and Data

This cell:
1. Loads the pre-trained VAE
2. Loads the dataset
3. Creates the GP and Vmodel
4. Sets up the optimizers

In [None]:
# Set random seed
torch.manual_seed(CONFIG['seed'])

# Determine if we need angle encoding based on kernel choice
# RBF, Matérn, and Spectral Mixture kernels work with continuous angles
use_angle_encoding = CONFIG['view_kernel'] in ['rbf', 'matern', 'spectral_mixture']

if use_angle_encoding:
    print("\n🎯 ANGLE ENCODING ENABLED")
    print(f"   Kernel '{CONFIG['view_kernel']}' requires continuous angle values")
    print(f"   Views will be encoded as normalized angles (e.g., -1.0 to +1.0)")
else:
    print("\n📍 DISCRETE VIEW INDICES MODE")
    print(f"   Kernel '{CONFIG['view_kernel']}' uses discrete view embeddings")

# Initialize W&B
if CONFIG['use_wandb']:
    wandb.init(
        project=CONFIG['wandb_project'],
        name=CONFIG['wandb_run_name'],
        config=CONFIG  # CONFIG already contains use_angle_encoding
    )

# Load VAE configuration (the context manager closes the file)
with open(CONFIG['vae_cfg'], "rb") as cfg_file:
    vae_cfg = pickle.load(cfg_file)
print(f"VAE config: {vae_cfg}")

# Load pre-trained VAE
print("\nLoading pre-trained VAE...")
vae = FaceVAE(**vae_cfg).to(device)
vae_state = torch.load(CONFIG['vae_weights'], map_location=device)
vae.load_state_dict(vae_state)
print(f"✅ VAE loaded from {CONFIG['vae_weights']}")
print(f"   Total VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")

# Load data with interpolation experiment configuration
print(f"\nLoading dataset with angle_encoding={use_angle_encoding}...")
img, obj, view = read_face_data(
    CONFIG['data'],
    use_angle_encoding=use_angle_encoding,
    view_split_mode=CONFIG['view_split_mode'],
    train_view_indices=CONFIG.get('train_view_indices'),
    val_view_indices=CONFIG.get('val_view_indices')
)

train_data = FaceDataset(img["train"], obj["train"], view["train"])
val_data = FaceDataset(img["val"], obj["val"], view["val"])
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)

# Enhanced diagnostic logging
print(f"\n✅ Data loaded:")
print(f"   Training samples: {len(train_data)}")
print(f"   Validation samples: {len(val_data)}")
if use_angle_encoding:
    print(f"   Train view angles: {np.unique(view['train'].numpy().round(3))}")
    print(f"   Val view angles: {np.unique(view['val'].numpy().round(3))}")
else:
    print(f"   Train view indices: {np.unique(view['train'].numpy())}")
    print(f"   Val view indices: {np.unique(view['val'].numpy())}")
print(f"   Unique train identities: {len(np.unique(obj['train'].numpy()))}")
print(f"   Unique val identities: {len(np.unique(obj['val'].numpy()))}")

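# Validation checks for interpolation experiment
# (CONFIG['view_split_mode'] is 'interpolation' in this notebook; 'by_view' is also accepted)
if CONFIG['view_split_mode'] in ('interpolation', 'by_view'):
    print("\n🔍 Interpolation Experiment Validation Checks:")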

    if use_angle_encoding:
        # With angle encoding, views are continuous floats
        # Need to check against expected angle values
        from data_parser_interpolation import encode_view_angles
        train_angles_expected = encode_view_angles(
            np.array(CONFIG['train_view_indices']), 
            encoding='normalized'
        ).numpy().round(6)
        val_angles_expected = encode_view_angles(
            np.array(CONFIG['val_view_indices']), 
            encoding='normalized'
        ).numpy().round(6)
        
        train_angles_actual = np.round(np.unique(view['train'].numpy().flatten()), 6)
        val_angles_actual = np.round(np.unique(view['val'].numpy().flatten()), 6)
        
        assert set(train_angles_actual) == set(train_angles_expected), "Train angles mismatch!"
        assert set(val_angles_actual) == set(val_angles_expected), "Val angles mismatch!"
        print("   ✅ View angles verified correctly!")
    else:
        # Check 1: View split correctness with discrete indices
        train_views_set = set(np.unique(view['train'].numpy().flatten()).astype(int))
        val_views_set = set(np.unique(view['val'].numpy().flatten()).astype(int))

        assert train_views_set == set(CONFIG['train_view_indices']), "Train views mismatch!"
        assert val_views_set == set(CONFIG['val_view_indices']), "Val views mismatch!"
        assert len(train_views_set & val_views_set) == 0, "Train and val views overlap!"
        print("   ✅ View split verified correctly!")

    # Check 2: Identity coverage
    train_ids = set(np.unique(obj['train'].numpy()))
    val_ids = set(np.unique(obj['val'].numpy()))
    assert train_ids == val_ids, "Identity sets don't match between train/val!"
    print(f"   ✅ All {len(train_ids)} identities present in both train/val!")

    # Check 3: Sample distribution
    train_samples_per_id = len(img['train']) / len(train_ids)
    val_samples_per_id = len(img['val']) / len(val_ids)
    print(f"   ✅ Train samples per identity: {train_samples_per_id:.1f} (expected: {len(CONFIG['train_view_indices'])}.0)")
    print(f"   ✅ Val samples per identity: {val_samples_per_id:.1f} (expected: {len(CONFIG['val_view_indices'])}.0)")

# Create object and view tensors for the GP on the selected device
# (.to(device) also works on a CPU-only runtime, unlike .cuda())
Dt = obj["train"][:, 0].long().to(device)
Dv = obj["val"][:, 0].long().to(device)

# Keep view as float if using angle encoding, otherwise convert to long
if use_angle_encoding:
    Wt = view["train"][:, 0].to(device)  # Float angles
    Wv = view["val"][:, 0].to(device)  # Float angles
else:
    Wt = view["train"][:, 0].long().to(device)  # Integer indices
    Wv = view["val"][:, 0].long().to(device)  # Integer indices

# Initialize GP and Vmodel
print("\nInitializing GP-VAE components...")

# Count unique identities and views
all_identities = np.unique(np.concatenate([obj["train"].numpy(), obj["val"].numpy()]))

if use_angle_encoding:
    # With angle encoding, Q is still the number of reference angles (9 views)
    Q = 9
    P = len(all_identities)
    print(f"   Objects (people): {P}")
    print(f"   Views (reference angles): {Q}")
    print(f"   Using continuous angle values")
else:
    # With discrete indices, count unique view indices
    all_views = np.unique(np.concatenate([view["train"].numpy(), view["val"].numpy()]))
    Q = len(all_views)
    P = len(all_identities)
    print(f"   Objects (people): {P}")
    print(f"   Views (discrete): {Q}")
    print(f"   Train views: {sorted(np.unique(view['train'].numpy()).astype(int).tolist())}")
    print(f"   Val views: {sorted(np.unique(view['val'].numpy()).astype(int).tolist())}")

# Initialize Vmodel with the selected view kernel and its keyword arguments
vm = Vmodel(
    P, Q,
    p=CONFIG['xdim'],
    q=Q,  # For legacy, q=Q
    view_kernel=CONFIG['view_kernel'],
    **CONFIG['kernel_kwargs']
).to(device)

print(f"\n🔬 Initializing view kernel: '{CONFIG['view_kernel']}'")
if CONFIG['kernel_kwargs']:
    print(f"   Kernel parameters: {CONFIG['kernel_kwargs']}")
else:
    print(f"   Kernel parameters: (default)")

gp = GP(n_rand_effs=1).to(device)

# Combine GP parameters (Vmodel + GP)
gp_params = nn.ParameterList()
gp_params.extend(vm.parameters())
gp_params.extend(gp.parameters())

print(f"‚úÖ GP-VAE components initialized:")
print(f"   Vmodel parameters: {sum(p.numel() for p in vm.parameters()):,}")
print(f"   GP parameters: {sum(p.numel() for p in gp.parameters()):,}")
print(f"   Total trainable: {sum(p.numel() for p in vae.parameters()) + sum(p.numel() for p in gp_params):,}")

# Create optimizers (separate for VAE and GP)
vae_optimizer = optim.Adam(vae.parameters(), lr=CONFIG['vae_lr'])
gp_optimizer = optim.Adam(gp_params, lr=CONFIG['gp_lr'])
print(f"\n✅ Optimizers created:")
print(f"   VAE optimizer: Adam(lr={CONFIG['vae_lr']})")
print(f"   GP optimizer: Adam(lr={CONFIG['gp_lr']})")

Run history:
diagnostics/gap_train_val,▁▄▄█▆▄▂▃▅▄
diagnostics/gap_val_out,█▇▆▃▁▁▂▂▂▁
diagnostics/variance_ratio,█▇▆▆▅▄▃▃▂▁
epoch,▁▂▃▃▄▅▆▆▇█
gp_nll,▃▇▄▁▁▃▆█▇█
loss,▁█▂▃▂▂▂▂▂▂
mse_out,▃█▃▃▂▁▁▂▁▁
mse_out_per_view/45L,▃█▃▂▂▁▁▂▁▁
mse_out_per_view/45R,▃█▃▃▂▁▁▂▁▁
mse_train,▁█▂▃▂▂▂▂▂▂

Run summary:
diagnostics/gap_train_val,-0.00071
diagnostics/gap_val_out,0.03981
diagnostics/variance_ratio,0.49539
epoch,9
gp_nll,0.0032
loss,-1.47101
mse_out,0.04705
mse_out_per_view/45L,0.04662
mse_out_per_view/45R,0.04747
mse_train,0.00653
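
The diagnostic entries in this summary are derived from the raw metrics. The short check below assumes the definitions used in the training loop (`gap_val_out = mse_out - mse_val`, `gap_train_val = mse_train - mse_val`, `variance_ratio = v0 / (v0 + v1)`). `mse_val` is not listed in the summary, so it is recovered from the other values; the variances passed to `variance_ratio` are hypothetical placeholders, not values from this run.

```python
# Values copied from the run summary
mse_out = 0.04705
mse_train = 0.00653
gap_val_out = 0.03981

# gap_val_out = mse_out - mse_val  =>  mse_val = mse_out - gap_val_out
mse_val = mse_out - gap_val_out
gap_train_val = mse_train - mse_val


def variance_ratio(v0, v1):
    """Share of the total variance carried by the structured (object/view) term."""
    return v0 / (v0 + v1)


print(f"mse_val       = {mse_val:.5f}")
print(f"gap_train_val = {gap_train_val:.5f}")
# Hypothetical variances, used only to show the formula
print(f"variance_ratio(1.0, 1.0) = {variance_ratio(1.0, 1.0):.3f}")
```

The recovered train-validation gap agrees with the logged `diagnostics/gap_train_val` to the printed precision, which suggests the summary values are mutually consistent.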


VAE config: {'nf': 32, 'zdim': 256, 'vy': 0.002}

Loading pre-trained VAE...
✅ VAE loaded from ./out/vae_colab_interpolation/20260103_192743/weights/weights.00499.pt
   Total VAE parameters: 553,304

Loading dataset WITHOUT angle encoding (will use view indices)...

📂 Loading data from: ./data/faceplace/data_faces.h5
   Split mode: interpolation
   Angle encoding: ✅ ENABLED (using actual angles)

🔍 DEBUG: View encoding from HDF5
   Unique Rid values in train: [np.bytes_(b'00F'), np.bytes_(b'30L'), np.bytes_(b'30R'), np.bytes_(b'45L'), np.bytes_(b'45R'), np.bytes_(b'60L'), np.bytes_(b'60R'), np.bytes_(b'90L'), np.bytes_(b'90R')]
   uRid (ordered): [b'90L' b'60L' b'45L' b'30L' b'00F' b'30R' b'45R' b'60R' b'90R']
   View mapping table_w: {np.bytes_(b'90L'): 0, np.bytes_(b'60L'): 1, np.bytes_(b'45L'): 2, np.bytes_(b'30L'): 3, np.bytes_(b'00F'): 4, np.bytes_(b'30R'): 5, np.bytes_(b'45R'): 6, np.bytes_(b'60R'): 7, np.bytes_(b'90R'): 8}
   W['train'] unique values: [np.int64(0), np.
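
The debug output orders the views from left profile to right profile (`90L` → index 0, `00F` → index 4, `90R` → index 8). The sketch below shows one way such labels could be converted to indices and signed angles. `label_to_angle` and `build_view_table` are hypothetical helpers, not functions from `data_parser_interpolation`, and the sign convention (left = negative) is an assumption.

```python
def label_to_angle(label):
    """Convert a view label such as '45L' or b'00F' to a signed angle in degrees."""
    if isinstance(label, bytes):
        label = label.decode("ascii")
    magnitude, side = int(label[:2]), label[2]
    if side == "L":
        return -magnitude  # assumed convention: left views are negative
    if side in ("R", "F"):
        return magnitude
    raise ValueError(f"Unknown view label: {label!r}")


def build_view_table(labels):
    """Sort labels by angle and map each label to (index, angle)."""
    ordered = sorted(labels, key=label_to_angle)
    return {lab: (i, label_to_angle(lab)) for i, lab in enumerate(ordered)}


labels = [b"00F", b"30L", b"30R", b"45L", b"45R", b"60L", b"60R", b"90L", b"90R"]
table = build_view_table(labels)
print(table[b"90L"], table[b"00F"], table[b"90R"])
```

Sorting by angle rather than by label string is what yields the left-to-right ordering; a plain lexicographic sort would place `00F` first.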

## 11. Define Training Functions

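These functions implement the three stages of each epoch: `encode_Y` encodes the full training set into latent means and scales, `eval_step` measures reconstruction and out-of-sample prediction error on the validation set (overall and per view), and `backprop_and_update` accumulates gradients of the joint loss over all mini-batches and then steps both optimizers once.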

In [None]:
def encode_Y(vae, train_queue):
    """Encode all training images to get latent codes"""
    vae.eval()

    with torch.no_grad():
        n = train_queue.dataset.Y.shape[0]
        Zm = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).cuda()
        Zs = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).cuda()

        for batch_i, data in enumerate(train_queue):
            y = data[0].cuda()
            idxs = data[-1].cuda()
            zm, zs = vae.encode(y)
            Zm[idxs], Zs[idxs] = zm.detach(), zs.detach()

    return Zm, Zs


def eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv, Wv, use_angle_encoding=False):
    """Enhanced evaluation with per-view metrics for Interpolation Experiment"""
    rv = {}

    with torch.no_grad():
        _X = vm.x().data.cpu().numpy()
        _W = vm.v().data.cpu().numpy()
        covs = {"XX": np.dot(_X, _X.T), "WW": np.dot(_W, _W.T)}
        rv["vars"] = gp.get_vs().data.cpu().numpy()

        # Out-of-sample prediction
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
        Kiz = gp.solve(Zm, U, UBi, vs)
        Zo = vs[0] * Vv.mm(Vt.transpose(0, 1).mm(Kiz))

        mse_out = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).cuda()
        mse_val = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).cuda()

        # Collect ALL validation samples first for diverse sampling
        all_Yv = []
        all_Yr = []
        all_Yo = []

        for batch_i, data in enumerate(val_queue):
            idxs = data[-1].cuda()
            Yv = data[0].cuda()
            Zv = vae.encode(Yv)[0].detach()
            Yr = vae.decode(Zv)
            Yo = vae.decode(Zo[idxs])
            mse_out[idxs] = ((Yv - Yo) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            mse_val[idxs] = ((Yv - Yr) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()

            # Collect all samples for diverse visualization
            all_Yv.append(Yv.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yr.append(Yr.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yo.append(Yo.data.cpu().numpy().transpose(0, 2, 3, 1))

        # Concatenate all validation samples
        all_Yv = np.concatenate(all_Yv, axis=0)
        all_Yr = np.concatenate(all_Yr, axis=0)
        all_Yo = np.concatenate(all_Yo, axis=0)

        # Sample diverse identities across the validation set (evenly spaced)
        n_total = all_Yv.shape[0]
        if n_total >= 24:
            sample_stride = max(1, n_total // 24)
            sample_indices = np.arange(0, n_total, sample_stride)[:24]
        else:
            sample_indices = np.arange(min(24, n_total))

        imgs = {}
        imgs["Yv"] = all_Yv[sample_indices]
        imgs["Yr"] = all_Yr[sample_indices]
        imgs["Yo"] = all_Yo[sample_indices]

        rv["mse_out"] = float(mse_out.data.mean().cpu())
        rv["mse_val"] = float(mse_val.data.mean().cpu())

        # Per-view metrics for the Interpolation Experiment.
        # Move tensors to NumPy once instead of once per view.
        wv_np = Wv.cpu().numpy().flatten()
        mse_val_np = mse_val.cpu().numpy()
        mse_out_np = mse_out.cpu().numpy()
        mse_val_per_view = {}
        mse_out_per_view = {}

        if use_angle_encoding:
            # Continuous angles: map each validation angle back to its view index
            from data_parser_interpolation import encode_view_angles
            val_indices = np.array(CONFIG['val_view_indices'])
            val_angles = encode_view_angles(val_indices, encoding='normalized').numpy()
            # Match with a small tolerance to absorb floating-point error
            view_masks = [(int(idx), np.abs(wv_np - angle) < 1e-5)
                          for idx, angle in zip(val_indices, val_angles)]
        else:
            # Discrete indices: use them directly
            view_masks = [(int(view_idx), wv_np == view_idx)
                          for view_idx in np.unique(wv_np)]

        for view_idx, view_mask in view_masks:
            if view_mask.sum() > 0:
                mse_val_per_view[view_idx] = float(mse_val_np[view_mask].mean())
                mse_out_per_view[view_idx] = float(mse_out_np[view_mask].mean())

        rv['mse_val_per_view'] = mse_val_per_view
        rv['mse_out_per_view'] = mse_out_per_view

    return rv, imgs, covs


def backprop_and_update(vae, gp, vm, train_queue, Dt, Wt, Eps, Zb, Vbs, vbs, vae_optimizer, gp_optimizer):
    """Joint optimization of VAE and GP"""
    rv = {}

    vae_optimizer.zero_grad()
    gp_optimizer.zero_grad()
    vae.train()
    gp.train()
    vm.train()

    for batch_i, data in enumerate(train_queue):
        # Get batch data
        y = data[0].cuda()
        eps = Eps[data[-1]]
        _d = Dt[data[-1]]
        _w = Wt[data[-1]]
        _Zb = Zb[data[-1]]
        _Vbs = [Vbs[0][data[-1]]]

        # Forward through VAE
        zm, zs = vae.encode(y)
        z = zm + zs * eps
        yr = vae.decode(z)
        recon_term, mse = vae.nll(y, yr)

        # Forward through GP
        _Vs = [vm(_d, _w)]
        gp_nll_fo = gp.taylor_expansion(z, _Vs, _Zb, _Vbs, vbs) / vae.K

        # Penalization term
        pen_term = -0.5 * zs.sum(1)[:, None] / vae.K

        # Joint loss and backward
        loss = (recon_term + gp_nll_fo + pen_term).sum()
        loss.backward()

        # Accumulate metrics
        _n = train_queue.dataset.Y.shape[0]
        smartSum(rv, "mse", float(mse.data.sum().cpu()) / _n)
        smartSum(rv, "recon_term", float(recon_term.data.sum().cpu()) / _n)
        smartSum(rv, "pen_term", float(pen_term.data.sum().cpu()) / _n)

    # Update both optimizers
    vae_optimizer.step()
    gp_optimizer.step()

    return rv


print("✅ Training functions defined with per-view metrics for interpolation")
print("✅ Diverse identity sampling now collects from entire validation set")

✅ Training functions defined with per-view metrics for interpolation
✅ Diverse identity sampling now collects from entire validation set
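
The out-of-sample prediction in `eval_step` computes `Zo = vs[0] * Vv @ Vt.T @ K⁻¹ Zm`. The NumPy sketch below reproduces that formula under the assumption that the GP covariance has the form `K = v0 * Vt @ Vt.T + v1 * I`. This is inferred from the two variance terms logged as `v0_object` and `v1_noise`; the internals of `gp.solve` and `gp.U_UBi_Shb` are not shown in this notebook. The second function uses the Woodbury identity so that only an r × r system is solved, which is one way a low-rank solve could be organized.

```python
import numpy as np


def predict_dense(Zm, Vt, Vv, v0, v1):
    """Posterior mean at validation inputs using a dense n x n solve."""
    n = Vt.shape[0]
    K = v0 * Vt @ Vt.T + v1 * np.eye(n)
    Kiz = np.linalg.solve(K, Zm)  # K^{-1} Zm
    return v0 * Vv @ (Vt.T @ Kiz)


def predict_lowrank(Zm, Vt, Vv, v0, v1):
    """Same prediction via the Woodbury identity (r x r solve only)."""
    r = Vt.shape[1]
    B = v1 * np.eye(r) + v0 * Vt.T @ Vt
    # K^{-1} Zm = (Zm - v0 * Vt B^{-1} Vt^T Zm) / v1
    Kiz = (Zm - v0 * Vt @ np.linalg.solve(B, Vt.T @ Zm)) / v1
    return v0 * Vv @ (Vt.T @ Kiz)


# Random stand-ins: 50 training rows, 10 validation rows, rank 4, 8 latent dims
rng = np.random.default_rng(0)
Vt = rng.normal(size=(50, 4))
Vv = rng.normal(size=(10, 4))
Zm = rng.normal(size=(50, 8))

Zo_dense = predict_dense(Zm, Vt, Vv, v0=0.5, v1=0.5)
Zo_lowrank = predict_lowrank(Zm, Vt, Vv, v0=0.5, v1=0.5)
print(Zo_dense.shape, np.allclose(Zo_dense, Zo_lowrank))
```

The low-rank form matters when the number of training images is much larger than the number of columns in `Vt`, since the dense solve scales cubically with the number of images.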


## 12. Train GP-VAE Model 🚀

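**This is joint optimization!** The VAE, the GP, and the Vmodel are all updated in the same step, once per epoch.

Training process per epoch:
1. Encode all training images to latent codes (VAE) and sample `Z`
2. Compute the Vmodel outputs `Vt` and `Vv`, then evaluate on the validation set
3. Compute the GP Taylor expansion coefficients and the GP negative log-likelihood on the latents
4. Backpropagate the joint loss over every mini-batch, accumulating gradients
5. Step the VAE and GP optimizers once, updating VAE, GP, and Vmodel simultaneously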
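
`backprop_and_update` calls `loss.backward()` for every mini-batch but steps the optimizers only once, so each update uses the gradient of the loss summed over the whole training set. The NumPy stand-in below illustrates that pattern: a linear least-squares model with a hand-written gradient replaces the VAE and GP, and all names and the learning rate are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(12, 4))
y = rng.normal(size=(12, 1))
w = rng.normal(size=(4, 1))


def grad_sum_sq(Xb, yb, w):
    """Gradient of sum((Xb @ w - yb) ** 2) with respect to w."""
    return 2.0 * Xb.T @ (Xb @ w - yb)


# Accumulate per-batch gradients, as repeated backward() calls would
batch_size = 4
accumulated = np.zeros_like(w)
for start in range(0, len(X), batch_size):
    Xb, yb = X[start:start + batch_size], y[start:start + batch_size]
    accumulated += grad_sum_sq(Xb, yb, w)

# Gradient computed on the full data set in one pass
full_batch = grad_sum_sq(X, y, w)

# A single parameter update per epoch (plain gradient descent here)
w_new = w - 0.01 * accumulated
print(np.allclose(accumulated, full_batch))
```

Because the per-batch losses are summed rather than averaged, the accumulated gradient equals the full-batch gradient regardless of the batch size; mini-batching only bounds memory use.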

In [None]:
import time
from IPython.display import clear_output

history = {}
start_time = time.time()

print(f"🚀 Starting GP-VAE Interpolation Experiment training for {CONFIG['epochs']} epochs...")
print("=" * 80)
print("Training mode: JOINT OPTIMIZATION (VAE + GP updated together)")
print(f"Experiment: Interpolation (boundaries → intermediate)")
print(f"  Training views: {CONFIG.get('train_view_indices', 'all')}")
print(f"  Validation views: {CONFIG.get('val_view_indices', 'all')}")
print("=" * 80)

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()

    # 1. Encode all training images
    Zm, Zs = encode_Y(vae, train_queue)

    # 2. Sample latent codes
    Eps = Variable(torch.randn(*Zs.shape), requires_grad=False).cuda()
    Z = Zm + Eps * Zs

    # 3. Compute variance matrices
    Vt = vm(Dt, Wt).detach()
    Vv = vm(Dv, Wv).detach()

    # 4. Evaluate on validation set (with per-view metrics)
    rv_eval, imgs, covs = eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv, Wv, use_angle_encoding=use_angle_encoding)

    # 5. Compute GP Taylor expansion coefficients
    Zb, Vbs, vbs, gp_nll = gp.taylor_coeff(Z, [Vt])
    rv_eval["gp_nll"] = float(gp_nll.data.mean().cpu()) / vae.K

    # 6. Joint training step (VAE + GP)
    rv_back = backprop_and_update(
        vae, gp, vm, train_queue, Dt, Wt, Eps,
        Zb, Vbs, vbs, vae_optimizer, gp_optimizer
    )
    rv_back["loss"] = rv_back["recon_term"] + rv_eval["gp_nll"] + rv_back["pen_term"]

    # Store history
    smartAppendDict(history, rv_eval)
    smartAppendDict(history, rv_back)
    smartAppend(history, "vs", gp.get_vs().data.cpu().numpy())

    epoch_time = time.time() - epoch_start
    total_time = time.time() - start_time

    # 🔬 Compute diagnostic metrics
    train_val_gap = rv_back["mse"] - rv_eval["mse_val"]
    val_out_gap = rv_eval["mse_out"] - rv_eval["mse_val"]

    vs = gp.get_vs().data.cpu().numpy()
    variance_ratio = vs[0] / (vs[0] + vs[1])

    # Check if kernel has learnable lengthscale
    learned_lengthscale = None
    if hasattr(vm, 'view_kernel') and hasattr(vm.view_kernel, 'log_lengthscale'):
        learned_lengthscale = torch.exp(vm.view_kernel.log_lengthscale).item()

    # Print progress
    if epoch % 5 == 0 or epoch == CONFIG['epochs'] - 1:
        print(f"Epoch {epoch:4d}/{CONFIG['epochs']} | "
              f"MSE train: {rv_back['mse']:.6f} | "
              f"MSE val: {rv_eval['mse_val']:.6f} | "
              f"MSE out: {rv_eval['mse_out']:.6f} | "
              f"GP NLL: {rv_eval['gp_nll']:.4f} | "
              f"Gap(T-V): {train_val_gap:.6f} | "
              f"Gap(V-O): {val_out_gap:.6f} | "
              f"v₀/(v₀+v₁): {variance_ratio:.3f}" +
              (f" | ℓ: {learned_lengthscale:.3f}" if learned_lengthscale is not None else "") +
              f" | Time: {epoch_time:.1f}s")

        # Print per-view breakdown (Interpolation specific)
        if CONFIG['view_split_mode'] == 'by_view' and epoch % 10 == 0:
            if 'mse_out_per_view' in rv_eval and rv_eval['mse_out_per_view']:
                print("   Per-view MSE_out (intermediate views):")
                # Map view indices to names
                view_names = {0: "90L", 1: "60L", 2: "45L", 3: "30L", 4: "00F",
                             5: "30R", 6: "45R", 7: "60R", 8: "90R"}
                for view_idx in sorted(rv_eval['mse_out_per_view'].keys()):
                    mse = rv_eval['mse_out_per_view'][view_idx]
                    view_name = view_names.get(view_idx, f"V{view_idx}")
                    print(f"      {view_name}: {mse:.6f}")

    # Log to W&B
    if CONFIG['use_wandb']:
        log_dict = {
            "epoch": epoch,
            "mse_train": rv_back["mse"],
            "mse_val": rv_eval["mse_val"],
            "mse_out": rv_eval["mse_out"],
            "gp_nll": rv_eval["gp_nll"],
            "recon_term": rv_back["recon_term"],
            "pen_term": rv_back["pen_term"],
            "loss": rv_back["loss"],
            "vars": rv_eval["vars"],
            "time/epoch_seconds": epoch_time,
            # 🔬 Diagnostic metrics
            "diagnostics/gap_train_val": train_val_gap,
            "diagnostics/gap_val_out": val_out_gap,
            "diagnostics/variance_ratio": variance_ratio,
            "vars/v0_object": vs[0],
            "vars/v1_noise": vs[1],
        }

        # Add lengthscale if available
        if learned_lengthscale is not None:
            log_dict["kernel/lengthscale"] = learned_lengthscale

        # Add per-view metrics (Interpolation specific)
        view_names = {0: "90L", 1: "60L", 2: "45L", 3: "30L", 4: "00F",
                     5: "30R", 6: "45R", 7: "60R", 8: "90R"}
        for metric in ("mse_val_per_view", "mse_out_per_view"):
            for view_idx, mse in rv_eval.get(metric, {}).items():
                view_name = view_names.get(view_idx, f"V{view_idx}")
                log_dict[f"{metric}/{view_name}"] = mse

        wandb.log(log_dict)

    # Save checkpoint
    if epoch % CONFIG['epoch_cb'] == 0 or epoch == CONFIG['epochs'] - 1:
        logging.info(f"Epoch {epoch} - saving checkpoint")

        # Save VAE weights
        vae_file = os.path.join(wdir, f"vae_weights.{epoch:05d}.pt")
        torch.save(vae.state_dict(), vae_file)

        # Save GP weights
        gp_file = os.path.join(wdir, f"gp_weights.{epoch:05d}.pt")
        torch.save({
            'gp_state': gp.state_dict(),
            'vm_state': vm.state_dict(),
            'gp_params': gp_params.state_dict(),
        }, gp_file)

        # Save visualization
        ffile = os.path.join(fdir, f"plot.{epoch:05d}.png")
        callback_gppvae(epoch, history, covs, imgs, ffile)

        if CONFIG['use_wandb']:
            wandb.log({
                "reconstructions": wandb.Image(ffile),
                "covariances/XX": wandb.Image(ffile),
            })

        print(f"  ✓ Checkpoint saved at epoch {epoch}")

# At the end, enhanced summary with per-view breakdown
total_time = time.time() - start_time
print("\n" + "=" * 80)
print(f"✅ GP-VAE Interpolation Experiment training complete!")
print(f"   Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"   Average time per epoch: {total_time/CONFIG['epochs']:.1f} seconds")
print(f"   Final training MSE: {rv_back['mse']:.6f}")
print(f"   Final validation MSE: {rv_eval['mse_val']:.6f}")
print(f"   Final out-of-sample MSE: {rv_eval['mse_out']:.6f}")
print(f"   Final GP NLL: {rv_eval['gp_nll']:.4f}")

print(f"\n🔬 Final Diagnostics:")
print(f"   Train-Val Gap: {train_val_gap:.6f} (lower = less overfitting)")
print(f"   Val-Out Gap: {val_out_gap:.6f} (CRITICAL for interpolation quality)")
print(f"   Variance Ratio: {variance_ratio:.3f} (higher = more structure learned)")
if learned_lengthscale is not None:
    print(f"   Learned Lengthscale: {learned_lengthscale:.3f}")

# Interpolation specific: Per-view breakdown
if CONFIG['view_split_mode'] == 'by_view' and 'mse_out_per_view' in rv_eval:
    print(f"\n📊 Final Per-View MSE_out (Interpolation Test):")

    view_names = {0: "90L (-90°)", 1: "60L (-60°)", 2: "45L (-45°)", 3: "30L (-30°)",
                 4: "00F (0°)", 5: "30R (+30°)", 6: "45R (+45°)", 7: "60R (+60°)", 8: "90R (+90°)"}

    # Separate training and validation views
    train_view_indices = CONFIG.get('train_view_indices', [])
    val_view_indices = CONFIG.get('val_view_indices', [])

    if rv_eval['mse_out_per_view']:
        print("   INTERMEDIATE VIEWS (held-out, interpolation targets):")
        interp_mses = []
        for view_idx in sorted(rv_eval['mse_out_per_view'].keys()):
            if view_idx in val_view_indices:
                mse = rv_eval['mse_out_per_view'][view_idx]
                interp_mses.append(mse)
                view_name = view_names.get(view_idx, f"V{view_idx}")
                print(f"      {view_name:15s}: {mse:.6f}")

        if interp_mses:
            avg_interp = np.mean(interp_mses)
            print(f"\n   Average MSE on intermediate views: {avg_interp:.6f}")
            print(f"   Overall MSE_out: {rv_eval['mse_out']:.6f}")
            print(f"\n💡 Lower MSE_out on intermediate views = better interpolation!")

if CONFIG['use_wandb']:
    wandb.finish()
    print("\n🔗 View detailed results in W&B dashboard")

🚀 Starting GP-VAE Interpolation Experiment training for 1000 epochs...
Training mode: JOINT OPTIMIZATION (VAE + GP updated together)
Experiment: Interpolation (boundaries → intermediate)
  Training views: [0, 1, 3, 4, 5, 7, 8]
  Validation views: [2, 6]


AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
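
A device-side assert is often raised by an out-of-range index in an indexing or embedding lookup, although the message above does not confirm that cause for this run. The sketch below is a hedged debugging aid: it sets `CUDA_LAUNCH_BLOCKING=1` (as the error message suggests) and validates index arrays on the CPU, where errors are readable. `check_index_range` is a hypothetical helper; applying it to `Dt`/`Wt` with `P`/`Q` assumes those tensors hold integer indices into tables of those sizes.

```python
import os

import numpy as np

# Makes CUDA errors surface at the failing call. This must be set before CUDA
# is initialized, so restart the runtime and run this cell first.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def check_index_range(indices, size, name):
    """Raise a readable error if any index falls outside [0, size)."""
    idx = np.asarray(indices)
    if idx.size == 0:
        raise ValueError(f"{name} is empty")
    if not np.issubdtype(idx.dtype, np.integer):
        raise TypeError(f"{name} has dtype {idx.dtype}, expected integer indices")
    lo, hi = int(idx.min()), int(idx.max())
    if lo < 0 or hi >= size:
        raise IndexError(f"{name} spans [{lo}, {hi}] but the table has {size} rows")
    return lo, hi


# Stand-in data; in the notebook this could be, for example,
# check_index_range(Wt.cpu().numpy(), Q, "Wt")
print(check_index_range(np.array([0, 1, 3, 4, 5, 7, 8]), 9, "Wt"))
```

The dtype check is included because float-valued view tensors (the angle-encoding path) would not be valid lookup indices, so a mismatch between the encoding mode and the kernel's expectations would show up here as a `TypeError`.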


## 13. Download Results

Download the trained model and visualizations to your computer:

In [None]:
# Compress output folder
output_zip = '/content/gppvae_output.zip'
!zip -r {output_zip} {CONFIG['outdir']}

# Download
from google.colab import files
print("Preparing download...")
files.download(output_zip)
print("\n✅ Download started! Extract the zip on your local machine.")
print(f"\nContents include:")
print(f"  - Trained VAE weights (fine-tuned)")
print(f"  - GP + Vmodel weights")
print(f"  - Visualization plots")
print(f"  - Training logs")
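
Outside Colab, the `!zip` shell escape and `google.colab.files` are not available. The sketch below is a portable alternative that uses only the standard library. `archive_output` is a hypothetical helper, the demo runs on a throwaway directory rather than `CONFIG['outdir']`, and the download step itself is left to the environment.

```python
import os
import shutil
import tempfile


def archive_output(outdir, archive_base):
    """Zip `outdir` into `<archive_base>.zip` and return the archive path."""
    if not os.path.isdir(outdir):
        raise FileNotFoundError(f"Output directory not found: {outdir}")
    return shutil.make_archive(archive_base, "zip", root_dir=outdir)


# Demo on a temporary directory with one placeholder file
with tempfile.TemporaryDirectory() as tmp:
    outdir = os.path.join(tmp, "out")
    os.makedirs(os.path.join(outdir, "weights"))
    with open(os.path.join(outdir, "weights", "dummy.txt"), "w") as fh:
        fh.write("placeholder")
    zip_path = archive_output(outdir, os.path.join(tmp, "gppvae_output"))
    archive_ok = os.path.getsize(zip_path) > 0
    print(os.path.basename(zip_path), archive_ok)
```

Checking that the directory exists before archiving avoids producing an empty zip when a training run failed before writing any checkpoints.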

## 14. Visualize Results

View the latest reconstruction and covariance plots: