# Task 3 (Extrapolation) Analysis - Kernel Comparison

This notebook analyzes the performance of different GP kernels on **Task 3 (Extrapolation)**:

## Task Description
- **Task 3 (Extrapolation)**: Predict views **BEYOND** the training range
- **Train Views**: Views 0-9 (0° to 180°) - covering the first half of the rotation
- **Val Views**: Views 10-11 (200°, 220°) - near extrapolation
- **Test Views**: Views 12-17 (240°-340°) - far extrapolation
- **Goal**: Test out-of-distribution generalization via GP extrapolation

## Research Question
> "How do different kernel-induced inductive biases influence **extrapolation** to unseen extreme views?"

## Kernels Compared
1. **Full Rank**: Free-form learnable covariance (Q×Q parameters) - **Expected to fail** (no structure linking unseen views to training views)
2. **Periodic**: Standard periodic kernel - **Expected to succeed** (wraps around at 360°)
3. **SM Wrapped**: Spectral Mixture with wrapped distance - Should benefit from periodicity
4. **SM Free**: Spectral Mixture (unwrapped) - Intermediate performance expected

## Analysis (Extended for Extrapolation)
- **Per-View MSE**: How MSE degrades as we extrapolate further from training
- **Extrapolation Distance Analysis**: MSE vs. circular angular distance to the nearest training view (e.g. 340° is only 20° from 0°)
- **Near vs Far Extrapolation**: Compare val (near) vs test (far) performance
- **Periodicity Benefit**: Which kernels leverage circular structure

## 1. Setup

In [None]:
import os
import sys
import glob
import pickle
import numpy as np
import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import seaborn as sns

# Global plotting defaults applied to every figure created after this cell.
PLOT_STYLE = 'seaborn-v0_8-whitegrid'
PLOT_PALETTE = 'husl'
plt.style.use(PLOT_STYLE)
sns.set_palette(PLOT_PALETTE)

# Record the runtime environment alongside the saved outputs.
for _env_line in (f"PyTorch version: {torch.__version__}",
                  f"CUDA available: {torch.cuda.is_available()}"):
    print(_env_line)

In [None]:
import os
import sys


def find_project_root(cwd):
    """Return the project root for a local (non-Colab) working directory.

    If ``cwd`` contains a path component named exactly ``notebooks``, the root
    is the parent of the deepest such component (``<root>/notebooks/analysis``
    and ``<root>/notebooks`` both give ``<root>``); otherwise ``cwd`` itself is
    the root.

    Matching whole components replaces the old substring test
    (``'notebooks' in cwd``) followed by two fixed ``dirname`` calls. That
    approach went one level too high when run from ``<root>/notebooks``. It
    also mangled any path where an unrelated directory name merely contained
    the word "notebooks".
    """
    parts = os.path.normpath(cwd).split(os.sep)
    for idx in range(len(parts) - 1, -1, -1):
        if parts[idx] == 'notebooks':
            return os.sep.join(parts[:idx]) or os.sep
    return cwd


current_dir = os.getcwd()
print(f"üìç Current directory: {current_dir}")

if current_dir == '/content':
    # Colab: the project is expected on Google Drive, which must be mounted first.
    print("\nüîÑ Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        drive_path = '/content/drive/MyDrive/gppvae'
        if os.path.exists(drive_path):
            PROJECT_PATH = drive_path
            print(f"‚úÖ Found project in Google Drive: {PROJECT_PATH}")
        else:
            print(f"‚ö†Ô∏è Project not found at: {drive_path}")
            PROJECT_PATH = '/content'
    except Exception as e:
        print(f"Could not mount Drive: {e}")
        PROJECT_PATH = '/content'
else:
    # Local run: presumably the notebook lives under <root>/notebooks/analysis
    # (the figures below are saved to ./notebooks/analysis relative to the root).
    PROJECT_PATH = find_project_root(current_dir)
    print(f"üíª Using project path: {PROJECT_PATH}")

# Make the COIL-100 model/data modules importable and run from the project root.
coil100_path = os.path.join(PROJECT_PATH, 'GPPVAE/pysrc/coil100')
sys.path.insert(0, coil100_path)

os.chdir(PROJECT_PATH)
print(f"Working directory: {os.getcwd()}")
print(f"Code path added: {coil100_path}")

In [None]:
# Project-local modules. They resolve because the previous cell inserted
# GPPVAE/pysrc/coil100 at the front of sys.path.
from vae import FaceVAE
from vmod import Vmodel
from gp import GP
from data_parser import COIL100Dataset, get_n_views, get_num_objects

print("‚úÖ All modules imported")

## 2. Configuration

In [None]:
# Task 3 (Extrapolation) configuration
CONFIG = {
    'task': 'task3_extrapolation',
    'data_path': './data/coil100/coil100_task3_extrapolation.h5',
    'batch_size': 64,
    'xdim': 64,
}


def first_existing_path(candidates):
    """Return the first entry of ``candidates`` that exists on disk, or ``None``.

    Candidates are checked in order, so the preferred location goes first.
    """
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


# Auto-detect paths: the layout differs between running from the directory
# that contains GPPVAE/ ('./GPPVAE/...') and from inside it ('./...').
found_results = first_existing_path(['./GPPVAE/results', './results'])
if found_results is not None:
    CONFIG['results_base'] = found_results
    print(f"‚úÖ Found results at: {os.path.abspath(CONFIG['results_base'])}")
else:
    CONFIG['results_base'] = './results'
    print(f"‚ö†Ô∏è Results folder not found")

found_data = first_existing_path([
    './GPPVAE/data/coil100/coil100_task3_extrapolation.h5',
    './data/coil100/coil100_task3_extrapolation.h5',
])
if found_data is not None:
    CONFIG['data_path'] = found_data
    print(f"‚úÖ Found data at: {os.path.abspath(CONFIG['data_path'])}")
else:
    # Previously silent: a missing file only surfaced later as an opaque
    # error when the dataset was opened.
    print(f"WARNING: data file not found; falling back to {os.path.abspath(CONFIG['data_path'])}")

# View configuration for Task 3 (Extrapolation)
# Train: views 0-9 (0°-180°)
# Val:   views 10-11 (200°, 220°) - near extrapolation
# Test:  views 12-17 (240°, 260°, 280°, 300°, 320°, 340°) - far extrapolation
N_VIEWS = 18                          # view indices 0..17
DEG_PER_VIEW = 20                     # angular step between consecutive views
FULL_CIRCLE = N_VIEWS * DEG_PER_VIEW  # 360°

TRAIN_VIEW_INDICES = list(range(0, 10))  # 0-9 (0° to 180°)
VAL_VIEW_INDICES = [10, 11]  # 200°, 220° (near extrapolation)
TEST_VIEW_INDICES = [12, 13, 14, 15, 16, 17]  # 240°-340° (far extrapolation)
VIEW_ANGLES = {i: i * DEG_PER_VIEW for i in range(N_VIEWS)}

# Guard against editing one split without the others: together the splits
# must list every view exactly once.
assert sorted(TRAIN_VIEW_INDICES + VAL_VIEW_INDICES + TEST_VIEW_INDICES) == list(range(N_VIEWS)), \
    "train/val/test view splits must partition all views"

# Compute extrapolation distances
TRAIN_MIN_ANGLE = min(VIEW_ANGLES[v] for v in TRAIN_VIEW_INDICES)  # 0°
TRAIN_MAX_ANGLE = max(VIEW_ANGLES[v] for v in TRAIN_VIEW_INDICES)  # 180°
EXTRAP_VIEWS = VAL_VIEW_INDICES + TEST_VIEW_INDICES
# Forward distance: how far past the last training angle a view lies.
EXTRAPOLATION_DISTANCES = {v: VIEW_ANGLES[v] - TRAIN_MAX_ANGLE for v in EXTRAP_VIEWS}
# Circular distance to the nearer end of the training arc. On the circle, 340°
# is only 20° from 0°. compute_extrapolation_distance uses the same notion for
# the per-view table and the MSE-vs-distance plot, so both are shown here.
EXTRAPOLATION_DISTANCES_CIRCULAR = {
    v: min(VIEW_ANGLES[v] - TRAIN_MAX_ANGLE, (TRAIN_MIN_ANGLE - VIEW_ANGLES[v]) % FULL_CIRCLE)
    for v in EXTRAP_VIEWS
}

print(f"\nüìä Extrapolation Task Setup:")
print(f"   Train views: {[VIEW_ANGLES[v] for v in TRAIN_VIEW_INDICES]}¬∞ (max: {TRAIN_MAX_ANGLE}¬∞)")
print(f"   Val views (near): {[VIEW_ANGLES[v] for v in VAL_VIEW_INDICES]}¬∞")
print(f"   Test views (far): {[VIEW_ANGLES[v] for v in TEST_VIEW_INDICES]}¬∞")
print(f"\n   Extrapolation distances:")
for v, dist in EXTRAPOLATION_DISTANCES.items():
    label = "(near)" if v in VAL_VIEW_INDICES else "(far)"
    print(f"      {VIEW_ANGLES[v]}¬∞ ‚Üí +{dist}¬∞ beyond training {label}")
    print(f"         circular distance to nearest training view: {EXTRAPOLATION_DISTANCES_CIRCULAR[v]} deg")

def _kernel_entry(folder, view_kernel, display_name, color, expected, **kernel_kwargs):
    """Bundle one kernel's run folder, Vmodel arguments and plotting metadata.

    Any extra keyword arguments become ``kernel_kwargs`` and are forwarded to
    ``Vmodel`` when the checkpoint is rebuilt.
    """
    return dict(
        folder=folder,
        view_kernel=view_kernel,
        kernel_kwargs=kernel_kwargs,
        display_name=display_name,
        color=color,
        expected=expected,
    )


# Kernel configurations (insertion order = order used in every table and plot)
KERNEL_CONFIGS = dict(
    fullrank=_kernel_entry(
        'task3_fullrank', 'full_rank', 'Full Rank', '#e74c3c', 'FAIL (no structure)',
    ),
    periodic=_kernel_entry(
        'task3_periodic', 'periodic', 'Periodic', '#3498db', 'SUCCEED (wraps at 360¬∞)',
        period=360.0, lengthscale=1.0, variance=1.0,
    ),
    sm_wrapped=_kernel_entry(
        'task3_sm_wrapped', 'sm_circle', 'SM (Wrapped)', '#2ecc71', 'Should leverage periodicity',
        freq_init=[1/360.0, 1/40.0], weight_init=[0.5, 0.5], length_init=[90, 30],
    ),
    sm_free=_kernel_entry(
        'task3_sm_free', 'sm_circle', 'SM (Free)', '#9b59b6', 'Intermediate',
        freq_init=[1/360.0, 1/40.0], weight_init=[0.5, 0.5], use_angle_input=True,
    ),
)

# Prefer the first GPU when one is visible.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

## 3. Load Data

In [None]:
# Load the three Task 3 splits from the same HDF5 file.
train_data = COIL100Dataset(CONFIG['data_path'], split='train', use_angle_encoding=False)
val_data = COIL100Dataset(CONFIG['data_path'], split='val', use_angle_encoding=False)
test_data = COIL100Dataset(CONFIG['data_path'], split='test', use_angle_encoding=False)

# shuffle=False keeps batches in dataset order; the evaluation helpers also
# index results by the per-item index returned as the last batch element.
train_queue = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=False)
val_queue = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)
test_queue = DataLoader(test_data, batch_size=CONFIG['batch_size'], shuffle=False)

P = get_num_objects(CONFIG['data_path'])
Q = get_n_views()

print(f"\nüìä Task 3 (Extrapolation) Dataset:")
print(f"   Objects (P): {P}")
print(f"   Views (Q): {Q}")
print(f"   Train: {len(train_data)} samples")
print(f"   Val (near extrap): {len(val_data)} samples")
print(f"   Test (far extrap): {len(test_data)} samples")

# Analyze view distribution
train_views = set(train_data.Rid.numpy())
val_views = set(val_data.Rid.numpy())
test_views = set(test_data.Rid.numpy())

print(f"\n   Train view indices: {sorted(train_views)}")
print(f"   Val view indices: {sorted(val_views)}")
print(f"   Test view indices: {sorted(test_views)}")
print(f"   Train angles: {[VIEW_ANGLES.get(int(v), v*20) for v in sorted(train_views)]}¬∞")
print(f"   Val angles: {[VIEW_ANGLES.get(int(v), v*20) for v in sorted(val_views)]}¬∞")
print(f"   Test angles: {[VIEW_ANGLES.get(int(v), v*20) for v in sorted(test_views)]}¬∞")

# Later cells label views near/far via VAL_VIEW_INDICES / TEST_VIEW_INDICES and
# measure distances from TRAIN_MAX_ANGLE, so flag any disagreement between the
# splits stored in the file and the configured ones.
for split_name, found_views, configured_views in (
    ('train', train_views, TRAIN_VIEW_INDICES),
    ('val', val_views, VAL_VIEW_INDICES),
    ('test', test_views, TEST_VIEW_INDICES),
):
    found_sorted = sorted(int(v) for v in found_views)
    if found_sorted != sorted(configured_views):
        print(f"WARNING: {split_name} views in data {found_sorted} "
              f"differ from configured {sorted(configured_views)}")

# Index tensors for the GP: Rid holds view indices; Did presumably object ids.
# `torch.autograd.Variable(..., requires_grad=False)` has been a deprecated
# no-op since PyTorch 0.4 (tensors default to requires_grad=False), so plain
# tensors are used directly.
Dt = train_data.Did.long().to(device)
Wt = train_data.Rid.long().to(device)
Dval = val_data.Did.long().to(device)
Wval = val_data.Rid.long().to(device)
Dtest = test_data.Did.long().to(device)
Wtest = test_data.Rid.long().to(device)

## 4. Helper Functions

In [None]:
def find_run_folders(results_base, kernel_folder):
    """List the run (seed) sub-directories of one kernel, sorted by name.

    Args:
        results_base: Root results directory.
        kernel_folder: Name of the kernel's folder under ``results_base``.

    Returns:
        Full paths of the immediate sub-directories of
        ``results_base/kernel_folder`` in lexicographic order of their names.
        Returns an empty list, after printing a warning, when that folder does
        not exist. Plain files inside it are ignored.
    """
    kernel_path = os.path.join(results_base, kernel_folder)
    if os.path.exists(kernel_path):
        return [
            os.path.join(kernel_path, entry)
            for entry in sorted(os.listdir(kernel_path))
            if os.path.isdir(os.path.join(kernel_path, entry))
        ]
    print(f"‚ö†Ô∏è Kernel folder not found: {kernel_path}")
    return []


def load_vae_config(run_folder=None):
    """Return the keyword arguments used to build ``FaceVAE(**cfg)``.

    Tries a fixed list of pickled config files, relative to the current working
    directory, and returns the first one that loads. If none exists, or all of
    them fail to load, returns built-in defaults.

    Args:
        run_folder: Currently unused; accepted so callers can pass the run being
            loaded. NOTE(review): if per-run configs exist, they should probably
            be checked first — confirm before relying on the global config.

    Returns:
        dict of FaceVAE keyword arguments.
    """
    possible_paths = [
        './GPPVAE/pysrc/coil100/vae.cfg.p',
        './out/vae/vae.cfg.p',
    ]
    for path in possible_paths:
        if not os.path.exists(path):
            continue
        try:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # config files produced by this project's own training runs.
            with open(path, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            # The former bare `except:` also swallowed KeyboardInterrupt /
            # SystemExit and hid which file was broken; report and move on.
            print(f"WARNING: could not load VAE config {path}: {e!r}")
    return {
        'img_size': 128, 'nf': 32, 'zdim': 256,
        'steps': 5, 'colors': 3, 'act': 'elu', 'vy': 0.001
    }


def load_models(run_folder, kernel_config, P, Q, xdim, device):
    """Rebuild the VAE, view model and GP of one run and load its best weights.

    Expects ``<run_folder>/weights/vae_weights.best.pt`` (a VAE state dict) and
    ``<run_folder>/weights/gp_weights.best.pt`` (a dict holding ``'gp_state'``
    and ``'vm_state'``).

    Args:
        run_folder: Directory of a single training run (seed).
        kernel_config: Entry of KERNEL_CONFIGS; supplies ``view_kernel`` and
            the extra ``kernel_kwargs`` passed to ``Vmodel``.
        P, Q, xdim: Forwarded to ``Vmodel`` as ``P``, ``Q`` and ``p``.
        device: Target device for all three models.

    Returns:
        ``(vae, vm, gp)``, each on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if a checkpoint file is missing (GP checked first).
    """
    weights_dir = os.path.join(run_folder, 'weights')
    gp_weights_path = os.path.join(weights_dir, 'gp_weights.best.pt')
    vae_weights_path = os.path.join(weights_dir, 'vae_weights.best.pt')

    # Fail fast, before any model is built.
    if not os.path.exists(gp_weights_path):
        raise FileNotFoundError(f"GP weights not found: {gp_weights_path}")
    if not os.path.exists(vae_weights_path):
        raise FileNotFoundError(f"VAE weights not found: {vae_weights_path}")

    # VAE: architecture from the (pickled or default) config, weights from the run.
    vae = FaceVAE(**load_vae_config(run_folder)).to(device)
    vae_state = torch.load(vae_weights_path, map_location=device)
    vae.load_state_dict(vae_state)
    vae.eval()

    # View model and GP share a single checkpoint file.
    view_model = Vmodel(
        P=P, Q=Q, p=xdim,
        view_kernel=kernel_config['view_kernel'],
        **kernel_config['kernel_kwargs']
    ).to(device)
    gp = GP(n_rand_effs=1).to(device)
    checkpoint = torch.load(gp_weights_path, map_location=device)
    gp.load_state_dict(checkpoint['gp_state'])
    view_model.load_state_dict(checkpoint['vm_state'])

    for module in (view_model, gp):
        module.eval()

    return vae, view_model, gp


def encode_dataset(vae, data_queue, device, zdim=None):
    """Encode every image of ``data_queue.dataset`` to its VAE latent mean.

    Args:
        vae: Model whose ``encode(y)`` returns ``(mean, other)``; only the
            mean is kept.
        data_queue: DataLoader whose batches have the images at position 0 and
            the dataset row indices at position -1. Its dataset must expose
            ``Y``, whose first dimension is the number of items.
        device: Device for the returned tensor.
        zdim: Latent size. ``None`` (default) infers it from the encoder
            output, so VAE configs whose ``zdim`` is not 256 work. 256 is used
            only when the loader yields no batches and nothing can be inferred.

    Returns:
        Tensor of shape ``(len(Y), zdim)`` in the default dtype (float32 unless
        changed). Row ``i`` holds the latent mean of dataset item ``i``.
    """
    vae.eval()
    n = data_queue.dataset.Y.shape[0]
    Zm = None if zdim is None else torch.zeros(n, zdim, device=device)
    with torch.no_grad():
        for data in data_queue:
            y = data[0].to(device)
            idxs = data[-1].to(device)
            zm, _ = vae.encode(y)
            if Zm is None:
                # Previously hard-coded to 256, which broke any VAE with a
                # different latent size; size the buffer from the first batch.
                Zm = torch.zeros(n, zm.shape[1], device=device)
            Zm[idxs] = zm.detach()
    if Zm is None:
        Zm = torch.zeros(n, 256, device=device)
    return Zm


def _gp_posterior_latents(vae, vm, gp, train_queue, Dt, Wt, D_eval, W_eval, device):
    """Predict latent codes for the evaluation pairs from the training latents via the GP.

    Encodes the training images to latent means ``Zm``. Builds view-model
    features for the train pairs (``Vt``) and eval pairs (``V_eval``). Returns
    ``vs[0] * V_eval @ (Vt^T @ Kiz)``, where ``Kiz = gp.solve(Zm, U, UBi, vs)``
    (presumably ``K^{-1} Zm``). The callers index the result by eval-dataset
    row and run it under ``torch.no_grad()``.
    """
    Zm = encode_dataset(vae, train_queue, device)
    Vt = vm(Dt, Wt).detach()
    V_eval = vm(D_eval, W_eval).detach()

    vs = gp.get_vs()
    U, UBi, _ = gp.U_UBi_Shb([Vt], vs)
    Kiz = gp.solve(Zm, U, UBi, vs)
    return vs[0] * V_eval.mm(Vt.transpose(0, 1).mm(Kiz))


def evaluate_with_per_view(vae, vm, gp, train_queue, eval_queue,
                           Dt, Wt, D_eval, W_eval, device):
    """
    Evaluate on any split (val or test) with per-view MSE breakdown.

    Latents for the eval pairs come from the GP conditioned on the training
    latents; each is decoded and compared with the real image.

    Returns:
        ``(mse_mean, mse_per_sample, mse_per_view)``:
          - mse_mean: mean per-image pixel MSE over the split (``nan`` if the
            split is empty, instead of raising ZeroDivisionError).
          - mse_per_sample: np.ndarray of per-image MSEs in loader order.
          - mse_per_view: dict mapping view index (int, from
            ``eval_queue.dataset.Rid``) to the list of per-image MSEs.
    """
    vae.eval()
    vm.eval()
    gp.eval()

    mse_per_view = {}
    mse_per_sample = []
    mse_total = 0.0

    with torch.no_grad():
        Zo_eval = _gp_posterior_latents(vae, vm, gp, train_queue,
                                        Dt, Wt, D_eval, W_eval, device)
        eval_Rid = eval_queue.dataset.Rid

        for data in eval_queue:
            idxs = data[-1].to(device)
            Y_eval = data[0].to(device)
            Yo = vae.decode(Zo_eval[idxs])
            # Per-image MSE averaged over all pixels/channels.
            mse_batch = ((Y_eval - Yo) ** 2).view(Y_eval.shape[0], -1).mean(1)

            for idx, mse_val in zip(data[-1].tolist(), mse_batch.tolist()):
                view = int(eval_Rid[idx].item())
                mse_per_view.setdefault(view, []).append(mse_val)
                mse_per_sample.append(mse_val)

            mse_total += mse_batch.sum().item()

    n_eval = len(eval_queue.dataset)
    mse_mean = mse_total / n_eval if n_eval else float('nan')
    return mse_mean, np.array(mse_per_sample), mse_per_view


def get_reconstructions(vae, vm, gp, train_queue, eval_queue,
                        Dt, Wt, D_eval, W_eval, device, n_samples=24):
    """Decode GP-predicted latents for an evenly strided subset of the eval split.

    Args:
        n_samples: Maximum number of items returned; items are taken every
            ``max(1, len(dataset) // n_samples)`` rows.

    Returns:
        ``(Y_orig, Y_recon, sample_indices)``: real and reconstructed images as
        numpy arrays transposed with ``(0, 2, 3, 1)`` (assumes channel-first
        inputs, giving channel-last output — TODO confirm), plus the dataset
        row indices that were used.
    """
    vae.eval()
    vm.eval()
    gp.eval()

    with torch.no_grad():
        Zo_eval = _gp_posterior_latents(vae, vm, gp, train_queue,
                                        Dt, Wt, D_eval, W_eval, device)

        n_total = len(eval_queue.dataset)
        sample_stride = max(1, n_total // n_samples)
        sample_indices = list(range(0, n_total, sample_stride))[:n_samples]

        Y_orig = eval_queue.dataset.Y[sample_indices].numpy().transpose(0, 2, 3, 1)
        sample_indices_tensor = torch.tensor(sample_indices, dtype=torch.long).to(device)
        Y_recon = vae.decode(Zo_eval[sample_indices_tensor]).cpu().numpy().transpose(0, 2, 3, 1)

    return Y_orig, Y_recon, sample_indices


def compute_extrapolation_distance(view_idx, train_max_angle=180, deg_per_view=20,
                                   train_min_angle=0, period=360):
    """Compute how far a view is beyond the training range.

    Distance is measured around the circle to the nearer end of the training
    arc ``[train_min_angle, train_max_angle]``. A view outside the arc is
    compared with ``train_max_angle`` (going forward) and with
    ``train_min_angle`` (wrapping past ``period``); the smaller gap is
    returned. With the defaults: 200° -> 20, 280° -> 80, 340° -> 20.

    Args:
        view_idx: View index; its angle is ``view_idx * deg_per_view``, reduced
            modulo ``period`` (so e.g. 380° is treated as 20°).
        train_max_angle: Last training angle in degrees.
        deg_per_view: Angular step between consecutive view indices.
        train_min_angle: First training angle in degrees.
        period: Length of the full circle in degrees.

    Returns:
        0 for views inside the training arc, otherwise the non-negative
        circular distance in degrees.
    """
    angle = (view_idx * deg_per_view) % period
    if train_min_angle <= angle <= train_max_angle:
        return 0
    forward_dist = (angle - train_max_angle) % period
    backward_dist = (train_min_angle - angle) % period
    return min(forward_dist, backward_dist)


print("‚úÖ Helper functions defined")

## 5. Evaluate All Kernels

In [None]:
# Evaluate every run (seed) of every kernel on the val split (near
# extrapolation) and the test split (far extrapolation). Per-seed numbers and
# per-kernel aggregates are collected in `results`, keyed by KERNEL_CONFIGS key.
# Kernels with no run folders get no entry at all. Kernels whose runs all
# failed get an entry with empty lists and no aggregate keys.
results = {}

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    print(f"\n{'='*60}")
    print(f"Evaluating: {kernel_config['display_name']}")
    print(f"Expected: {kernel_config['expected']}")
    print(f"{'='*60}")
    
    run_folders = find_run_folders(CONFIG['results_base'], kernel_config['folder'])
    
    if not run_folders:
        print(f"‚ö†Ô∏è No runs found for {kernel_name}")
        continue
    
    print(f"Found {len(run_folders)} runs (seeds)")
    
    # Parallel per-seed lists: index i in each list refers to the same seed.
    kernel_results = {
        'mse_val_per_seed': [],
        'mse_test_per_seed': [],
        'mse_per_view_val': [],
        'mse_per_view_test': [],
        'kernel_matrices': [],
        'variance_ratios': [],
    }
    
    for i, run_folder in enumerate(run_folders):
        # Each seed is best-effort: any exception (missing weights, load or
        # evaluation failure) is printed and the seed is skipped. All appends
        # happen together after every step succeeded, so the lists stay aligned.
        try:
            print(f"  Seed {i}: {os.path.basename(run_folder)}...", end=" ")
            
            vae, vm, gp = load_models(
                run_folder, kernel_config, P, Q, CONFIG['xdim'], device
            )
            
            # Evaluate on VAL (near extrapolation)
            mse_val, _, mse_per_view_val = evaluate_with_per_view(
                vae, vm, gp, train_queue, val_queue,
                Dt, Wt, Dval, Wval, device
            )
            
            # Evaluate on TEST (far extrapolation)
            mse_test, _, mse_per_view_test = evaluate_with_per_view(
                vae, vm, gp, train_queue, test_queue,
                Dt, Wt, Dtest, Wtest, device
            )
            
            with torch.no_grad():
                K = vm.get_kernel_matrix().cpu().numpy()
                vs = gp.get_vs().cpu().numpy()
                # Share of the two variance terms carried by vs[0].
                # NOTE(review): presumably vs[0] is the GP (object x view) effect
                # and vs[1] the residual noise — confirm against GP.get_vs.
                variance_ratio = vs[0] / (vs[0] + vs[1])
            
            kernel_results['mse_val_per_seed'].append(mse_val)
            kernel_results['mse_test_per_seed'].append(mse_test)
            kernel_results['mse_per_view_val'].append(mse_per_view_val)
            kernel_results['mse_per_view_test'].append(mse_per_view_test)
            kernel_results['kernel_matrices'].append(K)
            kernel_results['variance_ratios'].append(variance_ratio)
            
            print(f"Val={mse_val:.6f}, Test={mse_test:.6f}")
            
        except Exception as e:
            print(f"‚ùå Error: {e}")
            continue
    
    if kernel_results['mse_test_per_seed']:
        # Across-seed mean and spread; np.std uses its default ddof=0
        # (population standard deviation).
        kernel_results['mean_mse_val'] = np.mean(kernel_results['mse_val_per_seed'])
        kernel_results['std_mse_val'] = np.std(kernel_results['mse_val_per_seed'])
        kernel_results['mean_mse_test'] = np.mean(kernel_results['mse_test_per_seed'])
        kernel_results['std_mse_test'] = np.std(kernel_results['mse_test_per_seed'])
        
        # Extrapolation degradation: how much worse is far vs near
        kernel_results['extrap_degradation'] = kernel_results['mean_mse_test'] / kernel_results['mean_mse_val']
        
        # Aggregate per-view MSE: pool the per-image MSEs of every seed and of
        # both splits, then average them per view index.
        all_views = set()
        for pv in kernel_results['mse_per_view_val'] + kernel_results['mse_per_view_test']:
            all_views.update(pv.keys())
        
        kernel_results['mean_mse_per_view'] = {}
        for view in all_views:
            view_mses = []
            for pv in kernel_results['mse_per_view_val'] + kernel_results['mse_per_view_test']:
                if view in pv:
                    view_mses.extend(pv[view])
            kernel_results['mean_mse_per_view'][view] = np.mean(view_mses)
        
        print(f"\n  üìä {kernel_config['display_name']} Summary:")
        print(f"     Val MSE (near): {kernel_results['mean_mse_val']:.6f} ¬± {kernel_results['std_mse_val']:.6f}")
        print(f"     Test MSE (far): {kernel_results['mean_mse_test']:.6f} ¬± {kernel_results['std_mse_test']:.6f}")
        print(f"     Degradation (far/near): {kernel_results['extrap_degradation']:.2f}x")
    
    results[kernel_name] = kernel_results

# Loop variables (e.g. vae, vm, gp, K, vs) remain defined as notebook globals
# after this cell.
print("\n" + "="*60)
print("‚úÖ Evaluation complete!")

## 6. Summary Table

In [None]:
# One row per kernel with at least one successful seed.
SUMMARY_COLUMNS = ['Kernel', 'Val MSE (near)', 'Val Std', 'Test MSE (far)',
                   'Test Std', 'Degradation', 'N Seeds', 'Var Ratio']
summary_data = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_test_per_seed']:
        r = results[kernel_name]
        summary_data.append({
            'Kernel': kernel_config['display_name'],
            'Val MSE (near)': r['mean_mse_val'],
            'Val Std': r['std_mse_val'],
            'Test MSE (far)': r['mean_mse_test'],
            'Test Std': r['std_mse_test'],
            'Degradation': r['extrap_degradation'],
            'N Seeds': len(r['mse_test_per_seed']),
            'Var Ratio': np.mean(r['variance_ratios']),
        })

# Explicit columns keep the frame, and the sort below, valid even when no
# kernel produced results. Previously `sort_values` raised KeyError on the
# column-less empty frame.
summary_df = pd.DataFrame(summary_data, columns=SUMMARY_COLUMNS)
summary_df = summary_df.sort_values('Test MSE (far)')

print("\nüìä Task 3 (Extrapolation) - Kernel Comparison Summary")
print("="*90)
print(summary_df.to_string(index=False))
print("="*90)
print("\n‚ö†Ô∏è Degradation = Test MSE / Val MSE (lower = better extrapolation)")
if summary_df.empty:
    print("WARNING: no kernel has successful runs; check CONFIG['results_base'].")

In [None]:
# Styled table: fixed-precision numbers. The test-MSE column is shaded with
# the reversed RdYlGn colormap, so lower MSE appears green.
STYLED_FORMATS = {
    'Val MSE (near)': '{:.6f}',
    'Val Std': '{:.6f}',
    'Test MSE (far)': '{:.6f}',
    'Test Std': '{:.6f}',
    'Degradation': '{:.2f}x',
    'Var Ratio': '{:.3f}',
}
styled_df = summary_df.style.format(STYLED_FORMATS)
styled_df = styled_df.background_gradient(subset=['Test MSE (far)'], cmap='RdYlGn_r')

styled_df

## 6.1. MSE Comparison Plots

In [None]:
# Shared data for both figures: kernels with at least one successful seed, in
# KERNEL_CONFIGS order.
kernel_names = []
means = []
stds = []
data_for_box = []
labels_for_box = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_test_per_seed']:
        kernel_names.append(kernel_config['display_name'])
        means.append(results[kernel_name]['mean_mse_test'])
        stds.append(results[kernel_name]['std_mse_test'])
        data_for_box.append(results[kernel_name]['mse_test_per_seed'])
        labels_for_box.append(kernel_config['display_name'])


def _apply_clean_test_mse_style(ax):
    """Shared clean look for the test-MSE figures: y label, 11pt ticks, no top/right spines, no grid."""
    ax.set_ylabel('MSE [test set - far extrapolation]', fontsize=12)
    ax.tick_params(labelsize=11)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(False)


# Bar plot with error bars - clean style (Test MSE - far extrapolation)
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
x = np.arange(len(kernel_names))
bars = ax.bar(x, means, yerr=stds, capsize=5, color='#808080', alpha=0.8, edgecolor='black', linewidth=0.5)
ax.set_xticks(x)
ax.set_xticklabels(kernel_names, rotation=0, ha='center', fontsize=11)
ax.set_ylim(bottom=0)
_apply_clean_test_mse_style(ax)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_mse_bar.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_mse_bar.png (300 DPI)")

# Box plot - clean style (spread of per-seed test MSE)
if data_for_box:
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    # Boxplot's `labels=` keyword was renamed `tick_labels=` in Matplotlib 3.9
    # and the old name is deprecated. Setting the ticks explicitly works on
    # every version; boxes sit at x = 1..n by default.
    bp = ax.boxplot(data_for_box, patch_artist=True)
    ax.set_xticks(range(1, len(labels_for_box) + 1))
    ax.set_xticklabels(labels_for_box)
    for patch in bp['boxes']:
        patch.set_facecolor('#808080')
        patch.set_alpha(0.7)
        patch.set_linewidth(0.5)

    for element in ['whiskers', 'fliers', 'means', 'medians', 'caps']:
        plt.setp(bp[element], linewidth=0.5)

    ax.tick_params(axis='x', rotation=0)
    _apply_clean_test_mse_style(ax)

    plt.tight_layout()
    plt.savefig('./notebooks/analysis/task3_mse_boxplot.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("üìä Saved: task3_mse_boxplot.png (300 DPI)")
else:
    print("No successful runs: skipping the test-MSE box plot.")


## 7. Near vs Far Extrapolation Comparison

In [None]:
# Prepare data: one entry per kernel with at least one successful seed.
kernel_names = []
val_means = []
val_stds = []
test_means = []
test_stds = []
colors = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if not (kernel_name in results and results[kernel_name]['mse_test_per_seed']):
        continue
    kernel_result = results[kernel_name]
    kernel_names.append(kernel_config['display_name'])
    val_means.append(kernel_result['mean_mse_val'])
    val_stds.append(kernel_result['std_mse_val'])
    test_means.append(kernel_result['mean_mse_test'])
    test_stds.append(kernel_result['std_mse_test'])
    colors.append(kernel_config['color'])

x = np.arange(len(kernel_names))


def _style_kernel_bar_axes(ax, positions, tick_labels, ylabel):
    """Put kernel names under the bars and apply the shared clean bar-chart look."""
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=0, ha='center', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.tick_params(labelsize=11)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(False)


# Plot 1: Grouped bar chart (Val vs Test) - Clean style
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
width = 0.35

bars1 = ax.bar(x - width/2, val_means, width, yerr=val_stds, capsize=3,
               label='Val (near)', color='lightblue', edgecolor='navy', linewidth=0.5)
bars2 = ax.bar(x + width/2, test_means, width, yerr=test_stds, capsize=3,
               label='Test (far)', color='salmon', edgecolor='darkred', linewidth=0.5)
_style_kernel_bar_axes(ax, x, kernel_names, 'MSE')
ax.legend(frameon=False, fontsize=11)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_near_vs_far_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_near_vs_far_comparison.png (300 DPI)")

# Plot 2: Degradation factor (Test/Val ratio; the dashed line at 1.0 means no degradation)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

degradations = [t/v for t, v in zip(test_means, val_means)]
bars = ax.bar(x, degradations, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
ax.axhline(y=1.0, color='green', linestyle='--', linewidth=1, alpha=0.7)
_style_kernel_bar_axes(ax, x, kernel_names, 'Degradation (Test/Val)')
ax.set_ylim(bottom=0)

for bar, deg in zip(bars, degradations):
    bar_center_x = bar.get_x() + bar.get_width() / 2
    ax.annotate(f'{deg:.2f}x',
                xy=(bar_center_x, bar.get_height()),
                xytext=(0, 3), textcoords='offset points',
                ha='center', va='bottom', fontsize=11)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_degradation.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_degradation.png (300 DPI)")

# Plot 3: Test MSE with kernel colors - Clean style
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

bars = ax.bar(x, test_means, yerr=test_stds, capsize=5, color=colors,
              alpha=0.8, edgecolor='black', linewidth=0.5)
_style_kernel_bar_axes(ax, x, kernel_names, 'MSE [test set - far extrapolation]')
ax.set_ylim(bottom=0)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_test_mse_colored.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_test_mse_colored.png (300 DPI)")


## 8. Per-View Extrapolation Analysis

Key question: How does MSE degrade as we extrapolate further from training?

In [None]:
import pylab as pl
pl.rcdefaults()  # reset rcParams, dropping the seaborn style set in the setup cell

# Get all extrapolation views
extrap_views = sorted(list(val_views) + list(test_views))

# Define distinct colors for each kernel (avoiding green/red, which shade the regions)
line_colors = {
    'Full Rank': '#1f77b4',      # Blue
    'Periodic': '#ff7f0e',        # Orange
    'SM (Wrapped)': '#9467bd',    # Purple
    'SM (Free)': '#8c564b',       # Brown
}

# Plot 1: MSE vs View Angle
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and 'mean_mse_per_view' in results[kernel_name]:
        mean_per_view = results[kernel_name]['mean_mse_per_view']
        views = sorted(mean_per_view.keys())
        angles = [VIEW_ANGLES.get(v, v * 20) for v in views]
        mses = [mean_per_view[v] for v in views]
        display_name = kernel_config['display_name']
        color = line_colors.get(display_name, kernel_config['color'])
        ax.plot(angles, mses, 'o-', label=display_name,
                color=color, markersize=8, linewidth=2, alpha=0.8)

# Shade the training arc up to the last training angle (TRAIN_MAX_ANGLE, 180°)
# and everything past it as extrapolation. The boundary used to sit at 200°:
# that labelled the 180°-200° gap, which has no training views, as "Training
# range", and put the boundary right on the first validation view.
ax.axvspan(0, TRAIN_MAX_ANGLE, alpha=0.15, color='green', zorder=0)
ax.axvspan(TRAIN_MAX_ANGLE, 360, alpha=0.15, color='red', zorder=0)

ax.set_xlabel('View Angle (¬∞)', fontsize=16)
ax.set_ylabel('Mean MSE', fontsize=16)
ax.set_xlim(160, 360)
ax.set_xticks([180, 200, 220, 240, 260, 280, 300, 320, 340])
ax.tick_params(labelsize=16)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Custom patches for the region legend (Patch is imported in the setup cell)
region_legend_elements = [
    Patch(facecolor='green', alpha=0.15, label='Training range'),
    Patch(facecolor='red', alpha=0.15, label='Extrapolation')
]

# Add region legend (more space for larger text)
region_legend = ax.legend(handles=region_legend_elements, loc='upper center',
                         bbox_to_anchor=(0.5, -0.12), ncol=2, frameon=False, fontsize=16)
ax.add_artist(region_legend)  # Keep this legend when adding the next one

# Add kernel legend below (more space for larger text)
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.22), ncol=4, frameon=False, fontsize=16)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_mse_vs_angle.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_mse_vs_angle.png (300 DPI)")

# Plot 2: MSE vs Extrapolation Distance (no region shading, so only one legend)
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and 'mean_mse_per_view' in results[kernel_name]:
        mean_per_view = results[kernel_name]['mean_mse_per_view']

        # compute_extrapolation_distance is circular, so distance is not
        # monotonic in view order (200°..340° -> 20, 40, 60, 80, 80, 60, 40, 20).
        # Plotting in view order made every line double back on itself.
        # Instead, average views that share a distance and plot by distance.
        mses_by_distance = {}
        for view, mse in mean_per_view.items():
            dist = compute_extrapolation_distance(view, TRAIN_MAX_ANGLE)
            if dist > 0:  # Only extrapolation views
                mses_by_distance.setdefault(dist, []).append(mse)
        distances = sorted(mses_by_distance)
        mses = [np.mean(mses_by_distance[d]) for d in distances]

        display_name = kernel_config['display_name']
        color = line_colors.get(display_name, kernel_config['color'])
        ax.plot(distances, mses, 'o-', label=display_name,
                color=color, markersize=8, linewidth=2, alpha=0.8)

ax.set_xlabel('Distance from Training Range (¬∞)', fontsize=16)
ax.set_ylabel('Mean MSE', fontsize=16)
ax.tick_params(labelsize=16)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Legend below plot
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=4, frameon=False, fontsize=16)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_mse_vs_distance.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_mse_vs_distance.png (300 DPI)")


In [None]:
# Per-view MSE table. 'Distance' is the circular distance from
# compute_extrapolation_distance; 'Region' follows VAL_VIEW_INDICES.
per_view_data = []
all_extrap_views = sorted([int(v) for v in list(val_views) + list(test_views)])

for view in all_extrap_views:
    angle = VIEW_ANGLES.get(view, view * 20)
    dist = compute_extrapolation_distance(view, TRAIN_MAX_ANGLE)
    region = 'Near' if view in VAL_VIEW_INDICES else 'Far'

    row = {
        'View': f"{angle}¬∞",
        'Distance': f"+{dist}¬∞",
        'Region': region
    }
    for kernel_name, kernel_config in KERNEL_CONFIGS.items():
        if kernel_name in results and 'mean_mse_per_view' in results[kernel_name]:
            row[kernel_config['display_name']] = results[kernel_name]['mean_mse_per_view'].get(view, np.nan)
    per_view_data.append(row)

per_view_df = pd.DataFrame(per_view_data)
print("\nüìä Per-View MSE (Extrapolation Views):")
print("="*80)
print(per_view_df.to_string(index=False))

# Best kernel per view. Only kernels with a value for that view compete. This
# avoids `min()` on an empty list when no kernel has results, and avoids
# naming an arbitrary "winner" with NaN MSE.
print("\nüèÜ Best Kernel per Extrapolation View:")
kernel_cols = [c for c in per_view_df.columns if c not in ['View', 'Distance', 'Region']]
for _, row in per_view_df.iterrows():
    candidates = [k for k in kernel_cols if pd.notna(row[k])]
    if not candidates:
        print(f"   {row['View']} ({row['Region']}): no data")
        continue
    best_kernel = min(candidates, key=lambda k: row[k])
    print(f"   {row['View']} ({row['Region']}): {best_kernel} ({row[best_kernel]:.6f})")

## 9. Periodicity Analysis

Check whether periodic kernels exploit the wrap-around structure: 340° is only 20° from 0°, the first training view.

In [None]:
# Compare MSE at 280° (deep in the extrapolation arc) with 340°, which is
# close to 0° (a training view) once angles wrap around the circle.
VIEW_MID_EXTRAP = 14  # 280°
VIEW_NEAR_WRAP = 17   # 340°
# Circular distances, consistent with the per-view table (80° and 20° for the
# default split). The old hard-coded text claimed 280° was 60° away.
dist_mid = compute_extrapolation_distance(VIEW_MID_EXTRAP, TRAIN_MAX_ANGLE)
dist_wrap = compute_extrapolation_distance(VIEW_NEAR_WRAP, TRAIN_MAX_ANGLE)

print("üìä Periodicity Benefit Analysis:")
print("="*60)
print("\nIf periodic kernels work correctly:")
print(f"- 340¬∞ should have LOWER MSE than 280¬∞ (340¬∞ is {dist_wrap}¬∞ from 0¬∞, 280¬∞ is {dist_mid}¬∞ from training)")
print("- Full Rank should show INCREASING MSE as we go further from training\n")

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and 'mean_mse_per_view' in results[kernel_name]:
        mean_per_view = results[kernel_name]['mean_mse_per_view']

        # Get MSE at key views
        mse_280 = mean_per_view.get(VIEW_MID_EXTRAP, np.nan)  # 280°
        mse_340 = mean_per_view.get(VIEW_NEAR_WRAP, np.nan)  # 340°

        if not np.isnan(mse_280) and not np.isnan(mse_340):
            ratio = mse_340 / mse_280
            pattern = "‚úÖ PERIODIC (340¬∞ < 280¬∞)" if ratio < 1 else "‚ùå NON-PERIODIC (340¬∞ > 280¬∞)"
            print(f"{kernel_config['display_name']:15s}: 280¬∞={mse_280:.6f}, 340¬∞={mse_340:.6f}, ratio={ratio:.2f} {pattern}")
        else:
            print(f"{kernel_config['display_name']:15s}: MSE for 280 deg / 340 deg not available")

## 10. Kernel Matrices

In [None]:
import math

import pylab as pl
from matplotlib.lines import Line2D

pl.rcdefaults()

# Kernels (in KERNEL_CONFIGS order, matching the plotting loop) that have at least one
# saved view-covariance matrix. .get() avoids a KeyError for result entries without a
# 'kernel_matrices' field.
kernels_with_matrices = [
    k for k in KERNEL_CONFIGS
    if k in results and results[k].get('kernel_matrices')
]
n_kernels = len(kernels_with_matrices)

if n_kernels > 0:
    # Two columns; keep at least the 2x2 grid and add rows if there are more than 4 kernels.
    n_cols = 2
    n_rows = max(2, math.ceil(n_kernels / n_cols))
    fig, axes = pl.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    axes = axes.flatten()

    for ax, kernel_name in zip(axes, kernels_with_matrices):
        kernel_config = KERNEL_CONFIGS[kernel_name]
        # Use best seed (lowest test MSE) for consistency.
        # NOTE(review): assumes kernel_matrices[i] and mse_test_per_seed[i] refer to the
        # same seed — confirm against the code that filled `results`.
        best_seed_idx = np.argmin(results[kernel_name]['mse_test_per_seed'])
        K = results[kernel_name]['kernel_matrices'][best_seed_idx]

        im = ax.imshow(K, vmin=-0.4, vmax=1, aspect='auto')
        ax.set_title(f"{kernel_config['display_name']}\nWW (view cov)", fontsize=17)

        # Mark training vs extrapolation regions (boundary at 9.5 = after view 9 at 180¬∞)
        ax.axvline(x=9.5, color='red', linestyle='--', linewidth=2)
        ax.axhline(y=9.5, color='red', linestyle='--', linewidth=2)

        tick_positions = [0, 5, 10, 17]
        tick_labels = [f"{p*20}¬∞" for p in tick_positions]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=17)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, fontsize=17)
        ax.tick_params(labelsize=16)

        cbar = pl.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=16)

    # Hide unused subplots
    for ax in axes[n_kernels:]:
        ax.axis('off')

    # Legend explaining the red dashed line, placed below the figure
    legend_elements = [
        Line2D([0], [0], color='red', linestyle='--', linewidth=2,
               label='Training/extrapolation boundary')
    ]
    fig.legend(handles=legend_elements, loc='lower center',
               bbox_to_anchor=(0.5, -0.05), ncol=1, frameon=False, fontsize=17)

    pl.tight_layout()
    pl.savefig('./notebooks/analysis/task3_kernel_matrices.png', dpi=300, bbox_inches='tight')
    pl.show()
    print("üìä Saved: task3_kernel_matrices.png (300 DPI)")
else:
    print("‚ö†Ô∏è No kernel matrices to display")


## 11. Sample Reconstructions (Far Extrapolation)

In [None]:
n_samples = 8


def _hide_frame(ax):
    """Hide an image axis' ticks and spines while keeping its y-label visible.

    ``ax.axis('off')`` also hides axis labels, which would make the per-row
    'GT' / kernel-name-and-MSE labels invisible.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


# Rows are drawn in KERNEL_CONFIGS order, so count kernels the same way (a `results`
# key that is not in KERNEL_CONFIGS would otherwise reserve an empty row).
recon_kernels = [k for k in KERNEL_CONFIGS if k in results and results[k]['mse_test_per_seed']]
n_kernels = len(recon_kernels)

if n_kernels > 0:
    # Evenly strided test samples; fewer than n_samples when the test set is small.
    n_total = len(test_data)
    sample_stride = max(1, n_total // n_samples)
    sample_indices = list(range(0, n_total, sample_stride))[:n_samples]
    n_shown = len(sample_indices)
else:
    n_shown = 0

if n_shown > 0:
    # squeeze=False keeps `axes` 2-D even if only one column is shown.
    fig, axes = plt.subplots(n_kernels + 1, n_shown,
                             figsize=(n_shown * 1.5, (n_kernels + 1) * 1.5), squeeze=False)

    # Ground truth
    Y_gt = test_data.Y[sample_indices].numpy().transpose(0, 2, 3, 1)
    for i in range(n_shown):
        axes[0, i].imshow(np.clip(Y_gt[i], 0, 1))
        _hide_frame(axes[0, i])
        view_idx = int(test_data.Rid[sample_indices[i]].item())
        axes[0, i].set_title(f"{view_idx*20}¬∞", fontsize=10)
    axes[0, 0].set_ylabel('GT', fontsize=12)

    for row_idx, kernel_name in enumerate(recon_kernels, start=1):
        kernel_config = KERNEL_CONFIGS[kernel_name]
        best_seed_idx = np.argmin(results[kernel_name]['mse_test_per_seed'])
        best_run = find_run_folders(CONFIG['results_base'], kernel_config['folder'])[best_seed_idx]

        vae, vm, gp = load_models(best_run, kernel_config, P, Q, CONFIG['xdim'], device)
        # NOTE(review): assumes get_reconstructions returns reconstructions of the same test
        # samples as `sample_indices` above (its own originals, Y_orig, are not compared
        # here) — confirm, otherwise the GT row and reconstruction rows may not line up.
        Y_orig, Y_recon, _ = get_reconstructions(
            vae, vm, gp, train_queue, test_queue,
            Dt, Wt, Dtest, Wtest, device, n_samples=n_samples
        )

        for i in range(min(n_shown, len(Y_recon))):
            axes[row_idx, i].imshow(np.clip(Y_recon[i], 0, 1))
        for i in range(n_shown):
            _hide_frame(axes[row_idx, i])

        mse = results[kernel_name]['mse_test_per_seed'][best_seed_idx]
        axes[row_idx, 0].set_ylabel(f"{kernel_config['display_name']}\n({mse:.4f})", fontsize=10)

    # Test split = views 12-17, i.e. 240-340 deg.
    plt.suptitle('Task 3 (Extrapolation): Reconstructions at Far Views (240¬∞-340¬∞)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('./notebooks/analysis/task3_reconstructions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("üìä Saved: task3_reconstructions.png")


## 12. Statistical Analysis

In [None]:
from itertools import combinations

from scipy import stats

kernel_names_list = [k for k in KERNEL_CONFIGS if k in results and results[k]['mse_test_per_seed']]

if len(kernel_names_list) >= 2:
    print("\nüìä Pairwise Welch T-Tests on Test MSE (far extrapolation)")
    print("="*60)

    # combinations() yields each unordered pair once, in the same order as an i < j double loop.
    for k1, k2 in combinations(kernel_names_list, 2):
        mse1 = np.asarray(results[k1]['mse_test_per_seed'], dtype=float)
        mse2 = np.asarray(results[k2]['mse_test_per_seed'], dtype=float)

        # Welch's t-test (equal_var=False): the spread of per-seed MSE need not be the same
        # across kernels, so the pooled-variance Student test's assumption is not justified.
        t_stat, p_value = stats.ttest_ind(mse1, mse2, equal_var=False)

        k1_name = KERNEL_CONFIGS[k1]['display_name']
        k2_name = KERNEL_CONFIGS[k2]['display_name']
        if np.isnan(p_value):
            # e.g. a kernel with fewer than 2 seeds, or zero variance in both groups
            print(f"{k1_name} vs {k2_name}: p = nan (too few seeds or zero variance)")
            continue
        sig = "**" if p_value < 0.01 else "*" if p_value < 0.05 else ""
        print(f"{k1_name} vs {k2_name}: t = {t_stat:.3f}, p = {p_value:.4f} {sig}")

    print("\n* p < 0.05, ** p < 0.01 (uncorrected for multiple comparisons)")

## 13. Callback-Style Plots

In [None]:
from callbacks import callback_gppvae, _compose_multi
import pylab as pl

def generate_callback_plot_for_kernel(kernel_name, kernel_config, results, 
                                       train_queue, test_queue,
                                       Dt, Wt, Dtest, Wtest, P, Q, xdim, device,
                                       output_file):
    """Render a ``callback_gppvae``-style summary figure for one kernel's best seed.

    The best seed is the one with the smallest entry of
    ``results[kernel_name]['mse_test_per_seed']``; its run folder comes from
    ``find_run_folders`` and its models from ``load_models``. The figure receives:

    * ``covs``: Gram matrices of ``vm.x()`` (cropped to 20x20) and ``vm.v()``;
    * ``imgs``: up to 24 strided test images, their VAE reconstructions, and decodings
      of the GP-predicted test latents ``vs[0] * Vtest @ Vt.T @ Kiz``;
    * ``history``: per-seed metrics laid out as if each seed were an epoch.

    Returns:
        ``output_file`` after ``callback_gppvae`` has written it, or ``None`` if
        ``kernel_name`` is absent from ``results`` or has no per-seed test MSEs.
    """
    if kernel_name not in results:
        return None
    seed_test_mses = results[kernel_name]['mse_test_per_seed']
    if not seed_test_mses:
        return None

    pl.rcdefaults()

    best_idx = np.argmin(seed_test_mses)
    run_dir = find_run_folders(CONFIG['results_base'], kernel_config['folder'])[best_idx]
    vae, vm, gp = load_models(run_dir, kernel_config, P, Q, xdim, device)

    def _to_nhwc(array):
        # (N, C, H, W) -> (N, H, W, C) for image plotting
        return array.transpose((0, 2, 3, 1))

    with torch.no_grad():
        x_emb = vm.x().cpu().numpy()
        w_emb = vm.v().cpu().numpy()
        covs = {"XX": (x_emb @ x_emb.T)[:20, :20], "WW": w_emb @ w_emb.T}

        # GP prediction of the test latents from the encoded training set
        z_train = encode_dataset(vae, train_queue, device)
        v_train = vm(Dt, Wt).detach()
        v_test = vm(Dtest, Wtest).detach()
        vs = gp.get_vs()
        U, UBi, _ = gp.U_UBi_Shb([v_train], vs)
        Kiz = gp.solve(z_train, U, UBi, vs)
        z_test_pred = vs[0] * v_test.mm(v_train.transpose(0, 1).mm(Kiz))

        # Up to 24 evenly strided test images
        n_total = len(test_queue.dataset)
        stride = max(1, n_total // 24)
        picks = np.arange(0, n_total, stride)[:24]
        picks_t = torch.tensor(picks, dtype=torch.long).to(device)

        Yv = _to_nhwc(test_queue.dataset.Y[picks].numpy())
        z_enc, _ = vae.encode(test_queue.dataset.Y[picks].to(device))
        Yr = _to_nhwc(vae.decode(z_enc).data.cpu().numpy())
        Yo = _to_nhwc(vae.decode(z_test_pred[picks_t]).data.cpu().numpy())
        imgs = {"Yv": Yv, "Yr": Yr, "Yo": Yo}

        n_seeds = len(seed_test_mses)
        vs_np = vs.cpu().numpy()
        history = {
            "loss": seed_test_mses,
            "vs": [[vs_np[0], vs_np[1]]] * n_seeds,
            "recon_term": [0.0] * n_seeds,
            "gp_nll": [0.0] * n_seeds,
            "mse_out": seed_test_mses,
            "mse": [results[kernel_name]['mean_mse_test']] * n_seeds,
            "mse_val": results[kernel_name]['mse_val_per_seed'],
        }

    callback_gppvae(epoch=n_seeds - 1, history=history, covs=covs, imgs=imgs, ffile=output_file)
    return output_file

print("‚úÖ Callback plot function defined")

In [None]:
# Generate callback plots
for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_test_per_seed']:
        print(f"\nüìä Generating callback plot for {kernel_config['display_name']}...")
        filename = f'./notebooks/analysis/task3_callback_{kernel_name}.png'
        result = generate_callback_plot_for_kernel(
            kernel_name, kernel_config, results,
            train_queue, test_queue,
            Dt, Wt, Dtest, Wtest, P, Q, CONFIG['xdim'], device,
            output_file=filename
        )
        if result:
            print(f"   ‚úÖ Saved: {filename}")
            from IPython.display import Image, display
            display(Image(filename=filename))

## 14. Export Results

In [None]:
# Detailed results
detailed_results = []

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name in results and results[kernel_name]['mse_test_per_seed']:
        for seed_idx in range(len(results[kernel_name]['mse_test_per_seed'])):
            detailed_results.append({
                'kernel': kernel_config['display_name'],
                'seed': seed_idx,
                'val_mse': results[kernel_name]['mse_val_per_seed'][seed_idx],
                'test_mse': results[kernel_name]['mse_test_per_seed'][seed_idx],
                'variance_ratio': results[kernel_name]['variance_ratios'][seed_idx],
            })

detailed_df = pd.DataFrame(detailed_results)
detailed_df.to_csv('./notebooks/analysis/task3_detailed_results.csv', index=False)
print("üìä Saved: task3_detailed_results.csv")

summary_df.to_csv('./notebooks/analysis/task3_summary.csv', index=False)
print("üìä Saved: task3_summary.csv")

per_view_df.to_csv('./notebooks/analysis/task3_per_view_mse.csv', index=False)
print("üìä Saved: task3_per_view_mse.csv")

## 15. Final Summary

In [None]:
print("\n" + "="*70)
print("üìä TASK 3 (EXTRAPOLATION) - FINAL SUMMARY")
print("="*70)

print("\nüéØ Research Question:")
print("   How do different kernel-induced inductive biases influence")
print("   EXTRAPOLATION to unseen extreme views?")

# Kernels with per-seed test results, in KERNEL_CONFIGS order. Ranking only these keys
# also avoids a KeyError on KERNEL_CONFIGS if `results` holds a kernel it doesn't know.
kernels_with_results = [
    k for k in KERNEL_CONFIGS
    if k in results and results[k]['mse_test_per_seed']
]

print("\nüìä Hypothesis Check:")
for kernel_name in kernels_with_results:
    kernel_config = KERNEL_CONFIGS[kernel_name]
    print(f"   {kernel_config['display_name']:15s}: {kernel_config['expected']}")

print("\nüèÜ Kernel Ranking (by Far Extrapolation MSE):")
print("-"*50)

ranked = sorted(
    [(k, results[k]['mean_mse_test'], results[k]['std_mse_test'], results[k]['extrap_degradation'])
     for k in kernels_with_results],
    key=lambda x: x[1]
)

for rank, (kernel_name, mean_mse, std_mse, degradation) in enumerate(ranked, 1):
    display_name = KERNEL_CONFIGS[kernel_name]['display_name']
    medal = "ü•á" if rank == 1 else "ü•à" if rank == 2 else "ü•â" if rank == 3 else "  "
    print(f"{medal} {rank}. {display_name:15s}: {mean_mse:.6f} ¬± {std_mse:.6f} (degradation: {degradation:.2f}x)")

print("\nüìä Key Findings:")
print("-"*50)

# Check if periodic kernels outperform full rank.
# NOTE(review): relies on the KERNEL_CONFIGS keys being exactly 'fullrank' and
# 'periodic'; if they are named differently this comparison is silently skipped — confirm.
if 'fullrank' in results and 'periodic' in results:
    fr_mse = results['fullrank']['mean_mse_test']
    periodic_mse = results['periodic']['mean_mse_test']
    if not (np.isfinite(fr_mse) and np.isfinite(periodic_mse)):
        # Every comparison with NaN is False, which used to fall through to the
        # "Full Rank performed better" message.
        print("   Full Rank vs Periodic: comparison unavailable (non-finite mean test MSE)")
    elif periodic_mse < fr_mse:
        improvement = (fr_mse - periodic_mse) / fr_mse * 100
        print(f"   ‚úÖ Periodic kernel outperforms Full Rank by {improvement:.1f}%")
    elif periodic_mse == fr_mse:
        print("   Full Rank and Periodic have identical mean test MSE")
    else:
        print("   ‚ùå Full Rank performed better than Periodic (unexpected)")

print("\nüìÅ Generated Files:")
print("-"*50)
print("   - task3_near_vs_far.png")
print("   - task3_per_view_extrapolation.png")
print("   - task3_kernel_matrices.png")
print("   - task3_reconstructions.png")
# The callback-plot section writes one figure per kernel that has results.
for kernel_name in kernels_with_results:
    print(f"   - task3_callback_{kernel_name}.png")
print("   - task3_detailed_results.csv")
print("   - task3_summary.csv")
print("   - task3_per_view_mse.csv")

print("\n" + "="*70)
print("‚úÖ Analysis Complete!")
print("="*70)

## 16. Extract Learned Kernel Parameters (Best Seeds)

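Extract the learned hyperparameters from the best-performing seed (the one with the lowest test MSE) for each kernel. Because the seed is selected on the test set, treat these values as descriptive rather than as an unbiased estimate.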


In [None]:
import json


def _print_gp_variances(gp_model):
    """Print the GP variance components from ``gp_model.get_vs()``.

    vs[0] is reported as the between-view variance, vs[1] as the within-view (noise)
    variance, together with the ratio vs[0] / (vs[0] + vs[1]).
    """
    vs = gp_model.get_vs().detach().cpu().numpy()
    variance_ratio = vs[0] / (vs[0] + vs[1])
    print(f"üìà GP Variance Components:")
    print(f"   v_view (between-view):  {vs[0]:.6f}")
    print(f"   v_noise (within-view):  {vs[1]:.6f}")
    print(f"   Variance ratio (v_view / total): {variance_ratio:.4f}")


def _print_periodic_params(kernel):
    """Print period / lengthscale / variance of a periodic view kernel from its state dict."""
    state_dict = kernel.state_dict()
    if 'period' in state_dict:
        print(f"   Period: {state_dict['period'].item():.2f}¬∞")
    else:
        # Not stored in the state dict: say that 360 is assumed rather than printing it
        # as if it had been learned.
        print("   Period: 360.00¬∞ (not in state dict; assumed fixed)")
    if 'log_lengthscale' in state_dict:
        print(f"   Lengthscale: {torch.exp(state_dict['log_lengthscale']).item():.4f}")
    if 'log_variance' in state_dict:
        print(f"   Variance: {torch.exp(state_dict['log_variance']).item():.4f}")


def _sm_periods_deg(kernel, freqs):
    """Convert spectral-mixture frequencies to periods in degrees.

    When the kernel has ``use_angle_input`` set, the notebook treats its distances as
    angle / 360 (assumption mirrored from the manual SM reconstruction — confirm against
    the kernel implementation), so a frequency mu is in cycles per revolution and the
    period is 360 / mu degrees. Otherwise distances are raw degrees and the period is 1 / mu.
    """
    scale = 360.0 if getattr(kernel, 'use_angle_input', False) else 1.0
    with np.errstate(divide='ignore'):
        return scale / np.asarray(freqs, dtype=float)


def _print_sm_params(kernel):
    """Print the spectral-mixture parameters found in the kernel's state dict."""
    state_dict = kernel.state_dict()
    if 'log_weight' in state_dict:
        weights = torch.exp(state_dict['log_weight']).detach().cpu().numpy()
        print(f"   Mixture weights: {weights / weights.sum()}")
    if 'log_freq' in state_dict:
        freqs = torch.exp(state_dict['log_freq']).detach().cpu().numpy()
        print(f"   Frequencies: {freqs}")
        print(f"   Periods: {_sm_periods_deg(kernel, freqs)}¬∞")
    if 'log_length' in state_dict:
        lengths = torch.exp(state_dict['log_length']).detach().cpu().numpy()
        print(f"   Lengthscales: {lengths}")
    if 'log_var' in state_dict:
        variances = torch.exp(state_dict['log_var']).detach().cpu().numpy()
        print(f"   Component variances: {variances}")


def _print_fullrank_summary(vm_model):
    """Summarise a free-form view covariance: shape and off-diagonal statistics.

    The covariance K is normalised to a correlation matrix K_ij / sqrt(K_ii K_jj) before
    computing the "correlation" statistics, so they are true correlations regardless of
    the learned per-view scales; the mean raw off-diagonal covariance is reported too.
    """
    with torch.no_grad():
        K = vm_model.get_kernel_matrix().cpu().numpy()
    scales = np.sqrt(np.clip(np.diag(K), 1e-12, None))  # clip guards zero-variance views
    corr = K / np.outer(scales, scales)
    off_diag = np.triu_indices_from(K, k=1)
    print(f"   Learned full Q√óQ covariance matrix")
    print(f"   Matrix shape: {K.shape}")
    print(f"   Mean off-diagonal covariance: {K[off_diag].mean():.4f}")
    print(f"   Mean correlation: {corr[off_diag].mean():.4f}")
    print(f"   Std correlation: {corr[off_diag].std():.4f}")


def _print_state_dict_summary(kernel, max_numel=10):
    """List every kernel state-dict entry: values for small tensors, shapes otherwise."""
    print(f"\nüìù Available state dict keys:")
    for key, value in kernel.state_dict().items():
        if value.numel() <= max_numel:
            print(f"   {key}: {value.detach().cpu().numpy()}")
        else:
            print(f"   {key}: shape {value.shape}")


print("="*70)
print("üìä LEARNED KERNEL PARAMETERS (BEST SEEDS)")
print("="*70)

for kernel_name, kernel_config in KERNEL_CONFIGS.items():
    if kernel_name not in results or not results[kernel_name]['mse_test_per_seed']:
        continue

    print(f"\n{'='*70}")
    print(f"üî∑ {kernel_config['display_name']}")
    print(f"{'='*70}")

    # Best seed = lowest test MSE
    best_seed_idx = np.argmin(results[kernel_name]['mse_test_per_seed'])
    best_run = find_run_folders(CONFIG['results_base'], kernel_config['folder'])[best_seed_idx]

    print(f"Best seed: {best_seed_idx} (Test MSE: {results[kernel_name]['mse_test_per_seed'][best_seed_idx]:.6f})")
    print(f"Run folder: {os.path.basename(best_run)}\n")

    gp_weights_path = os.path.join(best_run, 'weights', 'gp_weights.best.pt')
    if not os.path.exists(gp_weights_path):
        print(f"‚ö†Ô∏è Checkpoint not found: {gp_weights_path}")
        continue

    # torch.load unpickles the file: only load checkpoints produced by this project.
    checkpoint = torch.load(gp_weights_path, map_location='cpu')

    vm_temp = Vmodel(
        P=P, Q=Q, p=CONFIG['xdim'],
        view_kernel=kernel_config['view_kernel'],
        **kernel_config['kernel_kwargs']
    )
    vm_temp.load_state_dict(checkpoint['vm_state'])

    gp_temp = GP(n_rand_effs=1)
    gp_temp.load_state_dict(checkpoint['gp_state'])

    _print_gp_variances(gp_temp)

    print(f"\nüîß Kernel Hyperparameters:")
    view_kernel = kernel_config['view_kernel']
    if view_kernel == 'periodic':
        _print_periodic_params(vm_temp.kernel)
    elif view_kernel == 'sm_circle':
        _print_sm_params(vm_temp.kernel)
    elif view_kernel == 'full_rank':
        _print_fullrank_summary(vm_temp)

    _print_state_dict_summary(vm_temp.kernel)

print("\n" + "="*70)
print("‚úÖ Parameter extraction complete!")
print("="*70)


## 17. Visualize Learned SM Kernel Functions

Plot the kernel function $k(\theta, \theta')$ for the best-performing SM Wrapped and SM Free models, along with GP prior samples drawn using the learned kernel.


In [None]:
import pylab as pl
pl.rcdefaults()

# Define which SM kernels to visualize
sm_kernels_to_plot = ['sm_wrapped', 'sm_free']
sm_kernel_names = {'sm_wrapped': 'SM (Wrapped)', 'sm_free': 'SM (Free)'}
sm_kernel_colors = {'sm_wrapped': '#9467bd', 'sm_free': '#8c564b'}

# Boundary between the last training view (view 9, 180 deg) and the first validation view
# (view 10, 200 deg): the same place as the 9.5 index boundary in the kernel-matrix plots.
# Placing it at 200 deg would shade the 200 deg validation view as "training".
SM_VIZ_TRAIN_BOUNDARY_DEG = 190.0
sm_view_angles = torch.tensor([i * 20.0 for i in range(18)])  # the 18 views, 0..340 deg


def _load_best_vmodel(kernel_name):
    """Load the view model of ``kernel_name``'s best seed (lowest test MSE), on CPU, in eval mode."""
    kernel_config = KERNEL_CONFIGS[kernel_name]
    best_seed_idx = np.argmin(results[kernel_name]['mse_test_per_seed'])
    best_run = find_run_folders(CONFIG['results_base'], kernel_config['folder'])[best_seed_idx]
    checkpoint = torch.load(os.path.join(best_run, 'weights', 'gp_weights.best.pt'), map_location='cpu')
    vm_model = Vmodel(
        P=P, Q=Q, p=CONFIG['xdim'],
        view_kernel=kernel_config['view_kernel'],
        **kernel_config['kernel_kwargs']
    )
    vm_model.load_state_dict(checkpoint['vm_state'])
    vm_model.eval()
    return vm_model


def _shade_regions(ax, label=None):
    """Draw the train/extrapolation boundary; shade training (green) vs extrapolation (red)."""
    ax.axvline(x=SM_VIZ_TRAIN_BOUNDARY_DEG, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label=label)
    ax.axvspan(0, SM_VIZ_TRAIN_BOUNDARY_DEG, alpha=0.1, color='green')
    ax.axvspan(SM_VIZ_TRAIN_BOUNDARY_DEG, 360, alpha=0.1, color='red')


def _style_angle_axis(ax):
    """Common 0-360 deg x-axis styling for the line plots."""
    ax.set_xlim(0, 360)
    ax.set_xticks([0, 90, 180, 270, 360])
    ax.tick_params(labelsize=14)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def _plot_kernel_function(ax, kernel, display_name, color):
    """Plot k(0, theta) on a 1-degree grid (first row of the kernel matrix)."""
    dense_angles = torch.linspace(0, 360, 361)
    with torch.no_grad():
        k_from_0 = kernel(dense_angles).cpu().numpy()[0, :]
    ax.plot(dense_angles.numpy(), k_from_0, color=color, linewidth=2)
    _shade_regions(ax, label='Training boundary')
    ax.set_xlabel('Angle Œ∏ (¬∞)', fontsize=16)
    ax.set_ylabel('k(0¬∞, Œ∏)', fontsize=16)
    ax.set_title(f'{display_name}\nKernel Function', fontsize=16)
    _style_angle_axis(ax)
    ax.legend(fontsize=12)


def _plot_view_covariance(ax, K_views, display_name):
    """Show the 18x18 learned view covariance with the train/extrapolation split marked."""
    im = ax.imshow(K_views, vmin=-0.5, vmax=1, aspect='auto', cmap='RdBu_r')
    ax.axvline(x=9.5, color='black', linestyle='--', linewidth=2)
    ax.axhline(y=9.5, color='black', linestyle='--', linewidth=2)
    tick_positions = [0, 4, 9, 13, 17]
    tick_labels = [f"{p*20}¬∞" for p in tick_positions]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontsize=12)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels, fontsize=12)
    ax.set_title(f'{display_name}\nLearned View Covariance', fontsize=16)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def _plot_prior_samples(ax, K_views, angles_deg, display_name, color,
                        n_prior_samples=5, seed=42, jitter=1e-4):
    """Plot draws f ~ N(0, K_views + jitter * I) at the view angles.

    A local ``np.random.RandomState(seed)`` produces the same draws as
    ``np.random.seed(seed)`` + ``np.random.randn`` without resetting NumPy's global RNG.
    """
    K_jittered = K_views + jitter * np.eye(K_views.shape[0])  # numerical stability
    try:
        L = np.linalg.cholesky(K_jittered)
    except np.linalg.LinAlgError:
        L = None

    if L is None:
        ax.text(0.5, 0.5, 'Cholesky failed\n(K not PSD)',
                transform=ax.transAxes, ha='center', fontsize=14)
    else:
        rng = np.random.RandomState(seed)
        for i in range(n_prior_samples):
            sample = L @ rng.randn(K_jittered.shape[0])
            ax.plot(angles_deg, sample, 'o-',
                    color=color,
                    alpha=1.0 if i == 0 else 0.5,
                    linewidth=2.5 if i == 0 else 1.5,
                    label='GP prior samples' if i == 0 else None)
        _shade_regions(ax)

    ax.set_xlabel('View Angle (¬∞)', fontsize=16)
    ax.set_ylabel('f(Œ∏)', fontsize=16)
    ax.set_title(f'{display_name}\nGP Prior Samples', fontsize=16)
    _style_angle_axis(ax)
    if L is not None:  # nothing labelled to put in a legend otherwise
        ax.legend(fontsize=12)


fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for row_idx, kernel_name in enumerate(sm_kernels_to_plot):
    if kernel_name not in results or not results[kernel_name]['mse_test_per_seed']:
        print(f"‚ö†Ô∏è {kernel_name} not found in results")
        for ax in axes[row_idx]:
            ax.axis('off')
        continue

    display_name = KERNEL_CONFIGS[kernel_name]['display_name']
    color = sm_kernel_colors[kernel_name]
    kernel = _load_best_vmodel(kernel_name).kernel

    # The 18-view kernel matrix is shared by the covariance panel and the prior samples.
    with torch.no_grad():
        K_views = kernel(sm_view_angles).cpu().numpy()

    _plot_kernel_function(axes[row_idx, 0], kernel, display_name, color)
    _plot_view_covariance(axes[row_idx, 1], K_views, display_name)
    _plot_prior_samples(axes[row_idx, 2], K_views, sm_view_angles.numpy(), display_name, color)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_sm_kernel_visualization.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Saved: task3_sm_kernel_visualization.png (300 DPI)")


In [None]:
# Detailed comparison of SM kernel components
print("="*70)
print("üìä SM KERNEL COMPONENT ANALYSIS")
print("="*70)

# Boundary between the last training view (180 deg) and the first validation view
# (200 deg), matching the 9.5 index boundary used in the kernel-matrix plots.
SM_COMP_TRAIN_BOUNDARY_DEG = 190.0
SM_COMPONENT_COLORS = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']  # cycled if >4 components


def _sm_component_curves(weights, means, variances, tau_deg, use_angle_input):
    """Evaluate each SM component w_q * exp(-2 pi^2 v_q d^2) * cos(2 pi mu_q d) on ``tau_deg``.

    ``d`` is ``tau_deg / 360`` when ``use_angle_input`` is set, else ``tau_deg`` itself.
    Returns an array of shape (n_components, len(tau_deg)).
    """
    D = tau_deg / 360.0 if use_angle_input else tau_deg
    w = np.asarray(weights, dtype=float)[:, None]
    mu = np.asarray(means, dtype=float)[:, None]
    var = np.asarray(variances, dtype=float)[:, None]
    return w * np.exp(-2 * (np.pi ** 2) * var * (D ** 2)) * np.cos(2 * np.pi * mu * D)


def _period_deg(mu, use_angle_input):
    """Period in degrees of a component with frequency ``mu`` (inf if mu <= 0).

    With normalised input (d = tau / 360) the cosine repeats every 1/mu in d, i.e. every
    360/mu degrees; with raw degree input it repeats every 1/mu degrees.
    """
    if mu <= 0:
        return np.inf
    return (360.0 if use_angle_input else 1.0) / mu


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, kernel_name in enumerate(['sm_wrapped', 'sm_free']):
    if kernel_name not in results or not results[kernel_name]['mse_test_per_seed']:
        for ax in axes[idx]:
            ax.axis('off')
        continue

    kernel_config = KERNEL_CONFIGS[kernel_name]
    display_name = kernel_config['display_name']
    color = sm_kernel_colors[kernel_name]

    # Best seed (lowest test MSE) and its view model
    best_seed_idx = np.argmin(results[kernel_name]['mse_test_per_seed'])
    best_run = find_run_folders(CONFIG['results_base'], kernel_config['folder'])[best_seed_idx]
    gp_weights_path = os.path.join(best_run, 'weights', 'gp_weights.best.pt')
    checkpoint = torch.load(gp_weights_path, map_location='cpu')

    vm_temp = Vmodel(
        P=P, Q=Q, p=CONFIG['xdim'],
        view_kernel=kernel_config['view_kernel'],
        **kernel_config['kernel_kwargs']
    )
    vm_temp.load_state_dict(checkpoint['vm_state'])
    kernel = vm_temp.kernel
    use_angle = bool(kernel.use_angle_input)

    # Learned parameters
    weights = kernel.weights.detach().cpu().numpy()
    means = kernel.means.detach().cpu().numpy()  # frequencies
    variances = kernel.variances.detach().cpu().numpy()
    periods_deg = np.array([_period_deg(mu, use_angle) for mu in means])
    period_label = "360/Œº" if use_angle else "1/Œº"

    print(f"\nüî∑ {display_name}:")
    print(f"   Mixture weights: {weights}")
    print(f"   Frequencies (Œº): {means}")
    print(f"   Periods ({period_label}): {periods_deg}¬∞")
    print(f"   Variances (œÉ¬≤): {variances}")

    # =====================
    # Left plot: Individual SM components and their sum
    # =====================
    ax = axes[idx, 0]
    tau = np.linspace(0, 180, 361)  # lag distance, 0-180 deg (the range of wrapped lags)
    components = _sm_component_curves(weights, means, variances, tau, use_angle)

    for q, (component, w, period) in enumerate(zip(components, weights, periods_deg)):
        ax.plot(tau, component, color=SM_COMPONENT_COLORS[q % len(SM_COMPONENT_COLORS)], linewidth=2.5,
                label=f'Component {q+1}: w={w:.2f}, T={period:.0f}¬∞')
    ax.plot(tau, components.sum(axis=0), 'k--', linewidth=2, label='Total k(œÑ)')

    ax.axhline(y=0, color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
    ax.set_xlabel('Lag Distance œÑ (¬∞)', fontsize=16)
    ax.set_ylabel('k(œÑ)', fontsize=16)
    ax.set_title(f'{display_name}\nSM Kernel Components', fontsize=16)
    ax.set_xlim(0, 180)
    ax.tick_params(labelsize=14)
    ax.legend(fontsize=11, loc='upper right')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # =====================
    # Right plot: Kernel correlation at key angles
    # =====================
    ax = axes[idx, 1]
    angles_full = torch.linspace(0, 360, 361)  # 1-degree grid, so array index == angle
    with torch.no_grad():
        k_from_0 = kernel(angles_full).cpu().numpy()[0, :]

    ax.plot(angles_full.numpy(), k_from_0, color=color, linewidth=2.5, label='k(0¬∞, Œ∏)')

    for angle in [0, 20, 180, 200, 280, 340, 360]:
        value = k_from_0[angle]
        ax.scatter(angle, value, s=80, color='black', zorder=5)
        ax.annotate(f'{value:.2f}', (angle, value),
                    textcoords='offset points', xytext=(5, 10), fontsize=10)

    ax.axvline(x=SM_COMP_TRAIN_BOUNDARY_DEG, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Training boundary')
    ax.axvspan(0, SM_COMP_TRAIN_BOUNDARY_DEG, alpha=0.1, color='green')
    ax.axvspan(SM_COMP_TRAIN_BOUNDARY_DEG, 360, alpha=0.1, color='red')

    ax.set_xlabel('Angle Œ∏ (¬∞)', fontsize=16)
    ax.set_ylabel('k(0¬∞, Œ∏)', fontsize=16)
    ax.set_title(f'{display_name}\nKernel Correlation from 0¬∞', fontsize=16)
    ax.set_xlim(0, 360)
    ax.set_xticks([0, 90, 180, 270, 360])
    ax.tick_params(labelsize=14)
    ax.legend(fontsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig('./notebooks/analysis/task3_sm_kernel_components.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nüìä Saved: task3_sm_kernel_components.png (300 DPI)")
