# Generate Best Weights for Task 2 & Task 3 Runs

This notebook scans the results folders for Task 2 (interpolation) and Task 3 (extrapolation), finds runs that are **missing `*.best.pt` files**, and evaluates all available checkpoints to determine and save the best weights.

## Kernels Processed:
- **SM Free** (sm_free)
- **SM Wrapped** (sm_wrapped)  
- **Periodic** (periodic)

## Process:
1. Scan each run folder for `gp_weights.best.pt` and `vae_weights.best.pt`
2. If either file is missing, load all available checkpoints
3. Evaluate each checkpoint on the validation set
4. Copy the best checkpoint's files to `*.best.pt`

## 1. Setup

In [None]:
import os
import sys
import glob
import shutil
import re

import torch
# Report the installed PyTorch build and, when CUDA is usable, the name of GPU 0.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Work out where the project lives, make its coil100 sources importable and
# switch the working directory to the project root.
current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

if current_dir != '/content':
    # Not started from /content: a working directory whose path contains
    # 'notebooks' is treated as sitting two levels below the project root,
    # anything else is taken to be the root itself.
    PROJECT_PATH = (
        os.path.dirname(os.path.dirname(current_dir))
        if 'notebooks' in current_dir
        else current_dir
    )
    print(f"üíª Using project path: {PROJECT_PATH}")
else:
    # Started from /content: try to mount Google Drive and use the copy of
    # the project stored there, falling back to /content on any problem.
    print("\nüîÑ Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/gppvae'
        if not os.path.exists(drive_path):
            print(f"‚ö†Ô∏è Project not found at: {drive_path}")
            PROJECT_PATH = '/content'
        else:
            PROJECT_PATH = drive_path
            print(f"‚úÖ Found project in Google Drive: {PROJECT_PATH}")
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'

# Put the coil100 source folder first on the import path.
coil100_path = os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100')
sys.path.insert(0, coil100_path)

# Relative paths used later in the notebook are resolved from the project root.
os.chdir(PROJECT_PATH)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Import models
# General-purpose dependencies. NOTE(review): `np` and `pickle` are not
# referenced by the code visible in this notebook — confirm before removing.
import numpy as np
import pickle
from torch.autograd import Variable
from torch.utils.data import DataLoader

# Project modules. NOTE(review): presumably resolved through the
# GPPVAE/pysrc/coil100 directory inserted into sys.path by the setup cell —
# confirm; they will fail to import if that cell has not run.
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
from data_parser import COIL100Dataset, get_n_views, get_num_objects

# Reached only if every import above succeeded.
print("‚úÖ All modules imported")

## 2. Configuration

In [None]:
# Locate the results root. ./GPPVAE/results takes precedence over ./results;
# when neither exists, ./results is kept as the default and a warning is shown.
RESULTS_BASE = next(
    (path for path in ('./GPPVAE/results', './results') if os.path.exists(path)),
    None,
)
if RESULTS_BASE is None:
    RESULTS_BASE = './results'
    print(f"‚ö†Ô∏è Results folder not found, using default: {RESULTS_BASE}")
else:
    print(f"‚úÖ Found results at: {os.path.abspath(RESULTS_BASE)}")

# Batch size for the evaluation DataLoaders, and the value handed to Vmodel
# as its ``p`` argument.
BATCH_SIZE = 64
XDIM = 64

# One entry per task: the result-folder prefix, the HDF5 data file and the
# kernels whose run folders are scanned.
TASKS_TO_PROCESS = [
    dict(
        task_name='task2',
        data_file='./data/coil-100/coil100_task2_interpolation.h5',
        kernels=['periodic', 'sm_free', 'sm_wrapped'],
    ),
    dict(
        task_name='task3',
        data_file='./data/coil-100/coil100_task3_extrapolation.h5',
        kernels=['periodic', 'sm_free', 'sm_wrapped'],
    ),
]

# Vmodel construction arguments per kernel name.
# NOTE(review): 'sm_free' and 'sm_wrapped' are configured identically here
# (same view_kernel, same kwargs) — confirm that this matches how the two
# kinds of run were actually trained.
KERNEL_CONFIGS = dict(
    periodic=dict(
        view_kernel='periodic',
        kernel_kwargs=dict(period=360.0, lengthscale=1.0, variance=1.0),
    ),
    sm_free=dict(
        view_kernel='sm_circle',
        kernel_kwargs=dict(num_mixtures=2, use_angle_input=True),
    ),
    sm_wrapped=dict(
        view_kernel='sm_circle',
        kernel_kwargs=dict(num_mixtures=2, use_angle_input=True),
    ),
)

# Keyword arguments used to rebuild FaceVAE before loading each checkpoint.
DEFAULT_VAE_CFG = dict(
    img_size=128,
    nf=32,
    zdim=256,
    steps=5,
    colors=3,
    act='elu',
    vy=0.001,
)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 3. Helper Functions

In [None]:
def find_runs_missing_best(results_base, task_name, kernel_name):
    """
    Split the runs of one task/kernel folder by whether their best weights exist.

    The folder inspected is ``<results_base>/<task_name>_<kernel_name>``. Each
    sub-directory that contains a ``weights`` directory counts as a run; other
    entries are ignored. A run is complete only when both
    ``weights/gp_weights.best.pt`` and ``weights/vae_weights.best.pt`` exist.

    Args:
        results_base: Root directory holding the ``<task>_<kernel>`` folders.
        task_name: Task prefix of the folder name (e.g. ``'task2'``).
        kernel_name: Kernel suffix of the folder name (e.g. ``'periodic'``).

    Returns:
        A ``(runs_missing, runs_complete)`` pair of lists of run paths, each in
        sorted directory-name order. Both lists are empty when the
        task/kernel folder does not exist.
    """
    kernel_path = os.path.join(results_base, f"{task_name}_{kernel_name}")

    if not os.path.isdir(kernel_path):
        print(f"‚ö†Ô∏è Folder not found: {kernel_path}")
        # Callers always unpack two lists, so the "nothing here" result must
        # be a pair as well (a bare [] made the unpacking raise ValueError).
        return [], []

    runs_missing = []
    runs_complete = []

    for run_dir in sorted(os.listdir(kernel_path)):
        run_path = os.path.join(kernel_path, run_dir)
        weights_dir = os.path.join(run_path, 'weights')
        # Only directories that actually carry a weights folder are runs.
        if not os.path.isdir(weights_dir):
            continue

        has_both_best = all(
            os.path.exists(os.path.join(weights_dir, name))
            for name in ('gp_weights.best.pt', 'vae_weights.best.pt')
        )
        (runs_complete if has_both_best else runs_missing).append(run_path)

    return runs_missing, runs_complete


def get_available_checkpoints(weights_dir):
    """
    List the numbered checkpoints in a weights directory.

    A checkpoint is a ``gp_weights.<digits>.pt`` file that has a matching
    ``vae_weights.<digits>.pt`` next to it. ``*.best.pt`` files and any other
    non-numeric names are ignored.

    All name checks are made on the file name only, never on the full path,
    so a directory name containing ``best`` or ``gp_weights`` cannot hide
    checkpoints or corrupt the derived VAE path.

    Args:
        weights_dir: Directory to search (need not exist).

    Returns:
        List of ``(epoch, gp_path, vae_path)`` tuples sorted by ascending
        epoch; empty when nothing matches.
    """
    # Escape the directory so characters such as '[' are matched literally.
    pattern = os.path.join(glob.escape(weights_dir), 'gp_weights.*.pt')

    checkpoints = []
    for gp_path in glob.glob(pattern):
        folder, gp_name = os.path.split(gp_path)

        # Only purely numeric epochs qualify; this also skips 'best' files.
        match = re.fullmatch(r'gp_weights\.(\d+)\.pt', gp_name)
        if match is None:
            continue

        vae_path = os.path.join(folder, f"vae_weights.{match.group(1)}.pt")
        if os.path.exists(vae_path):
            checkpoints.append((int(match.group(1)), gp_path, vae_path))

    # Sort by epoch
    checkpoints.sort(key=lambda item: item[0])
    return checkpoints


def encode_dataset(vae, data_queue, device, zdim=256):
    """Encode every batch of ``data_queue`` and gather the first encoder output.

    The VAE is switched to eval mode and run without gradient tracking. Each
    batch is indexed as ``batch[0]`` (encoder input) and ``batch[-1]`` (row
    positions in the output); the first value returned by ``vae.encode`` is
    written to those rows.

    Args:
        vae: Model exposing ``eval()`` and ``encode(x)`` returning a pair.
        data_queue: Iterable of batches whose ``dataset.Y`` gives the row count.
        device: Device for the output tensor and the batch tensors.
        zdim: Number of columns in the output.

    Returns:
        Tensor of shape ``(len(dataset.Y), zdim)``; rows never referenced by a
        batch stay zero.
    """
    vae.eval()
    num_rows = data_queue.dataset.Y.shape[0]
    latents = torch.zeros(num_rows, zdim, device=device)

    with torch.no_grad():
        for batch in data_queue:
            images = batch[0].to(device)
            row_ids = batch[-1].to(device)
            latent_mean, _ = vae.encode(images)
            latents[row_ids] = latent_mean.detach()

    return latents


def evaluate_checkpoint(vae, vm, gp, train_queue, val_queue,
                        Dt, Wt, Dv, Wv, device):
    """
    Score one checkpoint by its reconstruction error on the validation set.

    Training images are encoded with the VAE, validation latents are
    predicted from them through ``vm`` and ``gp``, and those latents are
    decoded and compared with the validation images.

    Args:
        vae: Model providing ``encode``/``decode``.
        vm: Callable as ``vm(D, W)``, giving the V matrices.
        gp: Object providing ``get_vs``, ``U_UBi_Shb`` and ``solve``.
        train_queue: Batches of training data (encoded, not scored).
        val_queue: Batches of validation data (scored).
        Dt, Wt: Inputs to ``vm`` for the training set.
        Dv, Wv: Inputs to ``vm`` for the validation set.
        device: Device the batch tensors are moved to.

    Returns:
        Float: the mean over validation images of the per-image mean squared
        difference between the image and its decoded prediction.
    """
    for module in (vae, vm, gp):
        module.eval()

    with torch.no_grad():
        # Latents of the training images.
        train_latents = encode_dataset(vae, train_queue, device)

        # V matrices for both splits.
        V_train = vm(Dt, Wt).detach()
        V_val = vm(Dv, Wv).detach()

        # Predict the validation latents from the training latents.
        gp_vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([V_train], gp_vs)
        Kiz = gp.solve(train_latents, U, UBi, gp_vs)
        val_latents = gp_vs[0] * V_val.mm(V_train.transpose(0, 1).mm(Kiz))

        # Accumulate the per-image MSE across all validation batches.
        squared_error_sum = 0.0
        image_count = 0
        for batch in val_queue:
            row_ids = batch[-1].to(device)
            targets = batch[0].to(device)
            decoded = vae.decode(val_latents[row_ids])
            per_image = ((targets - decoded) ** 2).view(targets.shape[0], -1).mean(1)
            squared_error_sum += per_image.sum().item()
            image_count += targets.shape[0]

    return squared_error_sum / image_count


def process_run(run_path, kernel_config, data_file, device):
    """
    Pick the best checkpoint of one run and copy it to the ``*.best.pt`` files.

    Every numbered checkpoint in ``<run_path>/weights`` is loaded into freshly
    built FaceVAE / Vmodel / GP objects and scored with
    ``evaluate_checkpoint``. The checkpoint with the lowest validation MSE is
    copied to ``gp_weights.best.pt`` and ``vae_weights.best.pt`` in the same
    directory. A checkpoint that fails to load or evaluate is reported and
    skipped.

    Args:
        run_path: Directory of the run (must contain ``weights``).
        kernel_config: Dict with ``'view_kernel'`` and ``'kernel_kwargs'``
            used to construct the Vmodel.
        data_file: Path of the dataset file for this task.
        device: Device on which models and tensors are placed.

    Returns:
        ``{'run', 'best_epoch', 'best_mse'}`` for the selected checkpoint, or
        ``None`` when the run has no checkpoints or none could be evaluated.
    """
    weights_dir = os.path.join(run_path, 'weights')
    checkpoints = get_available_checkpoints(weights_dir)

    if len(checkpoints) == 0:
        print(f"    ‚ö†Ô∏è No checkpoints found")
        return None

    print(f"    Found {len(checkpoints)} checkpoints")

    # Datasets and loaders; order is kept fixed (no shuffling).
    train_data = COIL100Dataset(data_file, split='train', use_angle_encoding=False)
    val_data = COIL100Dataset(data_file, split='val', use_angle_encoding=False)
    train_queue = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
    val_queue = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

    P = get_num_objects(data_file)
    Q = get_n_views()

    def _index_tensor(ids):
        # Integer, gradient-free copy of the ids on the target device.
        return Variable(ids.long(), requires_grad=False).to(device)

    Dt = _index_tensor(train_data.Did)
    Wt = _index_tensor(train_data.Rid)
    Dv = _index_tensor(val_data.Did)
    Wv = _index_tensor(val_data.Rid)

    best_mse = float('inf')
    best_epoch = -1
    best_files = None  # (gp_path, vae_path) of the current winner

    for epoch, gp_path, vae_path in checkpoints:
        try:
            # Rebuild the VAE and load its weights.
            vae = FaceVAE(**DEFAULT_VAE_CFG).to(device)
            vae.load_state_dict(torch.load(vae_path, map_location=device))

            # Rebuild Vmodel and GP; both states live in the gp checkpoint.
            vm = Vmodel(
                P=P, Q=Q, p=XDIM,
                view_kernel=kernel_config['view_kernel'],
                **kernel_config['kernel_kwargs']
            ).to(device)
            gp = GP(n_rand_effs=1).to(device)

            gp_state = torch.load(gp_path, map_location=device)
            gp.load_state_dict(gp_state['gp_state'])
            vm.load_state_dict(gp_state['vm_state'])

            mse = evaluate_checkpoint(
                vae, vm, gp, train_queue, val_queue,
                Dt, Wt, Dv, Wv, device
            )

            if mse < best_mse:
                best_mse, best_epoch = mse, epoch
                best_files = (gp_path, vae_path)

            # Drop the models before the next checkpoint is loaded.
            del vae, vm, gp
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"    ‚ö†Ô∏è Error at epoch {epoch}: {e}")
            continue

    if best_epoch < 0:
        # No checkpoint produced a usable score.
        return None

    print(f"    ‚úÖ Best epoch: {best_epoch} (MSE: {best_mse:.6f})")

    # Copy the winning pair next to the numbered checkpoints.
    best_gp_path, best_vae_path = best_files
    shutil.copy2(best_gp_path, os.path.join(weights_dir, 'gp_weights.best.pt'))
    shutil.copy2(best_vae_path, os.path.join(weights_dir, 'vae_weights.best.pt'))

    print(f"    üìÅ Saved: gp_weights.best.pt, vae_weights.best.pt")

    return {
        'run': os.path.basename(run_path),
        'best_epoch': best_epoch,
        'best_mse': best_mse,
    }

# Progress marker: printed once all helper definitions in this cell have run.
print("‚úÖ Helper functions defined")

## 4. Scan for Missing Best Weights

In [None]:
# Build the work list: one entry per run that lacks its *.best.pt files.
all_missing_runs = []

print("üìä Scanning for runs missing *.best.pt files...")
print("=" * 60)

for task_info in TASKS_TO_PROCESS:
    task_name = task_info['task_name']
    print(f"\nüîç {task_name.upper()}:")

    for kernel_name in task_info['kernels']:
        runs_missing, runs_complete = find_runs_missing_best(
            RESULTS_BASE, task_name, kernel_name
        )
        print(f"   {kernel_name}: {len(runs_complete)} complete, {len(runs_missing)} missing")

        # Remember everything needed to process each incomplete run later.
        all_missing_runs.extend(
            {
                'task': task_name,
                'kernel': kernel_name,
                'run_path': run_path,
                'data_file': task_info['data_file'],
            }
            for run_path in runs_missing
        )

print(f"\n{'='*60}")
print(f"üìã Total runs to process: {len(all_missing_runs)}")

if all_missing_runs:
    print("\nRuns to process:")
    for i, run_info in enumerate(all_missing_runs):
        print(f"  {i+1}. {run_info['task']}/{run_info['kernel']}: {os.path.basename(run_info['run_path'])}")

## 5. Process Missing Runs

In [None]:
# Generate best weights for every run collected by the scan cell.
results = []

print(f"\nüöÄ Processing {len(all_missing_runs)} runs...")
print("=" * 70)

for i, run_info in enumerate(all_missing_runs):
    task, kernel = run_info['task'], run_info['kernel']
    run_path, data_file = run_info['run_path'], run_info['data_file']

    print(f"\n[{i+1}/{len(all_missing_runs)}] {task}/{kernel}: {os.path.basename(run_path)}")

    # A run cannot be evaluated without its dataset.
    if not os.path.exists(data_file):
        print(f"    ‚ùå Data file not found: {data_file}")
        continue

    kernel_config = KERNEL_CONFIGS[kernel]

    try:
        result = process_run(run_path, kernel_config, data_file, device)
    except Exception as e:
        # Keep going with the remaining runs, but show the full traceback.
        print(f"    ‚ùå Error: {e}")
        import traceback
        traceback.print_exc()
    else:
        if result:
            result['task'] = task
            result['kernel'] = kernel
            results.append(result)

print("\n" + "=" * 70)
print(f"‚úÖ Processing complete! Generated best weights for {len(results)} runs.")

## 6. Summary

In [None]:
import pandas as pd

# Summarise the checkpoints selected by the processing cell.
if results:
    df = pd.DataFrame(results)
    print("\nüìä Generated Best Weights Summary:")
    print("=" * 70)
    print(df.to_string(index=False))

    # Group by task and kernel
    print("\nüìà Summary by Task/Kernel:")
    summary = df.groupby(['task', 'kernel']).agg({
        'best_mse': ['mean', 'std', 'min', 'max', 'count']
    }).round(6)
    print(summary)
elif all_missing_runs:
    # Runs were queued but none produced a result (errors, missing data
    # files or no checkpoints) — reporting "all done" here would be wrong.
    print(
        f"\nNo best weights were generated although {len(all_missing_runs)} "
        "run(s) are missing them; see the messages of the processing step."
    )
else:
    print("\n‚úÖ All runs already have best weights!")

## 7. Verify Generated Files

In [None]:
# Re-scan every task/kernel folder and report any run still lacking best weights.
print("\nüîç Verifying all runs now have best weights...")
print("=" * 60)

for task_info in TASKS_TO_PROCESS:
    task_name = task_info['task_name']
    print(f"\n{task_name.upper()}:")

    for kernel_name in task_info['kernels']:
        runs_missing, runs_complete = find_runs_missing_best(
            RESULTS_BASE, task_name, kernel_name
        )

        # Warning marker as soon as at least one run is incomplete.
        status = "‚ö†Ô∏è" if runs_missing else "‚úÖ"
        print(f"   {status} {kernel_name}: {len(runs_complete)} complete, {len(runs_missing)} missing")

        for run in runs_missing:
            print(f"      ‚ùå Still missing: {os.path.basename(run)}")

print("\n" + "=" * 60)
print("‚úÖ Verification complete!")