# Task 1 (Standard) Analysis - Kernel Comparison

This notebook analyzes the performance of different GP kernels on **Task 1 (Standard split)**:

## Task Description
- **Task 1 (Standard)**: Random 70/15/15 split across all views
- **Data**: COIL-100 dataset (100 objects × 18 views)
- **Goal**: Predict held-out images using GP interpolation

## Kernels Compared
1. **Full Rank**: Free-form learnable covariance (Q×Q parameters)
2. **Periodic**: Standard periodic kernel with learned lengthscale
3. **SM Wrapped**: Spectral Mixture with wrapped lag distance
4. **SM Free**: Spectral Mixture (unwrapped)

## Analysis
- Load best checkpoints from each kernel (5 seeds each)
- Evaluate on test set
- Compute mean and variance of MSE across seeds
- Visualize reconstructions and kernel matrices

## 1. Setup

In [1]:
import os
import sys
import glob
import pickle
import numpy as np
import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import seaborn as sns

# Set style
# Apply one global Matplotlib style sheet and seaborn palette for the figures below.
# NOTE(review): the 'seaborn-v0_8-*' style names only exist in newer Matplotlib
# releases - confirm against the pinned Matplotlib version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Report the runtime environment so saved outputs document where they were produced.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

ModuleNotFoundError: No module named 'pandas'

In [None]:
import os
import sys

# Work out PROJECT_PATH from the current working directory, then make the
# project code importable and switch the working directory to the project root.
current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

# A working directory of exactly '/content' is handled as a Google Colab session:
# Drive is mounted and the project is looked up at a fixed Drive location.
if current_dir == '/content':
    print("\nüîÑ Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"‚úÖ Found project in Google Drive: {PROJECT_PATH}")
        else:
            # Project folder missing on Drive: fall back to '/content'.
            print(f"‚ö†Ô∏è Project not found at: {drive_path}")
            PROJECT_PATH = '/content'
    except Exception as e:
        # Any failure to import/mount Drive also falls back to '/content'.
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    # Substring test on the whole path; when it matches, the project root is
    # taken to be two directory levels above the working directory.
    # NOTE(review): presumably the notebook lives in <project>/notebooks/<subdir>;
    # any other path that merely contains 'notebooks' resolves differently - confirm.
    if 'notebooks' in current_dir:
        PROJECT_PATH = os.path.dirname(os.path.dirname(current_dir))
    else:
        PROJECT_PATH = current_dir
    print(f"üíª Using project path: {PROJECT_PATH}")

# Add code paths
# Inserted at position 0 so these modules take precedence on the import path.
coil100_path = os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100')
sys.path.insert(0, coil100_path)

# Change to project root
# All relative paths used later in the notebook are resolved from here.
os.chdir(PROJECT_PATH)
print(f"Working directory: {os.getcwd()}")
print(f"Code path added: {coil100_path}")

In [None]:
# Import models
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
from data_parser import COIL100Dataset, get_n_views, get_num_objects

print("‚úÖ All modules imported")

## 2. Configuration

In [None]:
# Task 1 (Standard) configuration
CONFIG = dict(
    task='task1_standard',
    data_path='./data/coil100/coil100_task1_standard.h5',
    batch_size=64,
    xdim=64,  # Object embedding dimension
)

# Auto-detect the results folder: the first existing candidate wins
# (./GPPVAE/results is checked before ./results).
for _results_candidate in ('./GPPVAE/results', './results'):
    if os.path.exists(_results_candidate):
        CONFIG['results_base'] = _results_candidate
        print(f"‚úÖ Found results at: {os.path.abspath(CONFIG['results_base'])}")
        break
else:
    # Neither candidate exists: keep going with the default location.
    CONFIG['results_base'] = './results'
    print(f"‚ö†Ô∏è Results folder not found, using default: {CONFIG['results_base']}")

# Auto-detect the data file using the same first-match rule.
for _data_candidate in (
    './GPPVAE/data/coil100/coil100_task1_standard.h5',
    './data/coil100/coil100_task1_standard.h5',
):
    if os.path.exists(_data_candidate):
        CONFIG['data_path'] = _data_candidate
        print(f"‚úÖ Found data at: {os.path.abspath(CONFIG['data_path'])}")
        break
else:
    CONFIG['data_path'] = './data/coil100/coil100_task1_standard.h5'
    print(f"‚ö†Ô∏è Data file not found, using default: {CONFIG['data_path']}")

# Kernel configurations (must match training)
# Each entry provides:
#   folder        - sub-directory of CONFIG['results_base'] holding that kernel's runs
#   view_kernel   - kernel identifier passed to Vmodel(view_kernel=...)
#   kernel_kwargs - extra keyword arguments forwarded to Vmodel
#   display_name  - label used in tables and figures
#   color         - fallback plot colour when no entry exists in line_colors
KERNEL_CONFIGS = {
    'fullrank': {
        'folder': 'task1_fullrank',
        'view_kernel': 'full_rank',
        'kernel_kwargs': {},
        'display_name': 'Full Rank',
        'color': '#e74c3c',
    },
    'periodic': {
        'folder': 'task1_periodic',
        'view_kernel': 'periodic',
        'kernel_kwargs': {'period': 360.0, 'lengthscale': 1.0, 'variance': 1.0},
        'display_name': 'Periodic',
        'color': '#3498db',
    },
    'sm_wrapped': {
        'folder': 'task1_sm_wrapped',
        'view_kernel': 'sm_circle',
        'kernel_kwargs': {'freq_init': [1/360.0, 1/40.0], 'weight_init':[0.5, 0.5], 'length_init':[90,30]},
        'display_name': 'SM (Wrapped)',
        'color': '#2ecc71',
    },
    # Same 'sm_circle' identifier as sm_wrapped; differs only in kwargs
    # (no 'length_init', adds 'use_angle_input': True).
    # NOTE(review): flagged below as "Adjust if different" - confirm this matches
    # the kernel actually used for the task1_sm_free training runs.
    'sm_free': {
        'folder': 'task1_sm_free',
        'view_kernel': 'sm_circle',  # Adjust if different
        'kernel_kwargs': {'freq_init': [1/360.0, 1/40.0], 'weight_init':[0.5, 0.5], 'use_angle_input': True},
        'display_name': 'SM (Free)',
        'color': '#9b59b6',
    },
}

# Device
# First CUDA device when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3. Load Data

In [None]:
# Load the three splits of the Task 1 dataset from the same HDF5 file.
train_data = COIL100Dataset(CONFIG['data_path'], split='train', use_angle_encoding=False)
val_data = COIL100Dataset(CONFIG['data_path'], split='val', use_angle_encoding=False)
test_data = COIL100Dataset(CONFIG['data_path'], split='test', use_angle_encoding=False)

# shuffle=False everywhere: evaluation must visit samples in a fixed order.
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=False)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)
test_queue = DataLoader(test_data, batch_size=CONFIG['batch_size'], shuffle=False)

# Get dimensions
P = get_num_objects(CONFIG['data_path'])  # Number of objects
Q = get_n_views()  # Number of views (18)

print(f"\nüìä Task 1 (Standard) Dataset:")
print(f"   Objects (P): {P}")
print(f"   Views (Q): {Q}")
print(f"   Train: {len(train_data)} samples")
print(f"   Val: {len(val_data)} samples")
print(f"   Test: {len(test_data)} samples")

# Index tensors passed to Vmodel as vm(D, W).
# NOTE(review): Did / Rid are presumably per-sample object ids and view ids -
# confirm against data_parser.
# The deprecated torch.autograd.Variable wrapper was dropped: plain tensors
# already default to requires_grad=False, so the resulting tensors are the same.
Dt = train_data.Did.long().to(device)
Wt = train_data.Rid.long().to(device)
Dtest = test_data.Did.long().to(device)
Wtest = test_data.Rid.long().to(device)

## 4. Helper Functions

In [None]:
def find_run_folders(results_base, kernel_folder):
    """Return the run (seed) directories of one kernel, sorted by name.

    Args:
        results_base: Directory that contains one sub-directory per kernel.
        kernel_folder: Name of the kernel's sub-directory.

    Returns:
        List of paths ``results_base/kernel_folder/<run>`` for every entry that
        is a directory, ordered by entry name. Empty list (after printing a
        warning) when the kernel directory does not exist.
    """
    kernel_path = os.path.join(results_base, kernel_folder)

    if os.path.exists(kernel_path):
        run_dirs = []
        for entry in sorted(os.listdir(kernel_path)):
            candidate = os.path.join(kernel_path, entry)
            # Plain files next to the run folders are ignored.
            if os.path.isdir(candidate):
                run_dirs.append(candidate)
        return run_dirs

    print(f"‚ö†Ô∏è Kernel folder not found: {kernel_path}")
    return []


def load_vae_config(run_folder=None):
    """Load the pickled VAE config from a known location, or return a default.

    Candidate locations are tried in priority order; a pattern containing
    ``*`` is expanded with ``glob`` and its matches are tried in sorted order.
    A candidate that cannot be read or unpickled is reported and skipped.

    Args:
        run_folder: Accepted for call-site compatibility; not used by the lookup.

    Returns:
        The object stored in the first config file that loads successfully
        (expected to be a dict of ``FaceVAE`` keyword arguments), otherwise the
        built-in default config dict.
    """
    # Relative to the current working directory (the project root).
    candidate_patterns = [
        './GPPVAE/pysrc/coil100/vae.cfg.p',  # If there's a shared config
        './out/vae/vae.cfg.p',  # Common VAE output folder
        './out/vae_colab/*/vae.cfg.p',  # Colab VAE folders
    ]

    for pattern in candidate_patterns:
        if '*' in pattern:
            # Sorted so the chosen file does not depend on directory listing order.
            candidates = sorted(glob.glob(pattern))
        elif os.path.exists(pattern):
            candidates = [pattern]
        else:
            candidates = []

        for cfg_path in candidates:
            try:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # this at config files produced by your own training runs.
                with open(cfg_path, 'rb') as f:
                    vae_cfg = pickle.load(f)
            except Exception as exc:
                # Best effort: report and try the next candidate. Unlike a bare
                # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
                print(f"Could not load VAE config from {cfg_path}: {exc}")
                continue
            print(f"‚úÖ Loaded VAE config from: {cfg_path}")
            return vae_cfg

    # Default VAE config for COIL-100
    print("‚ö†Ô∏è Using default VAE config")
    vae_cfg = {
        'img_size': 128,
        'nf': 32,
        'zdim': 256,
        'steps': 5,
        'colors': 3,
        'act': 'elu',
        'vy': 0.001
    }
    return vae_cfg


def load_models(run_folder, kernel_config, P, Q, xdim, device):
    """Rebuild the VAE, Vmodel and GP of one run from its best checkpoint.

    Reads ``<run_folder>/weights/vae_weights.best.pt`` (VAE state dict) and
    ``<run_folder>/weights/gp_weights.best.pt`` (dict with the keys
    ``'gp_state'`` and ``'vm_state'``).

    Args:
        run_folder: Directory of a single training run.
        kernel_config: Entry providing ``'view_kernel'`` and ``'kernel_kwargs'``.
        P, Q, xdim: Passed to ``Vmodel`` as ``P``, ``Q`` and ``p``.
        device: Device the modules and checkpoints are mapped to.

    Returns:
        Tuple ``(vae, vm, gp)``, all switched to eval mode.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
    """
    weights_dir = os.path.join(run_folder, 'weights')
    gp_weights_path = os.path.join(weights_dir, 'gp_weights.best.pt')
    vae_weights_path = os.path.join(weights_dir, 'vae_weights.best.pt')

    # Fail early, before any model is constructed (GP file is checked first).
    for label, weights_path in (('GP', gp_weights_path), ('VAE', vae_weights_path)):
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"{label} weights not found: {weights_path}")

    # VAE: architecture from the config lookup, parameters from the checkpoint.
    vae = FaceVAE(**load_vae_config(run_folder)).to(device)
    vae_state = torch.load(vae_weights_path, map_location=device)
    vae.load_state_dict(vae_state)
    vae.eval()

    # Vmodel and GP share a single checkpoint file.
    vm = Vmodel(
        P=P,
        Q=Q,
        p=xdim,
        view_kernel=kernel_config['view_kernel'],
        **kernel_config['kernel_kwargs'],
    ).to(device)
    gp = GP(n_rand_effs=1).to(device)

    gp_checkpoint = torch.load(gp_weights_path, map_location=device)
    gp.load_state_dict(gp_checkpoint['gp_state'])
    vm.load_state_dict(gp_checkpoint['vm_state'])

    for module in (vm, gp):
        module.eval()

    return vae, vm, gp


def encode_dataset(vae, data_queue, device):
    """Encode every image served by ``data_queue`` into the VAE latent space.

    The latent width is taken from the encoder output, so VAEs trained with
    any ``zdim`` are supported (it used to be fixed at 256).

    Args:
        vae: Model exposing ``eval()`` and ``encode(y) -> (mean, second_output)``.
        data_queue: Iterable of batches whose first element is the image tensor
            and whose last element holds the dataset indices of the batch;
            ``data_queue.dataset.Y.shape[0]`` gives the number of samples.
        device: Device used for the inputs and the returned tensor.

    Returns:
        Tensor of shape ``(n_samples, zdim)``; row ``i`` is the latent mean of
        dataset sample ``i``. Rows never visited by the loader stay zero.
    """
    vae.eval()
    n = data_queue.dataset.Y.shape[0]
    fallback_zdim = 256  # Only used when the loader yields no batch at all

    Zm = None

    with torch.no_grad():
        for data in data_queue:
            y = data[0].to(device)
            idxs = data[-1].to(device)
            zm, _ = vae.encode(y)
            if Zm is None:
                # Allocate lazily, once the real latent width is known.
                Zm = torch.zeros(n, zm.shape[1], dtype=zm.dtype, device=device)
            # Scatter by dataset index so the result does not depend on batch order.
            Zm[idxs] = zm.detach()

    if Zm is None:
        Zm = torch.zeros(n, fallback_zdim).to(device)

    return Zm


def evaluate_on_test(vae, vm, gp, train_queue, test_queue, 
                     Dt, Wt, Dtest, Wtest, device):
    """
    Score GP-predicted reconstructions against the test images.

    Training images are encoded with the VAE, test latents are predicted from
    them through the GP, decoded, and compared with the ground-truth images.

    Returns:
        mse_test: Mean MSE on test set
        mse_per_sample: MSE for each test sample (NumPy array, loader order)
        mse_per_view: Dict mapping view index to list of MSEs
    """
    for module in (vae, vm, gp):
        module.eval()

    with torch.no_grad():
        # Latent means of the training images
        latent_train = encode_dataset(vae, train_queue, device)

        # V matrices for the training and test (object, view) pairs
        V_train = vm(Dt, Wt).detach()
        V_test = vm(Dtest, Wtest).detach()

        # GP prediction of the test latents from the training latents
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([V_train], vs)
        Kiz = gp.solve(latent_train, U, UBi, vs)
        latent_pred = vs[0] * V_test.mm(V_train.transpose(0, 1).mm(Kiz))

        # Accumulate per-sample errors batch by batch
        view_ids = test_queue.dataset.Rid
        mse_per_view = {}
        mse_per_sample = []
        running_sum = 0.0

        for batch in test_queue:
            sample_ids = batch[-1]
            targets = batch[0].to(device)
            decoded = vae.decode(latent_pred[sample_ids.to(device)])
            # Mean squared error over all pixels/channels of each image
            mse_batch = ((targets - decoded) ** 2).view(targets.shape[0], -1).mean(1)

            for pos, sample_id in enumerate(sample_ids):
                view = int(view_ids[sample_id].item())
                mse_val = mse_batch[pos].item()
                mse_per_view.setdefault(view, []).append(mse_val)
                mse_per_sample.append(mse_val)

            running_sum += mse_batch.sum().item()

        mse_test = running_sum / len(test_queue.dataset)

    return mse_test, np.array(mse_per_sample), mse_per_view


def get_reconstructions(vae, vm, gp, train_queue, test_queue,
                        Dt, Wt, Dtest, Wtest, device, n_samples=24):
    """
    Decode GP-predicted latents for an evenly strided subset of the test set.

    Up to ``n_samples`` test indices are picked with a constant stride starting
    at index 0.

    Returns:
        Y_orig: Original test images, transposed with (0, 2, 3, 1)
        Y_recon: GP-predicted reconstructions, same layout as Y_orig
    """
    for module in (vae, vm, gp):
        module.eval()

    with torch.no_grad():
        # Latent means of the training images
        latent_train = encode_dataset(vae, train_queue, device)

        # V matrices for the training and test (object, view) pairs
        V_train = vm(Dt, Wt).detach()
        V_test = vm(Dtest, Wtest).detach()

        # GP prediction of the test latents
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([V_train], vs)
        Kiz = gp.solve(latent_train, U, UBi, vs)
        latent_pred = vs[0] * V_test.mm(V_train.transpose(0, 1).mm(Kiz))

        # Evenly spaced sample positions across the test set
        test_set = test_queue.dataset
        n_total = len(test_set)
        stride = max(1, n_total // n_samples)
        sample_indices = list(range(0, n_total, stride)[:n_samples])

        Y_orig = test_set.Y[sample_indices].numpy().transpose(0, 2, 3, 1)

        # Decode the predicted latents of the same samples
        index_tensor = torch.tensor(sample_indices, dtype=torch.long).to(device)
        decoded = vae.decode(latent_pred[index_tensor])
        Y_recon = decoded.cpu().numpy().transpose(0, 2, 3, 1)

    return Y_orig, Y_recon

print("‚úÖ Helper functions defined")

## 5. Evaluate All Kernels

In [None]:
# Results storage: kernel name -> per-seed metrics plus summary statistics.
results = {}

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    print(f"\n{'='*60}")
    print(f"Evaluating: {kernel_config['display_name']}")
    print(f"{'='*60}")
    
    # Find all runs (seeds)
    run_folders = find_run_folders(CONFIG['results_base'], kernel_config['folder'])
    
    if not run_folders:
        print(f"‚ö†Ô∏è No runs found for {kernel_name}")
        continue
    
    print(f"Found {len(run_folders)} runs (seeds)")
    
    # All lists below are index-aligned: entry j of each list belongs to the
    # j-th run that evaluated successfully. 'run_folders' records which folder
    # that was, so a per-seed metric can always be traced back to its
    # checkpoint even when an earlier run failed and was skipped.
    kernel_results = {
        'mse_per_seed': [],
        'mse_per_sample_all': [],
        'kernel_matrices': [],
        'variance_ratios': [],
        'run_folders': [],
    }
    
    for i, run_folder in enumerate(run_folders):
        try:
            print(f"  Seed {i}: {os.path.basename(run_folder)}...", end=" ")
            
            # Load models
            vae, vm, gp = load_models(
                run_folder, kernel_config, P, Q, 
                CONFIG['xdim'], device
            )
            
            # Evaluate
            mse_test, mse_per_sample, mse_per_view = evaluate_on_test(
                vae, vm, gp, train_queue, test_queue,
                Dt, Wt, Dtest, Wtest, device
            )
            
            # Get kernel matrix and the share of vs[0] in vs[0] + vs[1]
            with torch.no_grad():
                K = vm.get_kernel_matrix().cpu().numpy()
                vs = gp.get_vs().cpu().numpy()
                variance_ratio = vs[0] / (vs[0] + vs[1])
            
            # Append only after every step succeeded so the lists stay aligned.
            kernel_results['mse_per_seed'].append(mse_test)
            kernel_results['mse_per_sample_all'].append(mse_per_sample)
            kernel_results['kernel_matrices'].append(K)
            kernel_results['variance_ratios'].append(variance_ratio)
            kernel_results['run_folders'].append(run_folder)
            
            print(f"MSE = {mse_test:.6f}")
            
        except Exception as e:
            # Best effort: a broken run is reported and skipped, not fatal.
            print(f"‚ùå Error: {e}")
            continue
    
    if kernel_results['mse_per_seed']:
        # np.std / np.var are called with NumPy defaults (population statistics).
        mse_array = np.array(kernel_results['mse_per_seed'])
        kernel_results['mean_mse'] = np.mean(mse_array)
        kernel_results['std_mse'] = np.std(mse_array)
        kernel_results['var_mse'] = np.var(mse_array)
        
        print(f"\n  üìä {kernel_config['display_name']} Summary:")
        print(f"     Mean MSE: {kernel_results['mean_mse']:.6f} ¬± {kernel_results['std_mse']:.6f}")
        print(f"     Variance: {kernel_results['var_mse']:.8f}")
        print(f"     Seeds: {len(kernel_results['mse_per_seed'])}")
    
    results[kernel_name] = kernel_results

print("\n" + "="*60)
print("‚úÖ Evaluation complete!")

## 6. Summary Table

In [None]:
# Create summary DataFrame: one row per kernel that has at least one evaluated seed.
summary_data = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_per_seed']:
        r = results[kernel_name]
        summary_data.append({
            'Kernel': kernel_config['display_name'],
            'Mean MSE': r['mean_mse'],
            'Std MSE': r['std_mse'],
            'Var MSE': r['var_mse'],
            'Min MSE': np.min(r['mse_per_seed']),
            'Max MSE': np.max(r['mse_per_seed']),
            'N Seeds': len(r['mse_per_seed']),
            'Mean Var Ratio': np.mean(r['variance_ratios']),
        })

# Explicit columns keep the frame well-formed when no kernel was evaluated;
# without them, sorting an empty frame by 'Mean MSE' raises KeyError.
summary_columns = [
    'Kernel', 'Mean MSE', 'Std MSE', 'Var MSE',
    'Min MSE', 'Max MSE', 'N Seeds', 'Mean Var Ratio',
]
summary_df = pd.DataFrame(summary_data, columns=summary_columns)
summary_df = summary_df.sort_values('Mean MSE')  # best (lowest) first

print("\nüìä Task 1 (Standard) - Kernel Comparison Summary")
print("="*80)
print(summary_df.to_string(index=False))
print("="*80)

In [None]:
# Display formatted table with pandas styling
# Fixed-precision formatting per column; 'Mean MSE' additionally gets a
# background gradient using the reversed 'RdYlGn' colormap.
styled_df = summary_df.style.format({
    'Mean MSE': '{:.6f}',
    'Std MSE': '{:.6f}',
    'Var MSE': '{:.8f}',
    'Min MSE': '{:.6f}',
    'Max MSE': '{:.6f}',
    'Mean Var Ratio': '{:.3f}',
}).background_gradient(subset=['Mean MSE'], cmap='RdYlGn_r')

# Last expression of the cell: rendered as the cell output.
styled_df

## 7. Visualization: MSE Comparison

In [None]:
# Define consistent colors for each kernel (keyed by display name; the
# per-kernel 'color' entry of KERNEL_CONFIGS is only a fallback).
line_colors = {
    'Full Rank': '#1f77b4',      # Blue
    'Periodic': '#ff7f0e',        # Orange
    'SM (Wrapped)': '#9467bd',    # Purple
    'SM (Free)': '#8c564b',       # Brown
}

# Figures are written below ./notebooks/analysis; make sure it exists so
# savefig cannot fail on a missing directory.
os.makedirs('./notebooks/analysis', exist_ok=True)

# Bar plot with error bars - clean style
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

kernel_names = []
means = []
stds = []
colors = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_per_seed']:
        display_name = kernel_config['display_name']
        kernel_names.append(display_name)
        means.append(results[kernel_name]['mean_mse'])
        stds.append(results[kernel_name]['std_mse'])
        colors.append(line_colors.get(display_name, kernel_config['color']))

x = np.arange(len(kernel_names))
# Error bars show one standard deviation across seeds.
bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
ax.set_xticks(x)
ax.set_xticklabels(kernel_names, rotation=0, ha='center')
ax.set_ylabel('MSE [test set]')
ax.set_ylim(bottom=0)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(False)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task1_mse_bar.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task1_mse_bar.png (300 DPI)")

# Box plot - clean style
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

data_for_box = []
labels_for_box = []
colors_for_box = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_per_seed']:
        display_name = kernel_config['display_name']
        data_for_box.append(results[kernel_name]['mse_per_seed'])
        labels_for_box.append(display_name)
        colors_for_box.append(line_colors.get(display_name, kernel_config['color']))

# boxplot() cannot draw an empty data list, so it is skipped when no kernel
# produced results.
if data_for_box:
    bp = ax.boxplot(data_for_box, patch_artist=True)
    # Tick labels are set on the axis rather than through boxplot(labels=...),
    # which newer Matplotlib deprecates in favour of tick_labels; this form
    # works on both old and new releases.
    ax.set_xticklabels(labels_for_box)

    for patch, color in zip(bp['boxes'], colors_for_box):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
        patch.set_linewidth(0.5)

    for element in ['whiskers', 'fliers', 'means', 'medians', 'caps']:
        plt.setp(bp[element], linewidth=0.5)

ax.set_ylabel('MSE [test set]')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(False)
ax.tick_params(axis='x', rotation=0)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task1_mse_boxplot.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task1_mse_boxplot.png (300 DPI)")

## 8. Visualization: Kernel Matrices

In [None]:
# Plot kernel matrices for each kernel type (first seed)
# Using same style as callback_gppvae from callbacks.py
import pylab as pl
pl.rcdefaults()  # Reset to matplotlib defaults for consistent styling

# Number of kernels that have at least one stored kernel matrix.
n_kernels = len([k for k in results if results[k]['kernel_matrices']])
if n_kernels > 0:
    # Fixed 2x2 grid: room for at most four kernels; a fifth kernel with
    # results would index past the end of `axes`.
    fig, axes = pl.subplots(2, 2, figsize=(10, 10))
    axes = axes.flatten()
    
    ax_idx = 0
    for kernel_name, kernel_config in KERNEL_CONFIGS.items():
        if kernel_name in results and results[kernel_name]['kernel_matrices']:
            K = results[kernel_name]['kernel_matrices'][0]  # First seed
            
            ax = axes[ax_idx]
            # Match callback_gppvae style: default colormap, vmin=-0.4, vmax=1, aspect='auto'
            # The shared colour range makes the panels directly comparable.
            im = ax.imshow(K, vmin=-0.4, vmax=1, aspect='auto')
            
            # Add angle labels
            # Labels are view index * 20; positions go up to index 17.
            # NOTE(review): assumes 18 views spaced 20 degrees apart - confirm
            # this matches Q and the dataset's view spacing.
            tick_positions = [0, 4, 8, 12, 17]
            tick_labels = [f"{p*20}¬∞" for p in tick_positions]
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(tick_labels, fontsize=8)
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(tick_labels, fontsize=8)
            ax.tick_params(labelsize=8)
            ax.set_title(kernel_config['display_name'], fontsize=10)
            
            pl.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            ax_idx += 1
    
    # Hide unused subplots
    for i in range(ax_idx, 4):
        axes[i].axis('off')
    
    pl.tight_layout()
    pl.savefig('./notebooks/analysis/task1_kernel_matrices.png', dpi=300, bbox_inches='tight')
    pl.show()
    print("üìä Saved: task1_kernel_matrices.png (300 DPI)")
else:
    print("‚ö†Ô∏è No kernel matrices to display")

## 9. Visualization: Sample Reconstructions

In [None]:
# Get reconstructions from each kernel (best seed by MSE)
n_samples = 8
n_kernels = len([k for k in results if results[k]['mse_per_seed']])


def _blank_axes(ax):
    """Hide ticks, grid and frame of an image panel but keep its axis label.

    ``ax.axis('off')`` would also hide the y-label, which carries the row
    title (GT / kernel name) in this figure.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    for spine in ax.spines.values():
        spine.set_visible(False)


if n_kernels > 0:
    os.makedirs('./notebooks/analysis', exist_ok=True)
    fig, axes = plt.subplots(n_kernels + 1, n_samples, figsize=(n_samples * 1.5, (n_kernels + 1) * 1.5))
    
    # Plot ground truth (first row); same strided indices as get_reconstructions
    n_total = len(test_data)
    sample_stride = max(1, n_total // n_samples)
    sample_indices = list(range(0, n_total, sample_stride))[:n_samples]
    Y_gt = test_data.Y[sample_indices].numpy().transpose(0, 2, 3, 1)
    
    for i in range(n_samples):
        _blank_axes(axes[0, i])
        if i < len(Y_gt):  # fewer test images than panels: leave the rest empty
            axes[0, i].imshow(np.clip(Y_gt[i], 0, 1))
    axes[0, 0].set_ylabel('GT', fontsize=10)
    
    # Plot reconstructions for each kernel
    row_idx = 1
    for kernel_name, kernel_config in KERNEL_CONFIGS.items():
        if kernel_name in results and results[kernel_name]['mse_per_seed']:
            kernel_result = results[kernel_name]
            seed_mses = kernel_result['mse_per_seed']
            
            # Find best seed. 'mse_per_seed' only holds runs that evaluated
            # successfully, so its positions match the folder listing only if
            # no run was skipped. Prefer run folders recorded next to the MSEs
            # (key 'run_folders', when the results carry it); otherwise
            # re-list the folders and warn if the counts disagree.
            best_seed_idx = int(np.argmin(seed_mses))
            candidate_runs = kernel_result.get('run_folders')
            if not candidate_runs or len(candidate_runs) != len(seed_mses):
                candidate_runs = find_run_folders(CONFIG['results_base'], kernel_config['folder'])
                if len(candidate_runs) != len(seed_mses):
                    print(f"Warning: {len(candidate_runs)} run folders but {len(seed_mses)} "
                          f"evaluated seeds for {kernel_name}; the selected 'best' run may be misaligned")
            best_run = candidate_runs[best_seed_idx]
            
            # Load models
            vae, vm, gp = load_models(
                best_run, kernel_config, P, Q, 
                CONFIG['xdim'], device
            )
            
            # Get reconstructions
            Y_orig, Y_recon = get_reconstructions(
                vae, vm, gp, train_queue, test_queue,
                Dt, Wt, Dtest, Wtest, device, n_samples=n_samples
            )
            
            for i in range(n_samples):
                _blank_axes(axes[row_idx, i])
                if i < len(Y_recon):
                    axes[row_idx, i].imshow(np.clip(Y_recon[i], 0, 1))
            
            mse = seed_mses[best_seed_idx]
            axes[row_idx, 0].set_ylabel(f"{kernel_config['display_name']}\n({mse:.4f})", fontsize=8)
            row_idx += 1
    
    plt.tight_layout()
    plt.savefig('./notebooks/analysis/task1_reconstructions.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("üìä Saved: task1_reconstructions.png (300 DPI)")
else:
    print("‚ö†Ô∏è No kernels to display")

## 10. Visualization: Per-Seed Performance

In [None]:
# Line plot showing MSE across seeds for each kernel
fig, ax = plt.subplots(figsize=(10, 5))

# Define consistent colors for each kernel
line_colors = {
    'Full Rank': '#1f77b4',      # Blue
    'Periodic': '#ff7f0e',        # Orange
    'SM (Wrapped)': '#9467bd',    # Purple
    'SM (Free)': '#8c564b',       # Brown
}

# Largest number of evaluated seeds over all kernels; drives the x ticks so
# the axis is right for any seed count, not only five.
max_seed_count = 0

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_per_seed']:
        mse_values = results[kernel_name]['mse_per_seed']
        seeds = range(len(mse_values))
        max_seed_count = max(max_seed_count, len(mse_values))
        display_name = kernel_config['display_name']
        color = line_colors.get(display_name, kernel_config['color'])
        # The x position is the index among successfully evaluated runs.
        ax.plot(seeds, mse_values, 'o-', 
                label=f"{display_name} (Œº={np.mean(mse_values):.5f})",
                color=color, markersize=8, linewidth=2)

ax.set_xlabel('Seed', fontsize=12)
ax.set_ylabel('Test MSE', fontsize=12)
ax.legend(loc='best')
ax.set_xticks(range(max_seed_count))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(False)

plt.tight_layout()
os.makedirs('./notebooks/analysis', exist_ok=True)
plt.savefig('./notebooks/analysis/task1_per_seed_mse.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task1_per_seed_mse.png (300 DPI)")

## 11. Statistical Analysis

In [None]:
from scipy import stats

# Pairwise t-tests between kernels
# Only kernels with at least one evaluated seed take part.
kernel_names_list = [k for k in KERNEL_CONFIGS if k in results and results[k]['mse_per_seed']]

if len(kernel_names_list) >= 2:
    print("\nüìä Pairwise T-Tests (p-values)")
    print("="*60)
    
    # Symmetric matrix of p-values with 1.0 on the diagonal. It is filled
    # here but not read again in this cell.
    pvalue_matrix = np.ones((len(kernel_names_list), len(kernel_names_list)))
    
    for i, k1 in enumerate(kernel_names_list):
        for j, k2 in enumerate(kernel_names_list):
            if i < j:  # each unordered pair once
                mse1 = results[k1]['mse_per_seed']
                mse2 = results[k2]['mse_per_seed']
                
                # Independent two-sample t-test (stats.ttest_ind). This is NOT
                # a paired test: the per-seed MSEs of the two kernels are
                # treated as unrelated samples and may differ in length.
                # No multiple-comparison correction is applied.
                t_stat, p_value = stats.ttest_ind(mse1, mse2)
                pvalue_matrix[i, j] = p_value
                pvalue_matrix[j, i] = p_value
                
                k1_name = KERNEL_CONFIGS[k1]['display_name']
                k2_name = KERNEL_CONFIGS[k2]['display_name']
                sig = "*" if p_value < 0.05 else ""
                print(f"{k1_name} vs {k2_name}: p = {p_value:.4f} {sig}")
    
    print("\n* indicates p < 0.05 (statistically significant)")
else:
    print("‚ö†Ô∏è Need at least 2 kernels for statistical comparison")

## 12. Export Results

In [None]:
# Save detailed results to CSV: one row per (kernel, evaluated seed).
detailed_results = [
    {
        'kernel': kernel_config['display_name'],
        'seed': seed_idx,
        'test_mse': mse,
        'variance_ratio': results[kernel_name]['variance_ratios'][seed_idx],
    }
    for kernel_name, kernel_config in KERNEL_CONFIGS.items()
    if kernel_name in results and results[kernel_name]['mse_per_seed']
    for seed_idx, mse in enumerate(results[kernel_name]['mse_per_seed'])
]

detailed_df = pd.DataFrame(detailed_results)
detailed_df.to_csv('./notebooks/analysis/task1_detailed_results.csv', index=False)
print("üìä Saved: task1_detailed_results.csv")

# Save summary to CSV (the per-kernel table built in the summary section).
summary_df.to_csv('./notebooks/analysis/task1_summary.csv', index=False)
print("üìä Saved: task1_summary.csv")

# Display detailed results
print("\nüìã Detailed Results:")
print(detailed_df.to_string(index=False))

## 13. Callback-Style Plot (like training)

In [None]:
# Import callback_gppvae directly from callbacks.py
from callbacks import callback_gppvae, _compose_multi
import pylab as pl

def generate_callback_plot_for_kernel(kernel_name, kernel_config, results, 
                                       train_queue, test_queue,
                                       Dt, Wt, Dtest, Wtest, P, Q, xdim, device,
                                       output_file):
    """
    Generate a callback-style plot for a kernel using the actual callback_gppvae function.

    The best seed (lowest test MSE) of the kernel is reloaded, up to 24 evenly
    strided test images are reconstructed both directly through the VAE and
    through the GP prediction, and everything is handed to ``callback_gppvae``
    together with a mock "history" built from the per-seed results.

    Args:
        kernel_name: Key into ``results`` / ``KERNEL_CONFIGS``.
        kernel_config: Config entry of the kernel (``'folder'``,
            ``'view_kernel'``, ``'kernel_kwargs'``).
        results: Evaluation results; ``results[kernel_name]`` must provide
            ``'mse_per_seed'`` and ``'mean_mse'``.
        train_queue, test_queue: Data loaders for the train and test splits.
        Dt, Wt, Dtest, Wtest: Index tensors passed to ``vm``.
        P, Q, xdim, device: Forwarded to ``load_models``.
        output_file: Passed to ``callback_gppvae`` as ``ffile``.

    Returns:
        ``output_file``, or ``None`` when the kernel has no evaluated seeds.
    """
    kernel_result = results.get(kernel_name)
    if not kernel_result or not kernel_result['mse_per_seed']:
        return None
    seed_mses = kernel_result['mse_per_seed']
    
    # Reset matplotlib style to default (same as training)
    pl.rcdefaults()
    
    # Use best seed. 'mse_per_seed' only holds runs that evaluated
    # successfully, so its positions match the folder listing only if no run
    # was skipped. Prefer run folders recorded next to the MSEs (key
    # 'run_folders', when the results carry it); otherwise re-list the folders
    # and warn if the counts disagree.
    best_seed_idx = int(np.argmin(seed_mses))
    candidate_runs = kernel_result.get('run_folders')
    if not candidate_runs or len(candidate_runs) != len(seed_mses):
        candidate_runs = find_run_folders(CONFIG['results_base'], kernel_config['folder'])
        if len(candidate_runs) != len(seed_mses):
            print(f"Warning: {len(candidate_runs)} run folders but {len(seed_mses)} "
                  f"evaluated seeds for {kernel_name}; the selected 'best' run may be misaligned")
    best_run = candidate_runs[best_seed_idx]
    
    # Load models
    vae, vm, gp = load_models(
        best_run, kernel_config, P, Q, xdim, device
    )
    
    with torch.no_grad():
        # Get covariances (same as training code)
        X = vm.x().cpu().numpy()
        W = vm.v().cpu().numpy()
        XX = X @ X.T
        WW = W @ W.T
        # Only the top-left 20x20 corner of XX is shown.
        covs = {"XX": XX[:20, :20], "WW": WW}
        
        # Encode training data
        Zm = encode_dataset(vae, train_queue, device)
        
        # Compute V matrices
        Vt = vm(Dt, Wt).detach()
        Vtest = vm(Dtest, Wtest).detach()
        
        # GP prediction
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
        Kiz = gp.solve(Zm, U, UBi, vs)
        Zo_test = vs[0] * Vtest.mm(Vt.transpose(0, 1).mm(Kiz))
        
        # Sample diverse test samples: up to 24, evenly strided
        n_total = len(test_queue.dataset)
        if n_total >= 24:
            sample_stride = max(1, n_total // 24)
            sample_indices = np.arange(0, n_total, sample_stride)[:24]
        else:
            sample_indices = np.arange(min(24, n_total))
        
        sample_indices_tensor = torch.tensor(sample_indices, dtype=torch.long).to(device)
        
        # Get images - Yv (ground truth), Yr (VAE recon), Yo (GP pred)
        Yv = test_queue.dataset.Y[sample_indices].numpy().transpose((0, 2, 3, 1))
        
        # VAE direct reconstruction
        Y_input = test_queue.dataset.Y[sample_indices].to(device)
        Zm_test, _ = vae.encode(Y_input)
        Yr = vae.decode(Zm_test).data.cpu().numpy().transpose((0, 2, 3, 1))
        
        # GP prediction reconstruction
        Yo = vae.decode(Zo_test[sample_indices_tensor]).data.cpu().numpy().transpose((0, 2, 3, 1))
        
        imgs = {"Yv": Yv, "Yr": Yr, "Yo": Yo}
        
        # Create history dict with analysis results (instead of training history).
        # This is a mock history: the "epoch" axis is really the seed index,
        # and the loss terms that have no analysis counterpart are zeros.
        n_seeds = len(seed_mses)
        vs_np = vs.cpu().numpy()
        history = {
            "loss": seed_mses,  # Use MSE per seed as "loss"
            "vs": [[vs_np[0], vs_np[1]]] * n_seeds,
            "recon_term": [0.0] * n_seeds,
            "gp_nll": [0.0] * n_seeds,
            "mse_out": seed_mses,
            "mse": [kernel_result['mean_mse']] * n_seeds,
            "mse_val": seed_mses,
        }
    
    # Call the actual callback_gppvae function
    callback_gppvae(
        epoch=n_seeds - 1,  # Index of the last seed stands in for the "epoch"
        history=history,
        covs=covs,
        imgs=imgs,
        ffile=output_file
    )
    
    return output_file

print("‚úÖ Callback plot function defined (using callback_gppvae from callbacks.py)")

In [None]:
# Generate callback plots for each kernel (using callback_gppvae from callbacks.py)
# Imported once up front rather than on every loop iteration.
from IPython.display import Image, display

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    # Kernels without any evaluated seed have nothing to plot.
    if kernel_name not in results or not results[kernel_name]['mse_per_seed']:
        continue

    print(f"\nüìä Generating callback plot for {kernel_config['display_name']}...")

    filename = f'./notebooks/analysis/task1_callback_{kernel_name}.png'

    result = generate_callback_plot_for_kernel(
        kernel_name, kernel_config, results,
        train_queue, test_queue,
        Dt, Wt, Dtest, Wtest, P, Q, CONFIG['xdim'], device,
        output_file=filename
    )

    # The generator returns None when it produced nothing.
    if not result:
        continue

    print(f"   ‚úÖ Saved: {filename}")
    # Display the saved image inline
    display(Image(filename=filename))

## 14. Final Summary

In [None]:
print("\n" + "="*70)
print("üìä TASK 1 (STANDARD) - FINAL SUMMARY")
print("="*70)

print("\nüèÜ Kernel Ranking (by Mean Test MSE):")
print("-"*50)

# (kernel key, mean MSE, std MSE) for every kernel with evaluated seeds,
# ordered from lowest (best) to highest mean test MSE.
ranked = sorted(
    [(k, results[k]['mean_mse'], results[k]['std_mse']) 
     for k in results if results[k]['mse_per_seed']],
    key=lambda x: x[1]
)

for rank, (kernel_name, mean_mse, std_mse) in enumerate(ranked, 1):
    display_name = KERNEL_CONFIGS[kernel_name]['display_name']
    # Medal marker for the top three ranks, blank padding otherwise.
    medal = "ü•á" if rank == 1 else "ü•à" if rank == 2 else "ü•â" if rank == 3 else "  "
    print(f"{medal} {rank}. {display_name:15s}: {mean_mse:.6f} ¬± {std_mse:.6f}")

# The list mirrors the files the notebook actually writes: the MSE comparison
# is saved as two figures (bar + box plot), not as 'task1_mse_comparison.png'.
print("\nüìÅ Generated Files:")
print("-"*50)
print("   - task1_mse_bar.png")
print("   - task1_mse_boxplot.png")
print("   - task1_kernel_matrices.png")
print("   - task1_reconstructions.png")
print("   - task1_per_seed_mse.png")
print("   - task1_detailed_results.csv")
print("   - task1_summary.csv")
for k in results:
    if results[k]['mse_per_seed']:
        print(f"   - task1_callback_{k}.png")

print("\n" + "="*70)
print("‚úÖ Analysis Complete!")
print("="*70)