# GP-VAE Training on COIL-100 (Standard Task) - Periodic Kernel

This notebook trains **GP-VAE** on the COIL-100 dataset using the **Periodic kernel**.

## Kernel: Periodic
- **Standard periodic kernel** for exact periodic patterns
- k(θ, θ') = variance × exp(-2 × sin²(π|θ-θ'|/period) / lengthscale²)
- **period=360°** fixed for full rotation
- **Parameters**: 2 learnable (lengthscale, variance)
- **Best for**: Data with exact periodicity (360° = 0°)
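
The kernel formula above can be checked by hand with a few lines of plain Python. This is a minimal sketch, not the project's `Vmodel` implementation: the function name `periodic_kernel` is hypothetical, and the values `lengthscale=1.0` and `variance=1.0` are taken from the training configuration used later in this notebook.

```python
import math

def periodic_kernel(theta_a, theta_b, period=360.0, lengthscale=1.0, variance=1.0):
    """k = variance * exp(-2 * sin^2(pi * |a - b| / period) / lengthscale^2)."""
    s = math.sin(math.pi * abs(theta_a - theta_b) / period)
    return variance * math.exp(-2.0 * s * s / lengthscale ** 2)

# The 18 COIL-100 view angles: 0, 20, ..., 340 degrees
angles = [20.0 * i for i in range(18)]
K = [[periodic_kernel(a, b) for b in angles] for a in angles]

print(f"K[0,0]  = {K[0][0]:.4f}  (self)")
print(f"K[0,1]  = {K[0][1]:.4f}  (20 degrees apart)")
print(f"K[0,9]  = {K[0][9]:.4f}  (180 degrees apart)")
print(f"K[0,17] = {K[0][17]:.4f}  (340 degrees, i.e. -20 degrees by periodicity)")
```

With these settings the 20° and 340° entries coincide, which is the wrap-around behaviour the periodic kernel is chosen for.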

## Dataset Info:
- **COIL-100**: 100 objects × 18 views (every 20°: 0°, 20°, ..., 340°)
- **Image size**: 128×128×3 RGB
- **Task**: Standard split (random train/val/test)

## Prerequisites:
- ✅ Trained VAE weights from `train_vae_colab_standard.ipynb`
- ✅ COIL-100 data file: `data/coil100/coil100_task1_standard.h5`

## 1. Check GPU Availability

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not detected!")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 15.83 GB


## 2. Auto-Detect Project Path

In [2]:
import os
import sys

current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

# Task configuration
DATA_TASK = "task1_standard"
KERNEL_TYPE = "periodic"

if current_dir == '/content':
    print("\nüîÑ Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"✅ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"⚠️ Project not found at: {drive_path}")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(current_dir)
    else:
        PROJECT_PATH = current_dir
    print(f"üíª Using project path: {PROJECT_PATH}")

# Check required files
print(f"\nüîç Checking required files:")
data_path = os.path.join(PROJECT_PATH, f'data/coil100/coil100_{DATA_TASK}.h5')
required = {
    'GPPVAE code': os.path.exists(os.path.join(PROJECT_PATH, 'GPPVAE')),
    'COIL-100 data': os.path.exists(data_path),
}

vae_base_dir = os.path.join(PROJECT_PATH, f'out/vae_colab_{DATA_TASK}_standard')
vae_run_found = False
if os.path.exists(vae_base_dir):
    runs = [d for d in os.listdir(vae_base_dir) if os.path.isdir(os.path.join(vae_base_dir, d))]
    for run in runs:
        weights_dir = os.path.join(vae_base_dir, run, 'weights')
        if os.path.exists(weights_dir) and any(f.endswith('.pt') for f in os.listdir(weights_dir)):
            vae_run_found = True
            break
required['VAE weights'] = vae_run_found

for name, exists in required.items():
    status = "✅" if exists else "❌"
    print(f"   {status} {name}")

📍 Current directory: /content

🔄 Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Found project in Google Drive: /content/drive/MyDrive/gppvae

🔍 Checking required files:
   ✅ GPPVAE code
   ✅ COIL-100 data
   ✅ VAE weights


## 3. Install Dependencies

In [3]:
!pip install -q wandb==0.12.21 imageio==2.15.0 pyyaml

# NOTE: the message below prints whenever the imports succeed, even if the
# pinned pip install above failed; check the pip output before relying on it.
import wandb
import numpy as np
print("✅ All dependencies installed!")

  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... error
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
✅ All dependencies installed!
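
In the output above, pip reports `metadata-generation-failed` while the cell still prints a success message, so the pinned versions may not be the ones actually imported. A minimal sketch of a stricter check, using only the standard library's `importlib.metadata`, is shown here. The helper name `installed_versions` is hypothetical and is not part of the project.

```python
from importlib import metadata

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Packages requested by the pip command in this section
report = installed_versions(["wandb", "imageio", "pyyaml"])
for name, version in report.items():
    print(f"{name:10s}: {version if version else 'NOT INSTALLED'}")
```

Comparing the printed versions against the pins (`wandb==0.12.21`, `imageio==2.15.0`) shows whether the install took effect or a preinstalled version is in use.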


## 4. Login to W&B (Optional)

In [4]:
import wandb
wandb.login()

wandb: Currently logged in as: minh1008 (minh1008-ludwig-maximilianuniversity-of-munich) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

## 5. Navigate to Project

In [5]:
import os
import sys

os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")

sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100'))
sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/faceplace'))

Current directory: /content/drive/MyDrive/gppvae


## 6. Find VAE Weights

In [6]:
import os
import pickle

vae_base_dir = f'./out/vae_colab_{DATA_TASK}_standard'
vae_runs = []

if os.path.exists(vae_base_dir):
    for run_dir in sorted(os.listdir(vae_base_dir), reverse=True):
        run_path = os.path.join(vae_base_dir, run_dir)
        cfg_path = os.path.join(run_path, 'vae.cfg.p')
        weights_dir = os.path.join(run_path, 'weights')

        if os.path.exists(cfg_path) and os.path.exists(weights_dir):
            weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
            if weight_files:
                vae_runs.append({
                    'run_dir': run_dir,
                    'cfg_path': cfg_path,
                    'weights_dir': weights_dir,
                    'weight_files': weight_files
                })

if vae_runs:
    print(f"✅ Found {len(vae_runs)} VAE run(s)")
    latest = vae_runs[0]
    print(f"\n💡 Latest: {latest['run_dir']}")
    print(f"   VAE_CFG = '{latest['cfg_path']}'")
    print(f"   VAE_WEIGHTS = '{os.path.join(latest['weights_dir'], latest['weight_files'][-1])}'")
else:
    print("❌ No VAE runs found!")

✅ Found 1 VAE run(s)

💡 Latest: 20260119_134924
   VAE_CFG = './out/vae_colab_task1_standard_standard/20260119_134924/vae.cfg.p'
   VAE_WEIGHTS = './out/vae_colab_task1_standard_standard/20260119_134924/weights/weights.00499.pt'
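
The cell above picks the last entry of a lexicographically sorted file list, which works as long as the epoch number in `weights.NNNNN.pt` is zero-padded to a fixed width. A hedged alternative that parses the epoch number explicitly is sketched below. The helper `latest_weight_file` and the example filenames are hypothetical; only the `weights.<epoch>.pt` naming pattern is taken from the output above.

```python
import re

# Matches names such as "weights.00499.pt" and captures the epoch number
WEIGHT_PATTERN = re.compile(r"^weights\.(\d+)\.pt$")

def latest_weight_file(filenames):
    """Return the weight file with the highest numeric epoch, or None."""
    best_epoch, best_name = -1, None
    for name in filenames:
        match = WEIGHT_PATTERN.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_name = int(match.group(1)), name
    return best_name

# Unpadded names are where a plain string sort goes wrong
example = ["weights.999.pt", "weights.1000.pt", "notes.txt"]
print("string sort picks :", sorted(f for f in example if f.endswith(".pt"))[-1])
print("numeric parse picks:", latest_weight_file(example))
```

For the zero-padded files produced in this notebook both approaches agree, so this only matters if the naming scheme changes.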


## 7. Configure Training

**Periodic Kernel Parameters:**
- `period`: Fixed at 360° for full rotation
- `lengthscale`: Controls smoothness within one period. Initialized to 1.0 in the config below; it is dimensionless here because it scales the sin² term
- `variance`: Signal variance. Default: 1.0
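
To get a feel for what the lengthscale does, the sketch below evaluates the periodic correlation (variance fixed at 1) at a few angular separations for several lengthscales. It is an illustration of the formula only: `periodic_corr` is a hypothetical helper, the lengthscale values are arbitrary, and the project's `Vmodel` may parameterize the kernel differently internally.

```python
import math

def periodic_corr(delta_deg, lengthscale, period=360.0):
    """Periodic-kernel correlation between two views delta_deg apart."""
    s = math.sin(math.pi * abs(delta_deg) / period)
    return math.exp(-2.0 * s * s / lengthscale ** 2)

for ls in (0.5, 1.0, 2.0):
    row = ", ".join(f"{d:3d}°: {periodic_corr(d, ls):.4f}" for d in (20, 90, 180))
    print(f"lengthscale={ls}: {row}")
```

Smaller lengthscales make opposite views (180° apart) nearly uncorrelated, while larger ones keep all views strongly coupled.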

In [20]:
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# ============================================================================
# UPDATE THESE PATHS!
# ============================================================================
VAE_CFG = './out/vae_colab_task1_standard_standard/20260119_134924/vae.cfg.p'
VAE_WEIGHTS = './out/vae_colab_task1_standard_standard/20260119_134924/weights/weights.00499.pt'

CONFIG = {
    # Data
    'data': f'./data/coil100/coil100_{DATA_TASK}.h5',
    'outdir': f'./out/gppvae_coil100_{KERNEL_TYPE}_{DATA_TASK}/{timestamp}',

    # VAE
    'vae_cfg': VAE_CFG,
    'vae_weights': VAE_WEIGHTS,

    # Training
    'epochs': 1200,
    'batch_size': 64,
    'vae_lr': 0.001,
    'gp_lr': 0.001,
    'xdim': 64,

    # Kernel - Periodic
    'view_kernel': KERNEL_TYPE,
    'kernel_kwargs': {
        'period': 360.0,  # Fixed: full rotation
        'lengthscale': 1.0,
        'variance': 1.0,
    },

    # Logging
    'epoch_cb': 100,
    'use_wandb': True,
    'wandb_project': 'gppvae-coil100',
    'wandb_run_name': f'gppvae_{KERNEL_TYPE}_{DATA_TASK}_{timestamp}',
    'seed': 0,
}

print("GP-VAE Training Configuration:")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"  {key:20s}: {value}")
print("=" * 60)

if not os.path.exists(CONFIG['vae_weights']):
    print("\n⚠️ Update VAE_CFG and VAE_WEIGHTS paths!")

GP-VAE Training Configuration:
  data                : ./data/coil100/coil100_task1_standard.h5
  outdir              : ./out/gppvae_coil100_periodic_task1_standard/20260120_090712
  vae_cfg             : ./out/vae_colab_task1_standard_standard/20260119_134924/vae.cfg.p
  vae_weights         : ./out/vae_colab_task1_standard_standard/20260119_134924/weights/weights.00499.pt
  epochs              : 1200
  batch_size          : 64
  vae_lr              : 0.001
  gp_lr               : 0.001
  xdim                : 64
  view_kernel         : periodic
  kernel_kwargs       : {'period': 360.0, 'lengthscale': 1.0, 'variance': 1.0}
  epoch_cb            : 100
  use_wandb           : True
  wandb_project       : gppvae-coil100
  wandb_run_name      : gppvae_periodic_task1_standard_20260120_090712
  seed                : 0


## 8. Import Modules

In [21]:
# IMPORTANT: Add coil100 to path FIRST before importing anything
# This ensures coil100's data_parser is used, not faceplace's
import sys
import os

coil100_path = os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100')
sys.path.insert(0, coil100_path)    # Add coil100 first (so it has priority)

os.chdir(coil100_path)
print(f"Working directory: {os.getcwd()}")
print(f"sys.path priority: coil100 > faceplace")

import matplotlib
matplotlib.use('Agg')

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
import numpy as np
import logging
import pylab as pl
from utils import smartSum, smartAppendDict, smartAppend, export_scripts
from callbacks import callback_gppvae
import pickle
import time
import wandb

# COIL-100 data parser (explicitly import from coil100, not faceplace)
from data_parser import COIL100Dataset, get_n_views, get_num_objects

# Verify we're using the right data_parser
import data_parser
print(f"✅ data_parser loaded from: {data_parser.__file__}")
if 'coil100' in data_parser.__file__:
    print("✅ Using COIL-100 data_parser (correct!)")
else:
    print("❌ WARNING: Using faceplace data_parser (wrong!)")
print("✅ All modules imported successfully!")

Working directory: /content/drive/MyDrive/gppvae/GPPVAE/pysrc/coil100
sys.path priority: coil100 > faceplace
✅ data_parser loaded from: /content/drive/MyDrive/gppvae/GPPVAE/pysrc/coil100/data_parser.py
✅ Using COIL-100 data_parser (correct!)
✅ All modules imported successfully!


## 9. Setup Training Environment

In [22]:
os.chdir(PROJECT_PATH)

outdir = CONFIG['outdir']
wdir = os.path.join(outdir, "weights")
fdir = os.path.join(outdir, "plots")
os.makedirs(wdir, exist_ok=True)
os.makedirs(fdir, exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

log_format = "%(asctime)s %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format, datefmt="%m/%d %I:%M:%S %p")
fh = logging.FileHandler(os.path.join(outdir, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

export_scripts(os.path.join(outdir, "scripts"))
print(f"✅ Output: {outdir}")

Using device: cuda:0
✅ Output: ./out/gppvae_coil100_periodic_task1_standard/20260120_090712


## 10. Initialize Models

In [23]:
torch.manual_seed(CONFIG['seed'])

if CONFIG['use_wandb']:
    wandb.init(project=CONFIG['wandb_project'], name=CONFIG['wandb_run_name'], config=CONFIG)

# Load VAE
with open(CONFIG['vae_cfg'], "rb") as f:  # close the file handle after loading
    vae_cfg = pickle.load(f)
vae = FaceVAE(**vae_cfg).to(device)
vae.load_state_dict(torch.load(CONFIG['vae_weights'], map_location=device))
print("✅ VAE loaded")

# Load data
train_data = COIL100Dataset(CONFIG['data'], split='train', use_angle_encoding=False)
val_data = COIL100Dataset(CONFIG['data'], split='val', use_angle_encoding=False)
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)

# IMPORTANT: Use get_num_objects for correct P (includes all objects from all splits)
P = get_num_objects(CONFIG['data'])  # 100 for COIL-100
Q = get_n_views()
print(f"P={P}, Q={Q}")

# Create object and view tensors (Did and Rid are 1D tensors)
Dt = Variable(train_data.Did.long(), requires_grad=False).to(device)
Dv = Variable(val_data.Did.long(), requires_grad=False).to(device)
Wt = Variable(train_data.Rid.long(), requires_grad=False).to(device)
Wv = Variable(val_data.Rid.long(), requires_grad=False).to(device)

# Initialize Vmodel with Periodic kernel
print(f"\nüî¨ Initializing '{KERNEL_TYPE}' kernel (period={CONFIG['kernel_kwargs']['period']}¬∞)...")
vm = Vmodel(P=P, Q=Q, p=CONFIG['xdim'], view_kernel=CONFIG['view_kernel'], **CONFIG['kernel_kwargs']).to(device)
gp = GP(n_rand_effs=1).to(device)

# Show kernel matrix
K = vm.get_kernel_matrix()
print(f"   K[0,0]={K[0,0].item():.4f} (self)")
print(f"   K[0,1]={K[0,1].item():.4f} (20° apart)")
print(f"   K[0,9]={K[0,9].item():.4f} (180° apart)")
print(f"   K[0,17]={K[0,17].item():.4f} (340°=-20° periodic)")

gp_params = nn.ParameterList()
gp_params.extend(vm.parameters())
gp_params.extend(gp.parameters())

vae_optimizer = optim.Adam(vae.parameters(), lr=CONFIG['vae_lr'])
gp_optimizer = optim.Adam(gp_params, lr=CONFIG['gp_lr'])
print("\n✅ Models initialized")

✅ VAE loaded
Loaded COIL-100 train: 1260 samples
  Y shape: torch.Size([1260, 3, 128, 128])
  Unique objects: 70
  Did range: [0, 99] (remapped to contiguous)
  Angle encoding: indices [0, 17]
Loaded COIL-100 val: 270 samples
  Y shape: torch.Size([270, 3, 128, 128])
  Unique objects: 15
  Did range: [1, 98] (remapped to contiguous)
  Angle encoding: indices [0, 17]
P=100, Q=18

🔬 Initializing 'periodic' kernel (period=360.0°)...
   K[0,0]=1.0000 (self)
   K[0,1]=0.9415 (20° apart)
   K[0,9]=0.1353 (180° apart)
   K[0,17]=0.9415 (340°=-20° periodic)

✅ Models initialized


## 11. Training Functions

In [24]:
def encode_Y(vae, train_queue):
    vae.eval()
    with torch.no_grad():
        n = train_queue.dataset.Y.shape[0]
        Zm = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).to(device)
        Zs = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).to(device)
        for data in train_queue:
            y = data[0].to(device)
            idxs = data[-1].to(device)
            zm, zs = vae.encode(y)
            Zm[idxs], Zs[idxs] = zm.detach(), zs.detach()
    return Zm, Zs

def eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv, Wv):
    rv = {}
    with torch.no_grad():
        _X = vm.x().data.cpu().numpy()
        _W = vm.v().data.cpu().numpy()
        covs = {"XX": np.dot(_X, _X.T), "WW": np.dot(_W, _W.T)}
        rv["vars"] = gp.get_vs().data.cpu().numpy()

        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
        Kiz = gp.solve(Zm, U, UBi, vs)
        Zo = vs[0] * Vv.mm(Vt.transpose(0, 1).mm(Kiz))

        mse_out = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).to(device)
        mse_val = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).to(device)
        all_Yv, all_Yr, all_Yo = [], [], []

        for data in val_queue:
            idxs = data[-1].to(device)
            Yv = data[0].to(device)
            Zv = vae.encode(Yv)[0].detach()
            Yr = vae.decode(Zv)
            Yo = vae.decode(Zo[idxs])
            mse_out[idxs] = ((Yv - Yo) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            mse_val[idxs] = ((Yv - Yr) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            all_Yv.append(Yv.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yr.append(Yr.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yo.append(Yo.data.cpu().numpy().transpose(0, 2, 3, 1))

        all_Yv = np.concatenate(all_Yv, axis=0)
        all_Yr = np.concatenate(all_Yr, axis=0)
        all_Yo = np.concatenate(all_Yo, axis=0)
        n_total = all_Yv.shape[0]
        sample_indices = np.arange(0, n_total, max(1, n_total // 24))[:24]
        imgs = {"Yv": all_Yv[sample_indices], "Yr": all_Yr[sample_indices], "Yo": all_Yo[sample_indices]}
        rv["mse_out"] = float(mse_out.data.mean().cpu())
        rv["mse_val"] = float(mse_val.data.mean().cpu())
    return rv, imgs, covs

def backprop_and_update(vae, gp, vm, train_queue, Dt, Wt, Eps, Zb, Vbs, vbs, vae_optimizer, gp_optimizer):
    rv = {}
    vae_optimizer.zero_grad()
    gp_optimizer.zero_grad()
    vae.train(); gp.train(); vm.train()

    for data in train_queue:
        y = data[0].to(device)
        eps = Eps[data[-1]]
        _d, _w = Dt[data[-1]], Wt[data[-1]]
        _Zb = Zb[data[-1]]
        _Vbs = [Vbs[0][data[-1]]]

        zm, zs = vae.encode(y)
        z = zm + zs * eps
        yr = vae.decode(z)
        recon_term, mse = vae.nll(y, yr)

        _Vs = [vm(_d, _w)]
        gp_nll_fo = gp.taylor_expansion(z, _Vs, _Zb, _Vbs, vbs) / vae.K
        pen_term = -0.5 * zs.sum(1)[:, None] / vae.K

        loss = (recon_term + gp_nll_fo + pen_term).sum()
        loss.backward()

        _n = train_queue.dataset.Y.shape[0]
        smartSum(rv, "mse", float(mse.data.sum().cpu()) / _n)
        smartSum(rv, "recon_term", float(recon_term.data.sum().cpu()) / _n)
        smartSum(rv, "pen_term", float(pen_term.data.sum().cpu()) / _n)

    vae_optimizer.step()
    gp_optimizer.step()
    return rv

print("✅ Training functions defined")

✅ Training functions defined


## 12. Train GP-VAE 🚀

In [None]:
# Train GP-VAE (with early stopping)

history = {}
start_time = time.time()

# -----------------------------
# Early stopping configuration
# -----------------------------
early_stop_patience = 150        # epochs without improvement
early_stop_min_delta = 1e-4      # minimum improvement threshold

best_mse_out = float("inf")
best_epoch = -1
no_improve_epochs = 0

print(f"üöÄ Training GP-VAE with {KERNEL_TYPE} kernel for up to {CONFIG['epochs']} epochs...")
print(f"üõë Early stopping patience = {early_stop_patience}")
print("=" * 70)

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()

    # -------- Encode training data --------
    Zm, Zs = encode_Y(vae, train_queue)
    Eps = Variable(torch.randn(*Zs.shape), requires_grad=False).to(device)
    Z = Zm + Eps * Zs

    # -------- Precompute V --------
    Vt = vm(Dt, Wt).detach()
    Vv = vm(Dv, Wv).detach()

    # -------- Validation step --------
    rv_eval, imgs, covs = eval_step(
        vae, gp, vm, val_queue, Zm, Vt, Vv, Wv
    )

    # -------- GP Taylor expansion --------
    Zb, Vbs, vbs, gp_nll = gp.taylor_coeff(Z, [Vt])
    rv_eval["gp_nll"] = float(gp_nll.data.mean().cpu()) / vae.K

    # -------- Backprop --------
    rv_back = backprop_and_update(
        vae, gp, vm, train_queue, Dt, Wt,
        Eps, Zb, Vbs, vbs,
        vae_optimizer, gp_optimizer
    )

    rv_back["loss"] = (
        rv_back["recon_term"] +
        rv_eval["gp_nll"] +
        rv_back["pen_term"]
    )

    # -------- Logging --------
    smartAppendDict(history, rv_eval)
    smartAppendDict(history, rv_back)
    smartAppend(history, "vs", gp.get_vs().data.cpu().numpy())

    vs = gp.get_vs().data.cpu().numpy()
    variance_ratio = vs[0] / (vs[0] + vs[1])

    current_mse_out = rv_eval["mse_out"]

    # -------- Early stopping check --------
    if current_mse_out < best_mse_out - early_stop_min_delta:
        best_mse_out = current_mse_out
        best_epoch = epoch
        no_improve_epochs = 0

        # Save BEST checkpoint
        torch.save(
            vae.state_dict(),
            os.path.join(wdir, "vae_weights.best.pt")
        )
        torch.save(
            {'gp_state': gp.state_dict(), 'vm_state': vm.state_dict()},
            os.path.join(wdir, "gp_weights.best.pt")
        )

    else:
        no_improve_epochs += 1

    learned_ls = vm.kernel.lengthscale.item()

    # -------- Console output --------
    if epoch % 5 == 0 or epoch == CONFIG['epochs'] - 1:
        print(
            f"Epoch {epoch:4d} | "
            f"MSE train: {rv_back['mse']:.6f} | "
            f"MSE out: {current_mse_out:.6f} | "
            f"GP NLL: {rv_eval['gp_nll']:.4f} | "
            f"v₀/(v₀+v₁): {variance_ratio:.3f}"
        )

    # -------- wandb --------
    if CONFIG['use_wandb']:
        wandb.log({
            "epoch": epoch,
            "mse_train": rv_back["mse"],
            "mse_out": current_mse_out,
            "gp_nll": rv_eval["gp_nll"],
            "variance_ratio": variance_ratio,
            "best_mse_out": best_mse_out,
            "no_improve_epochs": no_improve_epochs,
            "lengthscale": learned_ls,
        })

    # -------- Periodic checkpoint + plots --------
    if epoch % CONFIG['epoch_cb'] == 0 or epoch == CONFIG['epochs'] - 1:
        torch.save(
            vae.state_dict(),
            os.path.join(wdir, f"vae_weights.{epoch:05d}.pt")
        )
        torch.save(
            {'gp_state': gp.state_dict(), 'vm_state': vm.state_dict()},
            os.path.join(wdir, f"gp_weights.{epoch:05d}.pt")
        )
        ffile = os.path.join(fdir, f"plot.{epoch:05d}.png")
        callback_gppvae(epoch, history, covs, imgs, ffile)
        if CONFIG['use_wandb']:
            wandb.log({"reconstructions": wandb.Image(ffile)})
        print("  ✓ Checkpoint saved")

    # -------- Stop condition --------
    if no_improve_epochs >= early_stop_patience:
        print(
            f"\n⏹ Early stopping triggered at epoch {epoch}\n"
            f"   Best epoch: {best_epoch}\n"
            f"   Best mse_out: {best_mse_out:.6f}"
        )
        break

print(
    f"\n✅ Training complete in {(time.time()-start_time)/60:.1f} min\n"
    f"   Best epoch: {best_epoch}\n"
    f"   Best mse_out: {best_mse_out:.6f}"
)

if CONFIG['use_wandb']:
    wandb.finish()

🚀 Training GP-VAE with periodic kernel for up to 1200 epochs...
🛑 Early stopping patience = 150
Epoch    0 | MSE train: 0.001870 | MSE out: 0.045718 | GP NLL: 0.0087 | v₀/(v₀+v₁): 0.500
  ✓ Checkpoint saved
Epoch    5 | MSE train: 0.008683 | MSE out: 0.049707 | GP NLL: 0.0083 | v₀/(v₀+v₁): 0.497
Epoch   10 | MSE train: 0.005075 | MSE out: 0.044941 | GP NLL: 0.0095 | v₀/(v₀+v₁): 0.495
Epoch   15 | MSE train: 0.003834 | MSE out: 0.044722 | GP NLL: 0.0088 | v₀/(v₀+v₁): 0.492
Epoch   20 | MSE train: 0.003091 | MSE out: 0.044669 | GP NLL: 0.0068 | v₀/(v₀+v₁): 0.490
Epoch   25 | MSE train: 0.002597 | MSE out: 0.044850 | GP NLL: 0.0069 | v₀/(v₀+v₁): 0.488
Epoch   30 | MSE train: 0.002242 | MSE out: 0.043837 | GP NLL: 0.0064 | v₀/(v₀+v₁): 0.486
Epoch   35 | MSE train: 0.002060 | MSE out: 0.044384 | GP NLL: 0.0060 | v₀/(v₀+v₁): 0.484
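
The early-stopping rule embedded in the training loop (reset the counter only when `mse_out` improves by more than `early_stop_min_delta`, stop after `early_stop_patience` epochs without such an improvement) can be isolated for testing. The sketch below is one way to do that. The class name `EarlyStopper` and the toy metric values are hypothetical and chosen only to exercise the logic; they are not results from this run.

```python
class EarlyStopper:
    """Track the best metric value and signal when patience is exhausted."""

    def __init__(self, patience=150, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def update(self, epoch, metric):
        """Record one epoch; return True when training should stop."""
        if metric < self.best - self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter
            self.best = metric
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy validation curve (made-up numbers) with a short patience for illustration
toy_mse_out = [0.050, 0.045, 0.044, 0.044, 0.0445, 0.04395, 0.043]
stopper = EarlyStopper(patience=3, min_delta=1e-4)
stopped_at = None
for epoch, mse in enumerate(toy_mse_out):
    if stopper.update(epoch, mse):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}, best value {stopper.best}")
```

Note that the value 0.04395 at epoch 5 is lower than the best so far, but by less than `min_delta`, so it does not reset the counter. The same applies to small fluctuations of `mse_out` in the real loop.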


## 13. View Results

In [None]:
from IPython.display import Image, display
import glob

plot_files = sorted(glob.glob(os.path.join(fdir, "*.png")))
if plot_files:
    display(Image(filename=plot_files[-1]))

## 14. Download Results

In [None]:
!zip -r /content/gppvae_periodic_output.zip {CONFIG['outdir']}
from google.colab import files
files.download('/content/gppvae_periodic_output.zip')