# 00 - Setup and Configuration

This notebook sets up the environment for the Git Re-Basin spurious features experiment.

## What this notebook does:
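1. Validates that all dependencies are installed
2. Loads the global configuration from `src.config`
3. Sets deterministic seeds for reproducibility
4. Creates the necessary directories
5. Checks GPU availability
6. Runs a quick model and data-loading sanity check
7. Saves the configuration to the results directory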

## 1. Add src to path and validate imports

In [None]:
import sys
from pathlib import Path

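# Add project root to path (assumes the notebook is launched from a folder
# directly under the project root, such as a notebooks/ folder)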
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Validate core dependencies
import importlib

dependencies = [
    ('torch', 'PyTorch'),
    ('torchvision', 'TorchVision'),
    ('numpy', 'NumPy'),
    ('scipy', 'SciPy'),
    ('sklearn', 'Scikit-learn'),
    ('matplotlib', 'Matplotlib'),
    ('seaborn', 'Seaborn'),
    ('tqdm', 'tqdm'),
    ('PIL', 'Pillow'),
]

print("Checking dependencies...\n")
all_ok = True

for module_name, display_name in dependencies:
    try:
        module = importlib.import_module(module_name)
        version = getattr(module, '__version__', 'unknown')
        print(f"  [OK] {display_name}: {version}")
    except ImportError:
        print(f"  [MISSING] {display_name}")
        all_ok = False

if all_ok:
    print("\nAll dependencies are installed!")
else:
    print("\nSome dependencies are missing. Run: pip install -r requirements.txt")
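
Not every package exposes `__version__`, so the loop above can report `unknown` for a package that is installed. A minimal sketch of an alternative, assuming the standard-library `importlib.metadata` is acceptable here. `check_dependency` is a hypothetical helper, not part of `src`, and the distribution names (for example `scikit-learn` for `sklearn`, `Pillow` for `PIL`) have to be supplied by hand:

```python
import importlib.util
from importlib import metadata


def check_dependency(module_name, dist_name):
    """Return (installed, version) without importing the module itself."""
    # find_spec returns None when a top-level module cannot be located
    if importlib.util.find_spec(module_name) is None:
        return False, None
    try:
        # Version lookup uses the distribution (pip) name, not the import name
        return True, metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return True, "unknown"


for module_name, dist_name in [("sklearn", "scikit-learn"), ("PIL", "Pillow")]:
    installed, version = check_dependency(module_name, dist_name)
    status = "OK" if installed else "MISSING"
    print(f"  [{status}] {dist_name}: {version}")
```

Because nothing is imported, this variant is quick, but it only confirms that a package can be found, not that it imports cleanly.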

In [None]:
# Validate src module imports
print("Checking src modules...\n")

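src_modules = ['config', 'data', 'models', 'train', 'rebasin', 'interp', 'metrics', 'plotting']
src_ok = True

for module_name in src_modules:
    try:
        importlib.import_module(f'src.{module_name}')
        print(f"  [OK] src.{module_name}")
    except ImportError as e:
        print(f"  [ERROR] src.{module_name}: {e}")
        src_ok = False

# Only report success if every import actually worked
if src_ok:
    print("\nAll src modules loaded successfully!")
else:
    print("\nSome src modules failed to import. Fix the errors above before continuing.")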

## 2. Load Configuration

In [None]:
from src.config import get_config, CONFIG, set_seed, get_device, setup_directories

# Load configuration
config = get_config()

print("Global Configuration:")
print("=" * 50)
for section, values in config.items():
    print(f"\n[{section}]")
    if isinstance(values, dict):
        for key, val in values.items():
            print(f"  {key}: {val}")
    else:
        print(f"  {values}")

## 3. Set Deterministic Seeds

In [None]:
import torch
import numpy as np
import random

# Set global seed
GLOBAL_SEED = config['seeds']['global']
set_seed(GLOBAL_SEED)

print(f"Global seed set to: {GLOBAL_SEED}")
print(f"\nModel seeds:")
print(f"  Model A1 (spurious): {config['seeds']['model_A1']}")
print(f"  Model A2 (spurious): {config['seeds']['model_A2']}")
print(f"  Model R1 (robust):   {config['seeds']['model_R1']}")
print(f"  Model R2 (robust):   {config['seeds']['model_R2']}")

# Verify determinism
print(f"\nDeterminism settings:")
print(f"  torch.backends.cudnn.deterministic: {torch.backends.cudnn.deterministic}")
print(f"  torch.backends.cudnn.benchmark: {torch.backends.cudnn.benchmark}")

## 4. Create Directory Structure

In [None]:
# Create all necessary directories
dirs = setup_directories()

print("Directory structure:")
for name, path in dirs.items():
    # setup_directories() has already run, so this only confirms each path is present
    status = "[OK]" if path.exists() else "[MISSING]"
    print(f"  {status} {name}: {path}")
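
`setup_directories` is defined in `src.config` and its body is not shown in this notebook. A minimal sketch of how such a helper could create the folders and also record which ones were new, assuming a layout with `checkpoints`, `figures`, and `metrics` subfolders (the name `ensure_directories` and the layout are illustrative guesses):

```python
import tempfile
from pathlib import Path


def ensure_directories(root, names=("checkpoints", "figures", "metrics")):
    """Create root/<name> for each name; return {name: (path, was_created)}."""
    report = {}
    for name in names:
        path = Path(root) / name
        existed = path.exists()  # record the state BEFORE creating anything
        path.mkdir(parents=True, exist_ok=True)
        report[name] = (path, not existed)
    return report


# Demonstrate on a throwaway directory: the second call finds everything in place
with tempfile.TemporaryDirectory() as tmp:
    first = ensure_directories(tmp)
    second = ensure_directories(tmp)
    for name, (path, created) in first.items():
        print(f"  {'[CREATED]' if created else '[EXISTS]'} {name}")
```

Checking for existence before calling `mkdir` is what makes a created-versus-existing label meaningful. Once the folder has been created, `path.exists()` is true either way.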

## 5. Check Device (GPU/CPU)

In [None]:
device = get_device()

print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"\nCUDA Details:")
    print(f"  Device name: {torch.cuda.get_device_name(0)}")
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  Memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"  Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
elif device.type == 'mps':
    print(f"\nUsing Apple Metal Performance Shaders (MPS)")
else:
    print(f"\nNo GPU available, using CPU. Training will be slower.")

## 6. Quick Sanity Check

In [None]:
# Test that we can create a model and do a forward pass
from src.models import create_model, count_parameters

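model = create_model(config)
model = model.to(device)
model.eval()  # inference mode, so BatchNorm/Dropout do not behave as in training

# Create dummy input directly on the target device
dummy_input = torch.randn(2, 3, 32, 32, device=device)
with torch.no_grad():  # no gradients are needed for a shape check
    output = model(dummy_input)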

print("Model sanity check:")
print(f"  Model architecture: {config['model']['architecture']}")
print(f"  Parameters: {count_parameters(model):,}")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Forward pass: OK")

In [None]:
# Test data loading
from src.data import create_env_a_dataset, create_no_patch_dataset

print("Testing data loading...")

# This will download CIFAR-10 if not present
env_a_train = create_env_a_dataset(train=True, config=config)
env_a_test = create_env_a_dataset(train=False, config=config)
no_patch_test = create_no_patch_dataset(train=False, config=config)

print(f"\nDataset sizes:")
print(f"  Env A train: {len(env_a_train)}")
print(f"  Env A test (ID): {len(env_a_test)}")
print(f"  No patch test (OOD): {len(no_patch_test)}")

# Verify alignment rate
alignment_rate = env_a_train.get_alignment_rate()
expected_rate = config['patch']['p_align_env_a']
print(f"\nEnv A alignment rate: {alignment_rate:.3f} (expected: {expected_rate})")

## 7. Summary

In [None]:
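# Derive the status lines from the checks above instead of hard-coding "OK"
deps_status = "OK" if all_ok else "MISSING PACKAGES"
src_status = "OK" if all(f"src.{m}" in sys.modules for m in src_modules) else "IMPORT ERRORS"

print("=" * 60)
print("SETUP COMPLETE")
print("=" * 60)
print(f"""
Environment:
  - Device: {device}
  - Global seed: {GLOBAL_SEED}
  - All dependencies: {deps_status}
  - All src modules: {src_status}
  - Directory structure: OK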

Configuration:
  - Dataset: {config['data']['dataset']}
  - Patch size: {config['patch']['size']}x{config['patch']['size']}
  - Env A alignment: {config['patch']['p_align_env_a']}
  - Env B alignment: {config['patch']['p_align_env_b']}
  - Model: {config['model']['architecture']}
  - Training epochs: {config['training']['num_epochs']}
  - Batch size: {config['training']['batch_size']}

Next steps:
  1. Run 01_data_spurious_envs.ipynb to visualize datasets
  2. Run 02_train_models.ipynb to train all 4 models
  3. Continue with remaining notebooks in order
""")

In [None]:
# Save config to results for reference
import json
from src.config import RESULTS_DIR

config_path = RESULTS_DIR / 'config.json'

# Convert config to JSON-serializable format
config_json = {}
for key, value in config.items():
    if isinstance(value, dict):
        config_json[key] = {}
        for k, v in value.items():
            if isinstance(v, (list, tuple)):
                config_json[key][k] = list(v)
            else:
                config_json[key][k] = v
    else:
        config_json[key] = value

with open(config_path, 'w') as f:
    json.dump(config_json, f, indent=2)

print(f"Configuration saved to: {config_path}")

# 01 - Data: Spurious Environments

This notebook implements and visualizes the CIFAR-10 environments with spurious colored patches.

## Environments:
- **Env A (Spurious Aligned)**: Patch color matches label with probability 0.95
- **Env B (Spurious Flipped)**: Patch color matches label with probability 0.05
- **No Patch (OOD)**: Clean CIFAR-10 without any patches
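
The alignment probabilities above can be read as a per-sample coin flip. A minimal sketch, assuming a misaligned patch takes the color of a uniformly chosen *other* class. The actual rule lives in `src.data` and may differ, and `sample_patch_class` is a hypothetical name:

```python
import random


def sample_patch_class(label, p_align, n_classes, rng):
    """Return the class whose color is painted on the patch."""
    if rng.random() < p_align:
        return label  # aligned: the patch color encodes the true label
    # misaligned: pick uniformly among the other classes (assumed rule)
    wrong = [c for c in range(n_classes) if c != label]
    return rng.choice(wrong)


rng = random.Random(0)
labels = [i % 10 for i in range(20_000)]  # balanced toy labels, 10 classes
rates = {}
for name, p_align in [("Env A", 0.95), ("Env B", 0.05)]:
    patches = [sample_patch_class(y, p_align, 10, rng) for y in labels]
    rates[name] = sum(p == y for p, y in zip(patches, labels)) / len(labels)
    print(f"{name}: empirical alignment = {rates[name]:.3f}")
```

The empirical rates should land close to 0.95 and 0.05. The notebook checks the same property on the real datasets in the alignment-rate section.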

## What this notebook does:
1. Creates all environment datasets
2. Visualizes sample images from each environment
3. Verifies alignment rates
4. Saves visualization figures to results/figures/

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Load configuration and set seeds
from src.config import get_config, set_seed, get_device, FIGURES_DIR, CIFAR10_CLASSES

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")
print(f"Figures will be saved to: {FIGURES_DIR}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.data import (
    create_env_a_dataset,
    create_env_b_dataset,
    create_no_patch_dataset,
    create_mixed_env_dataset,
    denormalize,
    get_sample_batch,
)
from src.plotting import plot_sample_grid, save_figure

## 1. Create Datasets

In [None]:
# Create all environment datasets
print("Creating datasets (this may download CIFAR-10)...\n")

# Training sets
env_a_train = create_env_a_dataset(train=True, config=config)
env_b_train = create_env_b_dataset(train=True, config=config)
mixed_train = create_mixed_env_dataset(env_a_fraction=0.5, train=True, config=config)

# Test sets
env_a_test = create_env_a_dataset(train=False, config=config)
env_b_test = create_env_b_dataset(train=False, config=config)
no_patch_test = create_no_patch_dataset(train=False, config=config)

print("Dataset sizes:")
print(f"  Env A train: {len(env_a_train)}")
print(f"  Env B train: {len(env_b_train)}")
print(f"  Mixed train: {len(mixed_train)}")
print(f"  Env A test (ID): {len(env_a_test)}")
print(f"  Env B test: {len(env_b_test)}")
print(f"  No patch test (OOD): {len(no_patch_test)}")

## 2. Verify Alignment Rates

In [None]:
# Check alignment rates
print("Alignment rate verification:")
print("=" * 50)

env_a_rate = env_a_train.get_alignment_rate()
env_b_rate = env_b_train.get_alignment_rate()

print(f"\nEnv A (spurious aligned):")
print(f"  Expected alignment: {config['patch']['p_align_env_a']}")
print(f"  Actual alignment:   {env_a_rate:.4f}")
print(f"  Difference:         {abs(env_a_rate - config['patch']['p_align_env_a']):.4f}")

print(f"\nEnv B (spurious flipped):")
print(f"  Expected alignment: {config['patch']['p_align_env_b']}")
print(f"  Actual alignment:   {env_b_rate:.4f}")
print(f"  Difference:         {abs(env_b_rate - config['patch']['p_align_env_b']):.4f}")

# Sanity check
assert abs(env_a_rate - config['patch']['p_align_env_a']) < 0.02, "Env A alignment rate too far from expected!"
assert abs(env_b_rate - config['patch']['p_align_env_b']) < 0.02, "Env B alignment rate too far from expected!"
print("\nSanity check PASSED!")
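
The 0.02 tolerance in the assertions is generous. As a rough check, assuming each sample's alignment is an independent Bernoulli draw and the training set holds the 50,000 CIFAR-10 training images without subsampling, the standard error of the measured rate can be estimated as follows (`alignment_standard_error` is an illustrative helper):

```python
import math


def alignment_standard_error(p, n):
    """Standard error of an empirical proportion from n Bernoulli(p) draws."""
    return math.sqrt(p * (1.0 - p) / n)


n_train = 50_000  # CIFAR-10 training set size; assumes no subsampling
for p in (0.95, 0.05):
    se = alignment_standard_error(p, n_train)
    print(f"p={p}: standard error = {se:.5f}, tolerance / SE = {0.02 / se:.1f}")
```

Under these assumptions the standard error is roughly 0.001, so a deviation of 0.02 would be about twenty standard errors. A failed assertion therefore points to a bug in the dataset construction rather than sampling noise.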

## 3. Visualize Patch Colors

In [None]:
# Show the color palette for each class
fig, ax = plt.subplots(figsize=(12, 3))

class_colors = config['patch']['class_colors']
n_classes = len(class_colors)

for i, (color, class_name) in enumerate(zip(class_colors, CIFAR10_CLASSES)):
    # Normalize color to [0, 1]
    norm_color = tuple(c / 255.0 for c in color)
    rect = plt.Rectangle((i, 0), 1, 1, facecolor=norm_color, edgecolor='black')
    ax.add_patch(rect)
    ax.text(i + 0.5, -0.2, class_name, ha='center', va='top', fontsize=10, rotation=45)
    ax.text(i + 0.5, 0.5, f"{color}", ha='center', va='center', fontsize=8,
            color='white' if sum(color) < 400 else 'black')

ax.set_xlim(0, n_classes)
ax.set_ylim(-0.5, 1)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Patch Color Palette by Class', fontsize=14, pad=20)

plt.tight_layout()
save_figure(fig, 'color_palette')
plt.show()

## 4. Visualize Environment A (Spurious Aligned)

In [None]:
# Get sample batch from Env A
images_a, labels_a = get_sample_batch(env_a_train, n_samples=16, seed=42)

# Denormalize images for visualization
images_a_vis = np.array([denormalize(img, config) for img in images_a])

fig = plot_sample_grid(
    images_a_vis,
    labels_a.numpy(),
    title="Environment A (Spurious Aligned, p=0.95)",
    nrow=4,
    ncol=4,
    save_name='env_a_samples'
)
plt.show()

In [None]:
# Verify: Check that patch colors mostly match labels
print("Env A: Verifying patch-label alignment...\n")

n_check = 100
aligned = 0
for i in range(n_check):
    _, label = env_a_train.cifar[i]
    patch_color = env_a_train.get_patch_color(i)
    expected_color = class_colors[label]
    # Compare by value: a list and a tuple holding the same RGB triple are not equal
    if tuple(patch_color) == tuple(expected_color):
        aligned += 1

print(f"First {n_check} samples: {aligned}/{n_check} aligned ({aligned/n_check*100:.1f}%)")

## 5. Visualize Environment B (Spurious Flipped)

In [None]:
# Get sample batch from Env B
images_b, labels_b = get_sample_batch(env_b_train, n_samples=16, seed=42)

# Denormalize images for visualization
images_b_vis = np.array([denormalize(img, config) for img in images_b])

fig = plot_sample_grid(
    images_b_vis,
    labels_b.numpy(),
    title="Environment B (Spurious Flipped, p=0.05)",
    nrow=4,
    ncol=4,
    save_name='env_b_samples'
)
plt.show()

In [None]:
# Verify: Check that patch colors mostly DON'T match labels
print("Env B: Verifying patch-label misalignment...\n")

n_check = 100
aligned = 0
for i in range(n_check):
    _, label = env_b_train.cifar[i]
    patch_color = env_b_train.get_patch_color(i)
    expected_color = class_colors[label]
    # Compare by value: a list and a tuple holding the same RGB triple are not equal
    if tuple(patch_color) == tuple(expected_color):
        aligned += 1

print(f"First {n_check} samples: {aligned}/{n_check} aligned ({aligned/n_check*100:.1f}%)")
print(f"Expected ~{config['patch']['p_align_env_b']*100:.0f}% alignment")

## 6. Visualize No-Patch (OOD) Test Set

In [None]:
# Get sample batch from No Patch dataset
images_np, labels_np = get_sample_batch(no_patch_test, n_samples=16, seed=42)

# Denormalize images for visualization
images_np_vis = np.array([denormalize(img, config) for img in images_np])

fig = plot_sample_grid(
    images_np_vis,
    labels_np.numpy(),
    title="No Patch (OOD Test Set)",
    nrow=4,
    ncol=4,
    save_name='no_patch_samples'
)
plt.show()

## 7. Side-by-Side Comparison

In [None]:
# Create side-by-side comparison figure
fig, axes = plt.subplots(3, 8, figsize=(16, 6))

# Get samples with same seed for comparison
images_a, labels_a = get_sample_batch(env_a_train, n_samples=8, seed=123)
images_b, labels_b = get_sample_batch(env_b_train, n_samples=8, seed=123)
images_np, labels_np = get_sample_batch(no_patch_test, n_samples=8, seed=123)

# Note: Env A and Env B wrap the training split, while the no-patch set is the test split,
# so the no-patch row shows different images. The shared seed only makes the sampling
# within each dataset repeatable.

row_titles = ['Env A (aligned)', 'Env B (flipped)', 'No Patch (OOD)']
all_images = [images_a, images_b, images_np]
all_labels = [labels_a, labels_b, labels_np]

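for row, (images, labels, title) in enumerate(zip(all_images, all_labels, row_titles)):
    images_vis = [denormalize(img, config) for img in images]
    for col, (img, label) in enumerate(zip(images_vis, labels)):
        ax = axes[row, col]
        ax.imshow(img)
        # Hide ticks and spines instead of calling axis('off'), which would also hide the row label
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        # Title every panel, since the rows are not guaranteed to share labels
        ax.set_title(CIFAR10_CLASSES[int(label)], fontsize=9)
    axes[row, 0].set_ylabel(title, fontsize=10, rotation=90, labelpad=40)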

plt.suptitle('Environment Comparison', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'environment_comparison')
plt.show()

## 8. Class Distribution Check

In [None]:
# Verify class distribution is balanced
from collections import Counter

def get_class_distribution(dataset):
    labels = [dataset.cifar[i][1] for i in range(len(dataset))]
    return Counter(labels)

env_a_dist = get_class_distribution(env_a_train)

fig, ax = plt.subplots(figsize=(10, 5))

classes = list(range(10))
counts = [env_a_dist[c] for c in classes]

bars = ax.bar(classes, counts, color='steelblue', edgecolor='black')
ax.set_xticks(classes)
ax.set_xticklabels(CIFAR10_CLASSES, rotation=45, ha='right')
ax.set_xlabel('Class')
ax.set_ylabel('Count')
ax.set_title('CIFAR-10 Class Distribution (Training Set)')

# Add count labels on bars
for bar, count in zip(bars, counts):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 50,
            str(count), ha='center', va='bottom', fontsize=9)

plt.tight_layout()
save_figure(fig, 'class_distribution')
plt.show()

print(f"\nTotal samples: {sum(counts)}")
print(f"Expected per class: {sum(counts)//10}")

## 9. Summary

In [None]:
print("=" * 60)
print("DATA PREPARATION COMPLETE")
print("=" * 60)
print(f"""
Environments created:

1. Environment A (Spurious Aligned)
   - Patch color matches true label with p={config['patch']['p_align_env_a']}
   - Training: {len(env_a_train)} samples
   - Test (ID): {len(env_a_test)} samples
   - Actual alignment rate: {env_a_rate:.4f}

2. Environment B (Spurious Flipped)
   - Patch color matches true label with p={config['patch']['p_align_env_b']}
   - Training: {len(env_b_train)} samples
   - Actual alignment rate: {env_b_rate:.4f}

3. No Patch (OOD Test)
   - Clean CIFAR-10 without patches
   - Test: {len(no_patch_test)} samples

4. Mixed Environment (for robust training)
   - 50% Env A + 50% Env B
   - Training: {len(mixed_train)} samples

Patch configuration:
   - Size: {config['patch']['size']}x{config['patch']['size']} pixels
   - Position: {config['patch']['position']}
   - Colors: 10 distinct colors (one per class)

Figures saved:
   - {FIGURES_DIR / 'color_palette.png'}
   - {FIGURES_DIR / 'env_a_samples.png'}
   - {FIGURES_DIR / 'env_b_samples.png'}
   - {FIGURES_DIR / 'no_patch_samples.png'}
   - {FIGURES_DIR / 'environment_comparison.png'}
   - {FIGURES_DIR / 'class_distribution.png'}

Next: Run 02_train_models.ipynb to train the 4 models.
""")

# 02 - Train Models

This notebook trains the 4 models needed for the experiment:

1. **Model A1** (Spurious, seed 1): ERM on Env A only
2. **Model A2** (Spurious, seed 2): ERM on Env A only
3. **Model R1** (Robust, seed 1): ERM on mixed Env A + Env B
4. **Model R2** (Robust, seed 2): ERM on mixed Env A + Env B

## Expected behavior:
- Spurious models (A1, A2) are expected to reach high ID accuracy but noticeably lower OOD accuracy
- Robust models (R1, R2) are expected to show a much smaller gap between ID and OOD accuracy
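
One way to see why the mixed data should weaken reliance on the patch: assuming `create_mixed_env_dataset` draws each sample from Env A with probability `env_a_fraction` and from Env B otherwise (an assumption about `src.data`, which is not shown here), the overall alignment rate is a weighted average. `mixed_alignment_rate` is an illustrative helper:

```python
def mixed_alignment_rate(env_a_fraction, p_align_a, p_align_b):
    """Overall P(patch color matches label) in a two-environment mixture."""
    return env_a_fraction * p_align_a + (1.0 - env_a_fraction) * p_align_b


rate = mixed_alignment_rate(env_a_fraction=0.5, p_align_a=0.95, p_align_b=0.05)
print(f"Mixed-environment alignment rate: {rate:.2f}")
```

With a 50/50 mix the patch matches the label about half the time, so it is a far less reliable cue than in Env A, where it matches 95% of the time.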

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

config = get_config()
device = get_device()

print(f"Device: {device}")
print(f"Checkpoints will be saved to: {CHECKPOINTS_DIR}")

In [None]:
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    create_mixed_env_dataset,
    get_dataloaders,
)
from src.models import create_model, count_parameters
from src.train import Trainer, save_training_history
from src.plotting import plot_training_curves, plot_multiple_training_curves, save_figure

## 1. Create Datasets and DataLoaders

In [None]:
# Create test datasets (same for all models)
test_id = create_env_a_dataset(train=False, config=config)  # ID test
test_ood = create_no_patch_dataset(train=False, config=config)  # OOD test

# Create training datasets
train_spurious = create_env_a_dataset(train=True, config=config)  # For A1, A2
train_robust = create_mixed_env_dataset(env_a_fraction=0.5, train=True, config=config)  # For R1, R2

print("Datasets created:")
print(f"  Spurious training (Env A): {len(train_spurious)} samples")
print(f"  Robust training (Mixed): {len(train_robust)} samples")
print(f"  ID test: {len(test_id)} samples")
print(f"  OOD test: {len(test_ood)} samples")

In [None]:
# Create dataloaders
batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']

loaders_spurious = get_dataloaders(train_spurious, test_id, test_ood, config)
loaders_robust = get_dataloaders(train_robust, test_id, test_ood, config)

print(f"\nDataLoaders created with batch_size={batch_size}")

## 2. Training Function

In [None]:
def train_model(name, seed, dataloaders, config, device):
    """
    Train a single model and save checkpoint + history.
    
    Args:
        name: Model name (e.g., 'A1', 'R1')
        seed: Random seed for this model
        dataloaders: Dictionary with 'train', 'test_id', 'test_ood' loaders
        config: Configuration dictionary
        device: Torch device
    
    Returns:
        model: Trained model
        history: Training history dictionary
    """
    print(f"\n{'='*60}")
    print(f"Training Model {name} (seed={seed})")
    print(f"{'='*60}")
    
    # Set seed for reproducibility
    set_seed(seed)
    
    # Create model
    model = create_model(config)
    model = model.to(device)
    
    print(f"Model parameters: {count_parameters(model):,}")
    
    # Create trainer
    trainer = Trainer(model, device, config)
    
    # Train
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    history = trainer.train(
        train_loader=dataloaders['train'],
        test_id_loader=dataloaders['test_id'],
        test_ood_loader=dataloaders['test_ood'],
        num_epochs=config['training']['num_epochs'],
        verbose=True,
        checkpoint_path=checkpoint_path,
    )
    
    # Save history
    history_path = METRICS_DIR / f"history_{name}.json"
    save_training_history(history, history_path)
    
    print(f"\nModel {name} training complete!")
    print(f"  Final ID accuracy: {history['id_acc'][-1]*100:.2f}%")
    print(f"  Final OOD accuracy: {history['ood_acc'][-1]*100:.2f}%")
    print(f"  Checkpoint saved to: {checkpoint_path}")
    print(f"  History saved to: {history_path}")
    
    return model, history

## 3. Train Spurious Model A1

In [None]:
model_A1, history_A1 = train_model(
    name='A1',
    seed=config['seeds']['model_A1'],
    dataloaders=loaders_spurious,
    config=config,
    device=device,
)

## 4. Train Spurious Model A2

In [None]:
model_A2, history_A2 = train_model(
    name='A2',
    seed=config['seeds']['model_A2'],
    dataloaders=loaders_spurious,
    config=config,
    device=device,
)

## 5. Train Robust Model R1

In [None]:
model_R1, history_R1 = train_model(
    name='R1',
    seed=config['seeds']['model_R1'],
    dataloaders=loaders_robust,
    config=config,
    device=device,
)

## 6. Train Robust Model R2

In [None]:
model_R2, history_R2 = train_model(
    name='R2',
    seed=config['seeds']['model_R2'],
    dataloaders=loaders_robust,
    config=config,
    device=device,
)

## 7. Visualize Training Curves

In [None]:
# Individual training curves for each model
import matplotlib.pyplot as plt

fig = plot_training_curves(history_A1, title="Model A1 (Spurious, seed 1)", save_name='training_A1')
plt.show()

fig = plot_training_curves(history_A2, title="Model A2 (Spurious, seed 2)", save_name='training_A2')
plt.show()

fig = plot_training_curves(history_R1, title="Model R1 (Robust, seed 1)", save_name='training_R1')
plt.show()

fig = plot_training_curves(history_R2, title="Model R2 (Robust, seed 2)", save_name='training_R2')
plt.show()

In [None]:
# Comparison plots
histories = {
    'A1 (Spurious)': history_A1,
    'A2 (Spurious)': history_A2,
    'R1 (Robust)': history_R1,
    'R2 (Robust)': history_R2,
}

# ID accuracy comparison
fig = plot_multiple_training_curves(
    histories, 
    metric='id_acc', 
    title='ID Test Accuracy Comparison',
    save_name='comparison_id_acc'
)
plt.show()

# OOD accuracy comparison
fig = plot_multiple_training_curves(
    histories, 
    metric='ood_acc', 
    title='OOD Test Accuracy Comparison',
    save_name='comparison_ood_acc'
)
plt.show()

In [None]:
# Combined comparison figure
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = {'A1': 'tab:red', 'A2': 'tab:orange', 'R1': 'tab:blue', 'R2': 'tab:green'}
linestyles = {'A1': '-', 'A2': '--', 'R1': '-', 'R2': '--'}

for name, history in [('A1', history_A1), ('A2', history_A2), ('R1', history_R1), ('R2', history_R2)]:
    epochs = range(1, len(history['id_acc']) + 1)
    
    axes[0].plot(epochs, [x*100 for x in history['id_acc']], 
                 label=name, color=colors[name], linestyle=linestyles[name], linewidth=2)
    axes[1].plot(epochs, [x*100 for x in history['ood_acc']], 
                 label=name, color=colors[name], linestyle=linestyles[name], linewidth=2)

axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy (%)')
axes[0].set_title('ID Test Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('OOD Test Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.suptitle('Training Progress: Spurious vs Robust Models', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'training_comparison_combined')
plt.show()

## 8. Summary Statistics

In [None]:
# Compute and display summary statistics
summary = {}

for name, history in [('A1', history_A1), ('A2', history_A2), ('R1', history_R1), ('R2', history_R2)]:
    final_id_acc = history['id_acc'][-1]
    final_ood_acc = history['ood_acc'][-1]
    ood_drop = final_id_acc - final_ood_acc
    
    summary[name] = {
        'id_acc': final_id_acc,
        'ood_acc': final_ood_acc,
        'ood_drop': ood_drop,
        'final_train_loss': history['train_loss'][-1],
    }

print("\nFinal Model Performance:")
print("=" * 70)
print(f"{'Model':<10} {'ID Acc':<12} {'OOD Acc':<12} {'OOD Drop':<12} {'Train Loss':<12}")
print("-" * 70)

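for name, stats in summary.items():
    model_type = "Spurious" if name.startswith('A') else "Robust"
    label = f"{name} ({model_type[0]})"
    id_str = f"{stats['id_acc']*100:.2f}%"
    ood_str = f"{stats['ood_acc']*100:.2f}%"
    drop_str = f"{stats['ood_drop']*100:+.2f}%"
    # Use the same column widths as the header row so the table lines up
    print(f"{label:<10} {id_str:<12} {ood_str:<12} {drop_str:<12} {stats['final_train_loss']:<12.4f}")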

print("=" * 70)

In [None]:
# Verify expected behavior
print("\nVerification:")
print("-" * 50)

# Spurious models should have large OOD drop
spurious_avg_drop = (summary['A1']['ood_drop'] + summary['A2']['ood_drop']) / 2
robust_avg_drop = (summary['R1']['ood_drop'] + summary['R2']['ood_drop']) / 2

print(f"Average OOD drop (Spurious A1, A2): {spurious_avg_drop*100:.2f}%")
print(f"Average OOD drop (Robust R1, R2):   {robust_avg_drop*100:.2f}%")

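# Require a gap of at least 10 percentage points (0.1 in accuracy units)
if spurious_avg_drop > robust_avg_drop + 0.1:
    print("\n[PASS] Spurious models show an OOD drop at least 10 points larger than robust models.")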
else:
    print("\n[WARNING] OOD drop difference is smaller than expected.")
    print("          This may affect the semantic barrier analysis.")

In [None]:
# Save summary to JSON
summary_path = METRICS_DIR / 'training_summary.json'
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)
print(f"\nSummary saved to: {summary_path}")

## 9. Final Summary

In [None]:
print("\n" + "=" * 60)
print("MODEL TRAINING COMPLETE")
print("=" * 60)
print(f"""
Models trained:

Spurious Models (trained on Env A only):
  - A1: ID={summary['A1']['id_acc']*100:.1f}%, OOD={summary['A1']['ood_acc']*100:.1f}%, Drop={summary['A1']['ood_drop']*100:+.1f}%
  - A2: ID={summary['A2']['id_acc']*100:.1f}%, OOD={summary['A2']['ood_acc']*100:.1f}%, Drop={summary['A2']['ood_drop']*100:+.1f}%

Robust Models (trained on mixed Env A + B):
  - R1: ID={summary['R1']['id_acc']*100:.1f}%, OOD={summary['R1']['ood_acc']*100:.1f}%, Drop={summary['R1']['ood_drop']*100:+.1f}%
  - R2: ID={summary['R2']['id_acc']*100:.1f}%, OOD={summary['R2']['ood_acc']*100:.1f}%, Drop={summary['R2']['ood_drop']*100:+.1f}%

Checkpoints saved:
  - {CHECKPOINTS_DIR / 'model_A1.pt'}
  - {CHECKPOINTS_DIR / 'model_A2.pt'}
  - {CHECKPOINTS_DIR / 'model_R1.pt'}
  - {CHECKPOINTS_DIR / 'model_R2.pt'}

Training histories saved:
  - {METRICS_DIR / 'history_A1.json'}
  - {METRICS_DIR / 'history_A2.json'}
  - {METRICS_DIR / 'history_R1.json'}
  - {METRICS_DIR / 'history_R2.json'}

Figures saved:
  - Training curves for each model
  - Comparison plots

Next: Run 03_mechanism_verification.ipynb to quantify spurious reliance.
""")

# 03 - Mechanism Verification

This notebook quantifies the "spurious vs robust" reliance for each trained model.

## Metrics computed:
1. **OOD Drop**: Acc(ID) - Acc(OOD)
2. **Counterfactual Patch Sensitivity**:
   - Accuracy change when patch color is swapped
   - Mean change in true-class logit
3. **Spurious Reliance Score (SRS)**: Combined metric

## Spurious Reliance Score (SRS) Formula:
```
SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * CF_flip_rate
```
Where:
- OOD_drop = ID_accuracy - OOD_accuracy
- CF_accuracy_drop = Original_accuracy - Counterfactual_accuracy
- CF_flip_rate = Fraction of correct predictions that flip when patch is swapped
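
As a worked example of the formula, here is the calculation with made-up accuracies. The numbers are illustrative, not results from this experiment, and `spurious_reliance_score` is a stand-in for `compute_spurious_reliance_score` in `src.metrics`, whose implementation is not shown:

```python
def spurious_reliance_score(id_acc, ood_acc, orig_acc, cf_acc, cf_flip_rate):
    """SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * CF_flip_rate."""
    ood_drop = id_acc - ood_acc
    cf_accuracy_drop = orig_acc - cf_acc
    return 0.4 * ood_drop + 0.3 * cf_accuracy_drop + 0.3 * cf_flip_rate


# Hypothetical patch-reliant model: large drops on every component
spurious = spurious_reliance_score(0.95, 0.45, 0.95, 0.25, 0.70)
# Hypothetical robust model: small drops on every component
robust = spurious_reliance_score(0.80, 0.78, 0.80, 0.77, 0.04)
print(f"spurious-like SRS: {spurious:.3f}")
print(f"robust-like SRS:   {robust:.3f}")
```

With these inputs the first model scores 0.62 and the second about 0.03. Since every component is a difference or fraction of accuracies, a model that ignores the patch should score near zero.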

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    SpuriousPatchDataset,
    CounterfactualPatchDataset,
    get_transforms,
    DATA_DIR,
)
from src.models import create_model
from src.train import load_model, evaluate_model
from src.metrics import (
    compute_ood_drop,
    compute_patch_sensitivity,
    compute_spurious_reliance_score,
    compute_class_wise_accuracy,
)
from src.plotting import plot_spurious_reliance_comparison, save_figure
from torch.utils.data import DataLoader

## 1. Load Trained Models

In [None]:
# Load all 4 models
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}\n"
                               f"Please run 02_train_models.ipynb first.")
    
    model = create_model(config)
    model = load_model(model, checkpoint_path, device)
    models[name] = model
    print(f"Loaded model {name} from {checkpoint_path}")

print(f"\nAll {len(models)} models loaded successfully!")

## 2. Create Test Datasets

In [None]:
# Create test datasets
test_id = create_env_a_dataset(train=False, config=config)  # ID test (with aligned patches)
test_ood = create_no_patch_dataset(train=False, config=config)  # OOD test (no patches)

# Create DataLoaders
batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Test datasets:")
print(f"  ID test (Env A): {len(test_id)} samples")
print(f"  OOD test (No patch): {len(test_ood)} samples")

In [None]:
# Create counterfactual dataset for patch sensitivity analysis
# We use the ID test set as the base
cf_dataset = CounterfactualPatchDataset(
    base_dataset=test_id,
    swap_mode='random_wrong',
)

print(f"Counterfactual dataset: {len(cf_dataset)} samples")
print("  Each sample provides: (original_img, label, counterfactual_img)")

## 3. Compute Basic Accuracy Metrics

In [None]:
# Compute ID and OOD accuracy for all models
print("Computing ID and OOD accuracy...\n")

accuracy_results = {}

for name, model in models.items():
    _, id_acc = evaluate_model(model, id_loader, device)
    _, ood_acc = evaluate_model(model, ood_loader, device)
    ood_drop = compute_ood_drop(id_acc, ood_acc)
    
    accuracy_results[name] = {
        'id_acc': id_acc,
        'ood_acc': ood_acc,
        'ood_drop': ood_drop,
    }
    
    print(f"Model {name}:")
    print(f"  ID Accuracy:  {id_acc*100:.2f}%")
    print(f"  OOD Accuracy: {ood_acc*100:.2f}%")
    print(f"  OOD Drop:     {ood_drop*100:+.2f}%\n")

## 4. Compute Counterfactual Patch Sensitivity

In [None]:
# Compute patch sensitivity for all models
print("Computing counterfactual patch sensitivity...\n")

sensitivity_results = {}

for name, model in models.items():
    print(f"\nAnalyzing Model {name}...")
    sensitivity = compute_patch_sensitivity(
        model, cf_dataset, device,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    sensitivity_results[name] = sensitivity
    
    print(f"  Original Accuracy:      {sensitivity['original_accuracy']*100:.2f}%")
    print(f"  Counterfactual Accuracy: {sensitivity['counterfactual_accuracy']*100:.2f}%")
    print(f"  Accuracy Drop:          {sensitivity['accuracy_drop']*100:+.2f}%")
    print(f"  Prediction Flip Rate:   {sensitivity['prediction_flip_rate']*100:.2f}%")
    print(f"  Mean Logit Change:      {sensitivity['mean_logit_change']:.3f}")
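
The accuracy and flip-rate figures printed above can be computed from three aligned lists: the true labels, the predictions on the original images, and the predictions on the counterfactual images. A minimal sketch, assuming the flip rate is measured over originally correct predictions, as defined at the top of this notebook. `patch_sensitivity_from_predictions` is a hypothetical helper, not the `src.metrics` implementation, and the logit change is omitted:

```python
def patch_sensitivity_from_predictions(labels, orig_preds, cf_preds):
    """Summarize how predictions change when the patch color is swapped."""
    n = len(labels)
    orig_correct = [p == y for p, y in zip(orig_preds, labels)]
    cf_correct = [p == y for p, y in zip(cf_preds, labels)]
    original_accuracy = sum(orig_correct) / n
    counterfactual_accuracy = sum(cf_correct) / n
    # Flip rate: among originally correct samples, the share whose prediction changed
    flips = [o != c for o, c, ok in zip(orig_preds, cf_preds, orig_correct) if ok]
    flip_rate = sum(flips) / len(flips) if flips else 0.0
    return {
        "original_accuracy": original_accuracy,
        "counterfactual_accuracy": counterfactual_accuracy,
        "accuracy_drop": original_accuracy - counterfactual_accuracy,
        "prediction_flip_rate": flip_rate,
    }


# Toy data: 5 of 6 originally correct, and 3 of those 5 change after the swap
labels = [0, 1, 2, 3, 4, 5]
orig_preds = [0, 1, 2, 3, 4, 0]
cf_preds = [0, 1, 7, 8, 9, 0]
result = patch_sensitivity_from_predictions(labels, orig_preds, cf_preds)
for key, value in result.items():
    print(f"  {key}: {value:.3f}")
```

Restricting the flip rate to originally correct samples keeps it from being inflated by predictions that were already wrong before the swap.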

## 5. Compute Spurious Reliance Score (SRS)

In [None]:
# Compute full SRS for all models
print("Computing Spurious Reliance Score (SRS)...\n")

srs_results = {}

for name, model in models.items():
    print(f"\nModel {name}:")
    srs = compute_spurious_reliance_score(
        model, id_loader, ood_loader, cf_dataset, device
    )
    srs_results[name] = srs
    
    print(f"  Spurious Reliance Score: {srs['spurious_reliance_score']:.4f}")
    print(f"  Components:")
    print(f"    - OOD Drop:        {srs['ood_drop']*100:.2f}% (weight: 0.4)")
    print(f"    - CF Acc Drop:     {srs['cf_accuracy_drop']*100:.2f}% (weight: 0.3)")
    print(f"    - CF Flip Rate:    {srs['cf_flip_rate']*100:.2f}% (weight: 0.3)")
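
The component weights printed above match the formula stated in this notebook's summary: SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * flip_rate. A minimal sketch of that weighted sum follows. The function name and the example inputs are hypothetical, and the real compute_spurious_reliance_score may do extra work (for example clipping negative drops) that is not visible here.

```python
def spurious_reliance_score(ood_drop, cf_accuracy_drop, cf_flip_rate,
                            weights=(0.4, 0.3, 0.3)):
    """Weighted sum of the three components, all given as fractions in [0, 1]."""
    w_ood, w_cf, w_flip = weights
    return w_ood * ood_drop + w_cf * cf_accuracy_drop + w_flip * cf_flip_rate


# Hypothetical component values, not results from this experiment
score = spurious_reliance_score(ood_drop=0.5, cf_accuracy_drop=0.4, cf_flip_rate=0.6)
print(f"SRS = {score:.4f}")
```

Because the weights sum to 1 and each component is a fraction, the score stays between 0 and 1 as long as no component is negative.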

## 6. Visualize Results

In [None]:
# Plot SRS comparison
fig = plot_spurious_reliance_comparison(
    srs_results,
    title='Spurious Reliance Metrics by Model',
    save_name='spurious_reliance_comparison'
)
plt.show()

In [None]:
# Create detailed comparison figure
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

model_names = list(srs_results.keys())
x = np.arange(len(model_names))
colors = ['#e74c3c', '#e67e22', '#3498db', '#2ecc71']

# 1. SRS Score
srs_values = [srs_results[m]['spurious_reliance_score'] for m in model_names]
bars = axes[0, 0].bar(x, srs_values, color=colors)
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(model_names)
axes[0, 0].set_ylabel('SRS')
axes[0, 0].set_title('Spurious Reliance Score (SRS)')
axes[0, 0].grid(True, alpha=0.3, axis='y')
for bar, val in zip(bars, srs_values):
    axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontsize=10)

# 2. ID vs OOD Accuracy
id_accs = [srs_results[m]['id_accuracy']*100 for m in model_names]
ood_accs = [srs_results[m]['ood_accuracy']*100 for m in model_names]
width = 0.35
axes[0, 1].bar(x - width/2, id_accs, width, label='ID Acc', color='steelblue')
axes[0, 1].bar(x + width/2, ood_accs, width, label='OOD Acc', color='coral')
axes[0, 1].set_xticks(x)
axes[0, 1].set_xticklabels(model_names)
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('ID vs OOD Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3, axis='y')

# 3. Counterfactual Accuracy Drop
cf_drops = [srs_results[m]['cf_accuracy_drop']*100 for m in model_names]
bars = axes[1, 0].bar(x, cf_drops, color=colors)
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(model_names)
axes[1, 0].set_ylabel('Accuracy Drop (%)')
axes[1, 0].set_title('Counterfactual Patch Accuracy Drop')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. Prediction Flip Rate
flip_rates = [srs_results[m]['cf_flip_rate']*100 for m in model_names]
bars = axes[1, 1].bar(x, flip_rates, color=colors)
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(model_names)
axes[1, 1].set_ylabel('Flip Rate (%)')
axes[1, 1].set_title('Prediction Flip Rate (when patch swapped)')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.suptitle('Mechanism Verification: Spurious vs Robust Models', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'mechanism_verification_detailed')
plt.show()

## 7. Statistical Summary

In [None]:
# Create summary table
print("\n" + "=" * 90)
print("MECHANISM VERIFICATION SUMMARY")
print("=" * 90)

print(f"\n{'Model':<8} {'Type':<10} {'ID Acc':<10} {'OOD Acc':<10} {'OOD Drop':<10} "
      f"{'CF Drop':<10} {'Flip Rate':<10} {'SRS':<8}")
print("-" * 90)

for name in model_names:
    model_type = "Spurious" if name.startswith('A') else "Robust"
    srs = srs_results[name]
    print(f"{name:<8} {model_type:<10} {srs['id_accuracy']*100:>6.2f}%   "
          f"{srs['ood_accuracy']*100:>6.2f}%   {srs['ood_drop']*100:>+6.2f}%   "
          f"{srs['cf_accuracy_drop']*100:>6.2f}%   {srs['cf_flip_rate']*100:>6.2f}%   "
          f"{srs['spurious_reliance_score']:.4f}")

print("=" * 90)

In [None]:
# Compute group statistics
spurious_models = ['A1', 'A2']
robust_models = ['R1', 'R2']

spurious_avg_srs = np.mean([srs_results[m]['spurious_reliance_score'] for m in spurious_models])
robust_avg_srs = np.mean([srs_results[m]['spurious_reliance_score'] for m in robust_models])

spurious_avg_drop = np.mean([srs_results[m]['ood_drop'] for m in spurious_models])
robust_avg_drop = np.mean([srs_results[m]['ood_drop'] for m in robust_models])

print("\nGroup Statistics:")
print("-" * 50)
print(f"Spurious Models (A1, A2):")
print(f"  Average SRS:      {spurious_avg_srs:.4f}")
print(f"  Average OOD Drop: {spurious_avg_drop*100:.2f}%")
print(f"\nRobust Models (R1, R2):")
print(f"  Average SRS:      {robust_avg_srs:.4f}")
print(f"  Average OOD Drop: {robust_avg_drop*100:.2f}%")
srs_ratio = spurious_avg_srs / robust_avg_srs if robust_avg_srs > 0 else float('inf')
print(f"\nSRS Ratio (Spurious/Robust): {srs_ratio:.2f}x")

In [None]:
# Verification checks
print("\nVerification Checks:")
print("-" * 50)

# Check 1: Spurious models should have higher SRS
if spurious_avg_srs > robust_avg_srs:
    print("[PASS] Spurious models have higher SRS than robust models")
else:
    print("[FAIL] Expected spurious models to have higher SRS")

# Check 2: Spurious models should have larger OOD drop
if spurious_avg_drop > robust_avg_drop:
    print("[PASS] Spurious models have larger OOD accuracy drop")
else:
    print("[FAIL] Expected spurious models to have larger OOD drop")

# Check 3: All spurious models should have SRS > 0.05 (reasonable threshold)
all_spurious_high_srs = all(srs_results[m]['spurious_reliance_score'] > 0.05 for m in spurious_models)
if all_spurious_high_srs:
    print("[PASS] All spurious models have SRS > 0.05")
else:
    print("[WARN] Some spurious models have low SRS")

## 8. Save Results

In [None]:
# Save all metrics to JSON
mechanism_results = {
    'srs_results': {k: {kk: float(vv) for kk, vv in v.items()} 
                   for k, v in srs_results.items()},
    'group_statistics': {
        'spurious_avg_srs': float(spurious_avg_srs),
        'robust_avg_srs': float(robust_avg_srs),
        'spurious_avg_ood_drop': float(spurious_avg_drop),
        'robust_avg_ood_drop': float(robust_avg_drop),
        'srs_ratio': float(spurious_avg_srs / robust_avg_srs) if robust_avg_srs > 0 else float('inf'),
    }
}

results_path = METRICS_DIR / 'mechanism_verification.json'
with open(results_path, 'w') as f:
    json.dump(mechanism_results, f, indent=2)

print(f"Results saved to: {results_path}")

## 9. Summary

In [None]:
print("\n" + "=" * 60)
print("MECHANISM VERIFICATION COMPLETE")
print("=" * 60)
print(f"""
Key Findings:

1. Spurious Reliance Score (SRS):
   - Spurious models (A1, A2): Average SRS = {spurious_avg_srs:.4f}
   - Robust models (R1, R2):   Average SRS = {robust_avg_srs:.4f}
   - Ratio: {(spurious_avg_srs / robust_avg_srs if robust_avg_srs > 0 else float('inf')):.2f}x higher for spurious models

2. OOD Accuracy Drop:
   - Spurious models: {spurious_avg_drop*100:.1f}% average drop
   - Robust models:   {robust_avg_drop*100:.1f}% average drop

3. Interpretation:
   - Spurious models (A1, A2) rely heavily on the colored patch
   - When the patch is removed (OOD) or swapped (CF), accuracy drops significantly
   - Robust models (R1, R2) learned more content-based features
   - They are less sensitive to patch manipulation

SRS Formula:
   SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * flip_rate

Files saved:
   - {METRICS_DIR / 'mechanism_verification.json'}
   - {FIGURES_DIR / 'spurious_reliance_comparison.png'}
   - {FIGURES_DIR / 'mechanism_verification_detailed.png'}

Next: Run 04_rebasin_alignment.ipynb to perform weight matching.
""")

# 04 - Git Re-Basin Alignment

This notebook implements Git Re-Basin (weight matching) to align independently trained models.

## What is Git Re-Basin?
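Neural networks have permutation symmetries - you can reorder the hidden units in a layer, apply the same reordering to the next layer's incoming weights, and get the same function. Git Re-Basin searches for the permutations that best align one model's units with another's, enabling meaningful weight interpolation.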
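
To make the symmetry concrete, here is a small NumPy sketch with a hypothetical two-layer ReLU MLP (it is not the model built by create_model). Permuting the rows of the first layer, together with the matching columns of the second layer, leaves the output unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy MLP: 3 inputs -> 5 hidden units (ReLU) -> 2 outputs
W1, b1 = rng.normal(size=(5, 3)), rng.normal(size=5)
W2, b2 = rng.normal(size=(2, 5)), rng.normal(size=2)


def mlp(x, W1, b1, W2, b2):
    hidden = np.maximum(W1 @ x + b1, 0.0)  # ReLU acts on each unit separately
    return W2 @ hidden + b2


# Reorder the hidden units: rows of W1 and b1, and the matching columns of W2
perm = rng.permutation(5)
W1_p, b1_p = W1[perm], b1[perm]
W2_p = W2[:, perm]

x = rng.normal(size=3)
out_original = mlp(x, W1, b1, W2, b2)
out_permuted = mlp(x, W1_p, b1_p, W2_p, b2)

print(np.allclose(out_original, out_permuted))
```

The two weight sets differ entry by entry, yet they compute the same function, which is why raw weight distances between independently trained models can be misleading.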

## Model Pairs to Align:
1. **A1 vs A2** (spurious-spurious): Should align well (same mechanism)
2. **R1 vs R2** (robust-robust): Should align well (same mechanism)
3. **A1 vs R1** (spurious-robust): May not align well (different mechanisms)

## What this notebook does:
1. Implements weight matching permutation finding
2. Aligns model pairs
3. Verifies that alignment preserves accuracy and measures its effect on weight-space and functional similarity
4. Saves aligned models for interpolation analysis
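
The permutation search in step 1 is handled by src.rebasin, which is not shown here. As a rough sketch of the idea for a single hidden layer, assume weight matching reduces to a linear assignment problem over unit similarities. The function match_hidden_units is hypothetical, biases are ignored, and deeper networks would need to repeat this over layers.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_hidden_units(W1_ref, W2_ref, W1_other, W2_other):
    """Return perm so that W1_other[perm] and W2_other[:, perm] best match the reference.

    similarity[i, j] adds up the inner products between reference unit i and
    other unit j over its incoming (W1 row) and outgoing (W2 column) weights.
    """
    similarity = W1_ref @ W1_other.T + W2_ref.T @ W2_other
    _, perm = linear_sum_assignment(similarity, maximize=True)
    return perm


# Toy check: scramble the hidden units of a reference model, then recover them
rng = np.random.default_rng(0)
W1_ref = rng.normal(size=(6, 4))   # 4 inputs -> 6 hidden units
W2_ref = rng.normal(size=(3, 6))   # 6 hidden units -> 3 outputs

scramble = rng.permutation(6)
W1_other, W2_other = W1_ref[scramble], W2_ref[:, scramble]

perm = match_hidden_units(W1_ref, W2_ref, W1_other, W2_other)
W1_aligned, W2_aligned = W1_other[perm], W2_other[:, perm]

print(np.allclose(W1_aligned, W1_ref), np.allclose(W2_aligned, W2_ref))
```

Maximizing the summed inner products is equivalent to minimizing the squared distance between the reference and the permuted weights, because a permutation does not change the norms involved.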

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import create_env_a_dataset, create_no_patch_dataset
from src.models import create_model, model_agreement, clone_model
from src.train import load_model, evaluate_model
from src.rebasin import (
    rebasin,
    weight_matching,
    apply_permutations,
    compute_weight_distance,
    compute_cosine_similarity,
)
from src.plotting import save_figure
from torch.utils.data import DataLoader

## 1. Load Trained Models

In [None]:
# Load all 4 models
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}\n"
                               f"Please run 02_train_models.ipynb first.")
    
    model = create_model(config)
    model = load_model(model, checkpoint_path, device)
    models[name] = model
    print(f"Loaded model {name}")

print(f"\nAll {len(models)} models loaded!")

## 2. Create Test DataLoader

In [None]:
# Create test dataset for agreement computation
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Test loaders created with {len(test_id)} ID and {len(test_ood)} OOD samples")

## 3. Compute Pre-Rebasin Metrics

In [None]:
# Define model pairs to analyze
model_pairs = [
    ('A1', 'A2', 'spurious-spurious'),
    ('R1', 'R2', 'robust-robust'),
    ('A1', 'R1', 'spurious-robust'),
]

# Compute pre-rebasin metrics
pre_rebasin_metrics = {}

print("Computing pre-rebasin metrics...\n")
print(f"{'Pair':<25} {'Weight Dist':<15} {'Cosine Sim':<15} {'Agreement':<15}")
print("-" * 70)

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    
    # Weight space metrics
    weight_dist = compute_weight_distance(models[name1], models[name2])
    cosine_sim = compute_cosine_similarity(models[name1], models[name2])
    
    # Functional agreement
    agreement = model_agreement(models[name1], models[name2], id_loader, device)
    
    pre_rebasin_metrics[pair_name] = {
        'type': pair_type,
        'weight_distance': weight_dist,
        'cosine_similarity': cosine_sim,
        'agreement': agreement,
    }
    
    print(f"{pair_name} ({pair_type[:7]}...)  {weight_dist:>10.2f}     {cosine_sim:>10.4f}     {agreement*100:>10.2f}%")
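
The helpers compute_weight_distance and compute_cosine_similarity come from src.rebasin and are not shown here. A minimal sketch, assuming both operate on all parameters flattened into a single vector (L2 distance and cosine of the angle between the two vectors). The function names and the toy parameter dictionaries below are hypothetical.

```python
import numpy as np


def flatten_params(params):
    """Concatenate all parameter arrays into one vector, in a fixed key order."""
    return np.concatenate(
        [np.asarray(params[name], dtype=float).ravel() for name in sorted(params)]
    )


def weight_distance(params_a, params_b):
    """Euclidean (L2) distance between the flattened parameter vectors."""
    return float(np.linalg.norm(flatten_params(params_a) - flatten_params(params_b)))


def cosine_similarity(params_a, params_b):
    """Cosine of the angle between the flattened parameter vectors."""
    a, b = flatten_params(params_a), flatten_params(params_b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy example: params_b is params_a with its two weight rows swapped
params_a = {"fc.weight": [[1.0, 0.0], [0.0, 1.0]], "fc.bias": [0.0, 0.0]}
params_b = {"fc.weight": [[0.0, 1.0], [1.0, 0.0]], "fc.bias": [0.0, 0.0]}

dist = weight_distance(params_a, params_b)
cos = cosine_similarity(params_a, params_b)
print(f"distance={dist:.2f}, cosine={cos:.2f}")
```

In the toy example the two parameter sets contain the same rows in a different order, yet they look orthogonal in weight space. This is the kind of mismatch that alignment is meant to remove before comparing or interpolating models.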

## 4. Perform Git Re-Basin (Weight Matching)

In [None]:
def perform_rebasin(model_ref, model_to_align, name_ref, name_align, device):
    """
    Perform Git Re-Basin alignment.
    
    Args:
        model_ref: Reference model (we align TO this)
        model_to_align: Model to be aligned
        name_ref: Name of reference model
        name_align: Name of model to align
        device: Torch device
    
    Returns:
        Aligned model
    """
    print(f"\nAligning {name_align} to {name_ref}...")
    
    # Clone to avoid modifying original
    model_to_align_copy = clone_model(model_to_align).to(device)
    
    # Perform rebasing
    aligned_model = rebasin(model_ref, model_to_align_copy, device)
    
    print(f"  Alignment complete!")
    
    return aligned_model

In [None]:
# Perform rebasing for all pairs
aligned_models = {}

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    
    # Align model2 to model1
    aligned = perform_rebasin(
        models[name1], models[name2],
        name1, name2, device
    )
    aligned_models[pair_name] = aligned
    
    # Verify alignment preserves accuracy
    _, orig_acc = evaluate_model(models[name2], id_loader, device)
    _, aligned_acc = evaluate_model(aligned, id_loader, device)
    
    print(f"  Original {name2} accuracy: {orig_acc*100:.2f}%")
    print(f"  Aligned {name2} accuracy:  {aligned_acc*100:.2f}%")
    
    if abs(orig_acc - aligned_acc) > 0.01:
        print(f"  [WARNING] Accuracy changed significantly after alignment!")
    else:
        print(f"  [OK] Accuracy preserved")

## 5. Compute Post-Rebasin Metrics

In [None]:
# Compute post-rebasin metrics
post_rebasin_metrics = {}

print("\nComputing post-rebasin metrics...\n")
print(f"{'Pair':<25} {'Weight Dist':<15} {'Cosine Sim':<15} {'Agreement':<15}")
print("-" * 70)

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    aligned = aligned_models[pair_name]
    
    # Weight space metrics (reference model vs aligned model)
    weight_dist = compute_weight_distance(models[name1], aligned)
    cosine_sim = compute_cosine_similarity(models[name1], aligned)
    
    # Functional agreement
    agreement = model_agreement(models[name1], aligned, id_loader, device)
    
    post_rebasin_metrics[pair_name] = {
        'type': pair_type,
        'weight_distance': weight_dist,
        'cosine_similarity': cosine_sim,
        'agreement': agreement,
    }
    
    print(f"{pair_name} ({pair_type[:7]}...)  {weight_dist:>10.2f}     {cosine_sim:>10.4f}     {agreement*100:>10.2f}%")

## 6. Compare Pre vs Post Rebasin

In [None]:
# Compare metrics
print("\n" + "=" * 80)
print("COMPARISON: Pre vs Post Re-Basin")
print("=" * 80)

comparison_data = {}

for pair_name in pre_rebasin_metrics.keys():
    pre = pre_rebasin_metrics[pair_name]
    post = post_rebasin_metrics[pair_name]
    
    print(f"\n{pair_name} ({pre['type']}):")
    print(f"  Weight Distance: {pre['weight_distance']:.2f} -> {post['weight_distance']:.2f} "
          f"({post['weight_distance'] - pre['weight_distance']:+.2f})")
    print(f"  Cosine Similarity: {pre['cosine_similarity']:.4f} -> {post['cosine_similarity']:.4f} "
          f"({post['cosine_similarity'] - pre['cosine_similarity']:+.4f})")
    print(f"  Agreement: {pre['agreement']*100:.2f}% -> {post['agreement']*100:.2f}% "
          f"({(post['agreement'] - pre['agreement'])*100:+.2f}%)")
    
    comparison_data[pair_name] = {
        'type': pre['type'],
        'pre_weight_dist': pre['weight_distance'],
        'post_weight_dist': post['weight_distance'],
        'weight_dist_change': post['weight_distance'] - pre['weight_distance'],
        'pre_cosine_sim': pre['cosine_similarity'],
        'post_cosine_sim': post['cosine_similarity'],
        'cosine_sim_change': post['cosine_similarity'] - pre['cosine_similarity'],
        'pre_agreement': pre['agreement'],
        'post_agreement': post['agreement'],
        'agreement_change': post['agreement'] - pre['agreement'],
    }

In [None]:
# Visualize comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

pair_names = list(comparison_data.keys())
x = np.arange(len(pair_names))
width = 0.35

# Weight Distance
pre_dists = [comparison_data[p]['pre_weight_dist'] for p in pair_names]
post_dists = [comparison_data[p]['post_weight_dist'] for p in pair_names]
axes[0].bar(x - width/2, pre_dists, width, label='Pre-Rebasin', color='salmon')
axes[0].bar(x + width/2, post_dists, width, label='Post-Rebasin', color='steelblue')
axes[0].set_xticks(x)
axes[0].set_xticklabels(pair_names, rotation=45, ha='right')
axes[0].set_ylabel('Weight Distance')
axes[0].set_title('Weight Distance')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

# Cosine Similarity
pre_sims = [comparison_data[p]['pre_cosine_sim'] for p in pair_names]
post_sims = [comparison_data[p]['post_cosine_sim'] for p in pair_names]
axes[1].bar(x - width/2, pre_sims, width, label='Pre-Rebasin', color='salmon')
axes[1].bar(x + width/2, post_sims, width, label='Post-Rebasin', color='steelblue')
axes[1].set_xticks(x)
axes[1].set_xticklabels(pair_names, rotation=45, ha='right')
axes[1].set_ylabel('Cosine Similarity')
axes[1].set_title('Cosine Similarity')
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

# Agreement
pre_agrs = [comparison_data[p]['pre_agreement']*100 for p in pair_names]
post_agrs = [comparison_data[p]['post_agreement']*100 for p in pair_names]
axes[2].bar(x - width/2, pre_agrs, width, label='Pre-Rebasin', color='salmon')
axes[2].bar(x + width/2, post_agrs, width, label='Post-Rebasin', color='steelblue')
axes[2].set_xticks(x)
axes[2].set_xticklabels(pair_names, rotation=45, ha='right')
axes[2].set_ylabel('Agreement (%)')
axes[2].set_title('Prediction Agreement')
axes[2].legend()
axes[2].grid(True, alpha=0.3, axis='y')

plt.suptitle('Git Re-Basin Effect: Pre vs Post Alignment', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'rebasin_comparison')
plt.show()

## 7. Sanity Check: Same-Mechanism Pairs Should Improve More

In [None]:
print("\nSanity Check: Rebasin Effectiveness by Pair Type")
print("=" * 60)

# Group by type
same_mech_pairs = [p for p in pair_names if 'spurious-spurious' in comparison_data[p]['type'] or 
                  'robust-robust' in comparison_data[p]['type']]
diff_mech_pairs = [p for p in pair_names if 'spurious-robust' in comparison_data[p]['type']]

# Compute average improvements
same_mech_sim_change = np.mean([comparison_data[p]['cosine_sim_change'] for p in same_mech_pairs])
diff_mech_sim_change = np.mean([comparison_data[p]['cosine_sim_change'] for p in diff_mech_pairs])

same_mech_agr_change = np.mean([comparison_data[p]['agreement_change']*100 for p in same_mech_pairs])
diff_mech_agr_change = np.mean([comparison_data[p]['agreement_change']*100 for p in diff_mech_pairs])

print(f"\nSame-mechanism pairs ({', '.join(same_mech_pairs)}):")
print(f"  Average cosine similarity change: {same_mech_sim_change:+.4f}")
print(f"  Average agreement change: {same_mech_agr_change:+.2f}%")

print(f"\nDifferent-mechanism pairs ({', '.join(diff_mech_pairs)}):")
print(f"  Average cosine similarity change: {diff_mech_sim_change:+.4f}")
print(f"  Average agreement change: {diff_mech_agr_change:+.2f}%")

# Verification
print("\nVerification:")
if same_mech_sim_change >= diff_mech_sim_change:
    print("[PASS] Same-mechanism pairs show better alignment (cosine sim)")
else:
    print("[INFO] Different-mechanism pairs showed more improvement")
    print("       This is not necessarily unexpected: weight matching is agnostic to mechanism.")

## 8. Save Aligned Models

In [None]:
# Save aligned models for interpolation analysis
print("\nSaving aligned models...")

for pair_name, aligned_model in aligned_models.items():
    # Name: model_A2_aligned_to_A1.pt
    name1, name2 = pair_name.split('-')
    save_path = CHECKPOINTS_DIR / f"model_{name2}_aligned_to_{name1}.pt"
    
    torch.save({
        'model_state_dict': aligned_model.state_dict(),
        'reference_model': name1,
        'aligned_model': name2,
        'pair_type': comparison_data[pair_name]['type'],
    }, save_path)
    
    print(f"  Saved: {save_path}")

In [None]:
# Save rebasin metrics
rebasin_results = {
    'pre_rebasin': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                       for kk, vv in v.items()} 
                   for k, v in pre_rebasin_metrics.items()},
    'post_rebasin': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                        for kk, vv in v.items()} 
                    for k, v in post_rebasin_metrics.items()},
    'comparison': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                      for kk, vv in v.items()} 
                  for k, v in comparison_data.items()},
}

results_path = METRICS_DIR / 'rebasin_results.json'
with open(results_path, 'w') as f:
    json.dump(rebasin_results, f, indent=2)

print(f"\nResults saved to: {results_path}")

## 9. Summary

In [None]:
print("\n" + "=" * 60)
print("GIT RE-BASIN ALIGNMENT COMPLETE")
print("=" * 60)
print(f"""
Model pairs aligned:

1. A1-A2 (spurious-spurious):
   - Cosine sim: {comparison_data['A1-A2']['pre_cosine_sim']:.4f} -> {comparison_data['A1-A2']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['A1-A2']['pre_agreement']*100:.1f}% -> {comparison_data['A1-A2']['post_agreement']*100:.1f}%

2. R1-R2 (robust-robust):
   - Cosine sim: {comparison_data['R1-R2']['pre_cosine_sim']:.4f} -> {comparison_data['R1-R2']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['R1-R2']['pre_agreement']*100:.1f}% -> {comparison_data['R1-R2']['post_agreement']*100:.1f}%

3. A1-R1 (spurious-robust):
   - Cosine sim: {comparison_data['A1-R1']['pre_cosine_sim']:.4f} -> {comparison_data['A1-R1']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['A1-R1']['pre_agreement']*100:.1f}% -> {comparison_data['A1-R1']['post_agreement']*100:.1f}%

Key observations:
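- Git Re-Basin is expected to increase weight-space similarity (higher cosine similarity)
- Functional agreement should be essentially unchanged, since permuting hidden units does not change the function a model computes
- Compare the per-pair numbers above to see whether same-mechanism pairs align better than different-mechanism pairs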

Aligned models saved:
- {CHECKPOINTS_DIR / 'model_A2_aligned_to_A1.pt'}
- {CHECKPOINTS_DIR / 'model_R2_aligned_to_R1.pt'}
- {CHECKPOINTS_DIR / 'model_R1_aligned_to_A1.pt'}

Metrics saved:
- {METRICS_DIR / 'rebasin_results.json'}

Figures saved:
- {FIGURES_DIR / 'rebasin_comparison.png'}

Next: Run 05_interpolation_and_barriers.ipynb to analyze loss barriers.
""")

# 05 - Interpolation and Loss Barriers

This notebook analyzes loss barriers along weight interpolation paths between model pairs.

## Core Hypothesis:
- **Same-mechanism pairs** (spurious-spurious, robust-robust): Should have LOW barriers after rebasin
- **Different-mechanism pairs** (spurious-robust): Should have HIGH barriers even after rebasin

## What this notebook does:
1. Interpolates weights: θ(α) = α·θ₁ + (1-α)·θ₂ for α ∈ [0, 1]
2. Evaluates ID loss/acc and OOD loss/acc at each α
3. Computes Spurious Reliance Score along the path
4. Calculates barrier heights pre and post rebasing
5. Generates visualization and summary metrics
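
Steps 1 and 4 can be sketched in plain Python. The interpolation follows the formula in step 1. The barrier definition is an assumption: the largest excess of the path loss over the straight line joining the two endpoint losses. The actual compute_loss_barrier in src.interp is not shown and may differ. The function names and toy numbers below are hypothetical.

```python
def interpolate(theta_1, theta_2, alpha):
    """theta(alpha) = alpha * theta_1 + (1 - alpha) * theta_2, per parameter."""
    return {
        name: [alpha * p1 + (1 - alpha) * p2
               for p1, p2 in zip(theta_1[name], theta_2[name])]
        for name in theta_1
    }


def loss_barrier(alphas, losses):
    """Largest excess of the path loss over the line joining the endpoint losses."""
    a_start, a_end = alphas[0], alphas[-1]
    l_start, l_end = losses[0], losses[-1]
    excess = []
    for a, loss in zip(alphas, losses):
        t = (a - a_start) / (a_end - a_start)       # position along the path in [0, 1]
        baseline = (1 - t) * l_start + t * l_end    # linear blend of endpoint losses
        excess.append(loss - baseline)
    return max(excess)


# Toy parameters stored as flat lists
theta_1 = {"w": [1.0, 2.0]}
theta_2 = {"w": [3.0, 4.0]}
mid = interpolate(theta_1, theta_2, alpha=0.25)

# Hypothetical losses along a path with a bump in the middle
alphas = [0.0, 0.25, 0.5, 0.75, 1.0]
losses = [0.2, 0.6, 1.0, 0.7, 0.4]
barrier = loss_barrier(alphas, losses)
print(mid, round(barrier, 4))
```

Under this definition a barrier near zero means the loss along the path never rises much above what the endpoints already have, which is the outcome the core hypothesis predicts for same-mechanism pairs after rebasin.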

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR, RESULTS_DIR
)

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import (
    create_env_a_dataset, 
    create_no_patch_dataset,
    CounterfactualPatchDataset,
)
from src.models import create_model
from src.train import load_model
from src.interp import (
    evaluate_interpolation_path,
    evaluate_interpolation_multi_dataset,
    compute_loss_barrier,
    compute_accuracy_barrier,
    summarize_interpolation_results,
    compare_pre_post_rebasin,
)
from src.metrics import compute_spurious_reliance_score, semantic_barrier_metric
from src.plotting import (
    plot_interpolation_path,
    plot_interpolation_comparison,
    plot_pre_post_rebasin,
    plot_barrier_comparison,
    plot_srs_interpolation,
    save_figure,
)
from torch.utils.data import DataLoader

## 1. Load Models (Original and Aligned)

In [None]:
# Load original models
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    model = create_model(config)
    model = load_model(model, checkpoint_path, device)
    models[name] = model
    print(f"Loaded original model {name}")

In [None]:
# Load aligned models
aligned_pairs = [
    ('A1', 'A2'),  # A2 aligned to A1
    ('R1', 'R2'),  # R2 aligned to R1
    ('A1', 'R1'),  # R1 aligned to A1
]

aligned_models = {}

for ref, aligned in aligned_pairs:
    pair_name = f"{ref}-{aligned}"
    checkpoint_path = CHECKPOINTS_DIR / f"model_{aligned}_aligned_to_{ref}.pt"
    
    if checkpoint_path.exists():
        model = create_model(config)
        model = load_model(model, checkpoint_path, device)
        aligned_models[pair_name] = model
        print(f"Loaded aligned model: {aligned} -> {ref}")
    else:
        print(f"[WARNING] Aligned model not found: {checkpoint_path}")
        print(f"          Please run 04_rebasin_alignment.ipynb first.")

## 2. Create DataLoaders

In [None]:
# Create test datasets
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['interpolation']['eval_batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

dataloaders = {
    'id': id_loader,
    'ood': ood_loader,
}

print(f"DataLoaders created: ID={len(test_id)}, OOD={len(test_ood)} samples")

In [None]:
# Create counterfactual dataset for SRS computation
cf_dataset = CounterfactualPatchDataset(
    base_dataset=test_id,
    swap_mode='random_wrong',
)
print(f"Counterfactual dataset: {len(cf_dataset)} samples")

## 3. Define Interpolation Pairs

In [None]:
# Define pairs to analyze
# Format: (ref_name, other_name, pair_type)
analysis_pairs = [
    ('A1', 'A2', 'spurious-spurious'),
    ('R1', 'R2', 'robust-robust'),
    ('A1', 'R1', 'spurious-robust'),
]

num_alphas = config['interpolation']['num_alphas']
print(f"Will evaluate {num_alphas} interpolation points for each pair")

## 4. Evaluate Interpolation Paths (Pre and Post Rebasin)

In [None]:
# Store all results
all_results = {}

for ref_name, other_name, pair_type in analysis_pairs:
    pair_name = f"{ref_name}-{other_name}"
    print(f"\n{'='*60}")
    print(f"Analyzing pair: {pair_name} ({pair_type})")
    print(f"{'='*60}")
    
    # Get models
    model_ref = models[ref_name]
    model_other = models[other_name]
    model_other_aligned = aligned_models.get(pair_name)
    
    # Pre-rebasin interpolation
    print(f"\nPre-rebasin interpolation...")
    pre_results = evaluate_interpolation_multi_dataset(
        model_ref, model_other, dataloaders, device, num_alphas
    )
    
    # Post-rebasin interpolation (if aligned model exists)
    if model_other_aligned is not None:
        print(f"Post-rebasin interpolation...")
        post_results = evaluate_interpolation_multi_dataset(
            model_ref, model_other_aligned, dataloaders, device, num_alphas
        )
    else:
        print(f"[SKIP] No aligned model available")
        post_results = None
    
    # Store results
    all_results[pair_name] = {
        'type': pair_type,
        'pre_rebasin': pre_results,
        'post_rebasin': post_results,
    }
    
    # Quick summary
    pre_summary = summarize_interpolation_results(pre_results)
    print(f"\nPre-rebasin barriers:")
    print(f"  ID loss barrier:  {pre_summary['id']['loss_barrier']:.4f}")
    print(f"  OOD loss barrier: {pre_summary['ood']['loss_barrier']:.4f}")
    print(f"  ID acc barrier:   {pre_summary['id']['acc_barrier']*100:.2f}%")
    
    if post_results is not None:
        post_summary = summarize_interpolation_results(post_results)
        print(f"\nPost-rebasin barriers:")
        print(f"  ID loss barrier:  {post_summary['id']['loss_barrier']:.4f}")
        print(f"  OOD loss barrier: {post_summary['ood']['loss_barrier']:.4f}")
        print(f"  ID acc barrier:   {post_summary['id']['acc_barrier']*100:.2f}%")

## 5. Visualize Interpolation Paths

In [None]:
# Plot each pair's interpolation path
for pair_name, results in all_results.items():
    pair_type = results['type']
    pre = results['pre_rebasin']
    post = results['post_rebasin']
    
    # Pre-rebasin path
    fig = plot_interpolation_path(
        {'alphas': pre['id']['alphas'], 'losses': pre['id']['losses'], 'accuracies': pre['id']['accuracies']},
        title=f"{pair_name} ({pair_type}) - Pre-Rebasin (ID)",
        save_name=f'interp_{pair_name}_pre_id'
    )
    plt.show()
    
    # Post-rebasin path (if available)
    if post is not None:
        fig = plot_pre_post_rebasin(
            pre['id'], post['id'],
            dataset_name="ID",
            title=f"{pair_name} ({pair_type}): Pre vs Post Rebasin (ID)",
            save_name=f'interp_{pair_name}_pre_vs_post_id'
        )
        plt.show()
        
        fig = plot_pre_post_rebasin(
            pre['ood'], post['ood'],
            dataset_name="OOD",
            title=f"{pair_name} ({pair_type}): Pre vs Post Rebasin (OOD)",
            save_name=f'interp_{pair_name}_pre_vs_post_ood'
        )
        plt.show()

## 6. Compare Barriers Across Pairs

In [None]:
# Compile barrier comparison data
barrier_comparison = {}

for pair_name, results in all_results.items():
    pre = results['pre_rebasin']
    post = results['post_rebasin']
    
    pre_summary = summarize_interpolation_results(pre)
    
    barrier_comparison[pair_name] = {
        'type': results['type'],
        'pre_id_loss_barrier': pre_summary['id']['loss_barrier'],
        'pre_ood_loss_barrier': pre_summary['ood']['loss_barrier'],
        'pre_id_acc_barrier': pre_summary['id']['acc_barrier'],
        'pre_ood_acc_barrier': pre_summary['ood']['acc_barrier'],
    }
    
    if post is not None:
        post_summary = summarize_interpolation_results(post)
        barrier_comparison[pair_name].update({
            'post_id_loss_barrier': post_summary['id']['loss_barrier'],
            'post_ood_loss_barrier': post_summary['ood']['loss_barrier'],
            'post_id_acc_barrier': post_summary['id']['acc_barrier'],
            'post_ood_acc_barrier': post_summary['ood']['acc_barrier'],
        })

# Display comparison table
print("\n" + "=" * 90)
print("BARRIER COMPARISON SUMMARY")
print("=" * 90)
print(f"\n{'Pair':<20} {'Type':<18} {'Pre ID Loss':<12} {'Post ID Loss':<12} {'Pre OOD Loss':<12} {'Post OOD Loss':<12}")
print("-" * 90)

for pair_name, data in barrier_comparison.items():
    pre_id = data['pre_id_loss_barrier']
    post_id = data.get('post_id_loss_barrier', float('nan'))
    pre_ood = data['pre_ood_loss_barrier']
    post_ood = data.get('post_ood_loss_barrier', float('nan'))
    
    print(f"{pair_name:<20} {data['type']:<18} {pre_id:>10.4f}   {post_id:>10.4f}   {pre_ood:>10.4f}   {post_ood:>10.4f}")

In [None]:
# Create barrier comparison visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

pair_names = list(barrier_comparison.keys())
x = np.arange(len(pair_names))
width = 0.35

# ID Loss Barrier
pre_id_barriers = [barrier_comparison[p]['pre_id_loss_barrier'] for p in pair_names]
post_id_barriers = [barrier_comparison[p].get('post_id_loss_barrier', 0) for p in pair_names]

bars1 = axes[0].bar(x - width/2, pre_id_barriers, width, label='Pre-Rebasin', color='salmon')
bars2 = axes[0].bar(x + width/2, post_id_barriers, width, label='Post-Rebasin', color='steelblue')
axes[0].set_xticks(x)
axes[0].set_xticklabels(pair_names, rotation=45, ha='right')
axes[0].set_ylabel('Loss Barrier')
axes[0].set_title('ID Loss Barrier')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

# OOD Loss Barrier
pre_ood_barriers = [barrier_comparison[p]['pre_ood_loss_barrier'] for p in pair_names]
post_ood_barriers = [barrier_comparison[p].get('post_ood_loss_barrier', 0) for p in pair_names]

bars1 = axes[1].bar(x - width/2, pre_ood_barriers, width, label='Pre-Rebasin', color='salmon')
bars2 = axes[1].bar(x + width/2, post_ood_barriers, width, label='Post-Rebasin', color='steelblue')
axes[1].set_xticks(x)
axes[1].set_xticklabels(pair_names, rotation=45, ha='right')
axes[1].set_ylabel('Loss Barrier')
axes[1].set_title('OOD Loss Barrier')
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.suptitle('Loss Barriers: Pre vs Post Re-Basin', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'barrier_comparison_all')
plt.show()

## 7. Compute Spurious Reliance Score Along Interpolation
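
As a rough sketch of what each interpolation point involves, the snippet below blends two parameter dictionaries linearly. It uses plain NumPy arrays rather than the project's models. The helper name `interpolate_params` and the convention theta(alpha) = (1 - alpha) * theta_A + alpha * theta_B are assumptions made for illustration; `create_interpolated_model` in `src.interp` may order the endpoints differently, so it is worth checking which endpoint alpha = 0 maps to before labelling plots.

```python
import numpy as np

def interpolate_params(params_a, params_b, alpha):
    """Hypothetical helper: linear blend of two parameter dicts with matching keys and shapes."""
    if params_a.keys() != params_b.keys():
        raise ValueError("parameter dicts must have identical keys")
    # Assumed convention: alpha=0 returns params_a, alpha=1 returns params_b
    return {k: (1.0 - alpha) * params_a[k] + alpha * params_b[k] for k in params_a}

# Toy parameters, not taken from any trained model
theta_a = {"w": np.array([1.0, 2.0]), "b": np.array([0.0])}
theta_b = {"w": np.array([3.0, 6.0]), "b": np.array([1.0])}

midpoint = interpolate_params(theta_a, theta_b, 0.5)
print(midpoint["w"], midpoint["b"])
```

For real networks the same blend would be applied to every tensor in the two state dicts, after the second model has been permuted by Re-Basin.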

In [None]:
from src.interp import create_interpolated_model

def compute_srs_along_path(model_a, model_b, alphas, device, id_loader, ood_loader, cf_dataset):
    """
    Compute Spurious Reliance Score at each interpolation point.
    
    This is expensive - we'll use fewer points.
    """
    srs_values = []
    
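    # Use fewer points for SRS (expensive to compute): every 4th alpha,
    # always keeping the final endpoint so srs_values[-1] is the last alpha
    sample_alphas = list(alphas[::4])
    if sample_alphas[-1] != alphas[-1]:
        sample_alphas.append(alphas[-1])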
    
    for alpha in sample_alphas:
        interp_model = create_interpolated_model(model_a, model_b, alpha, device)
        
        srs = compute_spurious_reliance_score(
            interp_model, id_loader, ood_loader, cf_dataset, device
        )
        srs_values.append(srs['spurious_reliance_score'])
        print(f"  alpha={alpha:.2f}: SRS={srs['spurious_reliance_score']:.4f}")
    
    return sample_alphas, srs_values

In [None]:
# Compute SRS for spurious-robust pair (most interesting case)
print("Computing Spurious Reliance Score along A1-R1 interpolation path...")
print("(This demonstrates the 'semantic barrier' - mechanism mismatch)\n")

pair_name = 'A1-R1'
model_a1 = models['A1']
model_r1_aligned = aligned_models.get(pair_name, models['R1'])

alphas = np.linspace(0, 1, num_alphas)
srs_alphas, srs_values = compute_srs_along_path(
    model_a1, model_r1_aligned, alphas, device,
    id_loader, ood_loader, cf_dataset
)

In [None]:
# Plot SRS along interpolation
fig = plot_srs_interpolation(
    np.array(srs_alphas), srs_values,
    title="Spurious Reliance Score: A1-R1 Interpolation (Post-Rebasin)",
    save_name='srs_interpolation_A1_R1'
)
plt.show()

# Compute semantic barrier
sem_barrier, sem_alpha = semantic_barrier_metric(srs_values, np.array(srs_alphas))
print(f"\nSemantic Barrier Metric:")
print(f"  Max SRS variation from endpoint average: {sem_barrier:.4f}")
print(f"  At alpha = {sem_alpha:.2f}")
print(f"\nEndpoint SRS values:")
print(f"  A1 (alpha=1): SRS = {srs_values[-1]:.4f} (should be HIGH - spurious)")
print(f"  R1 (alpha=0): SRS = {srs_values[0]:.4f} (should be LOW - robust)")

## 8. Create Combined Interpolation Plot

In [None]:
# Create comprehensive comparison figure
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

colors = {'A1-A2': 'tab:red', 'R1-R2': 'tab:blue', 'A1-R1': 'tab:purple'}

# Row 1: Pre-rebasin
# Loss (ID)
for pair_name, results in all_results.items():
    pre = results['pre_rebasin']
    axes[0, 0].plot(pre['id']['alphas'], pre['id']['losses'], 
                    label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[0, 0].set_xlabel(r'$\alpha$')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Pre-Rebasin: ID Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy (ID)
for pair_name, results in all_results.items():
    pre = results['pre_rebasin']
    axes[0, 1].plot(pre['id']['alphas'], np.asarray(pre['id']['accuracies']) * 100,
                    label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[0, 1].set_xlabel(r'$\alpha$')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('Pre-Rebasin: ID Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Accuracy (OOD)
for pair_name, results in all_results.items():
    pre = results['pre_rebasin']
    axes[0, 2].plot(pre['ood']['alphas'], np.asarray(pre['ood']['accuracies']) * 100,
                    label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[0, 2].set_xlabel(r'$\alpha$')
axes[0, 2].set_ylabel('Accuracy (%)')
axes[0, 2].set_title('Pre-Rebasin: OOD Accuracy')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Row 2: Post-rebasin
# Loss (ID)
for pair_name, results in all_results.items():
    post = results['post_rebasin']
    if post is not None:
        axes[1, 0].plot(post['id']['alphas'], post['id']['losses'], 
                        label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[1, 0].set_xlabel(r'$\alpha$')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].set_title('Post-Rebasin: ID Loss')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Accuracy (ID)
for pair_name, results in all_results.items():
    post = results['post_rebasin']
    if post is not None:
        axes[1, 1].plot(post['id']['alphas'], np.asarray(post['id']['accuracies']) * 100,
                        label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[1, 1].set_xlabel(r'$\alpha$')
axes[1, 1].set_ylabel('Accuracy (%)')
axes[1, 1].set_title('Post-Rebasin: ID Accuracy')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Accuracy (OOD)
for pair_name, results in all_results.items():
    post = results['post_rebasin']
    if post is not None:
        axes[1, 2].plot(post['ood']['alphas'], np.asarray(post['ood']['accuracies']) * 100,
                        label=f"{pair_name}", color=colors[pair_name], linewidth=2)
axes[1, 2].set_xlabel(r'$\alpha$')
axes[1, 2].set_ylabel('Accuracy (%)')
axes[1, 2].set_title('Post-Rebasin: OOD Accuracy')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.suptitle('Weight Interpolation Analysis: Pre vs Post Re-Basin', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'interpolation_comprehensive')
plt.show()

## 9. Save All Results
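
`json.dump` cannot serialize NumPy scalars or arrays directly, which is why the cell below converts values with `float(...)`. As an alternative, here is a minimal sketch of a recursive converter. The helper name `to_jsonable` is hypothetical and is not part of `src`.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Hypothetical helper: recursively convert NumPy scalars/arrays to plain Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj

# Toy values for illustration only
example = {"barrier": np.float32(0.25), "alphas": np.linspace(0.0, 1.0, 3), "pair": "A1-R1"}
encoded = json.dumps(to_jsonable(example))
print(encoded)
```

Converting the whole structure once keeps the per-field list comprehensions out of the summary-building code, at the cost of one extra pass over the data.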

In [None]:
# Compile final summary
final_summary = {
    'barrier_comparison': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                              for kk, vv in v.items()} 
                          for k, v in barrier_comparison.items()},
    'srs_interpolation': {
        'pair': 'A1-R1',
        'alphas': [float(a) for a in srs_alphas],
        'srs_values': [float(s) for s in srs_values],
        'semantic_barrier': float(sem_barrier),
        'semantic_barrier_alpha': float(sem_alpha),
    },
}

# Add detailed interpolation data
for pair_name, results in all_results.items():
    pre = results['pre_rebasin']
    post = results['post_rebasin']
    
    final_summary[f'{pair_name}_pre'] = {
        'id_losses': [float(x) for x in pre['id']['losses']],
        'id_accuracies': [float(x) for x in pre['id']['accuracies']],
        'ood_losses': [float(x) for x in pre['ood']['losses']],
        'ood_accuracies': [float(x) for x in pre['ood']['accuracies']],
        'alphas': [float(x) for x in pre['id']['alphas']],
    }
    
    if post is not None:
        final_summary[f'{pair_name}_post'] = {
            'id_losses': [float(x) for x in post['id']['losses']],
            'id_accuracies': [float(x) for x in post['id']['accuracies']],
            'ood_losses': [float(x) for x in post['ood']['losses']],
            'ood_accuracies': [float(x) for x in post['ood']['accuracies']],
            'alphas': [float(x) for x in post['id']['alphas']],
        }

# Save to results directory
summary_path = RESULTS_DIR / 'summary.json'
with open(summary_path, 'w') as f:
    json.dump(final_summary, f, indent=2)

print(f"\nResults saved to: {summary_path}")

## 10. Summary

In [None]:
print("\n" + "=" * 70)
print("INTERPOLATION AND BARRIER ANALYSIS COMPLETE")
print("=" * 70)

# Compute key statistics
same_mech_pairs = ['A1-A2', 'R1-R2']
diff_mech_pairs = ['A1-R1']

# Missing post-rebasin results count as NaN (ignored by nanmean) rather than as a zero barrier
same_mech_post_barrier = np.nanmean([barrier_comparison[p].get('post_id_loss_barrier', np.nan) for p in same_mech_pairs])
diff_mech_post_barrier = np.nanmean([barrier_comparison[p].get('post_id_loss_barrier', np.nan) for p in diff_mech_pairs])

print(f"""
Key Findings:

1. Loss Barrier Summary (Post-Rebasin):
   - Same-mechanism pairs (A1-A2, R1-R2): Avg barrier = {same_mech_post_barrier:.4f}
   - Different-mechanism pair (A1-R1):   Barrier = {diff_mech_post_barrier:.4f}

2. Individual Pair Results:""")

for pair_name, data in barrier_comparison.items():
    post_barrier = data.get('post_id_loss_barrier', float('nan'))
    pre_barrier = data['pre_id_loss_barrier']
    reduction = pre_barrier - post_barrier  # NaN when no post-rebasin result exists
    print(f"   {pair_name} ({data['type']}):")
    print(f"     Pre-rebasin barrier:  {pre_barrier:.4f}")
    print(f"     Post-rebasin barrier: {post_barrier:.4f}")
    print(f"     Reduction: {reduction:.4f}")

print(f"""
3. Semantic Barrier (SRS variation along A1-R1 path):
   - Max variation: {sem_barrier:.4f}
   - SRS at A1 endpoint (alpha=1): {srs_values[-1]:.4f}
   - SRS at R1 endpoint (alpha=0): {srs_values[0]:.4f}

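4. Interpretation (check against the numbers above):
   - Git Re-Basin is expected to reduce barriers (positive Reduction values)
   - The hypothesis predicts that the different-mechanism pair retains a higher barrier
   - The "semantic barrier" (SRS variation) is the proposed signature of mechanism mismatch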

Files saved:
   - {RESULTS_DIR / 'summary.json'}
   - Multiple interpolation figures in {FIGURES_DIR}/

Next: Run 06_summary_report.ipynb for final analysis and conclusions.
""")

# 06 - Summary Report

This notebook compiles all results and generates the final summary for the project:

**"When Geometry Fails: Stress-Testing Git Re-Basin on Spurious vs Robust Features"**

## Core Hypothesis (Recap):
Permutation alignment (Git Re-Basin) can successfully connect models relying on the same feature type, but fails to connect models with mismatched mechanisms (spurious vs robust), producing measurable loss and semantic barriers.
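
To make "loss barrier" concrete, here is a minimal sketch of one common definition: the largest amount by which the loss along the interpolation path exceeds the straight line between the two endpoint losses. The function name `loss_barrier` and this exact definition are assumptions; the barrier values loaded in this notebook come from the project's own metrics code, which may define the barrier differently.

```python
import numpy as np

def loss_barrier(alphas, losses):
    """Hypothetical helper: max excess of the path loss over the endpoint chord."""
    alphas = np.asarray(alphas, dtype=float)
    losses = np.asarray(losses, dtype=float)
    # Straight line between the two endpoint losses, evaluated at each alpha
    t = (alphas - alphas[0]) / (alphas[-1] - alphas[0])
    chord = (1.0 - t) * losses[0] + t * losses[-1]
    excess = losses - chord
    idx = int(np.argmax(excess))
    return float(excess[idx]), float(alphas[idx])

# Toy loss curve with a bump in the middle (illustrative numbers only)
toy_alphas = np.linspace(0.0, 1.0, 5)
toy_losses = np.array([0.2, 0.5, 0.9, 0.6, 0.4])
barrier, at_alpha = loss_barrier(toy_alphas, toy_losses)
print(f"barrier={barrier:.3f} at alpha={at_alpha:.2f}")
```

Under this definition a path whose loss never rises above the chord has a barrier of zero, which is the outcome the hypothesis predicts for same-mechanism pairs after alignment.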

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from src.config import (
    get_config, RESULTS_DIR, FIGURES_DIR, METRICS_DIR
)
from src.plotting import save_figure

config = get_config()
print(f"Loading results from: {RESULTS_DIR}")

## 1. Load All Results

In [None]:
# Load all saved metrics
results = {}

# Training summary
training_path = METRICS_DIR / 'training_summary.json'
if training_path.exists():
    with open(training_path, 'r') as f:
        results['training'] = json.load(f)
    print(f"Loaded training summary")

# Mechanism verification
mechanism_path = METRICS_DIR / 'mechanism_verification.json'
if mechanism_path.exists():
    with open(mechanism_path, 'r') as f:
        results['mechanism'] = json.load(f)
    print(f"Loaded mechanism verification results")

# Rebasin results
rebasin_path = METRICS_DIR / 'rebasin_results.json'
if rebasin_path.exists():
    with open(rebasin_path, 'r') as f:
        results['rebasin'] = json.load(f)
    print(f"Loaded rebasin results")

# Interpolation summary
summary_path = RESULTS_DIR / 'summary.json'
if summary_path.exists():
    with open(summary_path, 'r') as f:
        results['interpolation'] = json.load(f)
    print(f"Loaded interpolation summary")

print(f"\nLoaded {len(results)} result files")

## 2. Model Performance Summary

In [None]:
if 'mechanism' in results:
    srs_results = results['mechanism']['srs_results']
    
    # Create performance table
    performance_data = []
    for model_name in ['A1', 'A2', 'R1', 'R2']:
        if model_name in srs_results:
            m = srs_results[model_name]
            model_type = 'Spurious' if model_name.startswith('A') else 'Robust'
            performance_data.append({
                'Model': model_name,
                'Type': model_type,
                'ID Acc (%)': f"{m['id_accuracy']*100:.1f}",
                'OOD Acc (%)': f"{m['ood_accuracy']*100:.1f}",
                'OOD Drop (%)': f"{m['ood_drop']*100:.1f}",
                'SRS': f"{m['spurious_reliance_score']:.4f}",
            })
    
    df_performance = pd.DataFrame(performance_data)
    print("\n" + "="*70)
    print("MODEL PERFORMANCE SUMMARY")
    print("="*70)
    print(df_performance.to_string(index=False))
else:
    print("[WARNING] Mechanism verification results not found")

## 3. Git Re-Basin Effectiveness

In [None]:
if 'rebasin' in results:
    comparison = results['rebasin']['comparison']
    
    rebasin_data = []
    for pair_name, data in comparison.items():
        rebasin_data.append({
            'Pair': pair_name,
            'Type': data['type'],
            'Pre Cosine Sim': f"{data['pre_cosine_sim']:.4f}",
            'Post Cosine Sim': f"{data['post_cosine_sim']:.4f}",
            'Change': f"{data['cosine_sim_change']:+.4f}",
            'Pre Agreement (%)': f"{data['pre_agreement']*100:.1f}",
            'Post Agreement (%)': f"{data['post_agreement']*100:.1f}",
        })
    
    df_rebasin = pd.DataFrame(rebasin_data)
    print("\n" + "="*90)
    print("GIT RE-BASIN EFFECTIVENESS")
    print("="*90)
    print(df_rebasin.to_string(index=False))
else:
    print("[WARNING] Rebasin results not found")

## 4. Barrier Analysis

In [None]:
if 'interpolation' in results and 'barrier_comparison' in results['interpolation']:
    barriers = results['interpolation']['barrier_comparison']
    
    barrier_data = []
    for pair_name, data in barriers.items():
        barrier_data.append({
            'Pair': pair_name,
            'Type': data['type'],
            'Pre ID Barrier': f"{data['pre_id_loss_barrier']:.4f}",
            'Post ID Barrier': f"{data.get('post_id_loss_barrier', float('nan')):.4f}",
            'Pre OOD Barrier': f"{data['pre_ood_loss_barrier']:.4f}",
            'Post OOD Barrier': f"{data.get('post_ood_loss_barrier', float('nan')):.4f}",
        })
    
    df_barriers = pd.DataFrame(barrier_data)
    print("\n" + "="*90)
    print("LOSS BARRIER ANALYSIS")
    print("="*90)
    print(df_barriers.to_string(index=False))
    
    # Compute statistics
    same_mech = [barriers[p]['post_id_loss_barrier'] for p in ['A1-A2', 'R1-R2'] 
                 if 'post_id_loss_barrier' in barriers.get(p, {})]
    diff_mech = [barriers[p]['post_id_loss_barrier'] for p in ['A1-R1'] 
                 if 'post_id_loss_barrier' in barriers.get(p, {})]
    
    if same_mech and diff_mech:
        print(f"\nKey Statistics:")
        print(f"  Same-mechanism pairs avg barrier: {np.mean(same_mech):.4f}")
        print(f"  Diff-mechanism pair barrier:      {np.mean(diff_mech):.4f}")
        print(f"  Ratio (diff/same):                {np.mean(diff_mech)/np.mean(same_mech):.2f}x")
else:
    print("[WARNING] Interpolation results not found")

## 5. Semantic Barrier (SRS Variation)
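
The semantic barrier reported here was computed in notebook 05 and described there as the "max SRS variation from endpoint average". The sketch below shows one possible reading of that description: the largest absolute deviation of SRS from the mean of the two endpoint values. The function name and this interpretation are assumptions; `semantic_barrier_metric` in `src.metrics` may be defined differently.

```python
import numpy as np

def semantic_barrier_sketch(srs_values, alphas):
    """Hypothetical reading: largest |SRS(alpha) - mean of the two endpoint SRS values|."""
    srs = np.asarray(srs_values, dtype=float)
    alphas = np.asarray(alphas, dtype=float)
    endpoint_avg = 0.5 * (srs[0] + srs[-1])
    deviation = np.abs(srs - endpoint_avg)
    idx = int(np.argmax(deviation))
    return float(deviation[idx]), float(alphas[idx])

# Toy SRS curve (illustrative numbers only)
toy_alphas = [0.0, 0.25, 0.5, 0.75, 1.0]
toy_srs = [0.1, 0.2, 0.9, 0.6, 0.7]
toy_barrier, toy_alpha = semantic_barrier_sketch(toy_srs, toy_alphas)
print(f"semantic barrier={toy_barrier:.2f} at alpha={toy_alpha:.2f}")
```

Under this reading a monotone SRS curve attains its maximum deviation at an endpoint, so an interior maximum indicates that intermediate models behave unlike either endpoint.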

In [None]:
if 'interpolation' in results and 'srs_interpolation' in results['interpolation']:
    srs_interp = results['interpolation']['srs_interpolation']
    
    print("\n" + "="*70)
    print("SEMANTIC BARRIER ANALYSIS (A1-R1 Interpolation)")
    print("="*70)
    print(f"\nSRS at endpoints:")
    print(f"  A1 (spurious, alpha=1): {srs_interp['srs_values'][-1]:.4f}")
    print(f"  R1 (robust, alpha=0):   {srs_interp['srs_values'][0]:.4f}")
    print(f"\nSemantic barrier metric:")
    print(f"  Max SRS variation: {srs_interp['semantic_barrier']:.4f}")
    print(f"  At alpha:          {srs_interp['semantic_barrier_alpha']:.2f}")
    
    # Plot SRS along path
    fig, ax = plt.subplots(figsize=(10, 6))
    alphas = srs_interp['alphas']
    srs_vals = srs_interp['srs_values']
    
    ax.plot(alphas, srs_vals, 'purple', linewidth=2, marker='o', markersize=8)
    ax.axhline(y=srs_vals[0], color='blue', linestyle='--', alpha=0.5, label='R1 (robust)')
    ax.axhline(y=srs_vals[-1], color='red', linestyle='--', alpha=0.5, label='A1 (spurious)')
    ax.set_xlabel(r'$\alpha$ (0=R1, 1=A1)', fontsize=12)
    ax.set_ylabel('Spurious Reliance Score', fontsize=12)
    ax.set_title('Semantic Barrier: SRS Along A1-R1 Interpolation', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Annotate endpoints
    ax.annotate('Robust\n(low SRS)', xy=(0, srs_vals[0]), xytext=(0.15, srs_vals[0]-0.05),
                fontsize=10, ha='center')
    ax.annotate('Spurious\n(high SRS)', xy=(1, srs_vals[-1]), xytext=(0.85, srs_vals[-1]+0.05),
                fontsize=10, ha='center')
    
    plt.tight_layout()
    save_figure(fig, 'semantic_barrier_summary')
    plt.show()
else:
    print("[WARNING] SRS interpolation results not found")

## 6. Create Final Summary Figure

In [None]:
# Create comprehensive summary figure
fig = plt.figure(figsize=(16, 12))

# Panel 1: Model performance comparison
ax1 = fig.add_subplot(2, 2, 1)
if 'mechanism' in results:
    srs = results['mechanism']['srs_results']
    models = ['A1', 'A2', 'R1', 'R2']
    x = np.arange(len(models))
    width = 0.35
    
    id_accs = [srs[m]['id_accuracy']*100 for m in models]
    ood_accs = [srs[m]['ood_accuracy']*100 for m in models]
    
    ax1.bar(x - width/2, id_accs, width, label='ID Acc', color='steelblue')
    ax1.bar(x + width/2, ood_accs, width, label='OOD Acc', color='coral')
    ax1.set_xticks(x)
    ax1.set_xticklabels(models)
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('(A) Model Performance: ID vs OOD')
    ax1.legend()
    ax1.grid(True, alpha=0.3, axis='y')

# Panel 2: SRS comparison
ax2 = fig.add_subplot(2, 2, 2)
if 'mechanism' in results:
    srs_vals = [srs[m]['spurious_reliance_score'] for m in models]
    colors = ['#e74c3c', '#e67e22', '#3498db', '#2ecc71']
    bars = ax2.bar(x, srs_vals, color=colors)
    ax2.set_xticks(x)
    ax2.set_xticklabels(models)
    ax2.set_ylabel('Spurious Reliance Score')
    ax2.set_title('(B) Spurious Reliance Score by Model')
    ax2.grid(True, alpha=0.3, axis='y')
    
    # Reference line at the mean SRS across all four models (a rough visual split between groups)
    avg_srs = np.mean(srs_vals)
    ax2.axhline(y=avg_srs, color='gray', linestyle='--', alpha=0.5)

# Panel 3: Barrier comparison
ax3 = fig.add_subplot(2, 2, 3)
if 'interpolation' in results and 'barrier_comparison' in results['interpolation']:
    barriers = results['interpolation']['barrier_comparison']
    pairs = list(barriers.keys())
    x = np.arange(len(pairs))
    width = 0.35
    
    pre = [barriers[p]['pre_id_loss_barrier'] for p in pairs]
    # Missing post-rebasin results are plotted as gaps (NaN), not as zero-height bars
    post = [barriers[p].get('post_id_loss_barrier', float('nan')) for p in pairs]
    
    ax3.bar(x - width/2, pre, width, label='Pre-Rebasin', color='salmon')
    ax3.bar(x + width/2, post, width, label='Post-Rebasin', color='steelblue')
    ax3.set_xticks(x)
    ax3.set_xticklabels(pairs)
    ax3.set_ylabel('Loss Barrier')
    ax3.set_title('(C) Loss Barriers: Pre vs Post Re-Basin')
    ax3.legend()
    ax3.grid(True, alpha=0.3, axis='y')

# Panel 4: Key finding - barrier ratio
ax4 = fig.add_subplot(2, 2, 4)
if 'interpolation' in results and 'barrier_comparison' in results['interpolation']:
    barriers = results['interpolation']['barrier_comparison']
    
    # Get post-rebasin barriers
    same_mech_barriers = [barriers[p].get('post_id_loss_barrier', barriers[p]['pre_id_loss_barrier']) 
                         for p in ['A1-A2', 'R1-R2']]
    diff_mech_barrier = barriers['A1-R1'].get('post_id_loss_barrier', barriers['A1-R1']['pre_id_loss_barrier'])
    
    categories = ['Same\nMechanism', 'Different\nMechanism']
    values = [np.mean(same_mech_barriers), diff_mech_barrier]
    colors = ['#2ecc71', '#e74c3c']
    
    bars = ax4.bar(categories, values, color=colors, edgecolor='black', linewidth=2)
    ax4.set_ylabel('Post-Rebasin Loss Barrier')
    ax4.set_title('(D) Key Finding: Mechanism Mismatch = Higher Barrier')
    ax4.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bar, val in zip(bars, values):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{val:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    
    # Add ratio annotation
    if values[0] > 0:
        ratio = values[1] / values[0]
        ax4.text(0.5, max(values) * 0.5, f'Ratio: {ratio:.1f}x', ha='center', fontsize=14,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.suptitle('When Geometry Fails: Git Re-Basin on Spurious vs Robust Features', 
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
save_figure(fig, 'final_summary')
plt.show()

## 7. Key Findings for Blog Post

In [None]:
print("\n" + "="*70)
print("KEY FINDINGS FOR CLASS BLOG")
print("="*70)

findings = []

# Finding 1: Spurious models rely on patches
if 'mechanism' in results:
    srs = results['mechanism']['srs_results']
    spurious_srs = np.mean([srs['A1']['spurious_reliance_score'], srs['A2']['spurious_reliance_score']])
    robust_srs = np.mean([srs['R1']['spurious_reliance_score'], srs['R2']['spurious_reliance_score']])
    
    findings.append(
        f"1. **Spurious Feature Reliance**: Models trained on spurious-aligned data (A1, A2) "
        f"show {spurious_srs/robust_srs:.1f}x higher Spurious Reliance Score than robust models (R1, R2), "
        f"confirming they learn to rely on the colored patch shortcut."
    )

# Finding 2: OOD accuracy gap
if 'mechanism' in results:
    spurious_ood_drop = np.mean([srs['A1']['ood_drop'], srs['A2']['ood_drop']]) * 100
    robust_ood_drop = np.mean([srs['R1']['ood_drop'], srs['R2']['ood_drop']]) * 100
    
    findings.append(
        f"2. **OOD Generalization Gap**: Spurious models suffer a {spurious_ood_drop:.1f}% accuracy drop "
        f"when patches are removed, while robust models drop only {robust_ood_drop:.1f}%."
    )

# Finding 3: Rebasin changes weight-space cosine similarity
if 'rebasin' in results:
    comp = results['rebasin']['comparison']
    avg_sim_increase = np.mean([comp[p]['cosine_sim_change'] for p in comp])
    
    findings.append(
        f"3. **Git Re-Basin Works**: Weight matching changes cosine similarity by "
        f"{avg_sim_increase:+.4f} on average (positive means closer after alignment), "
        f"enabling more meaningful weight interpolation."
    )

# Finding 4: Different mechanisms = higher barriers
if 'interpolation' in results and 'barrier_comparison' in results['interpolation']:
    barriers = results['interpolation']['barrier_comparison']
    same_mech = np.mean([barriers[p].get('post_id_loss_barrier', barriers[p]['pre_id_loss_barrier']) 
                        for p in ['A1-A2', 'R1-R2']])
    diff_mech = barriers['A1-R1'].get('post_id_loss_barrier', barriers['A1-R1']['pre_id_loss_barrier'])
    
    findings.append(
        f"4. **Geometry Fails for Mechanism Mismatch**: Even after Re-Basin, spurious-robust pairs "
        f"have {diff_mech/same_mech:.1f}x higher loss barriers than same-mechanism pairs, "
        f"indicating that geometric alignment cannot bridge semantic differences."
    )

# Finding 5: Semantic barrier
if 'interpolation' in results and 'srs_interpolation' in results['interpolation']:
    srs_interp = results['interpolation']['srs_interpolation']
    
    findings.append(
        f"5. **Semantic Barrier Evidence**: Along the A1-R1 interpolation path, SRS varies from "
        f"{srs_interp['srs_values'][0]:.3f} (robust) to {srs_interp['srs_values'][-1]:.3f} (spurious), "
        f"suggesting that intermediate models inherit inconsistent feature dependencies."
    )

# Print findings
for finding in findings:
    print(f"\n{finding}")

In [None]:
# Additional findings
print("\n" + "-"*70)
print("ADDITIONAL INSIGHTS")
print("-"*70)

additional = [
    "6. **Same-Mechanism Connectivity**: Models sharing the same feature dependency "
    "(both spurious or both robust) can be smoothly interpolated after Re-Basin, "
    "with minimal loss barriers along the path.",
    
    "7. **Practical Implication**: Before merging or ensembling models, practitioners should "
    "verify that models rely on similar features. Geometric tools like Re-Basin cannot fix "
    "fundamental differences in what models have learned.",
    
    "8. **Future Directions**: This work suggests that loss barrier analysis post-Re-Basin "
    "could serve as a diagnostic tool for detecting when models have learned qualitatively "
    "different solutions to the same task.",
]

for insight in additional:
    print(f"\n{insight}")

## 8. Export Final Summary

In [None]:
# Compile final summary
final_report = {
    'project_title': 'When Geometry Fails: Stress-Testing Git Re-Basin on Spurious vs Robust Features',
    'hypothesis': 'Permutation alignment (Git Re-Basin) can successfully connect models relying on the same feature type, but fails to connect models with mismatched mechanisms.',
    'key_findings': findings,
    'additional_insights': additional,
}

# Add numerical results if available
if 'mechanism' in results:
    final_report['model_performance'] = results['mechanism']['srs_results']
    final_report['group_statistics'] = results['mechanism']['group_statistics']

if 'interpolation' in results and 'barrier_comparison' in results['interpolation']:
    final_report['barrier_analysis'] = results['interpolation']['barrier_comparison']

# Save final report
report_path = RESULTS_DIR / 'final_report.json'
with open(report_path, 'w') as f:
    json.dump(final_report, f, indent=2, default=str)

print(f"\nFinal report saved to: {report_path}")

## 9. List All Generated Outputs

In [None]:
print("\n" + "="*70)
print("ALL GENERATED OUTPUTS")
print("="*70)

print("\nCheckpoints (results/checkpoints/):")
from src.config import CHECKPOINTS_DIR
for f in sorted(CHECKPOINTS_DIR.glob('*.pt')):
    print(f"  - {f.name}")

print("\nFigures (results/figures/):")
for f in sorted(FIGURES_DIR.glob('*.png')):
    print(f"  - {f.name}")

print("\nMetrics (results/metrics/):")
for f in sorted(METRICS_DIR.glob('*.json')):
    print(f"  - {f.name}")

print("\nSummary files (results/):")
for f in sorted(RESULTS_DIR.glob('*.json')):  # glob('*.json') is non-recursive: top-level only
    print(f"  - {f.name}")

## 10. Conclusion

In [None]:
print("\n" + "="*70)
print("EXPERIMENT COMPLETE")
print("="*70)
print("""
This experiment demonstrated that:

1. Git Re-Basin (weight matching) successfully aligns models in weight space,
   increasing cosine similarity and lowering loss barriers relative to their
   pre-rebasin values.

2. However, when models rely on fundamentally different features (spurious vs
   robust), significant barriers remain even after alignment.

3. The "semantic barrier" - measured by Spurious Reliance Score variation
   along the interpolation path - reveals mechanism mismatch that pure
   geometric methods cannot resolve.

4. This has practical implications for model merging, ensembling, and
   understanding the structure of loss landscapes.

Key Takeaway:
-------------
Geometry (permutation alignment) is necessary but not sufficient for
meaningful model interpolation. Models must also share similar learned
representations and feature dependencies.

"When geometry fails, it's because the models have learned to see
the world in fundamentally different ways."
""")

# 07 - Mechanism Distance Predicts Barrier Height

This notebook tests the hypothesis that **the height of the linear interpolation barrier is predictable from the "mechanism distance"** between the two endpoint models.

## Research Question
Can we predict how well Git Re-Basin will work (low barrier) based on how similar the models' learned mechanisms are?

## Mechanism Distance Metrics
1. **Cue-Reliance Distance (dist_srs)**: Absolute difference in Spurious Reliance Score
   - `dist_srs = |SRS(A) - SRS(B)|`
   - Models with similar spurious reliance should have similar mechanisms

2. **Representation Distance (dist_cka)**: CKA-based feature similarity
   - `dist_cka = 1 - mean(CKA)` across layers
   - Models with similar internal representations should have similar mechanisms
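
The two distances defined above can be sketched as follows. This assumes linear CKA on activation matrices of shape `(n_samples, n_features)`; the project's `compute_cka_distance` may use a different kernel or a debiased estimator. The activations and SRS values here are toy inputs, not measured results.

```python
import numpy as np

def linear_cka(x, y):
    """Linear CKA between activation matrices of shape (n_samples, n_features)."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(x.T @ y, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))

rng = np.random.default_rng(0)
feats_a = {"block2": rng.normal(size=(200, 16)), "fc1": rng.normal(size=(200, 8))}
# Second toy "model": an orthogonal rotation of the first, so linear CKA should be ~1
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))
feats_b = {"block2": feats_a["block2"] @ q, "fc1": feats_a["fc1"].copy()}

cka_per_layer = {name: linear_cka(feats_a[name], feats_b[name]) for name in feats_a}
dist_cka = 1.0 - float(np.mean(list(cka_per_layer.values())))
dist_srs = abs(0.80 - 0.15)  # toy SRS values, not measured results
print(f"dist_cka={dist_cka:.4f}, dist_srs={dist_srs:.2f}")
```

Linear CKA is invariant to orthogonal transformations of the features, which is why the rotated copy yields a distance of essentially zero. That property makes CKA a reasonable mechanism-distance candidate when the two networks differ by a permutation of hidden units.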

## Analysis Plan
1. Load all model pairs (S-S, R-R, S-R)
2. Compute mechanism distance metrics for each pair
3. Retrieve barrier heights (pre and post rebasin)
4. Correlate mechanism distance with barrier height
5. Fit regression: barrier ~ dist_srs + dist_cka
6. Generate publication-quality figures
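
Steps 4 and 5 of the plan can be sketched with plain NumPy as below. The helper name `fit_barrier_regression` is hypothetical (the notebook itself imports `fit_linear_regression` and `bootstrap_correlation` from `src.metrics`, whose signatures are not shown here), and the data is synthetic, generated from a known linear rule purely so the fit can be checked.

```python
import numpy as np

def fit_barrier_regression(dist_srs, dist_cka, barrier):
    """Hypothetical helper: ordinary least squares for barrier ~ 1 + dist_srs + dist_cka."""
    design = np.column_stack([np.ones(len(barrier)), dist_srs, dist_cka])
    coef, *_ = np.linalg.lstsq(design, barrier, rcond=None)
    pred = design @ coef
    ss_res = float(np.sum((barrier - pred) ** 2))
    ss_tot = float(np.sum((barrier - np.mean(barrier)) ** 2))
    r2 = 1.0 - ss_res / ss_tot
    return coef, r2

# Synthetic pairs generated from a known linear rule (not experimental results)
rng = np.random.default_rng(0)
toy_dist_srs = rng.uniform(0.0, 1.0, size=20)
toy_dist_cka = rng.uniform(0.0, 1.0, size=20)
toy_barrier = 0.1 + 2.0 * toy_dist_srs + 0.5 * toy_dist_cka

coef, r2 = fit_barrier_regression(toy_dist_srs, toy_dist_cka, toy_barrier)
pearson_r = float(np.corrcoef(toy_dist_srs, toy_barrier)[0, 1])
print(np.round(coef, 3), round(r2, 3), round(pearson_r, 3))
```

With only a handful of model pairs the fit has very few degrees of freedom; if only three pairs are available, a model with three coefficients is exactly determined and its R² carries no information.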

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR, RESULTS_DIR
)

# Set style for publication-quality figures
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 11
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['axes.labelsize'] = 11

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")
print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import project modules
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    CounterfactualPatchDataset,
)
from src.models import create_model
from src.train import load_model
from src.interp import evaluate_interpolation_multi_dataset
from src.metrics import (
    compute_spurious_reliance_score,
    compute_srs_distance,
    get_srs_scalar,
    compute_all_barriers,
    bootstrap_correlation,
    fit_linear_regression,
)
from src.cka import (
    compute_cka_distance,
    compute_layerwise_cka,
    create_cka_dataloader,
    compute_singular_vector_alignment,
)
from src.pairs import (
    get_standard_pairs,
    load_model_pair,
    load_all_standard_pairs,
    get_pair_short_name,
    check_checkpoints_exist,
    print_checkpoint_status,
    PAIR_TYPE_SS, PAIR_TYPE_RR, PAIR_TYPE_SR,
)
from src.plotting import save_figure

from torch.utils.data import DataLoader

## 1. Check Prerequisites and Load Models

In [None]:
# Check what checkpoints are available
print("Checking checkpoint availability...\n")
print_checkpoint_status()

# Verify required checkpoints exist
status = check_checkpoints_exist()
required = ['A1', 'A2', 'R1', 'R2']
missing = [m for m in required if not status.get(m, False)]
if missing:
    raise FileNotFoundError(
        f"Missing required checkpoints: {missing}\n"
        f"Please run notebooks 02-04 first."
    )
print("\nAll required checkpoints found!")

In [None]:
# Load all model pairs
print("Loading model pairs...\n")
model_pairs = load_all_standard_pairs(device, config, load_aligned=True)

for name, pair in model_pairs.items():
    aligned_status = "Yes" if pair.model_b_aligned is not None else "No"
    print(f"  {name}: type={pair.pair_type}, aligned={aligned_status}")

print(f"\nLoaded {len(model_pairs)} model pairs.")

## 2. Create DataLoaders
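
The CKA loader created in this section uses a fixed, seeded subset of the test set so that every pair is compared on the same inputs. A minimal sketch of drawing such a subset reproducibly is shown below. The helper name `fixed_subset_indices` is hypothetical, and how `create_cka_dataloader` selects its samples internally is an assumption.

```python
import numpy as np

def fixed_subset_indices(n_total, n_samples, seed):
    """Hypothetical helper: reproducible, sorted sample of dataset indices without replacement."""
    rng = np.random.default_rng(seed)
    n_samples = min(n_samples, n_total)
    return np.sort(rng.choice(n_total, size=n_samples, replace=False))

# Same seed, same indices: the subset is identical across calls and across pairs
idx_first = fixed_subset_indices(n_total=10_000, n_samples=2000, seed=42)
idx_again = fixed_subset_indices(n_total=10_000, n_samples=2000, seed=42)
print(len(idx_first), bool(np.array_equal(idx_first, idx_again)))
```

Indices like these could be passed to `torch.utils.data.Subset` and wrapped in a `DataLoader` with `shuffle=False`, so the activations of both models are collected in the same order.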

In [None]:
# Create test datasets
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['interpolation']['eval_batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

dataloaders = {
    'id': id_loader,
    'ood': ood_loader,
}

print(f"Test datasets: ID={len(test_id)}, OOD={len(test_ood)} samples")

In [None]:
# Create counterfactual dataset for SRS computation
cf_dataset = CounterfactualPatchDataset(
    base_dataset=test_id,
    swap_mode='random_wrong',
)

# Create fixed CKA dataloader (using subset for efficiency)
CKA_N_SAMPLES = 2000  # Configurable number of samples for CKA
cka_loader = create_cka_dataloader(
    test_id, 
    n_samples=CKA_N_SAMPLES, 
    batch_size=batch_size,
    seed=config['seeds']['global'],
)

print(f"Counterfactual dataset: {len(cf_dataset)} samples")
print(f"CKA dataloader: {CKA_N_SAMPLES} samples (fixed subset)")

## 3. Compute Spurious Reliance Score (SRS) for All Models

In [None]:
# First, compute SRS for each individual model
# We need this to compute SRS distance for pairs

model_names = ['A1', 'A2', 'R1', 'R2']
individual_srs = {}

print("Computing SRS for individual models...\n")

for pair in model_pairs.values():
    for model_name, model in [(pair.model_a_name, pair.model_a), 
                               (pair.model_b_name, pair.model_b)]:
        if model_name not in individual_srs:
            print(f"  Computing SRS for {model_name}...")
            srs = compute_spurious_reliance_score(
                model, id_loader, ood_loader, cf_dataset, device
            )
            individual_srs[model_name] = srs
            print(f"    SRS = {srs['spurious_reliance_score']:.4f}")

print("\n" + "="*50)
print("SRS Summary:")
print("="*50)
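# Iterate over the models that were actually scored so a differently named
# checkpoint cannot raise a KeyError.
for name, srs in sorted(individual_srs.items()):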
    print(f"{name}: SRS={srs['spurious_reliance_score']:.4f}, "
          f"ID={srs['id_accuracy']*100:.1f}%, OOD={srs['ood_accuracy']*100:.1f}%")

## 4. Compute Mechanism Distance Metrics
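
The representation distance below is reported as `1 - mean CKA` over the chosen layers. `compute_cka_distance` is defined in `src`; the sketch here only illustrates the idea under the assumption that it uses the standard (non-debiased) linear CKA estimator on centered activations. The helper name `linear_cka` and the random activations are placeholders.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between two (n_samples, n_features) activation matrices."""
    # Center every feature so the comparison ignores constant offsets.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
acts_a = rng.normal(size=(200, 16))

# Linear CKA is invariant to orthogonal transforms of the feature space,
# so a rotated copy of the same activations scores 1.
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))
acts_rotated = acts_a @ q

# Independent activations of the same shape score much lower.
acts_other = rng.normal(size=(200, 16))

per_layer = {
    "same_up_to_rotation": linear_cka(acts_a, acts_rotated),
    "independent": linear_cka(acts_a, acts_other),
}
dist_cka = 1.0 - float(np.mean(list(per_layer.values())))
print(per_layer, dist_cka)
```

The rotation case is worth keeping in mind when reading the results: a permutation of hidden units is an orthogonal transform, so linear CKA computed this way does not change when one model is re-permuted.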

In [None]:
# Configuration for CKA computation
CKA_LAYERS = ['block2', 'block3', 'fc1']  # Layers to compare
CKA_DEBIASED = False  # Use standard estimator

# Compute mechanism distances for each pair
mechanism_distances = {}

print("Computing mechanism distances...\n")

for pair_name, pair in model_pairs.items():
    print(f"\n{'='*60}")
    print(f"Pair: {pair_name} ({pair.pair_type})")
    print(f"{'='*60}")
    
    # (A) Cue-reliance distance (SRS)
    srs_a = individual_srs[pair.model_a_name]
    srs_b = individual_srs[pair.model_b_name]
    dist_srs = compute_srs_distance(srs_a, srs_b)
    print(f"\n  Cue-Reliance Distance (dist_srs):")
    print(f"    SRS({pair.model_a_name}) = {get_srs_scalar(srs_a):.4f}")
    print(f"    SRS({pair.model_b_name}) = {get_srs_scalar(srs_b):.4f}")
    print(f"    dist_srs = {dist_srs:.4f}")
    
    # (B) Representation distance (CKA)
    print(f"\n  Representation Distance (CKA):")
    dist_cka, cka_per_layer = compute_cka_distance(
        pair.model_a, pair.model_b,
        cka_loader, device,
        layer_names=CKA_LAYERS,
        n_samples=CKA_N_SAMPLES,
        debiased=CKA_DEBIASED,
    )
    print(f"    Per-layer CKA: {cka_per_layer}")
    print(f"    Mean CKA = {1 - dist_cka:.4f}")
    print(f"    dist_cka = {dist_cka:.4f}")
    
    # (C) Optional: Singular vector alignment
    print(f"\n  Singular Vector Alignment:")
    dist_sv, sv_per_layer = compute_singular_vector_alignment(
        pair.model_a, pair.model_b,
        layer_names=['block0', 'block1', 'block2', 'block3'],
        top_k=5,
    )
    print(f"    Per-layer alignment: {sv_per_layer}")
    print(f"    dist_sv = {dist_sv:.4f}")
    
    # Store results
    mechanism_distances[pair_name] = {
        'pair_type': pair.pair_type,
        'dist_srs': dist_srs,
        'dist_cka': dist_cka,
        'dist_sv': dist_sv,
        'cka_per_layer': cka_per_layer,
        'sv_per_layer': sv_per_layer,
        'srs_a': get_srs_scalar(srs_a),
        'srs_b': get_srs_scalar(srs_b),
    }

## 5. Compute Barrier Heights (Reuse or Recompute)
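
The barrier values come from `compute_all_barriers` in `src`, whose exact definition is not shown in this notebook. One common definition, assumed here purely for illustration, is the largest amount by which the loss along the interpolation path exceeds the straight line joining the two endpoint losses. The function name and the loss curve are made up.

```python
def loss_barrier(alphas, losses):
    """Largest excess of the interpolated loss over the straight line
    joining the two endpoint losses (assumed definition)."""
    start, end = losses[0], losses[-1]
    excess = [
        loss - ((1.0 - alpha) * start + alpha * end)
        for alpha, loss in zip(alphas, losses)
    ]
    return max(excess)


# Made-up losses along theta(alpha) = (1 - alpha) * theta_A + alpha * theta_B.
alphas = [0.0, 0.25, 0.5, 0.75, 1.0]
losses = [0.30, 0.90, 1.50, 0.80, 0.40]

barrier = loss_barrier(alphas, losses)
print(f"barrier = {barrier:.3f}")
```

The sketch assumes `alphas` starts at 0 and ends at 1, so the excess is zero at both endpoints and the barrier is never negative.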

In [None]:
# Try to load existing results from summary.json
summary_path = RESULTS_DIR / 'summary.json'

existing_barriers = None
if summary_path.exists():
    print(f"Loading existing barrier results from {summary_path}...")
    with open(summary_path, 'r') as f:
        existing_data = json.load(f)
    if 'barrier_comparison' in existing_data:
        existing_barriers = existing_data['barrier_comparison']
        print("  Found existing barrier data!")
    else:
        print("  No 'barrier_comparison' entry found. Will compute barriers.")
else:
    print("No existing summary.json found. Will compute barriers.")

In [None]:
# Compute barriers (or use existing)
num_alphas = config['interpolation']['num_alphas']
barrier_results = {}

print("\nComputing/Loading barrier heights...\n")

for pair_name, pair in model_pairs.items():
    print(f"\nPair: {pair_name}")
    
    # Check if we have existing barriers
    if existing_barriers and pair_name in existing_barriers:
        print("  Using cached barrier values.")
        eb = existing_barriers[pair_name]
        barrier_results[pair_name] = {
            'barrier_id_raw': eb.get('pre_id_loss_barrier', np.nan),
            'barrier_ood_raw': eb.get('pre_ood_loss_barrier', np.nan),
            'barrier_id_rebasin': eb.get('post_id_loss_barrier', np.nan),
            'barrier_ood_rebasin': eb.get('post_ood_loss_barrier', np.nan),
            'barrier_id_acc_raw': eb.get('pre_id_acc_barrier', np.nan),
            'barrier_ood_acc_raw': eb.get('pre_ood_acc_barrier', np.nan),
            'barrier_id_acc_rebasin': eb.get('post_id_acc_barrier', np.nan),
            'barrier_ood_acc_rebasin': eb.get('post_ood_acc_barrier', np.nan),
        }
    else:
        # Compute barriers
        print("  Computing pre-rebasin interpolation...")
        pre_results = evaluate_interpolation_multi_dataset(
            pair.model_a, pair.model_b, dataloaders, device, num_alphas
        )
        
        post_results = None
        if pair.model_b_aligned is not None:
            print("  Computing post-rebasin interpolation...")
            post_results = evaluate_interpolation_multi_dataset(
                pair.model_a, pair.model_b_aligned, dataloaders, device, num_alphas
            )
        
        # Extract barriers
        barrier_results[pair_name] = compute_all_barriers(pre_results, post_results)
    
    # Print summary
    br = barrier_results[pair_name]
    print(f"  ID barrier:  raw={br['barrier_id_raw']:.4f}, rebasin={br['barrier_id_rebasin']:.4f}")
    print(f"  OOD barrier: raw={br['barrier_ood_raw']:.4f}, rebasin={br['barrier_ood_rebasin']:.4f}")

## 6. Build Analysis DataFrame

In [None]:
# Combine mechanism distances and barriers into a single dataframe
pairs_data = []

for pair_name, pair in model_pairs.items():
    md = mechanism_distances[pair_name]
    br = barrier_results[pair_name]
    
    row = {
        'pair_id': pair_name,
        'pair_type': md['pair_type'],
        'pair_type_short': get_pair_short_name(md['pair_type']),
        'model_a': pair.model_a_name,
        'model_b': pair.model_b_name,
        
        # Mechanism distances
        'dist_srs': md['dist_srs'],
        'dist_cka': md['dist_cka'],
        'dist_sv': md['dist_sv'],
        
        # Individual SRS values
        'srs_a': md['srs_a'],
        'srs_b': md['srs_b'],
        
        # Per-layer CKA
        **{f'cka_{layer}': md['cka_per_layer'].get(layer, np.nan) 
           for layer in CKA_LAYERS},
        
        # Barriers
        **br,
    }
    pairs_data.append(row)

df = pd.DataFrame(pairs_data)
print("\nPairs DataFrame:")
print(df.to_string())

In [None]:
# Save the pairs dataframe
csv_path = RESULTS_DIR / 'mechdist_pairs.csv'
df.to_csv(csv_path, index=False)
print(f"\nSaved pairs data to: {csv_path}")

## 7. Statistical Analysis: Correlations
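
`bootstrap_correlation` is imported from `src` and its internals are not shown. A minimal sketch of a percentile bootstrap for Pearson's r, assuming pairs are resampled with replacement and degenerate resamples are dropped, could look like this. The name `bootstrap_pearson`, the seed, and the data are hypothetical.

```python
import numpy as np


def bootstrap_pearson(x, y, n_bootstrap=2000, ci=0.95, seed=0):
    """Pearson r with a percentile bootstrap confidence interval."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(x)
    estimates = []
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)  # resample pairs with replacement
        xs, ys = x[idx], y[idx]
        # A resample with a constant column has no defined correlation.
        if np.ptp(xs) == 0 or np.ptp(ys) == 0:
            continue
        estimates.append(np.corrcoef(xs, ys)[0, 1])
    tail = (1.0 - ci) / 2.0
    lower, upper = np.quantile(estimates, [tail, 1.0 - tail])
    return {
        "correlation": float(np.corrcoef(x, y)[0, 1]),
        "ci_lower": float(lower),
        "ci_upper": float(upper),
        "n_valid_resamples": len(estimates),
    }


# Made-up distances and barriers for four pairs.
dist = [0.02, 0.05, 0.40, 0.55]
barrier = [0.10, 0.08, 0.90, 1.20]
result = bootstrap_pearson(dist, barrier)
print(result)
```

With only a handful of pairs, each resample contains very few distinct points, so the resulting interval is coarse and should be read with caution.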

In [None]:
# Define barrier and distance columns for analysis
barrier_cols = ['barrier_id_raw', 'barrier_ood_raw', 'barrier_id_rebasin', 'barrier_ood_rebasin']
distance_cols = ['dist_srs', 'dist_cka']

# Compute correlations with bootstrapped CIs
N_BOOTSTRAP = 2000
correlation_results = {}

print("\n" + "="*70)
print("CORRELATION ANALYSIS")
print("="*70)

for barrier_col in barrier_cols:
    print(f"\n{barrier_col}:")
    print("-" * 50)
    
    y = df[barrier_col].values
    
    # Skip if all NaN
    if np.all(np.isnan(y)):
        print("  [SKIP] All values are NaN")
        continue
    
    for dist_col in distance_cols:
        x = df[dist_col].values
        
        # Filter out NaN
        mask = ~(np.isnan(x) | np.isnan(y))
        x_clean, y_clean = x[mask], y[mask]
        
        if len(x_clean) < 3:
            print(f"  {dist_col}: [SKIP] Insufficient data points")
            continue
        
        # Pearson correlation
        pearson = bootstrap_correlation(
            x_clean, y_clean, 
            n_bootstrap=N_BOOTSTRAP, 
            method='pearson'
        )
        
        # Spearman correlation
        spearman = bootstrap_correlation(
            x_clean, y_clean, 
            n_bootstrap=N_BOOTSTRAP, 
            method='spearman'
        )
        
        key = f"{barrier_col}_vs_{dist_col}"
        correlation_results[key] = {
            'pearson': pearson,
            'spearman': spearman,
            'n': len(x_clean),
        }
        
        print(f"  {dist_col}:")
        print(f"    Pearson r = {pearson['correlation']:.3f} "
              f"[{pearson['ci_lower']:.3f}, {pearson['ci_upper']:.3f}] "
              f"(p={pearson['p_value']:.4f})")
        print(f"    Spearman rho = {spearman['correlation']:.3f} "
              f"[{spearman['ci_lower']:.3f}, {spearman['ci_upper']:.3f}] "
              f"(p={spearman['p_value']:.4f})")

## 8. Regression Analysis
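
`fit_linear_regression` comes from `src`. As a sketch of the underlying computation, assuming ordinary least squares with an intercept, the fit and its R^2 can be reproduced with `np.linalg.lstsq`. The example also shows why R^2 is uninformative with three pairs: two predictors plus an intercept fit three points exactly. The function name and data are made up.

```python
import numpy as np


def fit_ols(X, y):
    """Ordinary least squares with an intercept; returns coefficients and R^2."""
    X, y = np.asarray(X, dtype=float), np.asarray(y, dtype=float)
    design = np.column_stack([np.ones(len(y)), X])  # prepend intercept column
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    residual = y - design @ beta
    ss_res = float(residual @ residual)
    ss_tot = float(((y - y.mean()) ** 2).sum())
    return {
        "intercept": float(beta[0]),
        "coefficients": [float(b) for b in beta[1:]],
        "r_squared": 1.0 - ss_res / ss_tot,
    }


# Three made-up pairs, two predictors: three parameters fit three points exactly.
X_three = [[0.02, 0.10], [0.05, 0.30], [0.50, 0.45]]
y_three = [0.10, 0.08, 1.10]
r2_three = fit_ols(X_three, y_three)["r_squared"]
print(f"R^2 with 3 pairs: {r2_three:.4f}")

# A fourth pair leaves one residual degree of freedom.
X_four = X_three + [[0.45, 0.20]]
y_four = y_three + [0.60]
r2_four = fit_ols(X_four, y_four)["r_squared"]
print(f"R^2 with 4 pairs: {r2_four:.4f}")
```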

In [None]:
# Fit regression: barrier ~ dist_srs + dist_cka
regression_results = {}

print("\n" + "="*70)
print("REGRESSION ANALYSIS: barrier ~ dist_srs + dist_cka")
print("="*70)

for barrier_col in barrier_cols:
    y = df[barrier_col].values
    
    # Skip if all NaN
    if np.all(np.isnan(y)):
        continue
    
    X = df[['dist_srs', 'dist_cka']].values
    
    # Filter out NaN
    mask = ~np.any(np.isnan(np.column_stack([X, y.reshape(-1, 1)])), axis=1)
    X_clean, y_clean = X[mask], y[mask]
    
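    # Two predictors plus an intercept give three free parameters, so three
    # points are fit exactly and R^2 is only informative from four points on.
    if len(y_clean) < 4:
        print(f"\n{barrier_col}: [SKIP] Insufficient data (need at least 4 pairs)")
        continue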
    
    # Fit regression
    reg = fit_linear_regression(
        X_clean, y_clean, 
        feature_names=['dist_srs', 'dist_cka']
    )
    regression_results[barrier_col] = reg
    
    print(f"\n{barrier_col}:")
    print(f"  R^2 = {reg['r_squared']:.4f}")
    print(f"  Intercept = {reg['intercept']:.4f}")
    for feat, coef in reg['coefficients'].items():
        print(f"  {feat}: {coef:.4f}")

## 9. Visualization: Barrier vs Mechanism Distance

In [None]:
# Color palette for pair types
pair_colors = {
    'S-S': '#e74c3c',   # Red for spurious-spurious
    'R-R': '#3498db',   # Blue for robust-robust  
    'S-R': '#9b59b6',   # Purple for spurious-robust
}

# Marker styles
pair_markers = {
    'S-S': 'o',
    'R-R': 's',
    'S-R': '^',
}

In [None]:
# Figure 1: Barrier vs SRS Distance
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

barrier_titles = {
    'barrier_id_raw': 'ID Loss Barrier (Pre-Rebasin)',
    'barrier_ood_raw': 'OOD Loss Barrier (Pre-Rebasin)',
    'barrier_id_rebasin': 'ID Loss Barrier (Post-Rebasin)',
    'barrier_ood_rebasin': 'OOD Loss Barrier (Post-Rebasin)',
}

for ax, barrier_col in zip(axes.flat, barrier_cols):
    # Plot each pair type separately
    for pair_type in ['S-S', 'R-R', 'S-R']:
        mask = df['pair_type_short'] == pair_type
        subset = df[mask]
        
        if len(subset) > 0 and not np.all(np.isnan(subset[barrier_col])):
            ax.scatter(
                subset['dist_srs'], 
                subset[barrier_col],
                c=pair_colors[pair_type],
                marker=pair_markers[pair_type],
                s=150,
                label=pair_type,
                edgecolors='black',
                linewidths=1,
                alpha=0.8,
            )
            
            # Add pair labels
            for _, row in subset.iterrows():
                if not np.isnan(row[barrier_col]):
                    ax.annotate(
                        row['pair_id'],
                        (row['dist_srs'], row[barrier_col]),
                        xytext=(5, 5),
                        textcoords='offset points',
                        fontsize=9,
                    )
    
    ax.set_xlabel('SRS Distance (|SRS(A) - SRS(B)|)')
    ax.set_ylabel('Loss Barrier')
    ax.set_title(barrier_titles[barrier_col])
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.suptitle('Barrier Height vs. Cue-Reliance Distance', fontsize=14, y=1.02)
plt.tight_layout()

# Save figure
fig_path = FIGURES_DIR / 'barrier_vs_mechdist.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"Saved: {fig_path}")

plt.show()

In [None]:
# Figure 2: Barrier vs CKA Distance
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for ax, barrier_col in zip(axes.flat, barrier_cols):
    # Plot each pair type separately
    for pair_type in ['S-S', 'R-R', 'S-R']:
        mask = df['pair_type_short'] == pair_type
        subset = df[mask]
        
        if len(subset) > 0 and not np.all(np.isnan(subset[barrier_col])):
            ax.scatter(
                subset['dist_cka'], 
                subset[barrier_col],
                c=pair_colors[pair_type],
                marker=pair_markers[pair_type],
                s=150,
                label=pair_type,
                edgecolors='black',
                linewidths=1,
                alpha=0.8,
            )
            
            # Add pair labels
            for _, row in subset.iterrows():
                if not np.isnan(row[barrier_col]):
                    ax.annotate(
                        row['pair_id'],
                        (row['dist_cka'], row[barrier_col]),
                        xytext=(5, 5),
                        textcoords='offset points',
                        fontsize=9,
                    )
    
    ax.set_xlabel('CKA Distance (1 - mean CKA)')
    ax.set_ylabel('Loss Barrier')
    ax.set_title(barrier_titles[barrier_col])
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.suptitle('Barrier Height vs. Representation Distance (CKA)', fontsize=14, y=1.02)
plt.tight_layout()

# Save figure
fig_path = FIGURES_DIR / 'barrier_vs_cka.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"Saved: {fig_path}")

plt.show()

In [None]:
# Combined summary figure for publication
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

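# Use post-rebasin ID barrier as the primary metric; pairs without a
# post-rebasin value fall back to the raw barrier, so the y-axis label
# is only exact when every pair has an aligned model.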
barrier_col = 'barrier_id_rebasin'
fallback_col = 'barrier_id_raw'

# Left: Barrier vs SRS Distance
ax = axes[0]
for pair_type in ['S-S', 'R-R', 'S-R']:
    mask = df['pair_type_short'] == pair_type
    subset = df[mask]
    
    # Use rebasin if available, else raw
    y_vals = subset[barrier_col].fillna(subset[fallback_col])
    
    ax.scatter(
        subset['dist_srs'], 
        y_vals,
        c=pair_colors[pair_type],
        marker=pair_markers[pair_type],
        s=200,
        label=pair_type,
        edgecolors='black',
        linewidths=1.5,
        alpha=0.9,
    )

ax.set_xlabel('Cue-Reliance Distance\n|SRS(A) - SRS(B)|', fontsize=12)
ax.set_ylabel('ID Loss Barrier (Post-Rebasin)', fontsize=12)
ax.set_title('(A) Barrier vs. Cue-Reliance Distance', fontsize=13)
ax.legend(title='Pair Type', loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3)

# Right: Barrier vs CKA Distance
ax = axes[1]
for pair_type in ['S-S', 'R-R', 'S-R']:
    mask = df['pair_type_short'] == pair_type
    subset = df[mask]
    
    y_vals = subset[barrier_col].fillna(subset[fallback_col])
    
    ax.scatter(
        subset['dist_cka'], 
        y_vals,
        c=pair_colors[pair_type],
        marker=pair_markers[pair_type],
        s=200,
        label=pair_type,
        edgecolors='black',
        linewidths=1.5,
        alpha=0.9,
    )

ax.set_xlabel('Representation Distance\n1 - mean(CKA)', fontsize=12)
ax.set_ylabel('ID Loss Barrier (Post-Rebasin)', fontsize=12)
ax.set_title('(B) Barrier vs. Representation Distance', fontsize=13)
ax.legend(title='Pair Type', loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()

# Save publication figure
fig_path = FIGURES_DIR / 'mechanism_distance_predicts_barrier.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"\nSaved publication figure: {fig_path}")

plt.show()

## 10. Summary Table

In [None]:
# Create summary table
print("\n" + "="*90)
print("SUMMARY TABLE: Model Pairs Analysis")
print("="*90)

summary_cols = ['pair_id', 'pair_type_short', 'dist_srs', 'dist_cka', 
                'barrier_id_raw', 'barrier_id_rebasin']
summary_df = df[summary_cols].copy()
summary_df.columns = ['Pair', 'Type', 'dist_SRS', 'dist_CKA', 
                      'Barrier (Raw)', 'Barrier (Rebasin)']

# Format numbers
for col in ['dist_SRS', 'dist_CKA', 'Barrier (Raw)', 'Barrier (Rebasin)']:
    # pd.notna also handles None (e.g. a JSON null from cached barriers), which np.isnan rejects
    summary_df[col] = summary_df[col].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")

print(summary_df.to_string(index=False))

## 11. Save Results

In [None]:
# Update summary.json with correlation and regression results
summary_path = RESULTS_DIR / 'summary.json'

# Load existing or create new
if summary_path.exists():
    with open(summary_path, 'r') as f:
        summary = json.load(f)
else:
    summary = {}

# Add mechanism distance analysis results
summary['mechanism_distance_analysis'] = {
    'description': 'Analysis of whether mechanism distance predicts barrier height',
    'metrics': {
        'cka_n_samples': CKA_N_SAMPLES,
        'cka_layers': CKA_LAYERS,
        'srs_weights': {'ood_drop': 0.4, 'acc_drop_cf': 0.3, 'flip_rate': 0.3},
    },
    'pair_distances': {
        pair_name: {
            'pair_type': md['pair_type'],
            'dist_srs': float(md['dist_srs']),
            'dist_cka': float(md['dist_cka']),
            'dist_sv': float(md['dist_sv']),
            'srs_a': float(md['srs_a']),
            'srs_b': float(md['srs_b']),
            'cka_per_layer': {k: float(v) for k, v in md['cka_per_layer'].items()},
        }
        for pair_name, md in mechanism_distances.items()
    },
    'correlations': {
        key: {
            'pearson_r': res['pearson']['correlation'],
            'pearson_ci': [res['pearson']['ci_lower'], res['pearson']['ci_upper']],
            'pearson_p': res['pearson']['p_value'],
            'spearman_rho': res['spearman']['correlation'],
            'spearman_ci': [res['spearman']['ci_lower'], res['spearman']['ci_upper']],
            'spearman_p': res['spearman']['p_value'],
            'n_samples': res['n'],
        }
        for key, res in correlation_results.items()
    },
    'regressions': {
        barrier: {
            'r_squared': reg['r_squared'],
            'intercept': reg['intercept'],
            'coefficients': reg['coefficients'],
        }
        for barrier, reg in regression_results.items()
    },
}

# Save updated summary
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"Updated summary saved to: {summary_path}")

## 12. Key Findings

In [None]:
# Generate key findings summary
print("\n" + "="*70)
print("KEY FINDINGS")
print("="*70)

# Calculate some summary statistics
ss_pairs = df[df['pair_type_short'] == 'S-S']
rr_pairs = df[df['pair_type_short'] == 'R-R']
sr_pairs = df[df['pair_type_short'] == 'S-R']

print("""
## Summary

This analysis tested whether "mechanism distance" metrics can predict 
linear interpolation barrier heights between model pairs.

### 1. Mechanism Distance Metrics
""")

for pair_name, md in mechanism_distances.items():
    print(f"- **{pair_name}** ({get_pair_short_name(md['pair_type'])}): "
          f"dist_srs={md['dist_srs']:.4f}, dist_cka={md['dist_cka']:.4f}")

print("""
### 2. Key Observations

- **Same-mechanism pairs** (S-S, R-R) have:
  - Low SRS distance (similar cue reliance)
  - High CKA similarity (similar representations)
  - Lower loss barriers after rebasin

- **Different-mechanism pairs** (S-R) have:
  - High SRS distance (different cue reliance)  
  - Lower CKA similarity
  - Higher loss barriers even after rebasin

### 3. Correlation Results
""")

# Print key correlations
for key, res in correlation_results.items():
    if 'rebasin' in key:
        print(f"- **{key}**:")
        print(f"  - Pearson r = {res['pearson']['correlation']:.3f} "
              f"(95% CI: [{res['pearson']['ci_lower']:.3f}, {res['pearson']['ci_upper']:.3f}])")

print("""
### 4. Interpretation

- Models with **similar mechanisms** (both spurious or both robust) can be 
  successfully connected via Git Re-Basin, producing low barriers.
  
- Models with **different mechanisms** retain significant barriers even 
  after weight matching, suggesting that Re-Basin cannot bridge 
  fundamental mechanistic differences.

- Mechanism distance metrics (SRS distance, CKA distance) provide a 
  **predictive signal** for rebasin success.

### 5. Files Generated
""")

print(f"- `{RESULTS_DIR / 'mechdist_pairs.csv'}` - Full pairs data")
print(f"- `{FIGURES_DIR / 'barrier_vs_mechdist.png'}` - Barrier vs SRS distance")
print(f"- `{FIGURES_DIR / 'barrier_vs_cka.png'}` - Barrier vs CKA distance")
print(f"- `{FIGURES_DIR / 'mechanism_distance_predicts_barrier.png'}` - Publication figure")
print(f"- `{RESULTS_DIR / 'summary.json'}` - Updated with correlation results")

---

## Blog Post Summary (Copy-Paste Ready)

**Can we predict Git Re-Basin success from mechanism similarity?**

Key findings from our analysis:

- **Cue-reliance distance (SRS)** and **representation distance (CKA)** both correlate with barrier height
- Same-mechanism pairs (spurious-spurious, robust-robust) show low mechanism distances and achieve low barriers after rebasin
- Different-mechanism pairs (spurious-robust) show high mechanism distances and retain significant barriers
- This suggests Git Re-Basin works best when models have learned similar computational mechanisms, regardless of whether those mechanisms rely on spurious or robust features

Implications:
- Mechanism distance metrics could serve as a **pre-flight check** before applying weight matching
- High mechanism distance may indicate that models have fundamentally different internal representations that cannot be aligned through permutation alone
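
To make "aligned through permutation" concrete, here is a toy sketch of hidden-unit matching. It is not the project's `rebasin` module: it assumes a two-layer MLP with made-up weights and matches units with SciPy's `linear_sum_assignment`, in the spirit of weight matching. When model B is just a shuffled copy of model A, the matching recovers A exactly; models that learned different mechanisms offer no such guarantee.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)

# Toy two-layer MLP for "model A": hidden = relu(W1 @ x), out = W2 @ hidden.
w1_a = rng.normal(size=(8, 4))
w2_a = rng.normal(size=(3, 8))

# "Model B" is model A with its hidden units shuffled, so both models
# compute the same function.
perm_true = rng.permutation(8)
w1_b = w1_a[perm_true]
w2_b = w2_a[:, perm_true]

# Score every (unit of A, unit of B) pair by the similarity of incoming
# and outgoing weights, then pick the assignment with the highest total.
similarity = w1_a @ w1_b.T + w2_a.T @ w2_b
rows, cols = linear_sum_assignment(similarity, maximize=True)

# cols[i] is the unit of B matched to unit i of A; reorder B accordingly.
w1_b_aligned = w1_b[cols]
w2_b_aligned = w2_b[:, cols]

recovered = np.allclose(w1_b_aligned, w1_a) and np.allclose(w2_b_aligned, w2_a)
print("recovered A exactly:", recovered)
```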