# GPPVAE Training on Google Colab

This notebook trains the VAE model on Google Colab's free GPU, which is substantially faster than training on an M1 CPU (see Expected Performance below).

## Setup Instructions for VS Code + Colab:

1. **Open this notebook in VS Code**
2. **Connect to Colab**: Click the kernel picker (top-right) → "Connect to Colab" → choose a GPU runtime (T4)
3. **Important**: When prompted with "Alias your server", just press Enter to confirm
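4. **Run the cell in section 2** to locate your project files
5. **Note**: The Colab runtime has its own filesystem, so paths there differ from your local Mac. If the runtime starts in `/content` without your project, the cell mounts Google Drive and looks for the project at `/content/drive/MyDrive/gppvae`

The cell then checks for the required files and, if any are missing, prints instructions for uploading them to Google Drive.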

## Output Directory Structure:

Each training run creates a **timestamped directory** to avoid overwriting previous runs:
- Format: `./out/vae_colab/YYYYMMDD_HHMMSS/`
- Example: `./out/vae_colab/20251224_143022/weights/weights.00100.pt`
- This allows you to compare different training runs and keep a history!
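Because the timestamp is zero-padded (`YYYYMMDD_HHMMSS`), sorting directory names alphabetically also sorts them chronologically, which makes the newest run easy to find. A minimal sketch assuming the `./out/vae_colab/` layout above; `latest_run_dir` is a hypothetical helper, not part of the project:

```python
import os
import re

# Run directories follow the YYYYMMDD_HHMMSS pattern set in the configuration cell
RUN_NAME = re.compile(r"^\d{8}_\d{6}$")

def latest_run_dir(base="./out/vae_colab"):
    """Return the path of the most recent timestamped run under `base`, or None.

    Hypothetical helper: zero-padded timestamps sort lexicographically in
    chronological order, so the largest name is the newest run.
    """
    if not os.path.isdir(base):
        return None
    runs = [
        name for name in os.listdir(base)
        if RUN_NAME.match(name) and os.path.isdir(os.path.join(base, name))
    ]
    return os.path.join(base, max(runs)) if runs else None
```

Calling `latest_run_dir()` from the project root returns something like `./out/vae_colab/20251224_143022`; directories that do not match the timestamp pattern are ignored.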

## Expected Performance:
- **10 epochs**: ~5-10 minutes (vs 2 hours on M1 CPU!)
- **100 epochs**: ~30-60 minutes  
- **1000 epochs**: ~5-10 hours
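These figures imply the speedup directly. A back-of-envelope sketch using the 10-epoch row (about 2 hours on the M1 CPU versus 5-10 minutes on the T4); the inputs are the estimates quoted in this notebook, not new measurements:

```python
# Per-epoch times derived from the estimates quoted above (not new measurements)
M1_SECONDS_PER_EPOCH = 2 * 3600 / 10                # ~2 hours for 10 epochs on an M1 CPU
T4_SECONDS_PER_EPOCH = (5 * 60 / 10, 10 * 60 / 10)  # ~5-10 minutes for 10 epochs on a T4

def speedup_range(cpu_s, gpu_range):
    """Return (low, high) speedup factors; the slowest GPU time gives the lowest factor."""
    fast, slow = gpu_range
    return cpu_s / slow, cpu_s / fast

low, high = speedup_range(M1_SECONDS_PER_EPOCH, T4_SECONDS_PER_EPOCH)
print(f"Estimated speedup: {low:.0f}x to {high:.0f}x")
```

With these inputs it prints `Estimated speedup: 12x to 24x` (720 s per epoch on the CPU versus 30-60 s on the GPU).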


## 1. Check GPU Availability

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not detected! Go to Runtime → Change runtime type → GPU")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 15.83 GB


## 2. Auto-Detect Project Path

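This cell looks for your project files on the Colab runtime. If the working directory is `/content` (a fresh Colab runtime), it mounts Google Drive and looks for the project at `/content/drive/MyDrive/gppvae`; otherwise it uses the current directory, or its parent when the path contains `notebooks`, as the project root. It then checks for the required files.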

In [4]:
import os
import sys

# Get the current working directory
current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

# Check if we're in /content (Colab) and need to upload files
if current_dir == '/content':
    print("\nüîÑ You're on Colab but files aren't synced yet.")
    print("\nOption 1: Upload via Google Drive (Recommended)")
    print("=" * 60)
    
    # Try to mount Google Drive
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        
        # Check if project exists in Drive
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"✅ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"\n⚠️  Project not found at: {drive_path}")
            print("\nTo upload your project to Google Drive:")
            print("1. Go to https://drive.google.com")
            print("2. Create a folder called 'gppvae' in My Drive")
            print("3. Upload these folders into it:")
            print("   - GPPVAE/ (code)")
            print("   - data/ (your data_faces.h5 file)")
            print("   - notebooks/ (this notebook)")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
        
    print(f"\nüìÇ Using project path: {PROJECT_PATH}")
else:
    # Running locally or files are synced
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(current_dir)
    else:
        PROJECT_PATH = current_dir
    print(f"üíª Using project path: {PROJECT_PATH}")

# Verify what we have
print(f"\nüìÅ Contents of {PROJECT_PATH}:")
if os.path.exists(PROJECT_PATH):
    items = os.listdir(PROJECT_PATH)
    for item in sorted(items)[:15]:
        item_path = os.path.join(PROJECT_PATH, item)
        if os.path.isdir(item_path):
            print(f"   üìÇ {item}/")
        else:
            print(f"   üìÑ {item}")
    
    # Check for required folders
    print(f"\nüîç Checking required files:")
    required = {
        'GPPVAE': os.path.exists(os.path.join(PROJECT_PATH, 'GPPVAE')),
        'data': os.path.exists(os.path.join(PROJECT_PATH, 'data')),
        'data_faces.h5': os.path.exists(os.path.join(PROJECT_PATH, 'data/faceplace/data_faces.h5'))
    }
    
    for name, exists in required.items():
        status = "✅" if exists else "❌"
        print(f"   {status} {name}")
    
    if not all(required.values()):
        print("\n⚠️  Missing files! Please upload your gppvae folder to Google Drive")
else:
    print(f"❌ Path doesn't exist: {PROJECT_PATH}")

📍 Current directory: /content

🔄 You're on Colab but files aren't synced yet.

Option 1: Upload via Google Drive (Recommended)
Could not mount Drive: mount failed

📂 Using project path: /content

📁 Contents of /content:
   📂 .config/
   📂 sample_data/

🔍 Checking required files:
   ❌ GPPVAE
   ❌ data
   ❌ data_faces.h5

⚠️  Missing files! Please upload your gppvae folder to Google Drive
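The detection logic above comes down to checking candidate directories for the required files. A stdlib-only sketch of that check; `find_project_root` is hypothetical, the Drive path and required files are taken from the cell above, and the search order is an assumption:

```python
import os

# Paths the notebook needs, relative to the project root (taken from section 2)
REQUIRED = ("GPPVAE", "data/faceplace/data_faces.h5")

def find_project_root(candidates, required=REQUIRED):
    """Return the first candidate directory that contains every required path, else None."""
    for root in candidates:
        if all(os.path.exists(os.path.join(root, rel)) for rel in required):
            return root
    return None

# Hypothetical search order: Drive copy first, then the working directory and its parent
candidates = [
    "/content/drive/MyDrive/gppvae",
    os.getcwd(),
    os.path.dirname(os.getcwd()),
]
print(find_project_root(candidates))
```

Returning `None` instead of silently falling back to `/content` makes a missing upload fail early rather than later in the data-loading cell.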


## 3. Install Dependencies

In [None]:
# Install required packages
!pip install -q wandb==0.12.21 imageio==2.15.0 pyyaml

# Check installations
import wandb
import imageio
import yaml
print("✅ All dependencies installed successfully!")
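Because `-q` hides pip's output, a quick check can confirm that the pinned versions were actually installed. A minimal sketch using the standard library's `importlib.metadata`; `check_pins` is a hypothetical helper and the pins are copied from the command above:

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the pip command above
PINNED = {"wandb": "0.12.21", "imageio": "2.15.0"}

def check_pins(pins):
    """Map each package name to (installed_version_or_None, pinned_version)."""
    report = {}
    for pkg, wanted in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        report[pkg] = (installed, wanted)
    return report

for pkg, (installed, wanted) in check_pins(PINNED).items():
    flag = "ok" if installed == wanted else "MISMATCH"
    print(f"{pkg}: installed={installed} pinned={wanted} [{flag}]")
```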

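## 4. Log In to Weights & Biases (Optional)

If you want to track your experiments, log in to W&B. Otherwise, skip this cell and set `'use_wandb': False` in the training configuration (section 7), since the training cells call `wandb.init` and `wandb.log` when it is `True`.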

In [None]:
import wandb
wandb.login()

# Or set to offline mode if you don't want to use W&B
# os.environ['WANDB_MODE'] = 'offline'
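Instead of editing this cell by hand, the mode can be chosen from the environment. A sketch of one hypothetical policy (`choose_wandb_mode` is not part of W&B): keep an explicit `WANDB_MODE` if one is set, use `online` when a `WANDB_API_KEY` environment variable exists, and fall back to `offline` otherwise. It only inspects environment variables, so a key saved by `wandb.login()` (typically in `~/.netrc`) is not detected; run it in place of the cell above, not after it.

```python
import os

def choose_wandb_mode(env=None):
    """Hypothetical policy for picking a W&B mode from environment variables.

    Offline runs are stored locally and can be uploaded later with `wandb sync`.
    """
    env = os.environ if env is None else env
    if env.get("WANDB_MODE"):   # respect an explicit user choice
        return env["WANDB_MODE"]
    return "online" if env.get("WANDB_API_KEY") else "offline"

os.environ["WANDB_MODE"] = choose_wandb_mode()
print("W&B mode:", os.environ["WANDB_MODE"])
```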

## 5. Navigate to Project Directory

In [None]:
import os
os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")

# Add project to Python path
import sys
sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

# List files to verify
print("\nProject structure:")
!ls -la

## 6. Check Data File

In [None]:
data_path = './data/faceplace/data_faces.h5'

if os.path.exists(data_path):
    import h5py
    with h5py.File(data_path, 'r') as f:
        print("✅ Data file found!")
        print("\nDatasets in file:")
        for key in f.keys():
            # HDF5 groups have no shape, so fall back to a label for them
            print(f"  - {key}: {getattr(f[key], 'shape', '(group)')}")
else:
    print(f"❌ Data file not found at: {data_path}")
    print("\nPlease ensure you've uploaded the data folder to Google Drive")

## 7. Configure Training Parameters

Adjust these parameters as needed:

In [None]:
from datetime import datetime

# Training configuration
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG = {
    'data': './data/faceplace/data_faces.h5',
    'outdir': f'./out/vae_colab/{timestamp}',  # Timestamped directory for each run
    'epochs': 100,  # Start with 100 epochs (~30-60 min)
    'batch_size': 64,
    'lr': 0.0002,
    'zdim': 256,
    'filts': 32,
    'epoch_cb': 10,  # Save every 10 epochs
    'use_wandb': True,
    'wandb_project': 'gppvae',
    'wandb_run_name': f'colab_vae_gpu_{timestamp}',  # Also add timestamp to run name
}

print("Training configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
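Checkpoints are written whenever `epoch % epoch_cb == 0` and after the last epoch (see section 12). Epochs are counted from 0, so with `epochs=100` and `epoch_cb=10` the last file is `weights.00099.pt`, not `weights.00100.pt`. A small sketch (the helper `checkpoint_epochs` is hypothetical) that lists the saved epochs for a given configuration:

```python
def checkpoint_epochs(epochs, every):
    """Epoch indices at which the training loop saves weights and plots.

    Mirrors the condition `epoch % every == 0 or epoch == epochs - 1` from section 12.
    """
    return [e for e in range(epochs) if e % every == 0 or e == epochs - 1]

# In the notebook, pass CONFIG['epochs'] and CONFIG['epoch_cb']
saved = checkpoint_epochs(100, 10)
print(f"{len(saved)} checkpoints, last file: weights.{saved[-1]:05d}.pt")
```

For the default configuration this prints `11 checkpoints, last file: weights.00099.pt`.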

## 8. Import Training Modules

In [None]:
# Change to the training script directory
os.chdir(os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

# Import required modules
import matplotlib
matplotlib.use('Agg')

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vae import FaceVAE
import h5py
import scipy as sp
import logging
import pylab as pl
from utils import smartSum, smartAppendDict, smartAppend, export_scripts
from callbacks import callback
from data_parser import read_face_data, FaceDataset
import pickle
import time
import wandb

print("✅ All modules imported successfully!")

## 9. Setup Training Environment

In [None]:
# Go back to project root
os.chdir(PROJECT_PATH)

# Create output directories
outdir = CONFIG['outdir']
wdir = os.path.join(outdir, "weights")
fdir = os.path.join(outdir, "plots")
os.makedirs(wdir, exist_ok=True)
os.makedirs(fdir, exist_ok=True)

# Setup device (GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

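# Setup logging
# force=True (Python 3.8+) replaces any handlers already attached to the root logger;
# without it, basicConfig does nothing if another library configured logging first.
# It also closes the previous FileHandler, so re-running this cell does not duplicate log lines.
log_format = "%(asctime)s %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    datefmt="%m/%d %I:%M:%S %p",
    force=True,
)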
fh = logging.FileHandler(os.path.join(outdir, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

print("✅ Training environment setup complete!")

## 10. Initialize Model and Data

In [None]:
# Set random seed
torch.manual_seed(0)

# Initialize wandb
if CONFIG['use_wandb']:
    wandb.init(
        project=CONFIG['wandb_project'],
        name=CONFIG['wandb_run_name'],
        config=CONFIG
    )

# Define VAE config
vae_cfg = {
    "nf": CONFIG['filts'],
    "zdim": CONFIG['zdim'],
    "vy": 0.002
}

# Save VAE config so the model can be rebuilt with the same settings later
with open(os.path.join(outdir, "vae.cfg.p"), "wb") as cfg_file:
    pickle.dump(vae_cfg, cfg_file)

# Create VAE model
vae = FaceVAE(**vae_cfg).to(device)
print(f"✅ VAE model created with {sum(p.numel() for p in vae.parameters()):,} parameters")

# Create optimizer
optimizer = optim.Adam(vae.parameters(), lr=CONFIG['lr'])

# Load data
print("Loading data...")
img, obj, view = read_face_data(CONFIG['data'])
train_data = FaceDataset(img["train"], obj["train"], view["train"])
val_data = FaceDataset(img["val"], obj["val"], view["val"])
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)

print(f"✅ Data loaded: {len(train_data)} training samples, {len(val_data)} validation samples")
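`DataLoader` keeps a final partial batch by default (`drop_last=False`), so one epoch has `ceil(N / batch_size)` batches, which is also what `len(train_queue)` returns. A sketch with a hypothetical dataset size, since the actual sample counts depend on your data file:

```python
def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Batches a DataLoader yields per epoch; the last one is smaller unless drop_last=True."""
    if drop_last:
        return n_samples // batch_size
    return (n_samples + batch_size - 1) // batch_size  # integer ceiling division

# Hypothetical dataset of 4,000 images with the notebook's batch_size of 64
print(batches_per_epoch(4000, 64))                  # 63 (the last batch holds 32 images)
print(batches_per_epoch(4000, 64, drop_last=True))  # 62
```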

## 11. Define Training Functions

In [None]:
def train_ep(vae, train_queue, optimizer, device):
    """Train for one epoch; returns per-sample averages of each metric."""
    rv = {}
    vae.train()
    _n = float(train_queue.dataset.Y.shape[0])  # number of training samples

    for batch_i, data in enumerate(train_queue):
        # Forward pass: draw the latent noise directly on the target device
        y = data[0].to(device)
        eps = torch.randn(y.shape[0], CONFIG['zdim'], device=device)
        elbo, mse, nll, kld = vae.forward(y, eps)
        loss = elbo.sum()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate metrics: batch sums divided by the dataset size add up to epoch means
        smartSum(rv, "mse", mse.detach().sum().item() / _n)
        smartSum(rv, "nll", nll.detach().sum().item() / _n)
        smartSum(rv, "kld", kld.detach().sum().item() / _n)
        smartSum(rv, "loss", elbo.detach().sum().item() / _n)

    return rv


def eval_ep(vae, val_queue, device):
    """Evaluate on the validation set; returns per-sample averages of each metric."""
    rv = {}
    vae.eval()
    _n = float(val_queue.dataset.Y.shape[0])  # number of validation samples

    with torch.no_grad():
        for batch_i, data in enumerate(val_queue):
            # Forward pass (no gradients needed)
            y = data[0].to(device)
            eps = torch.randn(y.shape[0], CONFIG['zdim'], device=device)
            elbo, mse, nll, kld = vae.forward(y, eps)

            # Accumulate metrics
            smartSum(rv, "mse_val", mse.sum().item() / _n)
            smartSum(rv, "nll_val", nll.sum().item() / _n)
            smartSum(rv, "kld_val", kld.sum().item() / _n)
            smartSum(rv, "loss_val", elbo.sum().item() / _n)

    return rv

print("✅ Training functions defined")
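`train_ep` and `eval_ep` add each batch's summed metric divided by the dataset size, so after a full pass each entry is the per-sample mean over the epoch, even when the last batch is smaller. A stdlib sketch of that arithmetic; `smart_sum` is a hypothetical stand-in for `utils.smartSum`, assuming it adds a value to a dict entry and creates the entry if absent:

```python
def smart_sum(d, key, value):
    """Hypothetical stand-in for utils.smartSum: accumulate `value` into d[key]."""
    d[key] = d.get(key, 0.0) + value

def epoch_mean(batch_losses, n_total):
    """Accumulate per-batch sums divided by the dataset size, as train_ep does."""
    rv = {}
    for batch in batch_losses:  # each batch: a list of per-sample losses
        smart_sum(rv, "loss", sum(batch) / n_total)
    return rv["loss"]

# Three batches of sizes 2, 2 and 1 (the last one partial)
batches = [[1.0, 3.0], [2.0, 2.0], [7.0]]
print(epoch_mean(batches, n_total=5))  # per-sample mean 15 / 5 = 3.0 (up to float rounding)

# Averaging batch means instead would over-weight the partial batch
naive = sum(sum(b) / len(b) for b in batches) / len(batches)
print(naive)
```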

## 12. Train the Model üöÄ

This is where the magic happens! Monitor the output to see training progress.

In [None]:
import time
from IPython.display import clear_output

history = {}
start_time = time.time()

print(f"üöÄ Starting training for {CONFIG['epochs']} epochs...\n")
print("=" * 80)

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()
    
    # Train and evaluate
    ht = train_ep(vae, train_queue, optimizer, device)
    hv = eval_ep(vae, val_queue, device)
    smartAppendDict(history, ht)
    smartAppendDict(history, hv)
    
    epoch_time = time.time() - epoch_start
    total_time = time.time() - start_time
    
    # Print progress
    if epoch % 5 == 0 or epoch == CONFIG['epochs'] - 1:
        print(f"Epoch {epoch:4d}/{CONFIG['epochs']} | "
              f"Train MSE: {ht['mse']:.6f} | "
              f"Val MSE: {hv['mse_val']:.6f} | "
              f"Time: {epoch_time:.1f}s | "
              f"Total: {total_time/60:.1f}min")
    
    # Log to wandb
    if CONFIG['use_wandb']:
        wandb.log({
            "epoch": epoch,
            "train/mse": ht["mse"],
            "train/nll": ht["nll"],
            "train/kld": ht["kld"],
            "train/loss": ht["loss"],
            "val/mse": hv["mse_val"],
            "val/nll": hv["nll_val"],
            "val/kld": hv["kld_val"],
            "val/loss": hv["loss_val"],
            "time/epoch_seconds": epoch_time,
        })
    
    # Save checkpoint
    if epoch % CONFIG['epoch_cb'] == 0 or epoch == CONFIG['epochs'] - 1:
        logging.info(f"Epoch {epoch} - saving checkpoint")
        wfile = os.path.join(wdir, f"weights.{epoch:05d}.pt")
        ffile = os.path.join(fdir, f"plot.{epoch:05d}.png")
        torch.save(vae.state_dict(), wfile)
        callback(epoch, val_queue, vae, history, ffile, device)
        
        if CONFIG['use_wandb']:
            wandb.log({"reconstructions": wandb.Image(ffile)})
        
        print(f"  ✓ Checkpoint saved at epoch {epoch}")

total_time = time.time() - start_time
print("\n" + "=" * 80)
print(f"✅ Training complete! Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"   Average time per epoch: {total_time/CONFIG['epochs']:.1f} seconds")
print(f"   Final validation MSE: {hv['mse_val']:.6f}")

if CONFIG['use_wandb']:
    wandb.finish()
    print("\nüîó View results in W&B dashboard")

## 13. Download Results to Local Machine

Download the trained model and plots to your computer:

In [None]:
# Compress output folder
!zip -r /content/vae_output.zip {CONFIG['outdir']}

# Download via Colab
from google.colab import files
print("Preparing download...")
files.download('/content/vae_output.zip')
print("\n✅ Download started! Extract the zip file on your local machine.")
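The `!zip` command depends on the `zip` binary and keeps the directory prefix from the path it is given. A standard-library alternative, sketched as a hypothetical helper `archive_run`, builds the zip with `shutil.make_archive` and stores paths relative to the run directory's parent:

```python
import os
import shutil

def archive_run(outdir, dest_base="/content/vae_output"):
    """Zip `outdir` into `<dest_base>.zip` and return the archive path.

    Entries are stored as `<timestamp>/weights/...`, `<timestamp>/plots/...`.
    """
    outdir = os.path.abspath(outdir)
    return shutil.make_archive(
        dest_base, "zip",
        root_dir=os.path.dirname(outdir),
        base_dir=os.path.basename(outdir),
    )
```

In the notebook it could replace the first two lines of the cell above, e.g. `files.download(archive_run(CONFIG['outdir']))`.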

## 14. View Sample Reconstructions

In [None]:
from IPython.display import Image, display
import glob

# Get latest plot
plot_files = sorted(glob.glob(os.path.join(fdir, "*.png")))
if plot_files:
    latest_plot = plot_files[-1]
    print(f"Latest reconstruction plot: {latest_plot}")
    display(Image(filename=latest_plot))
else:
    print("No plots generated yet")

## 15. Next Steps

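### To continue training:
1. Keep this notebook running: the model and optimizer exist only in memory
2. Set `CONFIG['epochs']` to the number of additional epochs in a new cell (re-running section 7 would also change `CONFIG['outdir']` to a new timestamp)
3. If W&B is enabled, re-run the `wandb.init(...)` call from section 10, because section 12 ends the run with `wandb.finish()`
4. Re-run section 12 only; re-running section 10 in full builds a fresh model and optimizer and restarts training from scratch

The loop in section 12 counts epochs from 0 again, so new checkpoints overwrite files with the same epoch numbers in the current run directory.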
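To resume from a saved checkpoint instead, with the epoch counter continuing where it left off, the newest weights file can be located by name. A stdlib sketch; `latest_checkpoint` is a hypothetical helper that assumes the `weights.{epoch:05d}.pt` naming used in section 12:

```python
import os
import re

# Checkpoint names follow the pattern used in section 12: weights.{epoch:05d}.pt
CKPT = re.compile(r"^weights\.(\d{5})\.pt$")

def latest_checkpoint(wdir):
    """Return (epoch, path) of the newest checkpoint in `wdir`, or (None, None)."""
    best = None
    for name in os.listdir(wdir):
        m = CKPT.match(name)
        if m and (best is None or int(m.group(1)) > best[0]):
            best = (int(m.group(1)), os.path.join(wdir, name))
    return best if best else (None, None)
```

In the notebook you would then call `vae.load_state_dict(torch.load(path, map_location=device))` and start the loop at `epoch + 1`, so new checkpoints do not overwrite earlier ones. Only the model weights are saved, so the Adam optimizer state starts fresh.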

### To train GPPVAE next:
1. Use the trained VAE weights from this run
2. Create a similar notebook for `train_gppvae.py`
3. Or download weights and run locally

### Performance comparison:
- **Colab GPU (T4)**: ~5-10 min for 10 epochs, ~30-60 min for 100 epochs
- **M1 Pro CPU**: ~2 hours for 10 epochs
- **Speedup**: roughly 12-24x, based on the 10-epoch timings 🚀