# GP-VAE Training on COIL-100 (Extrapolation Task) - Spectral Mixture Kernel

This notebook trains **GP-VAE** on the COIL-100 dataset using the **Spectral Mixture (SM) kernel** for the **extrapolation task**.

## Task: Extrapolation
- **Train**: 10 views (0°-180°) - first half of rotation
- **Val**: 2 views (200°, 220°) - immediately after train
- **Test**: 6 views (240°-340°) - far extrapolation
- **Goal**: Predict views BEYOND the training range using GP extrapolation

## Kernel: Spectral Mixture (SM)
- **Mixture of Gaussians in the spectral domain** for flexible patterns
- k(θ, θ') = Σᵢ wᵢ × exp(-2π²σᵢ²d²) × cos(2πμᵢd)
- where d = wrapped lag distance between θ and θ'
- **Parameters**: 3 per mixture component: weight wᵢ, frequency μᵢ, and spectral variance σᵢ² (which sets the lengthscale)
- **Best for**: Complex periodic patterns, multiple frequency components
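
The kernel formula above can be sketched in plain Python. This is only an illustration: it assumes angles in degrees, a period of 360, and frequencies in cycles per degree, and the parameter values are made up. The project's actual `sm_circle` implementation may parameterise things differently.

```python
import math

def wrapped_lag(theta_a, theta_b, period=360.0):
    """Shortest angular distance between two angles on a circle."""
    diff = abs(theta_a - theta_b) % period
    return min(diff, period - diff)

def sm_kernel(theta_a, theta_b, weights, means, variances, period=360.0):
    """Spectral mixture kernel evaluated on the wrapped lag d.

    k = sum_i w_i * exp(-2 * pi^2 * var_i * d^2) * cos(2 * pi * mu_i * d)
    """
    d = wrapped_lag(theta_a, theta_b, period)
    return sum(
        w * math.exp(-2.0 * math.pi ** 2 * var * d ** 2) * math.cos(2.0 * math.pi * mu * d)
        for w, mu, var in zip(weights, means, variances)
    )

# Hypothetical parameters for two mixture components (not learned values).
weights = [0.7, 0.3]
means = [1.0 / 360.0, 2.0 / 360.0]  # assumed unit: cycles per degree
variances = [1e-5, 1e-5]            # spectral variances

k_self = sm_kernel(0.0, 0.0, weights, means, variances)   # zero lag: sum of the weights
k_20 = sm_kernel(0.0, 20.0, weights, means, variances)
k_340 = sm_kernel(0.0, 340.0, weights, means, variances)  # wraps to the same 20-degree lag
print(k_self, k_20, k_340)
```

Because the lag is wrapped, the views at 20° and 340° get the same covariance with the view at 0°, which is the property the extrapolation task relies on.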

## Dataset Info:
- **COIL-100**: 100 objects × 18 views (every 20°: 0°, 20°, ..., 340°)
- **Image size**: 128×128×3 RGB

## Prerequisites:
- ✅ Trained VAE weights from `train_vae_colab_extrapolation.ipynb`
- ✅ COIL-100 data file: `data/coil-100/coil100_task3_extrapolation.h5`

## 1. Check GPU Availability

In [None]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not detected!")

## 2. Auto-Detect Project Path

In [None]:
import os
import sys

current_dir = os.getcwd()
print(f"📍 Current directory: {current_dir}")

# Task configuration
DATA_TASK = "task3_extrapolation"
KERNEL_TYPE = "sm_circle"

if current_dir == '/content':
    print("\n🔄 Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"✅ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"⚠️ Project not found at: {drive_path}")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(current_dir)
    else:
        PROJECT_PATH = current_dir
    print(f"💻 Using project path: {PROJECT_PATH}")

# Check required files
print(f"\n🔍 Checking required files:")
data_path = os.path.join(PROJECT_PATH, f'data/coil-100/coil100_{DATA_TASK}.h5')
required = {
    'GPPVAE code': os.path.exists(os.path.join(PROJECT_PATH, 'GPPVAE')),
    'COIL-100 data': os.path.exists(data_path),
}

# Look for VAE weights trained on extrapolation task
vae_base_dir = os.path.join(PROJECT_PATH, f'out/vae_colab_{DATA_TASK}')
vae_run_found = False
if os.path.exists(vae_base_dir):
    runs = [d for d in os.listdir(vae_base_dir) if os.path.isdir(os.path.join(vae_base_dir, d))]
    for run in runs:
        weights_dir = os.path.join(vae_base_dir, run, 'weights')
        if os.path.exists(weights_dir) and any(f.endswith('.pt') for f in os.listdir(weights_dir)):
            vae_run_found = True
            break
required['VAE weights'] = vae_run_found

for name, exists in required.items():
    status = "✅" if exists else "❌"
    print(f"   {status} {name}")

print(f"\n📊 Extrapolation Task Info:")
print(f"   Train: 10 views (0°-180°) × 100 objects = 1000 samples")
print(f"   Val: 2 views (200°, 220°) × 100 objects = 200 samples")
print(f"   Test: 6 views (240°-340°) × 100 objects = 600 samples")
print(f"   ⚠️ This is a HARD task: predicting beyond training range!")

## 3. Install Dependencies

In [None]:
!pip install -q wandb==0.12.21 imageio==2.15.0 pyyaml

import wandb
import numpy as np
print("✅ All dependencies installed!")

## 4. Login to W&B (Optional)

In [None]:
import wandb
wandb.login()

## 5. Navigate to Project

In [None]:
import os
import sys

os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")

sys.path.insert(0, os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100'))

## 6. Find VAE Weights

In [None]:
import os
import pickle

# Look for VAE trained on extrapolation task
vae_base_dir = f'./out/vae_colab_{DATA_TASK}'
vae_runs = []

if os.path.exists(vae_base_dir):
    for run_dir in sorted(os.listdir(vae_base_dir), reverse=True):
        run_path = os.path.join(vae_base_dir, run_dir)
        cfg_path = os.path.join(run_path, 'vae.cfg.p')
        weights_dir = os.path.join(run_path, 'weights')

        if os.path.exists(cfg_path) and os.path.exists(weights_dir):
            weight_files = sorted([f for f in os.listdir(weights_dir) if f.endswith('.pt')])
            if weight_files:
                vae_runs.append({
                    'run_dir': run_dir,
                    'cfg_path': cfg_path,
                    'weights_dir': weights_dir,
                    'weight_files': weight_files
                })

if vae_runs:
    print(f"✅ Found {len(vae_runs)} VAE run(s) for extrapolation task")
    latest = vae_runs[0]
    print(f"\n💡 Latest: {latest['run_dir']}")
    print(f"   VAE_CFG = '{latest['cfg_path']}'")
    print(f"   VAE_WEIGHTS = '{os.path.join(latest['weights_dir'], latest['weight_files'][-1])}'")
else:
    print("❌ No VAE runs found for extrapolation task!")
    print("   Run train_vae_colab_extrapolation.ipynb first.")

## 7. Configure Training

**Spectral Mixture Kernel Parameters:**
- `num_mixtures`: Number of mixture components (default: 2)
- SM kernel learns: weights, frequencies (means), variances

**⚠️ Important for Extrapolation:**
- SM kernel with wrapped distance helps capture circular structure
- Views at 340° are only 20° away from 0° (which IS in training) via wrapped distance
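
To make the wrapped-distance argument concrete, the sketch below computes how far each held-out view is from its nearest training view. It assumes the 20° view spacing and the train/val/test angles listed in this notebook; the helper name `wrapped_distance` is illustrative and not part of the project code.

```python
def wrapped_distance(a, b, period=360):
    """Shortest angular distance between two angles, in degrees."""
    diff = abs(a - b) % period
    return min(diff, period - diff)

train_angles = list(range(0, 181, 20))       # 0°, 20°, ..., 180° (10 views)
held_out_angles = list(range(200, 341, 20))  # 200°, ..., 340° (val + test views)

# Distance from each held-out view to the closest training view on the circle.
nearest = {
    angle: min(wrapped_distance(angle, t) for t in train_angles)
    for angle in held_out_angles
}
for angle, dist in nearest.items():
    print(f"{angle:3d}° is {dist:3d}° from the nearest training view")
```

Under these assumptions no held-out view is more than 80° from a training view once distances wrap around the circle, even though 340° is 160° past the last training angle in the linear sense.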

In [None]:
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# ============================================================================
# UPDATE THESE PATHS!
# ============================================================================
VAE_CFG = './out/vae_colab_task3_extrapolation/YYYYMMDD_HHMMSS/vae.cfg.p'  # UPDATE
VAE_WEIGHTS = './out/vae_colab_task3_extrapolation/YYYYMMDD_HHMMSS/weights/weights.00499.pt'  # UPDATE

CONFIG = {
    # Data
    'data': f'./data/coil-100/coil100_{DATA_TASK}.h5',
    'outdir': f'./out/gppvae_coil100_{KERNEL_TYPE}_{DATA_TASK}/{timestamp}',

    # VAE
    'vae_cfg': VAE_CFG,
    'vae_weights': VAE_WEIGHTS,

    # Training
    'epochs': 1500,
    'batch_size': 64,
    'vae_lr': 0.001,
    'gp_lr': 0.001,
    'xdim': 64,

    # Kernel - Spectral Mixture
    'view_kernel': KERNEL_TYPE,
    'kernel_kwargs': {
        'num_mixtures': 2,  # 2 mixtures to avoid overfitting
        # Frequencies, lengthscales, weights are learned automatically
    },

    # Logging
    'epoch_cb': 100,
    'use_wandb': True,
    'wandb_project': 'gppvae-coil100',
    'wandb_run_name': f'gppvae_{KERNEL_TYPE}_{DATA_TASK}_{timestamp}',
    'seed': 0,
}

print("GP-VAE Training Configuration (Extrapolation Task):")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"  {key:20s}: {value}")
print("=" * 60)

if not os.path.exists(CONFIG['vae_weights']):
    print(f"\n⚠️ Update VAE_CFG and VAE_WEIGHTS paths!")

## 8. Import Modules

In [None]:
# IMPORTANT: Add coil100 to path FIRST before importing anything
# This ensures coil100's data_parser is used, not faceplace's
import sys
import os

coil100_path = os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100')
sys.path.insert(0, coil100_path)    # Add coil100 first (so it has priority)

os.chdir(coil100_path)
print(f"Working directory: {os.getcwd()}")
print(f"sys.path priority: coil100 > faceplace")

import matplotlib
matplotlib.use('Agg')

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
import numpy as np
import logging
import pylab as pl
from utils import smartSum, smartAppendDict, smartAppend, export_scripts
from callbacks import callback_gppvae
import pickle
import time
import wandb

# COIL-100 data parser (explicitly import from coil100, not faceplace)
from data_parser import COIL100Dataset, get_n_views, get_num_objects

# Verify we're using the right data_parser
import data_parser
print(f"✅ data_parser loaded from: {data_parser.__file__}")
if 'coil100' in data_parser.__file__:
    print("✅ Using COIL-100 data_parser (correct!)")
else:
    print("❌ WARNING: Using faceplace data_parser (wrong!)")
print("✅ All modules imported successfully!")

## 9. Setup Training Environment

In [None]:
os.chdir(PROJECT_PATH)

outdir = CONFIG['outdir']
wdir = os.path.join(outdir, "weights")
fdir = os.path.join(outdir, "plots")
os.makedirs(wdir, exist_ok=True)
os.makedirs(fdir, exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

log_format = "%(asctime)s %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format, datefmt="%m/%d %I:%M:%S %p")
fh = logging.FileHandler(os.path.join(outdir, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

export_scripts(os.path.join(outdir, "scripts"))
print(f"✅ Output: {outdir}")

## 10. Initialize Models

**Extrapolation Task Specifics:**
- Training views: indices [0-9] → angles [0°, 20°, ..., 180°]
- Validation views: indices [10, 11] → angles [200°, 220°]
- Test views: indices [12-17] → angles [240°, 260°, ..., 340°]
- The GP must learn to predict views BEYOND the training range
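
The index-to-angle mapping and the resulting sample counts can be checked with a few lines. This is a standalone sketch that assumes 18 evenly spaced views and 100 objects, as described for COIL-100 in this notebook; the `splits` dictionary is illustrative and is not read from the HDF5 file.

```python
N_VIEWS = 18                 # one view every 20 degrees
N_OBJECTS = 100              # COIL-100 object count
STEP_DEG = 360 // N_VIEWS    # 20

# View indices per split, as listed for the extrapolation task.
splits = {
    "train": list(range(0, 10)),
    "val": list(range(10, 12)),
    "test": list(range(12, 18)),
}

# Convert each view index to its rotation angle in degrees.
angles = {name: [i * STEP_DEG for i in idx] for name, idx in splits.items()}
# Every object contributes one image per view in a split.
n_samples = {name: len(idx) * N_OBJECTS for name, idx in splits.items()}

for name in splits:
    print(f"{name:5s}: {angles[name][0]}°-{angles[name][-1]}°, {n_samples[name]} samples")
```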

In [None]:
torch.manual_seed(CONFIG['seed'])

if CONFIG['use_wandb']:
    wandb.init(project=CONFIG['wandb_project'], name=CONFIG['wandb_run_name'], config=CONFIG)

# Load VAE
with open(CONFIG['vae_cfg'], "rb") as f:  # close the config file after reading
    vae_cfg = pickle.load(f)
vae = FaceVAE(**vae_cfg).to(device)
vae.load_state_dict(torch.load(CONFIG['vae_weights'], map_location=device))
print(f"✅ VAE loaded")

# Load data
train_data = COIL100Dataset(CONFIG['data'], split='train', use_angle_encoding=False)
val_data = COIL100Dataset(CONFIG['data'], split='val', use_angle_encoding=False)
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)

# IMPORTANT: Use get_num_objects for correct P (includes all objects from all splits)
P = get_num_objects(CONFIG['data'])  # 100 for COIL-100
Q = get_n_views()  # 18 total views
print(f"P={P}, Q={Q}")

# Show extrapolation task structure
train_views = sorted(train_data.Rid.unique().tolist())
val_views = sorted(val_data.Rid.unique().tolist())
print(f"\n📊 Extrapolation Task:")
print(f"   Train views (Rid): {train_views} → angles {[v*20 for v in train_views]}°")
print(f"   Val views (Rid):   {val_views} → angles {[v*20 for v in val_views]}°")
print(f"   Train samples: {len(train_data)}, Val samples: {len(val_data)}")
print(f"   ⚠️ Val views are BEYOND training range (extrapolation)!")

# Create object and view tensors (Did and Rid are 1D tensors)
Dt = Variable(train_data.Did.long(), requires_grad=False).to(device)
Dv = Variable(val_data.Did.long(), requires_grad=False).to(device)
Wt = Variable(train_data.Rid.long(), requires_grad=False).to(device)
Wv = Variable(val_data.Rid.long(), requires_grad=False).to(device)

# Initialize Vmodel with SM kernel
print(f"\n🔬 Initializing '{KERNEL_TYPE}' kernel (num_mixtures={CONFIG['kernel_kwargs']['num_mixtures']})...")
vm = Vmodel(P=P, Q=Q, p=CONFIG['xdim'], view_kernel=CONFIG['view_kernel'], **CONFIG['kernel_kwargs']).to(device)
gp = GP(n_rand_effs=1).to(device)

# Show kernel matrix - important for extrapolation!
K = vm.get_kernel_matrix()
print(f"\n📈 Kernel correlations (critical for extrapolation):")
print(f"   K[0,0]={K[0,0].item():.4f} (self, 0°)")
print(f"   K[0,9]={K[0,9].item():.4f} (180° apart - edge of training)")
print(f"   K[0,10]={K[0,10].item():.4f} (200° view = 160° wrapped distance - val view, extrapolation!)")
print(f"   K[0,17]={K[0,17].item():.4f} (340° = 20° wrapped distance!)")

gp_params = nn.ParameterList()
gp_params.extend(vm.parameters())
gp_params.extend(gp.parameters())

vae_optimizer = optim.Adam(vae.parameters(), lr=CONFIG['vae_lr'])
gp_optimizer = optim.Adam(gp_params, lr=CONFIG['gp_lr'])
print(f"\n✅ Models initialized")

## 11. Training Functions

In [None]:
def encode_Y(vae, train_queue):
    vae.eval()
    with torch.no_grad():
        n = train_queue.dataset.Y.shape[0]
        Zm = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).to(device)
        Zs = Variable(torch.zeros(n, vae_cfg["zdim"]), requires_grad=False).to(device)
        for data in train_queue:
            y = data[0].to(device)
            idxs = data[-1].to(device)
            zm, zs = vae.encode(y)
            Zm[idxs], Zs[idxs] = zm.detach(), zs.detach()
    return Zm, Zs

def eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv, Wv, val_Rid):
    """Evaluate GP-VAE on validation set (unseen views for extrapolation task).
    
    Args:
        val_Rid: View indices for validation set (to compute per-view MSE)
    
    Returns:
        rv: Dict with mse_out, mse_val, and per-view MSE (mse_view_XXX)
    """
    rv = {}
    with torch.no_grad():
        _X = vm.x().data.cpu().numpy()
        _W = vm.v().data.cpu().numpy()
        covs = {"XX": np.dot(_X, _X.T), "WW": np.dot(_W, _W.T)}
        rv["vars"] = gp.get_vs().data.cpu().numpy()

        # GP prediction: predict validation latents from training latents
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
        Kiz = gp.solve(Zm, U, UBi, vs)
        # Zo: predicted latents for validation views using GP extrapolation
        Zo = vs[0] * Vv.mm(Vt.transpose(0, 1).mm(Kiz))

        mse_out = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).to(device)
        mse_val = Variable(torch.zeros(Vv.shape[0], 1), requires_grad=False).to(device)
        all_Yv, all_Yr, all_Yo = [], [], []

        for data in val_queue:
            idxs = data[-1].to(device)
            Yv = data[0].to(device)  # Ground truth validation images
            Zv = vae.encode(Yv)[0].detach()  # Encoded validation latents
            Yr = vae.decode(Zv)  # Reconstruction (encode-decode)
            Yo = vae.decode(Zo[idxs])  # GP-predicted images (extrapolated!)

            # mse_out: How well can we predict views BEYOND training range?
            mse_out[idxs] = ((Yv - Yo) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            mse_val[idxs] = ((Yv - Yr) ** 2).view(Yv.shape[0], -1).mean(1)[:, None].detach()
            all_Yv.append(Yv.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yr.append(Yr.data.cpu().numpy().transpose(0, 2, 3, 1))
            all_Yo.append(Yo.data.cpu().numpy().transpose(0, 2, 3, 1))

        all_Yv = np.concatenate(all_Yv, axis=0)
        all_Yr = np.concatenate(all_Yr, axis=0)
        all_Yo = np.concatenate(all_Yo, axis=0)
        n_total = all_Yv.shape[0]
        sample_indices = np.arange(0, n_total, max(1, n_total // 24))[:24]
        imgs = {"Yv": all_Yv[sample_indices], "Yr": all_Yr[sample_indices], "Yo": all_Yo[sample_indices]}
        rv["mse_out"] = float(mse_out.data.mean().cpu())
        rv["mse_val"] = float(mse_val.data.mean().cpu())
        
        # Compute per-view MSE for extrapolation analysis
        mse_out_cpu = mse_out.data.cpu().squeeze()
        val_Rid_cpu = val_Rid.cpu()
        unique_views = torch.unique(val_Rid_cpu).tolist()
        for view_idx in unique_views:
            mask = (val_Rid_cpu == view_idx)
            view_mse = mse_out_cpu[mask].mean().item()
            angle = int(view_idx * 20)  # Convert index to angle
            rv[f"mse_view_{angle:03d}"] = view_mse
        
    return rv, imgs, covs

def backprop_and_update(vae, gp, vm, train_queue, Dt, Wt, Eps, Zb, Vbs, vbs, vae_optimizer, gp_optimizer):
    rv = {}
    vae_optimizer.zero_grad()
    gp_optimizer.zero_grad()
    vae.train(); gp.train(); vm.train()

    for data in train_queue:
        y = data[0].to(device)
        eps = Eps[data[-1]]
        _d, _w = Dt[data[-1]], Wt[data[-1]]
        _Zb = Zb[data[-1]]
        _Vbs = [Vbs[0][data[-1]]]

        zm, zs = vae.encode(y)
        z = zm + zs * eps
        yr = vae.decode(z)
        recon_term, mse = vae.nll(y, yr)

        _Vs = [vm(_d, _w)]
        gp_nll_fo = gp.taylor_expansion(z, _Vs, _Zb, _Vbs, vbs) / vae.K
        pen_term = -0.5 * zs.sum(1)[:, None] / vae.K

        loss = (recon_term + gp_nll_fo + pen_term).sum()
        loss.backward()

        _n = train_queue.dataset.Y.shape[0]
        smartSum(rv, "mse", float(mse.data.sum().cpu()) / _n)
        smartSum(rv, "recon_term", float(recon_term.data.sum().cpu()) / _n)
        smartSum(rv, "pen_term", float(pen_term.data.sum().cpu()) / _n)

    vae_optimizer.step()
    gp_optimizer.step()
    return rv

print("✅ Training functions defined")

## 12. Train GP-VAE 🚀

**Key Metrics for Extrapolation:**
- `mse_out`: MSE on **unseen views** (extrapolated) - This is the key metric!
- `mse_view_XXX`: Per-view MSE for each validation angle (200°, 220°)
- SM kernel parameters are automatically learned and logged
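
As a minimal illustration of how the `mse_view_XXX` keys are formed, the sketch below groups per-sample errors by view index and averages them. It uses plain Python lists and made-up error values instead of the notebook's tensors, and `per_view_mse` is a hypothetical helper, not a function from the project.

```python
from collections import defaultdict

def per_view_mse(sample_mse, view_indices, step_deg=20):
    """Average per-sample MSE within each view, keyed by zero-padded angle."""
    buckets = defaultdict(list)
    for mse, view in zip(sample_mse, view_indices):
        buckets[int(view) * step_deg].append(float(mse))
    return {
        f"mse_view_{angle:03d}": sum(vals) / len(vals)
        for angle, vals in sorted(buckets.items())
    }

# Made-up per-sample errors for four validation samples (two per view).
sample_mse = [0.010, 0.030, 0.020, 0.040]
view_indices = [10, 11, 10, 11]   # view index 10 is 200°, index 11 is 220°

result = per_view_mse(sample_mse, view_indices)
for key, value in result.items():
    print(f"{key}: {value:.4f}")
```

Tracking the error per angle rather than only the overall `mse_out` shows whether quality degrades with distance from the training range.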

In [None]:
history = {}
start_time = time.time()

# Get validation view indices for per-view MSE tracking
val_Rid = val_data.Rid.to(device)
val_view_angles = sorted([int(v * 20) for v in val_data.Rid.unique().tolist()])
print(f"📊 Tracking per-view MSE for validation angles: {val_view_angles}°")

# Helper: extract SM kernel parameters for logging/plotting
def get_sm_param_log_dict(vm):
    k = vm.kernel
    if not hasattr(k, "num_mixtures"):
        return {}

    with torch.no_grad():
        weights = k.weights.detach().cpu().numpy()
        means = k.means.detach().cpu().numpy()
        variances = k.variances.detach().cpu().numpy()
        eff_lens = 1.0 / (np.sqrt(variances) + 1e-12)

    d = {}
    # Aggregate summaries
    d["sm/weights_entropy"] = float(-(weights * np.log(weights + 1e-12)).sum())
    d["sm/mean_freq"] = float(means.mean())
    d["sm/mean_eff_lengthscale"] = float(eff_lens.mean())

    # Per-mixture
    for i in range(int(k.num_mixtures)):
        d[f"sm/mix{i}/weight"] = float(weights[i])
        d[f"sm/mix{i}/freq"] = float(means[i])
        d[f"sm/mix{i}/eff_lengthscale"] = float(eff_lens[i])

    return d

print(f"\n🚀 Training GP-VAE with {KERNEL_TYPE} kernel for {CONFIG['epochs']} epochs...")
print(f"   Task: EXTRAPOLATION (predicting views BEYOND training range)")
print("=" * 70)

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()

    Zm, Zs = encode_Y(vae, train_queue)
    Eps = Variable(torch.randn(*Zs.shape), requires_grad=False).to(device)
    Z = Zm + Eps * Zs

    Vt = vm(Dt, Wt).detach()
    Vv = vm(Dv, Wv).detach()

    # Pass val_Rid for per-view MSE computation
    rv_eval, imgs, covs = eval_step(vae, gp, vm, val_queue, Zm, Vt, Vv, Wv, val_Rid)
    Zb, Vbs, vbs, gp_nll = gp.taylor_coeff(Z, [Vt])
    rv_eval["gp_nll"] = float(gp_nll.data.mean().cpu()) / vae.K

    rv_back = backprop_and_update(
        vae, gp, vm, train_queue, Dt, Wt, Eps, Zb, Vbs, vbs, vae_optimizer, gp_optimizer
    )
    rv_back["loss"] = rv_back["recon_term"] + rv_eval["gp_nll"] + rv_back["pen_term"]

    smartAppendDict(history, rv_eval)
    smartAppendDict(history, rv_back)
    smartAppend(history, "vs", gp.get_vs().data.cpu().numpy())

    epoch_time = time.time() - epoch_start
    vs = gp.get_vs().data.cpu().numpy()
    variance_ratio = vs[0] / (vs[0] + vs[1])

    # Extract SM kernel params (learned) for printing/logging
    sm_log = get_sm_param_log_dict(vm)

    if epoch % 5 == 0 or epoch == CONFIG['epochs'] - 1:
        # Print per-view MSE summary
        view_mse_str = " | ".join([f"{a}°:{rv_eval[f'mse_view_{a:03d}']:.4f}" for a in val_view_angles])
        print(
            f"Epoch {epoch:4d} | MSE: {rv_back['mse']:.6f} | Extrap: {rv_eval['mse_out']:.6f} | "
            f"GP NLL: {rv_eval['gp_nll']:.4f} | v₀/(v₀+v₁): {variance_ratio:.3f}"
        )
        print(f"         Per-view: {view_mse_str}")

        # Lightweight print of learned SM params
        if sm_log:
            mixes = int(vm.kernel.num_mixtures)
            mix_str = []
            for i in range(mixes):
                mix_str.append(
                    f"mix{i}: w={sm_log[f'sm/mix{i}/weight']:.3f}, "
                    f"f={sm_log[f'sm/mix{i}/freq']:.5f}, "
                    f"ℓ≈{sm_log[f'sm/mix{i}/eff_lengthscale']:.2f}"
                )
            print("         SM params: " + " | ".join(mix_str))

    if CONFIG['use_wandb']:
        # Log basic metrics
        log_dict = {
            "epoch": epoch,
            "mse_train": rv_back["mse"],
            "mse_extrap": rv_eval["mse_out"],
            "gp_nll": rv_eval["gp_nll"],
            "variance_ratio": variance_ratio,
        }

        # Log per-view MSE
        for angle in val_view_angles:
            log_dict[f"mse_view_{angle:03d}"] = rv_eval[f"mse_view_{angle:03d}"]

        # Log SM learned parameters
        log_dict.update(sm_log)

        wandb.log(log_dict)

    if epoch % CONFIG['epoch_cb'] == 0 or epoch == CONFIG['epochs'] - 1:
        torch.save(vae.state_dict(), os.path.join(wdir, f"vae_weights.{epoch:05d}.pt"))
        torch.save({"gp_state": gp.state_dict(), "vm_state": vm.state_dict()}, os.path.join(wdir, f"gp_weights.{epoch:05d}.pt"))
        ffile = os.path.join(fdir, f"plot.{epoch:05d}.png")
        callback_gppvae(epoch, history, covs, imgs, ffile)
        if CONFIG['use_wandb']:
            wandb.log({"reconstructions": wandb.Image(ffile)})
        print(f"  ✓ Checkpoint saved")

print(f"\n✅ Complete! Time: {(time.time()-start_time)/60:.1f}min")
print(f"   Final extrapolation MSE: {rv_eval['mse_out']:.6f}")
print(f"\n📊 Final per-view MSE:")
for angle in val_view_angles:
    print(f"   {angle:3d}°: {rv_eval[f'mse_view_{angle:03d}']:.6f}")
if CONFIG['use_wandb']:
    wandb.finish()

## 13. Evaluate on Test Set (Far Extrapolation)

The test set contains view indices [12-17] → angles [240°, 260°, 280°, 300°, 320°, 340°].

**Note:** This is FAR extrapolation: these views lie 60°-160° beyond the training range!
However, 340° is only 20° from 0° via wrapped distance, so the SM kernel should help.
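
Read off the evaluation code in this notebook, the predicted test latents look like a GP posterior mean with a low-rank covariance. The expression below is a sketch under two assumptions about the `GP` class, which is not shown here: that `gp.solve` applies the inverse of the training covariance $v_0 V_t V_t^\top + v_1 I$, and that $v_1$ is the noise variance.

$$
\hat{Z}_{\text{test}} = v_0 \, V_{\text{test}} V_t^{\top} \left( v_0 V_t V_t^{\top} + v_1 I \right)^{-1} Z_m
$$

Here $Z_m$ holds the encoded training latent means, $V_t$ and $V_{\text{test}}$ are the `Vmodel` outputs for the training and test (object, view) pairs, and $v_0$, $v_1$ are the two entries of `gp.get_vs()`. The predicted latents are then passed through `vae.decode` to produce the extrapolated images.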

In [None]:
# Load test data
test_data = COIL100Dataset(CONFIG['data'], split='test', use_angle_encoding=False)
test_queue = DataLoader(test_data, batch_size=CONFIG['batch_size'], shuffle=False)

test_views = sorted(test_data.Rid.unique().tolist())
print(f"Test views (Rid): {test_views} → angles {[v*20 for v in test_views]}°")
print(f"Test samples: {len(test_data)}")
print(f"⚠️ These are FAR extrapolation views (60°-160° beyond training)!")
print(f"💡 But via wrapped distance: 340°→20°, 320°→40° from training views!")

# Create test tensors
Dtest = Variable(test_data.Did.long(), requires_grad=False).to(device)
Wtest = Variable(test_data.Rid.long(), requires_grad=False).to(device)

# Evaluate
vae.eval()
vm.eval()
gp.eval()

with torch.no_grad():
    # Re-encode training data
    Zm, _ = encode_Y(vae, train_queue)
    Vt = vm(Dt, Wt).detach()
    Vtest = vm(Dtest, Wtest).detach()

    # GP prediction for test set
    vs = gp.get_vs()
    U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
    Kiz = gp.solve(Zm, U, UBi, vs)
    Zo_test = vs[0] * Vtest.mm(Vt.transpose(0, 1).mm(Kiz))

    # Per-view test MSE
    test_Rid = test_data.Rid
    mse_per_view = {}
    mse_test_total = 0.0
    
    for data in test_queue:
        idxs = data[-1].to(device)
        Ytest = data[0].to(device)
        Yo = vae.decode(Zo_test[idxs])
        mse_batch = ((Ytest - Yo) ** 2).view(Ytest.shape[0], -1).mean(1)
        
        # Accumulate per-view
        for i, idx in enumerate(data[-1]):
            view = int(test_Rid[idx].item())
            if view not in mse_per_view:
                mse_per_view[view] = []
            mse_per_view[view].append(mse_batch[i].item())
        
        mse_test_total += mse_batch.sum().item()

    mse_test = mse_test_total / len(test_data)
    print(f"\n🎯 Test MSE (far extrapolation): {mse_test:.6f}")
    print(f"\n📊 Test per-view MSE:")
    for view in sorted(mse_per_view.keys()):
        angle = int(view * 20)
        view_mse = np.mean(mse_per_view[view])
        wrapped_dist = min(angle, 360 - angle)  # Wrapped distance from 0°
        if wrapped_dist <= 40:
            print(f"   {angle:3d}°: {view_mse:.6f}  (wrapped: {wrapped_dist}° from 0°)")
        else:
            print(f"   {angle:3d}°: {view_mse:.6f}")

## 14. View Results

In [None]:
from IPython.display import Image, display
import glob

plot_files = sorted(glob.glob(os.path.join(fdir, "*.png")))
if plot_files:
    display(Image(filename=plot_files[-1]))

## 15. Download Results

In [None]:
!zip -r /content/gppvae_sm_extrapolation_output.zip {CONFIG['outdir']}
from google.colab import files
files.download('/content/gppvae_sm_extrapolation_output.zip')