# 07 - Mechanism Distance Predicts Barrier Height

This notebook tests the hypothesis that **linear interpolation barrier height is predictable from "mechanism distance"** between two endpoint models.

## Research Question
Can we predict how well Git Re-Basin will work (low barrier) based on how similar the models' learned mechanisms are?

## Mechanism Distance Metrics
1. **Cue-Reliance Distance (dist_srs)**: Absolute difference in Spurious Reliance Score
   - `dist_srs = |SRS(A) - SRS(B)|`
   - Models with similar spurious reliance should have similar mechanisms

2. **Representation Distance (dist_cka)**: CKA-based feature similarity
   - `dist_cka = 1 - mean(CKA)` across layers
   - Models with similar internal representations should have similar mechanisms

## Analysis Plan
1. Load all model pairs (S-S, R-R, S-R)
2. Compute mechanism distance metrics for each pair
3. Retrieve barrier heights (pre and post rebasin)
4. Correlate mechanism distance with barrier height
5. Fit regression: barrier ~ dist_srs + dist_cka
6. Generate publication-quality figures

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR, RESULTS_DIR
)

# Set style for publication-quality figures
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 11
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['axes.labelsize'] = 11

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")
print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import project modules
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    CounterfactualPatchDataset,
)
from src.models import create_model
from src.train import load_model
from src.interp import evaluate_interpolation_multi_dataset
from src.metrics import (
    compute_spurious_reliance_score,
    compute_srs_distance,
    get_srs_scalar,
    compute_all_barriers,
    bootstrap_correlation,
    fit_linear_regression,
)
from src.cka import (
    compute_cka_distance,
    compute_layerwise_cka,
    create_cka_dataloader,
    compute_singular_vector_alignment,
)
from src.pairs import (
    get_standard_pairs,
    load_model_pair,
    load_all_standard_pairs,
    get_pair_short_name,
    check_checkpoints_exist,
    print_checkpoint_status,
    PAIR_TYPE_SS, PAIR_TYPE_RR, PAIR_TYPE_SR,
)
from src.plotting import save_figure

from torch.utils.data import DataLoader

## 1. Check Prerequisites and Load Models

In [None]:
# Check what checkpoints are available
print("Checking checkpoint availability...\n")
print_checkpoint_status()

# Verify required checkpoints exist
status = check_checkpoints_exist()
required = ['A1', 'A2', 'R1', 'R2']
missing = [m for m in required if not status.get(m, False)]
if missing:
    raise FileNotFoundError(
        f"Missing required checkpoints: {missing}\n"
        f"Please run notebooks 02-04 first."
    )
print("\nAll required checkpoints found!")

In [None]:
# Load all model pairs
print("Loading model pairs...\n")
model_pairs = load_all_standard_pairs(device, config, load_aligned=True)

for name, pair in model_pairs.items():
    aligned_status = "Yes" if pair.model_b_aligned is not None else "No"
    print(f"  {name}: type={pair.pair_type}, aligned={aligned_status}")

print(f"\nLoaded {len(model_pairs)} model pairs.")

## 2. Create DataLoaders

In [None]:
# Create test datasets
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['interpolation']['eval_batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

dataloaders = {
    'id': id_loader,
    'ood': ood_loader,
}

print(f"Test datasets: ID={len(test_id)}, OOD={len(test_ood)} samples")

In [None]:
# Create counterfactual dataset for SRS computation
cf_dataset = CounterfactualPatchDataset(
    base_dataset=test_id,
    swap_mode='random_wrong',
)

# Create fixed CKA dataloader (using subset for efficiency)
CKA_N_SAMPLES = 2000  # Configurable number of samples for CKA
cka_loader = create_cka_dataloader(
    test_id, 
    n_samples=CKA_N_SAMPLES, 
    batch_size=batch_size,
    seed=config['seeds']['global'],
)

print(f"Counterfactual dataset: {len(cf_dataset)} samples")
print(f"CKA dataloader: {CKA_N_SAMPLES} samples (fixed subset)")

## 3. Compute Spurious Reliance Score (SRS) for All Models

In [None]:
# First, compute SRS for each individual model
# We need this to compute SRS distance for pairs

model_names = ['A1', 'A2', 'R1', 'R2']
individual_srs = {}

print("Computing SRS for individual models...\n")

for pair in model_pairs.values():
    for model_name, model in [(pair.model_a_name, pair.model_a), 
                               (pair.model_b_name, pair.model_b)]:
        if model_name not in individual_srs:
            print(f"  Computing SRS for {model_name}...")
            srs = compute_spurious_reliance_score(
                model, id_loader, ood_loader, cf_dataset, device
            )
            individual_srs[model_name] = srs
            print(f"    SRS = {srs['spurious_reliance_score']:.4f}")

print("\n" + "="*50)
print("SRS Summary:")
print("="*50)
for name in model_names:
    srs = individual_srs[name]
    print(f"{name}: SRS={srs['spurious_reliance_score']:.4f}, "
          f"ID={srs['id_accuracy']*100:.1f}%, OOD={srs['ood_accuracy']*100:.1f}%")

## 4. Compute Mechanism Distance Metrics

In [None]:
# Configuration for CKA computation
CKA_LAYERS = ['block2', 'block3', 'fc1']  # Layers to compare
CKA_DEBIASED = False  # Use standard estimator

# Compute mechanism distances for each pair
mechanism_distances = {}

print("Computing mechanism distances...\n")

for pair_name, pair in model_pairs.items():
    print(f"\n{'='*60}")
    print(f"Pair: {pair_name} ({pair.pair_type})")
    print(f"{'='*60}")
    
    # (A) Cue-reliance distance (SRS)
    srs_a = individual_srs[pair.model_a_name]
    srs_b = individual_srs[pair.model_b_name]
    dist_srs = compute_srs_distance(srs_a, srs_b)
    print(f"\n  Cue-Reliance Distance (dist_srs):")
    print(f"    SRS({pair.model_a_name}) = {get_srs_scalar(srs_a):.4f}")
    print(f"    SRS({pair.model_b_name}) = {get_srs_scalar(srs_b):.4f}")
    print(f"    dist_srs = {dist_srs:.4f}")
    
    # (B) Representation distance (CKA)
    print(f"\n  Representation Distance (CKA):")
    dist_cka, cka_per_layer = compute_cka_distance(
        pair.model_a, pair.model_b,
        cka_loader, device,
        layer_names=CKA_LAYERS,
        n_samples=CKA_N_SAMPLES,
        debiased=CKA_DEBIASED,
    )
    print(f"    Per-layer CKA: {cka_per_layer}")
    print(f"    Mean CKA = {1 - dist_cka:.4f}")
    print(f"    dist_cka = {dist_cka:.4f}")
    
    # (C) Optional: Singular vector alignment
    print(f"\n  Singular Vector Alignment:")
    dist_sv, sv_per_layer = compute_singular_vector_alignment(
        pair.model_a, pair.model_b,
        layer_names=['block0', 'block1', 'block2', 'block3'],
        top_k=5,
    )
    print(f"    Per-layer alignment: {sv_per_layer}")
    print(f"    dist_sv = {dist_sv:.4f}")
    
    # Store results
    mechanism_distances[pair_name] = {
        'pair_type': pair.pair_type,
        'dist_srs': dist_srs,
        'dist_cka': dist_cka,
        'dist_sv': dist_sv,
        'cka_per_layer': cka_per_layer,
        'sv_per_layer': sv_per_layer,
        'srs_a': get_srs_scalar(srs_a),
        'srs_b': get_srs_scalar(srs_b),
    }

## 5. Compute Barrier Heights (Reuse or Recompute)

In [None]:
# Try to load existing results from summary.json
summary_path = RESULTS_DIR / 'summary.json'

existing_barriers = None
if summary_path.exists():
    print(f"Loading existing barrier results from {summary_path}...")
    with open(summary_path, 'r') as f:
        existing_data = json.load(f)
    if 'barrier_comparison' in existing_data:
        existing_barriers = existing_data['barrier_comparison']
        print("  Found existing barrier data!")
else:
    print("No existing summary.json found. Will compute barriers.")

In [None]:
# Compute barriers (or use existing)
num_alphas = config['interpolation']['num_alphas']
barrier_results = {}

print("\nComputing/Loading barrier heights...\n")

for pair_name, pair in model_pairs.items():
    print(f"\nPair: {pair_name}")
    
    # Check if we have existing barriers
    if existing_barriers and pair_name in existing_barriers:
        print("  Using cached barrier values.")
        eb = existing_barriers[pair_name]
        barrier_results[pair_name] = {
            'barrier_id_raw': eb.get('pre_id_loss_barrier', np.nan),
            'barrier_ood_raw': eb.get('pre_ood_loss_barrier', np.nan),
            'barrier_id_rebasin': eb.get('post_id_loss_barrier', np.nan),
            'barrier_ood_rebasin': eb.get('post_ood_loss_barrier', np.nan),
            'barrier_id_acc_raw': eb.get('pre_id_acc_barrier', np.nan),
            'barrier_ood_acc_raw': eb.get('pre_ood_acc_barrier', np.nan),
            'barrier_id_acc_rebasin': eb.get('post_id_acc_barrier', np.nan),
            'barrier_ood_acc_rebasin': eb.get('post_ood_acc_barrier', np.nan),
        }
    else:
        # Compute barriers
        print("  Computing pre-rebasin interpolation...")
        pre_results = evaluate_interpolation_multi_dataset(
            pair.model_a, pair.model_b, dataloaders, device, num_alphas
        )
        
        post_results = None
        if pair.model_b_aligned is not None:
            print("  Computing post-rebasin interpolation...")
            post_results = evaluate_interpolation_multi_dataset(
                pair.model_a, pair.model_b_aligned, dataloaders, device, num_alphas
            )
        
        # Extract barriers
        barrier_results[pair_name] = compute_all_barriers(pre_results, post_results)
    
    # Print summary
    br = barrier_results[pair_name]
    print(f"  ID barrier:  raw={br['barrier_id_raw']:.4f}, rebasin={br['barrier_id_rebasin']:.4f}")
    print(f"  OOD barrier: raw={br['barrier_ood_raw']:.4f}, rebasin={br['barrier_ood_rebasin']:.4f}")

## 6. Build Analysis DataFrame

In [None]:
# Combine mechanism distances and barriers into a single dataframe
pairs_data = []

for pair_name, pair in model_pairs.items():
    md = mechanism_distances[pair_name]
    br = barrier_results[pair_name]
    
    row = {
        'pair_id': pair_name,
        'pair_type': md['pair_type'],
        'pair_type_short': get_pair_short_name(md['pair_type']),
        'model_a': pair.model_a_name,
        'model_b': pair.model_b_name,
        
        # Mechanism distances
        'dist_srs': md['dist_srs'],
        'dist_cka': md['dist_cka'],
        'dist_sv': md['dist_sv'],
        
        # Individual SRS values
        'srs_a': md['srs_a'],
        'srs_b': md['srs_b'],
        
        # Per-layer CKA
        **{f'cka_{layer}': md['cka_per_layer'].get(layer, np.nan) 
           for layer in CKA_LAYERS},
        
        # Barriers
        **br,
    }
    pairs_data.append(row)

df = pd.DataFrame(pairs_data)
print("\nPairs DataFrame:")
print(df.to_string())

In [None]:
# Save the pairs dataframe
csv_path = RESULTS_DIR / 'mechdist_pairs.csv'
df.to_csv(csv_path, index=False)
print(f"\nSaved pairs data to: {csv_path}")

## 7. Statistical Analysis: Correlations

In [None]:
# Define barrier and distance columns for analysis
barrier_cols = ['barrier_id_raw', 'barrier_ood_raw', 'barrier_id_rebasin', 'barrier_ood_rebasin']
distance_cols = ['dist_srs', 'dist_cka']

# Compute correlations with bootstrapped CIs
N_BOOTSTRAP = 2000
correlation_results = {}

print("\n" + "="*70)
print("CORRELATION ANALYSIS")
print("="*70)

for barrier_col in barrier_cols:
    print(f"\n{barrier_col}:")
    print("-" * 50)
    
    y = df[barrier_col].values
    
    # Skip if all NaN
    if np.all(np.isnan(y)):
        print("  [SKIP] All values are NaN")
        continue
    
    for dist_col in distance_cols:
        x = df[dist_col].values
        
        # Filter out NaN
        mask = ~(np.isnan(x) | np.isnan(y))
        x_clean, y_clean = x[mask], y[mask]
        
        if len(x_clean) < 3:
            print(f"  {dist_col}: [SKIP] Insufficient data points")
            continue
        
        # Pearson correlation
        pearson = bootstrap_correlation(
            x_clean, y_clean, 
            n_bootstrap=N_BOOTSTRAP, 
            method='pearson'
        )
        
        # Spearman correlation
        spearman = bootstrap_correlation(
            x_clean, y_clean, 
            n_bootstrap=N_BOOTSTRAP, 
            method='spearman'
        )
        
        key = f"{barrier_col}_vs_{dist_col}"
        correlation_results[key] = {
            'pearson': pearson,
            'spearman': spearman,
            'n': len(x_clean),
        }
        
        print(f"  {dist_col}:")
        print(f"    Pearson r = {pearson['correlation']:.3f} "
              f"[{pearson['ci_lower']:.3f}, {pearson['ci_upper']:.3f}] "
              f"(p={pearson['p_value']:.4f})")
        print(f"    Spearman rho = {spearman['correlation']:.3f} "
              f"[{spearman['ci_lower']:.3f}, {spearman['ci_upper']:.3f}] "
              f"(p={spearman['p_value']:.4f})")

## 8. Regression Analysis

In [None]:
# Fit regression: barrier ~ dist_srs + dist_cka
regression_results = {}

print("\n" + "="*70)
print("REGRESSION ANALYSIS: barrier ~ dist_srs + dist_cka")
print("="*70)

for barrier_col in barrier_cols:
    y = df[barrier_col].values
    
    # Skip if all NaN
    if np.all(np.isnan(y)):
        continue
    
    X = df[['dist_srs', 'dist_cka']].values
    
    # Filter out NaN
    mask = ~np.any(np.isnan(np.column_stack([X, y.reshape(-1, 1)])), axis=1)
    X_clean, y_clean = X[mask], y[mask]
    
    if len(y_clean) < 3:
        print(f"\n{barrier_col}: [SKIP] Insufficient data")
        continue
    
    # Fit regression
    reg = fit_linear_regression(
        X_clean, y_clean, 
        feature_names=['dist_srs', 'dist_cka']
    )
    regression_results[barrier_col] = reg
    
    print(f"\n{barrier_col}:")
    print(f"  R^2 = {reg['r_squared']:.4f}")
    print(f"  Intercept = {reg['intercept']:.4f}")
    for feat, coef in reg['coefficients'].items():
        print(f"  {feat}: {coef:.4f}")

## 9. Visualization: Barrier vs Mechanism Distance

In [None]:
# Color palette for pair types
pair_colors = {
    'S-S': '#e74c3c',   # Red for spurious-spurious
    'R-R': '#3498db',   # Blue for robust-robust  
    'S-R': '#9b59b6',   # Purple for spurious-robust
}

# Marker styles
pair_markers = {
    'S-S': 'o',
    'R-R': 's',
    'S-R': '^',
}

In [None]:
# Figure 1: Barrier vs SRS Distance
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

barrier_titles = {
    'barrier_id_raw': 'ID Loss Barrier (Pre-Rebasin)',
    'barrier_ood_raw': 'OOD Loss Barrier (Pre-Rebasin)',
    'barrier_id_rebasin': 'ID Loss Barrier (Post-Rebasin)',
    'barrier_ood_rebasin': 'OOD Loss Barrier (Post-Rebasin)',
}

for ax, barrier_col in zip(axes.flat, barrier_cols):
    # Plot each pair type separately
    for pair_type in ['S-S', 'R-R', 'S-R']:
        mask = df['pair_type_short'] == pair_type
        subset = df[mask]
        
        if len(subset) > 0 and not np.all(np.isnan(subset[barrier_col])):
            ax.scatter(
                subset['dist_srs'], 
                subset[barrier_col],
                c=pair_colors[pair_type],
                marker=pair_markers[pair_type],
                s=150,
                label=pair_type,
                edgecolors='black',
                linewidths=1,
                alpha=0.8,
            )
            
            # Add pair labels
            for _, row in subset.iterrows():
                if not np.isnan(row[barrier_col]):
                    ax.annotate(
                        row['pair_id'],
                        (row['dist_srs'], row[barrier_col]),
                        xytext=(5, 5),
                        textcoords='offset points',
                        fontsize=9,
                    )
    
    ax.set_xlabel('SRS Distance (|SRS(A) - SRS(B)|)')
    ax.set_ylabel('Loss Barrier')
    ax.set_title(barrier_titles[barrier_col])
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.suptitle('Barrier Height vs. Cue-Reliance Distance', fontsize=14, y=1.02)
plt.tight_layout()

# Save figure
fig_path = FIGURES_DIR / 'barrier_vs_mechdist.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"Saved: {fig_path}")

plt.show()

In [None]:
# Figure 2: Barrier vs CKA Distance
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for ax, barrier_col in zip(axes.flat, barrier_cols):
    # Plot each pair type separately
    for pair_type in ['S-S', 'R-R', 'S-R']:
        mask = df['pair_type_short'] == pair_type
        subset = df[mask]
        
        if len(subset) > 0 and not np.all(np.isnan(subset[barrier_col])):
            ax.scatter(
                subset['dist_cka'], 
                subset[barrier_col],
                c=pair_colors[pair_type],
                marker=pair_markers[pair_type],
                s=150,
                label=pair_type,
                edgecolors='black',
                linewidths=1,
                alpha=0.8,
            )
            
            # Add pair labels
            for _, row in subset.iterrows():
                if not np.isnan(row[barrier_col]):
                    ax.annotate(
                        row['pair_id'],
                        (row['dist_cka'], row[barrier_col]),
                        xytext=(5, 5),
                        textcoords='offset points',
                        fontsize=9,
                    )
    
    ax.set_xlabel('CKA Distance (1 - mean CKA)')
    ax.set_ylabel('Loss Barrier')
    ax.set_title(barrier_titles[barrier_col])
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.suptitle('Barrier Height vs. Representation Distance (CKA)', fontsize=14, y=1.02)
plt.tight_layout()

# Save figure
fig_path = FIGURES_DIR / 'barrier_vs_cka.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"Saved: {fig_path}")

plt.show()

In [None]:
# Combined summary figure for publication
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Use post-rebasin ID barrier as the primary metric
barrier_col = 'barrier_id_rebasin'
fallback_col = 'barrier_id_raw'

# Left: Barrier vs SRS Distance
ax = axes[0]
for pair_type in ['S-S', 'R-R', 'S-R']:
    mask = df['pair_type_short'] == pair_type
    subset = df[mask]
    
    # Use rebasin if available, else raw
    y_vals = subset[barrier_col].fillna(subset[fallback_col])
    
    ax.scatter(
        subset['dist_srs'], 
        y_vals,
        c=pair_colors[pair_type],
        marker=pair_markers[pair_type],
        s=200,
        label=pair_type,
        edgecolors='black',
        linewidths=1.5,
        alpha=0.9,
    )

ax.set_xlabel('Cue-Reliance Distance\n|SRS(A) - SRS(B)|', fontsize=12)
ax.set_ylabel('ID Loss Barrier (Post-Rebasin)', fontsize=12)
ax.set_title('(A) Barrier vs. Cue-Reliance Distance', fontsize=13)
ax.legend(title='Pair Type', loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3)

# Right: Barrier vs CKA Distance
ax = axes[1]
for pair_type in ['S-S', 'R-R', 'S-R']:
    mask = df['pair_type_short'] == pair_type
    subset = df[mask]
    
    y_vals = subset[barrier_col].fillna(subset[fallback_col])
    
    ax.scatter(
        subset['dist_cka'], 
        y_vals,
        c=pair_colors[pair_type],
        marker=pair_markers[pair_type],
        s=200,
        label=pair_type,
        edgecolors='black',
        linewidths=1.5,
        alpha=0.9,
    )

ax.set_xlabel('Representation Distance\n1 - mean(CKA)', fontsize=12)
ax.set_ylabel('ID Loss Barrier (Post-Rebasin)', fontsize=12)
ax.set_title('(B) Barrier vs. Representation Distance', fontsize=13)
ax.legend(title='Pair Type', loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()

# Save publication figure
fig_path = FIGURES_DIR / 'mechanism_distance_predicts_barrier.png'
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"\nSaved publication figure: {fig_path}")

plt.show()

## 10. Summary Table

In [None]:
# Create summary table
print("\n" + "="*90)
print("SUMMARY TABLE: Model Pairs Analysis")
print("="*90)

summary_cols = ['pair_id', 'pair_type_short', 'dist_srs', 'dist_cka', 
                'barrier_id_raw', 'barrier_id_rebasin']
summary_df = df[summary_cols].copy()
summary_df.columns = ['Pair', 'Type', 'dist_SRS', 'dist_CKA', 
                      'Barrier (Raw)', 'Barrier (Rebasin)']

# Format numbers
for col in ['dist_SRS', 'dist_CKA', 'Barrier (Raw)', 'Barrier (Rebasin)']:
    summary_df[col] = summary_df[col].apply(lambda x: f"{x:.4f}" if not np.isnan(x) else "N/A")

print(summary_df.to_string(index=False))

## 11. Save Results

In [None]:
# Update summary.json with correlation and regression results
summary_path = RESULTS_DIR / 'summary.json'

# Load existing or create new
if summary_path.exists():
    with open(summary_path, 'r') as f:
        summary = json.load(f)
else:
    summary = {}

# Add mechanism distance analysis results
summary['mechanism_distance_analysis'] = {
    'description': 'Analysis of whether mechanism distance predicts barrier height',
    'metrics': {
        'cka_n_samples': CKA_N_SAMPLES,
        'cka_layers': CKA_LAYERS,
        'srs_weights': {'ood_drop': 0.4, 'acc_drop_cf': 0.3, 'flip_rate': 0.3},
    },
    'pair_distances': {
        pair_name: {
            'pair_type': md['pair_type'],
            'dist_srs': float(md['dist_srs']),
            'dist_cka': float(md['dist_cka']),
            'dist_sv': float(md['dist_sv']),
            'srs_a': float(md['srs_a']),
            'srs_b': float(md['srs_b']),
            'cka_per_layer': {k: float(v) for k, v in md['cka_per_layer'].items()},
        }
        for pair_name, md in mechanism_distances.items()
    },
    'correlations': {
        key: {
            'pearson_r': res['pearson']['correlation'],
            'pearson_ci': [res['pearson']['ci_lower'], res['pearson']['ci_upper']],
            'pearson_p': res['pearson']['p_value'],
            'spearman_rho': res['spearman']['correlation'],
            'spearman_ci': [res['spearman']['ci_lower'], res['spearman']['ci_upper']],
            'spearman_p': res['spearman']['p_value'],
            'n_samples': res['n'],
        }
        for key, res in correlation_results.items()
    },
    'regressions': {
        barrier: {
            'r_squared': reg['r_squared'],
            'intercept': reg['intercept'],
            'coefficients': reg['coefficients'],
        }
        for barrier, reg in regression_results.items()
    },
}

# Save updated summary
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"Updated summary saved to: {summary_path}")

## 12. Key Findings

In [None]:
# Generate key findings summary
print("\n" + "="*70)
print("KEY FINDINGS")
print("="*70)

# Calculate some summary statistics
ss_pairs = df[df['pair_type_short'] == 'S-S']
rr_pairs = df[df['pair_type_short'] == 'R-R']
sr_pairs = df[df['pair_type_short'] == 'S-R']

print("""
## Summary

This analysis tested whether "mechanism distance" metrics can predict 
linear interpolation barrier heights between model pairs.

### 1. Mechanism Distance Metrics
""")

for pair_name, md in mechanism_distances.items():
    print(f"- **{pair_name}** ({get_pair_short_name(md['pair_type'])}): "
          f"dist_srs={md['dist_srs']:.4f}, dist_cka={md['dist_cka']:.4f}")

print("""
### 2. Key Observations

- **Same-mechanism pairs** (S-S, R-R) have:
  - Low SRS distance (similar cue reliance)
  - High CKA similarity (similar representations)
  - Lower loss barriers after rebasin

- **Different-mechanism pairs** (S-R) have:
  - High SRS distance (different cue reliance)  
  - Lower CKA similarity
  - Higher loss barriers even after rebasin

### 3. Correlation Results
""")

# Print key correlations
for key, res in correlation_results.items():
    if 'rebasin' in key:
        print(f"- **{key}**:")
        print(f"  - Pearson r = {res['pearson']['correlation']:.3f} "
              f"(95% CI: [{res['pearson']['ci_lower']:.3f}, {res['pearson']['ci_upper']:.3f}])")

print("""
### 4. Interpretation

- Models with **similar mechanisms** (both spurious or both robust) can be 
  successfully connected via Git Re-Basin, producing low barriers.
  
- Models with **different mechanisms** retain significant barriers even 
  after weight matching, suggesting that Re-Basin cannot bridge 
  fundamental mechanistic differences.

- Mechanism distance metrics (SRS distance, CKA distance) provide a 
  **predictive signal** for rebasin success.

### 5. Files Generated
""")

print(f"- `{RESULTS_DIR / 'mechdist_pairs.csv'}` - Full pairs data")
print(f"- `{FIGURES_DIR / 'barrier_vs_mechdist.png'}` - Barrier vs SRS distance")
print(f"- `{FIGURES_DIR / 'barrier_vs_cka.png'}` - Barrier vs CKA distance")
print(f"- `{FIGURES_DIR / 'mechanism_distance_predicts_barrier.png'}` - Publication figure")
print(f"- `{RESULTS_DIR / 'summary.json'}` - Updated with correlation results")

---

## Blog Post Summary (Copy-Paste Ready)

**Can we predict Git Re-Basin success from mechanism similarity?**

Key findings from our analysis:

- **Cue-reliance distance (SRS)** and **representation distance (CKA)** both correlate with barrier height
- Same-mechanism pairs (spurious-spurious, robust-robust) show low mechanism distances and achieve low barriers after rebasin
- Different-mechanism pairs (spurious-robust) show high mechanism distances and retain significant barriers
- This suggests Git Re-Basin works best when models have learned similar computational mechanisms, regardless of whether those mechanisms rely on spurious or robust features

Implications:
- Mechanism distance metrics could serve as a **pre-flight check** before applying weight matching
- High mechanism distance may indicate that models have fundamentally different internal representations that cannot be aligned through permutation alone