# 05 - Interpolation and Loss Barriers

This notebook analyzes loss barriers along weight interpolation paths between model pairs.

## Core Hypothesis:
- **Same-mechanism pairs** (spurious-spurious, robust-robust): Should have LOW barriers after rebasin
- **Different-mechanism pairs** (spurious-robust): Should have HIGH barriers even after rebasin

## What this notebook does:
1. Interpolates weights: θ(α) = α·θ₁ + (1-α)·θ₂ for α ∈ [0, 1]
2. Evaluates ID loss/acc and OOD loss/acc at each α
3. Computes Spurious Reliance Score along the path
4. Calculates barrier heights pre- and post-rebasin
5. Generates visualization and summary metrics

In [None]:
import sys
from pathlib import Path

# Put the parent of the current working directory at the front of sys.path
# so the `src` package imported below can be resolved.
PROJECT_ROOT = Path.cwd().parent
project_root_entry = str(PROJECT_ROOT)
if project_root_entry not in sys.path:
    sys.path.insert(0, project_root_entry)

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR, RESULTS_DIR
)

# Read the experiment configuration, seed the RNGs from it, and pick the
# compute device before any model or dataset is created.
config = get_config()
global_seed = config['seeds']['global']
set_seed(global_seed)
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import (
    create_env_a_dataset, 
    create_no_patch_dataset,
    CounterfactualPatchDataset,
)
from src.models import create_model
from src.train import load_model
from src.interp import (
    evaluate_interpolation_path,
    evaluate_interpolation_multi_dataset,
    compute_loss_barrier,
    compute_accuracy_barrier,
    summarize_interpolation_results,
    compare_pre_post_rebasin,
)
from src.metrics import compute_spurious_reliance_score, semantic_barrier_metric
from src.plotting import (
    plot_interpolation_path,
    plot_interpolation_comparison,
    plot_pre_post_rebasin,
    plot_barrier_comparison,
    plot_srs_interpolation,
    save_figure,
)
from torch.utils.data import DataLoader

## 1. Load Models (Original and Aligned)

In [None]:
# Build one model per checkpoint name and restore it from
# CHECKPOINTS_DIR / "model_<name>.pt" via load_model.
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    models[name] = load_model(create_model(config), checkpoint_path, device)
    print(f"Loaded original model {name}")

In [None]:
# Each entry is (reference, aligned); the checkpoint looked up is
# "model_<aligned>_aligned_to_<reference>.pt".
aligned_pairs = [
    ('A1', 'A2'),  # A2 aligned to A1
    ('R1', 'R2'),  # R2 aligned to R1
    ('A1', 'R1'),  # R1 aligned to A1
]

# Keyed by "<reference>-<aligned>"; pairs without a checkpoint are simply absent.
aligned_models = {}

for ref, aligned in aligned_pairs:
    pair_name = f"{ref}-{aligned}"
    checkpoint_path = CHECKPOINTS_DIR / f"model_{aligned}_aligned_to_{ref}.pt"

    if not checkpoint_path.exists():
        # Missing checkpoints are reported but do not stop the notebook.
        print(f"[WARNING] Aligned model not found: {checkpoint_path}")
        print(f"          Please run 04_rebasin_alignment.ipynb first.")
        continue

    aligned_models[pair_name] = load_model(create_model(config), checkpoint_path, device)
    print(f"Loaded aligned model: {aligned} -> {ref}")

## 2. Create DataLoaders

In [None]:
# Test splits: the Env-A dataset is used as "ID", the no-patch dataset as "OOD".
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['interpolation']['eval_batch_size']
num_workers = config['training']['num_workers']

# Both loaders share the same settings; evaluation order is kept fixed (no shuffling).
loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
id_loader = DataLoader(test_id, **loader_kwargs)
ood_loader = DataLoader(test_ood, **loader_kwargs)

dataloaders = dict(id=id_loader, ood=ood_loader)

print(f"DataLoaders created: ID={len(test_id)}, OOD={len(test_ood)} samples")

In [None]:
# Counterfactual wrapper around the ID test set; passed to the SRS computation later on.
cf_dataset = CounterfactualPatchDataset(base_dataset=test_id, swap_mode='random_wrong')
print(f"Counterfactual dataset: {len(cf_dataset)} samples")

## 3. Define Interpolation Pairs

In [None]:
# Pairs to interpolate between, as (ref_name, other_name, pair_type).
# "<ref_name>-<other_name>" is also the key used to look up aligned models.
analysis_pairs = [
    ('A1', 'A2', 'spurious-spurious'),
    ('R1', 'R2', 'robust-robust'),
    ('A1', 'R1', 'spurious-robust'),
]

# Number of alpha values evaluated along every interpolation path.
interpolation_cfg = config['interpolation']
num_alphas = interpolation_cfg['num_alphas']
print(f"Will evaluate {num_alphas} interpolation points for each pair")

## 4. Evaluate Interpolation Paths (Pre and Post Rebasin)

In [None]:
# Interpolation results per pair:
#   all_results["<ref>-<other>"] = {'type', 'pre_rebasin', 'post_rebasin'}
# 'post_rebasin' is None when no aligned checkpoint was loaded for the pair.
all_results = {}
banner = '=' * 60

for ref_name, other_name, pair_type in analysis_pairs:
    pair_name = f"{ref_name}-{other_name}"
    print(f"\n{banner}")
    print(f"Analyzing pair: {pair_name} ({pair_type})")
    print(f"{banner}")

    model_ref = models[ref_name]
    model_other = models[other_name]
    model_other_aligned = aligned_models.get(pair_name)

    # Path between the two original (un-aligned) checkpoints.
    print(f"\nPre-rebasin interpolation...")
    pre_results = evaluate_interpolation_multi_dataset(
        model_ref, model_other, dataloaders, device, num_alphas
    )

    # Path between the reference and the aligned checkpoint, when one exists.
    if model_other_aligned is None:
        print(f"[SKIP] No aligned model available")
        post_results = None
    else:
        print(f"Post-rebasin interpolation...")
        post_results = evaluate_interpolation_multi_dataset(
            model_ref, model_other_aligned, dataloaders, device, num_alphas
        )

    all_results[pair_name] = {
        'type': pair_type,
        'pre_rebasin': pre_results,
        'post_rebasin': post_results,
    }

    # Quick per-pair barrier report for every stage that was evaluated.
    for stage_label, stage_results in (('Pre', pre_results), ('Post', post_results)):
        if stage_results is None:
            continue
        stage_summary = summarize_interpolation_results(stage_results)
        id_summary = stage_summary['id']
        ood_summary = stage_summary['ood']
        print(f"\n{stage_label}-rebasin barriers:")
        print(f"  ID loss barrier:  {id_summary['loss_barrier']:.4f}")
        print(f"  OOD loss barrier: {ood_summary['loss_barrier']:.4f}")
        print(f"  ID acc barrier:   {id_summary['acc_barrier']*100:.2f}%")

## 5. Visualize Interpolation Paths

In [None]:
# One pre-rebasin ID figure per pair, plus pre-vs-post figures (ID and OOD)
# for the pairs that have post-rebasin results.
for pair_name, results in all_results.items():
    pair_type = results['type']
    pre, post = results['pre_rebasin'], results['post_rebasin']
    pair_label = f"{pair_name} ({pair_type})"

    # Only the three curves the plotting helper is given are forwarded.
    pre_id_curve = {key: pre['id'][key] for key in ('alphas', 'losses', 'accuracies')}
    fig = plot_interpolation_path(
        pre_id_curve,
        title=f"{pair_label} - Pre-Rebasin (ID)",
        save_name=f'interp_{pair_name}_pre_id'
    )
    plt.show()

    if post is None:
        continue

    for split_key, split_label in (('id', 'ID'), ('ood', 'OOD')):
        fig = plot_pre_post_rebasin(
            pre[split_key], post[split_key],
            dataset_name=split_label,
            title=f"{pair_label}: Pre vs Post Rebasin ({split_label})",
            save_name=f'interp_{pair_name}_pre_vs_post_{split_key}'
        )
        plt.show()

## 6. Compare Barriers Across Pairs

In [None]:
# Flatten the per-pair summaries into one dict per pair with keys of the form
# "<pre|post>_<id|ood>_<loss|acc>_barrier". The post_* keys are only present
# when post-rebasin results exist for that pair.
barrier_fields = (
    ('id', 'loss_barrier'),
    ('ood', 'loss_barrier'),
    ('id', 'acc_barrier'),
    ('ood', 'acc_barrier'),
)


def _flatten_barriers(prefix, summary):
    """Return {"<prefix>_<split>_<field>": value} for every entry of barrier_fields."""
    return {f'{prefix}_{split}_{field}': summary[split][field] for split, field in barrier_fields}


barrier_comparison = {}

for pair_name, results in all_results.items():
    entry = {'type': results['type']}
    entry.update(_flatten_barriers('pre', summarize_interpolation_results(results['pre_rebasin'])))

    post = results['post_rebasin']
    if post is not None:
        entry.update(_flatten_barriers('post', summarize_interpolation_results(post)))

    barrier_comparison[pair_name] = entry

# Plain-text comparison table; missing post-rebasin values are shown as nan.
table_rule = "=" * 90
missing = float('nan')

print("\n" + table_rule)
print("BARRIER COMPARISON SUMMARY")
print(table_rule)
print(f"\n{'Pair':<20} {'Type':<18} {'Pre ID Loss':<12} {'Post ID Loss':<12} {'Pre OOD Loss':<12} {'Post OOD Loss':<12}")
print("-" * 90)

for pair_name, data in barrier_comparison.items():
    pre_id = data['pre_id_loss_barrier']
    pre_ood = data['pre_ood_loss_barrier']
    post_id = data.get('post_id_loss_barrier', missing)
    post_ood = data.get('post_ood_loss_barrier', missing)

    print(f"{pair_name:<20} {data['type']:<18} {pre_id:>10.4f}   {post_id:>10.4f}   {pre_ood:>10.4f}   {post_ood:>10.4f}")

In [None]:
# Grouped bar chart of pre- vs post-rebasin loss barriers, one panel per split.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

pair_names = list(barrier_comparison.keys())
x = np.arange(len(pair_names))
width = 0.35
missing = float('nan')

for ax, split, panel_title in zip(axes, ('id', 'ood'), ('ID Loss Barrier', 'OOD Loss Barrier')):
    pre_barriers = [barrier_comparison[p][f'pre_{split}_loss_barrier'] for p in pair_names]
    # A pair without an aligned model has no post-rebasin barrier. Use NaN so
    # no bar is drawn, instead of 0, which would read as "barrier removed".
    post_barriers = [barrier_comparison[p].get(f'post_{split}_loss_barrier', missing) for p in pair_names]

    ax.bar(x - width/2, pre_barriers, width, label='Pre-Rebasin', color='salmon')
    ax.bar(x + width/2, post_barriers, width, label='Post-Rebasin', color='steelblue')

    # Mark the missing measurements explicitly so the gap is not misread.
    for xpos, value in zip(x, post_barriers):
        if np.isnan(value):
            ax.text(xpos + width/2, 0, 'n/a', ha='center', va='bottom')

    ax.set_xticks(x)
    ax.set_xticklabels(pair_names, rotation=45, ha='right')
    ax.set_ylabel('Loss Barrier')
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Loss Barriers: Pre vs Post Re-Basin', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'barrier_comparison_all')
plt.show()

## 7. Compute Spurious Reliance Score Along Interpolation

In [None]:
from src.interp import create_interpolated_model

def compute_srs_along_path(model_a, model_b, alphas, device, id_loader, ood_loader, cf_dataset, stride=4):
    """
    Compute the Spurious Reliance Score (SRS) at a subsample of interpolation points.

    SRS is expensive to compute, so only every ``stride``-th alpha is evaluated.
    The last alpha is always included, so both endpoints of the path are
    sampled whatever the length of ``alphas``.

    Args:
        model_a, model_b: Models passed to ``create_interpolated_model``.
        alphas: 1-D sequence of interpolation coefficients.
        device: Device forwarded to the interpolation and SRS helpers.
        id_loader, ood_loader, cf_dataset: Forwarded unchanged to
            ``compute_spurious_reliance_score``.
        stride: Keep every ``stride``-th alpha (default 4, the original behaviour).

    Returns:
        (sample_alphas, srs_values): a NumPy array of the evaluated alphas and a
        list holding the ``'spurious_reliance_score'`` entry for each of them.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError("stride must be a positive integer")

    alphas = np.asarray(alphas, dtype=float)
    sample_alphas = alphas[::stride]

    # alphas[::stride] ends at index ((n - 1) // stride) * stride, which is the
    # final point only when (n - 1) is a multiple of stride. Append the endpoint
    # otherwise so callers can rely on sample_alphas[-1] == alphas[-1].
    if alphas.size and (alphas.size - 1) % stride != 0:
        sample_alphas = np.append(sample_alphas, alphas[-1])

    srs_values = []
    for alpha in sample_alphas:
        interp_model = create_interpolated_model(model_a, model_b, alpha, device)

        srs = compute_spurious_reliance_score(
            interp_model, id_loader, ood_loader, cf_dataset, device
        )
        srs_values.append(srs['spurious_reliance_score'])
        print(f"  alpha={alpha:.2f}: SRS={srs['spurious_reliance_score']:.4f}")

    return sample_alphas, srs_values

In [None]:
# Compute SRS for the spurious-robust pair (most interesting case)
print("Computing Spurious Reliance Score along A1-R1 interpolation path...")
print("(This demonstrates the 'semantic barrier' - mechanism mismatch)\n")

pair_name = 'A1-R1'
model_a1 = models['A1']

# Prefer the aligned R1. If its checkpoint was not loaded, fall back to the
# original R1 but say so loudly: the resulting path is then a PRE-rebasin path.
model_r1_aligned = aligned_models.get(pair_name)
if model_r1_aligned is None:
    print(f"[WARNING] No aligned model for {pair_name}; using the un-aligned R1.")
    print(f"          The SRS path below is therefore PRE-rebasin, not post-rebasin.\n")
    model_r1_aligned = models['R1']

alphas = np.linspace(0, 1, num_alphas)
srs_alphas, srs_values = compute_srs_along_path(
    model_a1, model_r1_aligned, alphas, device,
    id_loader, ood_loader, cf_dataset
)

In [None]:
# Plot SRS along interpolation.
# The SRS path uses the aligned R1 only when that checkpoint was loaded;
# label the figure according to what was actually interpolated.
srs_path_stage = "Post-Rebasin" if 'A1-R1' in aligned_models else "Pre-Rebasin"

fig = plot_srs_interpolation(
    np.array(srs_alphas), srs_values,
    title=f"Spurious Reliance Score: A1-R1 Interpolation ({srs_path_stage})",
    save_name='srs_interpolation_A1_R1'
)
plt.show()

# Compute semantic barrier
sem_barrier, sem_alpha = semantic_barrier_metric(srs_values, np.array(srs_alphas))
print(f"\nSemantic Barrier Metric:")
print(f"  Max SRS variation from endpoint average: {sem_barrier:.4f}")
print(f"  At alpha = {sem_alpha:.2f}")

# Report the alphas that were really sampled: the subsampled path is not
# guaranteed to end exactly at alpha=1, so do not hard-code the endpoint.
last_alpha = float(srs_alphas[-1])
first_alpha = float(srs_alphas[0])
print(f"\nEndpoint SRS values:")
print(f"  A1 (alpha={last_alpha:.2f}): SRS = {srs_values[-1]:.4f} (should be HIGH - spurious)")
print(f"  R1 (alpha={first_alpha:.2f}): SRS = {srs_values[0]:.4f} (should be LOW - robust)")

## 8. Create Combined Interpolation Plot

In [None]:
# Comprehensive comparison figure: rows are pre/post rebasin, columns are
# ID loss, ID accuracy and OOD accuracy, with one curve per pair.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

colors = {'A1-A2': 'tab:red', 'R1-R2': 'tab:blue', 'A1-R1': 'tab:purple'}

# (results key, row title prefix)
stage_rows = [
    ('pre_rebasin', 'Pre-Rebasin'),
    ('post_rebasin', 'Post-Rebasin'),
]
# (split, metric key, y-axis label, column title suffix, scale factor)
panel_columns = [
    ('id', 'losses', 'Loss', 'ID Loss', 1.0),
    ('id', 'accuracies', 'Accuracy (%)', 'ID Accuracy', 100.0),
    ('ood', 'accuracies', 'Accuracy (%)', 'OOD Accuracy', 100.0),
]

for row, (stage_key, stage_title) in enumerate(stage_rows):
    for col, (split, metric, ylabel, panel_title, scale) in enumerate(panel_columns):
        ax = axes[row, col]

        for pair_name, results in all_results.items():
            stage_results = results[stage_key]
            if stage_results is None:
                # No aligned model for this pair, so nothing to draw in this row.
                continue
            curve = stage_results[split]
            # np.asarray makes the scaling element-wise even if the values are a
            # plain Python list (list * 100 would repeat the list instead).
            values = np.asarray(curve[metric], dtype=float) * scale
            # colors.get falls back to matplotlib's colour cycle for unlisted pairs.
            ax.plot(curve['alphas'], values,
                    label=f"{pair_name}", color=colors.get(pair_name), linewidth=2)

        ax.set_xlabel(r'$\alpha$')
        ax.set_ylabel(ylabel)
        ax.set_title(f'{stage_title}: {panel_title}')
        # Only add a legend when something was plotted (avoids an empty-legend warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

plt.suptitle('Weight Interpolation Analysis: Pre vs Post Re-Basin', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'interpolation_comprehensive')
plt.show()

## 9. Save All Results

In [None]:
# Compile the final summary and write it to RESULTS_DIR / 'summary.json'.

def _as_float_list(values):
    """Convert a sequence of numeric values to a list of built-in floats (JSON-safe)."""
    return [float(v) for v in values]


def _serialize_interpolation(stage_results):
    """Flatten one stage's ID/OOD interpolation curves into JSON-serialisable lists."""
    return {
        'id_losses': _as_float_list(stage_results['id']['losses']),
        'id_accuracies': _as_float_list(stage_results['id']['accuracies']),
        'ood_losses': _as_float_list(stage_results['ood']['losses']),
        'ood_accuracies': _as_float_list(stage_results['ood']['accuracies']),
        'alphas': _as_float_list(stage_results['id']['alphas']),
    }


final_summary = {
    # NumPy floats are not JSON-serialisable; other values (e.g. the 'type'
    # string) are stored unchanged.
    'barrier_comparison': {
        pair: {key: float(value) if isinstance(value, (float, np.floating)) else value
               for key, value in entry.items()}
        for pair, entry in barrier_comparison.items()
    },
    'srs_interpolation': {
        'pair': 'A1-R1',
        'alphas': _as_float_list(srs_alphas),
        'srs_values': _as_float_list(srs_values),
        'semantic_barrier': float(sem_barrier),
        'semantic_barrier_alpha': float(sem_alpha),
    },
}

# Add detailed interpolation data; "<pair>_post" is only written when it exists.
for pair_name, results in all_results.items():
    final_summary[f'{pair_name}_pre'] = _serialize_interpolation(results['pre_rebasin'])

    post = results['post_rebasin']
    if post is not None:
        final_summary[f'{pair_name}_post'] = _serialize_interpolation(post)

# Save to results directory, creating it first so a fresh checkout does not
# fail at the very last step of a long run.
summary_path = RESULTS_DIR / 'summary.json'
Path(summary_path).parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(final_summary, f, indent=2)

print(f"\nResults saved to: {summary_path}")

## 10. Summary

In [None]:
print("\n" + "=" * 70)
print("INTERPOLATION AND BARRIER ANALYSIS COMPLETE")
print("=" * 70)

# Compute key statistics
same_mech_pairs = ['A1-A2', 'R1-R2']
diff_mech_pairs = ['A1-R1']


def _mean_post_id_barrier(pairs):
    """Mean post-rebasin ID loss barrier over the pairs that have one; NaN if none do.

    Pairs without post-rebasin results are skipped rather than counted as 0,
    which would drag the average toward "no barrier".
    """
    available = [
        barrier_comparison[p]['post_id_loss_barrier']
        for p in pairs
        if 'post_id_loss_barrier' in barrier_comparison.get(p, {})
    ]
    return float(np.mean(available)) if available else float('nan')


same_mech_post_barrier = _mean_post_id_barrier(same_mech_pairs)
diff_mech_post_barrier = _mean_post_id_barrier(diff_mech_pairs)

print(f"""
Key Findings:

1. Loss Barrier Summary (Post-Rebasin):
   - Same-mechanism pairs (A1-A2, R1-R2): Avg barrier = {same_mech_post_barrier:.4f}
   - Different-mechanism pair (A1-R1):   Barrier = {diff_mech_post_barrier:.4f}

2. Individual Pair Results:""")

# Pre-minus-post ID loss barrier per pair; NaN when no post-rebasin result exists.
reductions = {}
for pair_name, data in barrier_comparison.items():
    post_barrier = data.get('post_id_loss_barrier', float('nan'))
    pre_barrier = data['pre_id_loss_barrier']
    reduction = pre_barrier - post_barrier  # NaN propagates: unknown, not zero
    reductions[pair_name] = reduction
    print(f"   {pair_name} ({data['type']}):")
    print(f"     Pre-rebasin barrier:  {pre_barrier:.4f}")
    print(f"     Post-rebasin barrier: {post_barrier:.4f}")
    print(f"     Reduction: {reduction:.4f}")

# Derive the interpretation from the measured numbers instead of asserting it.
measured = {name: r for name, r in reductions.items() if not np.isnan(r)}
not_reduced = [name for name, r in measured.items() if r <= 0]
if not measured:
    rebasin_finding = "No post-rebasin results available; barrier reduction could not be assessed"
elif not_reduced:
    rebasin_finding = f"Git Re-Basin did NOT reduce the ID loss barrier for: {', '.join(not_reduced)}"
elif len(measured) == len(reductions):
    rebasin_finding = "Git Re-Basin reduces barriers for ALL pairs"
else:
    rebasin_finding = f"Git Re-Basin reduces barriers for all evaluated pairs ({', '.join(measured)})"

if np.isnan(same_mech_post_barrier) or np.isnan(diff_mech_post_barrier):
    mechanism_finding = "Same- vs different-mechanism comparison unavailable (missing post-rebasin results)"
elif diff_mech_post_barrier > same_mech_post_barrier:
    mechanism_finding = "However, different-mechanism pairs retain higher barriers"
else:
    mechanism_finding = "Different-mechanism pairs do NOT retain higher ID loss barriers in this run"

# Label the SRS endpoints with the alphas that were actually sampled.
srs_last_alpha = float(srs_alphas[-1])
srs_first_alpha = float(srs_alphas[0])

print(f"""
3. Semantic Barrier (SRS variation along A1-R1 path):
   - Max variation: {sem_barrier:.4f}
   - SRS at A1 endpoint (alpha={srs_last_alpha:.2f}): {srs_values[-1]:.4f}
   - SRS at R1 endpoint (alpha={srs_first_alpha:.2f}): {srs_values[0]:.4f}

4. Interpretation:
   - {rebasin_finding}
   - {mechanism_finding}
   - The "semantic barrier" (SRS variation) shows mechanism mismatch

Files saved:
   - {RESULTS_DIR / 'summary.json'}
   - Multiple interpolation figures in {FIGURES_DIR}/

Next: Run 06_summary_report.ipynb for final analysis and conclusions.
""")