# 04 - Git Re-Basin Alignment

This notebook implements Git Re-Basin (weight matching) to align independently trained models.

## What is Git Re-Basin?
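Neural networks have permutation symmetries: reordering the neurons of a hidden layer, and permuting the incoming and outgoing weights to match, yields exactly the same function. Git Re-Basin searches for the permutations that bring one model's weights as close as possible to another's, so that interpolating between the two weight vectors combines corresponding neurons rather than unrelated ones.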
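To make the symmetry concrete, here is a minimal NumPy sketch (not the `src.rebasin` code) using a hypothetical two-layer ReLU MLP. Permuting the rows of the first weight matrix and bias, and the columns of the second weight matrix, with the same permutation leaves every output unchanged even though the weights themselves differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 2-layer MLP: x -> ReLU(W1 x + b1) -> W2 h + b2
d_in, d_hidden, d_out = 4, 8, 3
W1 = rng.normal(size=(d_hidden, d_in))
b1 = rng.normal(size=d_hidden)
W2 = rng.normal(size=(d_out, d_hidden))
b2 = rng.normal(size=d_out)


def mlp(x, W1, b1, W2, b2):
    """Forward pass for a batch x of shape (n, d_in)."""
    h = np.maximum(x @ W1.T + b1, 0.0)  # ReLU hidden activations, shape (n, d_hidden)
    return h @ W2.T + b2                 # logits, shape (n, d_out)


# Reorder hidden units with a fixed non-identity permutation (cyclic shift)
perm = np.roll(np.arange(d_hidden), 1)
W1_p, b1_p = W1[perm], b1[perm]  # permute the rows that produce each hidden unit
W2_p = W2[:, perm]               # permute the columns that consume each hidden unit

x = rng.normal(size=(5, d_in))
out_original = mlp(x, W1, b1, W2, b2)
out_permuted = mlp(x, W1_p, b1_p, W2_p, b2)

print(np.allclose(out_original, out_permuted))  # same function
print(np.allclose(W1, W1_p))                    # different weights
```

Because ReLU acts elementwise, it commutes with any reordering of the hidden units, which is why only the adjacent weight matrices need to be permuted.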

## Model Pairs to Align
1. **A1 vs A2** (spurious-spurious): Should align well (same mechanism)
2. **R1 vs R2** (robust-robust): Should align well (same mechanism)
3. **A1 vs R1** (spurious-robust): May not align well (different mechanisms)

## What This Notebook Does
1. Implements weight matching permutation finding
2. Aligns model pairs
3. Verifies that alignment preserves each model's function (accuracy and predictions) while increasing weight-space similarity
4. Saves aligned models for interpolation analysis
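Step 1 follows the weight-matching objective of Git Re-Basin. Write $W_\ell^{A}$ and $W_\ell^{B}$ for the layer-$\ell$ weight matrices of the reference model and the model being aligned, and $P_\ell$ for a permutation matrix that reorders the outputs of layer $\ell$. Input and output orderings are fixed ($P_0 = P_L = I$), and biases are permuted as $b_\ell \mapsto P_\ell b_\ell$ (their inner-product terms $\langle b_\ell^{A}, P_\ell b_\ell^{B} \rangle$ are omitted below for brevity). Weight matching solves

$$
\max_{P_1, \dots, P_{L-1}} \; \sum_{\ell=1}^{L} \left\langle W_\ell^{A},\; P_\ell W_\ell^{B} P_{\ell-1}^{\top} \right\rangle_F
$$

Because permutations preserve the Frobenius norm,

$$
\sum_{\ell=1}^{L} \left\| W_\ell^{A} - P_\ell W_\ell^{B} P_{\ell-1}^{\top} \right\|_F^2
= \sum_{\ell=1}^{L} \left( \|W_\ell^{A}\|_F^2 + \|W_\ell^{B}\|_F^2 \right)
- 2 \sum_{\ell=1}^{L} \left\langle W_\ell^{A},\; P_\ell W_\ell^{B} P_{\ell-1}^{\top} \right\rangle_F ,
$$

so maximizing the inner product is equivalent to minimizing the L2 distance between the two weight vectors. Each $P_\ell$ appears in two adjacent terms; holding all other permutations fixed turns the problem for a single $P_\ell$ into a linear assignment problem, and the method cycles over layers until no permutation changes. This describes the published technique; the exact details of `src.rebasin` are not shown in this notebook.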

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import create_env_a_dataset, create_no_patch_dataset
from src.models import create_model, model_agreement, clone_model
from src.train import load_model, evaluate_model
from src.rebasin import (
    rebasin,
    weight_matching,
    apply_permutations,
    compute_weight_distance,
    compute_cosine_similarity,
)
from src.plotting import save_figure
from torch.utils.data import DataLoader

## 1. Load Trained Models

In [None]:
# Load all 4 models
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}\n"
                               f"Please run 02_train_models.ipynb first.")
    
    model = create_model(config)
    model = load_model(model, checkpoint_path, device)
    models[name] = model
    print(f"Loaded model {name}")

print(f"\nAll {len(models)} models loaded!")

## 2. Create Test DataLoader

In [None]:
# Create test dataset for agreement computation
test_id = create_env_a_dataset(train=False, config=config)
test_ood = create_no_patch_dataset(train=False, config=config)

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Test loaders created with {len(test_id)} ID and {len(test_ood)} OOD samples")

## 3. Compute Pre-Rebasin Metrics

In [None]:
# Define model pairs to analyze
model_pairs = [
    ('A1', 'A2', 'spurious-spurious'),
    ('R1', 'R2', 'robust-robust'),
    ('A1', 'R1', 'spurious-robust'),
]

# Compute pre-rebasin metrics
pre_rebasin_metrics = {}

print("Computing pre-rebasin metrics...\n")
print(f"{'Pair':<25} {'Weight Dist':<15} {'Cosine Sim':<15} {'Agreement':<15}")
print("-" * 70)

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    
    # Weight space metrics
    weight_dist = compute_weight_distance(models[name1], models[name2])
    cosine_sim = compute_cosine_similarity(models[name1], models[name2])
    
    # Functional agreement
    agreement = model_agreement(models[name1], models[name2], id_loader, device)
    
    pre_rebasin_metrics[pair_name] = {
        'type': pair_type,
        'weight_distance': weight_dist,
        'cosine_similarity': cosine_sim,
        'agreement': agreement,
    }
    
    pair_label = f"{pair_name} ({pair_type})"
    print(f"{pair_label:<25} {weight_dist:<15.2f} {cosine_sim:<15.4f} {agreement*100:.2f}%")
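The helpers `compute_weight_distance`, `compute_cosine_similarity` (from `src.rebasin`) and `model_agreement` (from `src.models`) are not shown in this notebook. The hypothetical NumPy sketch below shows one common way to define such metrics, assuming each model's parameters are flattened into a single vector and agreement is the fraction of inputs on which both models predict the same class. The real helpers may differ, for example in whether biases or BatchNorm buffers are included.

```python
import numpy as np


def flatten_params(params):
    """Concatenate a dict of parameter arrays into one 1-D vector (keys sorted by name)."""
    return np.concatenate([np.ravel(params[k]) for k in sorted(params)])


def weight_distance(params_a, params_b):
    """L2 distance between the flattened parameter vectors."""
    return float(np.linalg.norm(flatten_params(params_a) - flatten_params(params_b)))


def cosine_similarity(params_a, params_b):
    """Cosine of the angle between the flattened parameter vectors."""
    va, vb = flatten_params(params_a), flatten_params(params_b)
    return float(va @ vb / (np.linalg.norm(va) * np.linalg.norm(vb)))


def agreement(logits_a, logits_b):
    """Fraction of samples on which both models predict the same class."""
    return float(np.mean(np.argmax(logits_a, axis=1) == np.argmax(logits_b, axis=1)))


# Toy usage with hypothetical parameter dicts
a = {"w": np.array([[1.0, 0.0], [0.0, 1.0]]), "b": np.array([0.0, 0.0])}
b = {"w": np.array([[0.0, 1.0], [1.0, 0.0]]), "b": np.array([0.0, 0.0])}
print(weight_distance(a, b))    # 2.0
print(cosine_similarity(a, b))  # 0.0
```

A permutation only reorders entries, so it leaves each model's parameter norm unchanged. Under these definitions, lowering the weight distance and raising the cosine similarity are therefore two views of the same alignment goal.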

## 4. Perform Git Re-Basin (Weight Matching)
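For a single hidden layer, finding the permutation is a linear assignment problem. The sketch below is an illustration, not the `src.rebasin.weight_matching` implementation: it assumes a hypothetical two-layer MLP stored as NumPy arrays, and it uses SciPy's `linear_sum_assignment` (an extra dependency this notebook does not import) to match each hidden unit of model B to the unit of model A with the most similar incoming weights, bias, and outgoing weights. The function names `match_hidden_layer` and `apply_hidden_perm` are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_hidden_layer(W1_a, b1_a, W2_a, W1_b, b1_b, W2_b):
    """Return perm such that hidden unit perm[i] of model B is matched to unit i of model A.

    Shapes: W1 (hidden, in), b1 (hidden,), W2 (out, hidden).
    """
    # similarity[i, j]: inner product of incoming weights, bias product,
    # and inner product of outgoing weights for unit i of A and unit j of B
    similarity = W1_a @ W1_b.T + np.outer(b1_a, b1_b) + W2_a.T @ W2_b
    _, perm = linear_sum_assignment(similarity, maximize=True)
    return perm


def apply_hidden_perm(perm, W1, b1, W2):
    """Reorder hidden units; the network still computes the same function."""
    return W1[perm], b1[perm], W2[:, perm]


# Toy check: model B is model A with shuffled hidden units plus a little noise
rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 10, 16, 3
W1_a = rng.normal(size=(n_hidden, n_in))
b1_a = rng.normal(size=n_hidden)
W2_a = rng.normal(size=(n_out, n_hidden))


def jitter(arr):
    """Add small noise so the two models are not exact copies."""
    return arr + 0.01 * rng.normal(size=arr.shape)


shuffle = rng.permutation(n_hidden)
W1_b, b1_b, W2_b = jitter(W1_a[shuffle]), jitter(b1_a[shuffle]), jitter(W2_a[:, shuffle])

perm = match_hidden_layer(W1_a, b1_a, W2_a, W1_b, b1_b, W2_b)
W1_aligned, b1_aligned, W2_aligned = apply_hidden_perm(perm, W1_b, b1_b, W2_b)

print("distance before:", np.linalg.norm(W1_a - W1_b))
print("distance after: ", np.linalg.norm(W1_a - W1_aligned))
```

In this toy setting the matching should undo the shuffle. In deeper networks the permutations of adjacent layers interact, because each hidden layer's permutation touches both its incoming and outgoing weights, so weight matching re-solves one assignment problem at a time with the others held fixed.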

In [None]:
def perform_rebasin(model_ref, model_to_align, name_ref, name_align, device):
    """
    Perform Git Re-Basin alignment.
    
    Args:
        model_ref: Reference model (we align TO this)
        model_to_align: Model to be aligned
        name_ref: Name of reference model
        name_align: Name of model to align
        device: Torch device
    
    Returns:
        Aligned model
    """
    print(f"\nAligning {name_align} to {name_ref}...")
    
    # Clone to avoid modifying original
    model_to_align_copy = clone_model(model_to_align).to(device)
    
    # Perform rebasing
    aligned_model = rebasin(model_ref, model_to_align_copy, device)
    
    print("  Alignment complete!")
    
    return aligned_model

In [None]:
# Perform rebasing for all pairs
aligned_models = {}

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    
    # Align model2 to model1
    aligned = perform_rebasin(
        models[name1], models[name2],
        name1, name2, device
    )
    aligned_models[pair_name] = aligned
    
    # A permutation is function-preserving, so accuracy should match up to floating-point noise
    _, orig_acc = evaluate_model(models[name2], id_loader, device)
    _, aligned_acc = evaluate_model(aligned, id_loader, device)
    
    print(f"  Original {name2} accuracy: {orig_acc*100:.2f}%")
    print(f"  Aligned {name2} accuracy:  {aligned_acc*100:.2f}%")
    
    if abs(orig_acc - aligned_acc) > 0.01:
        print("  [WARNING] Accuracy changed after alignment; check that every tensor was permuted consistently")
    else:
        print(f"  [OK] Accuracy preserved")
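Because a valid permutation does not change the function, the accuracy check above should agree to within floating-point noise, and a 1-point tolerance is generous. A stricter test compares outputs directly. Assuming the architecture uses BatchNorm (not shown in this notebook), a common source of mismatch is permuting the affine weights but not the `running_mean`/`running_var` buffers. The hypothetical NumPy sketch below mimics an eval-mode BatchNorm layer to show the effect; `forward` and `permute` are illustrative names only.

```python
import numpy as np

rng = np.random.default_rng(1)
n_in, n_hidden, n_out = 6, 8, 4
eps = 1e-5

# Hypothetical Linear -> BatchNorm (eval mode) -> ReLU -> Linear network
params = {
    "W1": rng.normal(size=(n_hidden, n_in)),
    "gamma": rng.normal(size=n_hidden),
    "beta": rng.normal(size=n_hidden),
    "running_mean": rng.normal(size=n_hidden),
    "running_var": rng.uniform(0.5, 2.0, size=n_hidden),
    "W2": rng.normal(size=(n_out, n_hidden)),
}


def forward(p, x):
    """Eval-mode forward pass: normalize with running statistics, then ReLU."""
    z = x @ p["W1"].T
    z = (z - p["running_mean"]) / np.sqrt(p["running_var"] + eps) * p["gamma"] + p["beta"]
    return np.maximum(z, 0.0) @ p["W2"].T


def permute(p, perm, include_buffers=True):
    """Reorder hidden units; include_buffers=False reproduces the bug."""
    q = dict(p)
    q["W1"] = p["W1"][perm]
    q["gamma"], q["beta"] = p["gamma"][perm], p["beta"][perm]
    q["W2"] = p["W2"][:, perm]
    if include_buffers:
        q["running_mean"] = p["running_mean"][perm]
        q["running_var"] = p["running_var"][perm]
    return q


x = rng.normal(size=(32, n_in))
perm = np.roll(np.arange(n_hidden), 1)
ref = forward(params, x)

full_diff = np.abs(forward(permute(params, perm), x) - ref).max()
buggy_diff = np.abs(forward(permute(params, perm, include_buffers=False), x) - ref).max()
print(f"all tensors permuted:  max |diff| = {full_diff:.2e}")
print(f"buffers left in place: max |diff| = {buggy_diff:.2e}")
```

Comparing logits (or predictions) between the original and aligned copy of the same model catches this kind of error even when accuracy happens to move by less than the tolerance.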

## 5. Compute Post-Rebasin Metrics

In [None]:
# Compute post-rebasin metrics
post_rebasin_metrics = {}

print("\nComputing post-rebasin metrics...\n")
print(f"{'Pair':<25} {'Weight Dist':<15} {'Cosine Sim':<15} {'Agreement':<15}")
print("-" * 70)

for name1, name2, pair_type in model_pairs:
    pair_name = f"{name1}-{name2}"
    aligned = aligned_models[pair_name]
    
    # Weight space metrics (reference model vs aligned model)
    weight_dist = compute_weight_distance(models[name1], aligned)
    cosine_sim = compute_cosine_similarity(models[name1], aligned)
    
    # Functional agreement
    agreement = model_agreement(models[name1], aligned, id_loader, device)
    
    post_rebasin_metrics[pair_name] = {
        'type': pair_type,
        'weight_distance': weight_dist,
        'cosine_similarity': cosine_sim,
        'agreement': agreement,
    }
    
    pair_label = f"{pair_name} ({pair_type})"
    print(f"{pair_label:<25} {weight_dist:<15.2f} {cosine_sim:<15.4f} {agreement*100:.2f}%")

## 6. Compare Pre vs Post Rebasin

In [None]:
# Compare metrics
print("\n" + "=" * 80)
print("COMPARISON: Pre vs Post Re-Basin")
print("=" * 80)

comparison_data = {}

for pair_name in pre_rebasin_metrics.keys():
    pre = pre_rebasin_metrics[pair_name]
    post = post_rebasin_metrics[pair_name]
    
    print(f"\n{pair_name} ({pre['type']}):")
    print(f"  Weight Distance: {pre['weight_distance']:.2f} -> {post['weight_distance']:.2f} "
          f"({post['weight_distance'] - pre['weight_distance']:+.2f})")
    print(f"  Cosine Similarity: {pre['cosine_similarity']:.4f} -> {post['cosine_similarity']:.4f} "
          f"({post['cosine_similarity'] - pre['cosine_similarity']:+.4f})")
    print(f"  Agreement: {pre['agreement']*100:.2f}% -> {post['agreement']*100:.2f}% "
          f"({(post['agreement'] - pre['agreement'])*100:+.2f} pp)")
    
    comparison_data[pair_name] = {
        'type': pre['type'],
        'pre_weight_dist': pre['weight_distance'],
        'post_weight_dist': post['weight_distance'],
        'weight_dist_change': post['weight_distance'] - pre['weight_distance'],
        'pre_cosine_sim': pre['cosine_similarity'],
        'post_cosine_sim': post['cosine_similarity'],
        'cosine_sim_change': post['cosine_similarity'] - pre['cosine_similarity'],
        'pre_agreement': pre['agreement'],
        'post_agreement': post['agreement'],
        'agreement_change': post['agreement'] - pre['agreement'],
    }

In [None]:
# Visualize comparison: one grouped bar chart per metric
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

pair_names = list(comparison_data.keys())
x = np.arange(len(pair_names))
width = 0.35

# (pre-rebasin key, post-rebasin key, y-axis label, panel title, scale factor)
panels = [
    ('pre_weight_dist', 'post_weight_dist', 'Weight Distance', 'Weight Distance', 1),
    ('pre_cosine_sim', 'post_cosine_sim', 'Cosine Similarity', 'Cosine Similarity', 1),
    ('pre_agreement', 'post_agreement', 'Agreement (%)', 'Prediction Agreement', 100),
]

for ax, (pre_key, post_key, ylabel, title, scale) in zip(axes, panels):
    pre_vals = [comparison_data[p][pre_key] * scale for p in pair_names]
    post_vals = [comparison_data[p][post_key] * scale for p in pair_names]
    ax.bar(x - width/2, pre_vals, width, label='Pre-Rebasin', color='salmon')
    ax.bar(x + width/2, post_vals, width, label='Post-Rebasin', color='steelblue')
    ax.set_xticks(x)
    ax.set_xticklabels(pair_names, rotation=45, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Git Re-Basin Effect: Pre vs Post Alignment', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'rebasin_comparison')
plt.show()

## 7. Sanity Check: Do Same-Mechanism Pairs Improve More?

In [None]:
print("\nSanity Check: Rebasin Effectiveness by Pair Type")
print("=" * 60)

# Group pairs by whether both models share a mechanism
same_mech_pairs = [p for p in pair_names
                   if comparison_data[p]['type'] in ('spurious-spurious', 'robust-robust')]
diff_mech_pairs = [p for p in pair_names if comparison_data[p]['type'] == 'spurious-robust']

# Compute average improvements
same_mech_sim_change = np.mean([comparison_data[p]['cosine_sim_change'] for p in same_mech_pairs])
diff_mech_sim_change = np.mean([comparison_data[p]['cosine_sim_change'] for p in diff_mech_pairs])

# Agreement change is expected to be ~0: permuting neurons does not change predictions
same_mech_agr_change = np.mean([comparison_data[p]['agreement_change']*100 for p in same_mech_pairs])
diff_mech_agr_change = np.mean([comparison_data[p]['agreement_change']*100 for p in diff_mech_pairs])

print(f"\nSame-mechanism pairs ({', '.join(same_mech_pairs)}):")
print(f"  Average cosine similarity change: {same_mech_sim_change:+.4f}")
print(f"  Average agreement change: {same_mech_agr_change:+.2f}%")

print(f"\nDifferent-mechanism pairs ({', '.join(diff_mech_pairs)}):")
print(f"  Average cosine similarity change: {diff_mech_sim_change:+.4f}")
print(f"  Average agreement change: {diff_mech_agr_change:+.2f}%")

# Verification
print("\nVerification:")
if same_mech_sim_change >= diff_mech_sim_change:
    print("[PASS] Same-mechanism pairs gained at least as much cosine similarity as the different-mechanism pair")
else:
    print("[INFO] The different-mechanism pair gained more cosine similarity.")
    print("       This is not a failure: weight matching is agnostic to the mechanism a model uses.")

## 8. Save Aligned Models

In [None]:
# Save aligned models for interpolation analysis
print("\nSaving aligned models...")

for pair_name, aligned_model in aligned_models.items():
    # Name: model_A2_aligned_to_A1.pt
    name1, name2 = pair_name.split('-')
    save_path = CHECKPOINTS_DIR / f"model_{name2}_aligned_to_{name1}.pt"
    
    torch.save({
        'model_state_dict': aligned_model.state_dict(),
        'reference_model': name1,
        'aligned_model': name2,
        'pair_type': comparison_data[pair_name]['type'],
    }, save_path)
    
    print(f"  Saved: {save_path}")
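These checkpoints are meant for interpolation analysis. As a hedged illustration of what that typically involves (the next notebook's code is not shown here), the sketch below linearly interpolates two state-dict-like mappings of NumPy arrays, $\theta(\alpha) = (1-\alpha)\,\theta_{\text{ref}} + \alpha\,\theta_{\text{aligned}}$. The function name `interpolate_state` is hypothetical, and with real PyTorch state dicts, integer buffers such as BatchNorm's `num_batches_tracked` would need separate handling.

```python
import numpy as np


def interpolate_state(state_ref, state_aligned, alpha):
    """Return (1 - alpha) * state_ref + alpha * state_aligned, key by key."""
    if state_ref.keys() != state_aligned.keys():
        raise KeyError("state dicts must have identical keys")
    return {k: (1.0 - alpha) * state_ref[k] + alpha * state_aligned[k] for k in state_ref}


# Toy usage: endpoints and midpoint of a one-parameter path
ref = {"w": np.array([0.0, 2.0])}
aligned = {"w": np.array([4.0, 6.0])}
for alpha in (0.0, 0.5, 1.0):
    print(alpha, interpolate_state(ref, aligned, alpha)["w"])
```

Interpolating toward the aligned copy rather than the original is the reason for this notebook: both versions of the second model compute the same function, but only the aligned one has its hidden units in the same order as the reference, so averaging combines corresponding units instead of unrelated ones.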

In [None]:
# Save rebasin metrics
rebasin_results = {
    'pre_rebasin': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                       for kk, vv in v.items()} 
                   for k, v in pre_rebasin_metrics.items()},
    'post_rebasin': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                        for kk, vv in v.items()} 
                    for k, v in post_rebasin_metrics.items()},
    'comparison': {k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv 
                      for kk, vv in v.items()} 
                  for k, v in comparison_data.items()},
}

METRICS_DIR.mkdir(parents=True, exist_ok=True)
results_path = METRICS_DIR / 'rebasin_results.json'
with open(results_path, 'w') as f:
    json.dump(rebasin_results, f, indent=2)

print(f"\nResults saved to: {results_path}")
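The `float(...)` conversion in the results dictionary exists because the standard-library `json` module only knows how to serialize built-in types and their subclasses. `np.float64` subclasses Python's `float` and serializes fine, but `np.float32` does not, as this small check shows.

```python
import json

import numpy as np

value32 = np.float32(0.875)

# json.dumps raises TypeError for NumPy float32 scalars
try:
    json.dumps({"agreement": value32})
    serializable = True
except TypeError:
    serializable = False

print(serializable)                               # False
print(json.dumps({"agreement": float(value32)}))  # {"agreement": 0.875}
```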

## 9. Summary

In [None]:
print("\n" + "=" * 60)
print("GIT RE-BASIN ALIGNMENT COMPLETE")
print("=" * 60)
print(f"""
Model pairs aligned:

1. A1-A2 (spurious-spurious):
   - Cosine sim: {comparison_data['A1-A2']['pre_cosine_sim']:.4f} -> {comparison_data['A1-A2']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['A1-A2']['pre_agreement']*100:.1f}% -> {comparison_data['A1-A2']['post_agreement']*100:.1f}%

2. R1-R2 (robust-robust):
   - Cosine sim: {comparison_data['R1-R2']['pre_cosine_sim']:.4f} -> {comparison_data['R1-R2']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['R1-R2']['pre_agreement']*100:.1f}% -> {comparison_data['R1-R2']['post_agreement']*100:.1f}%

3. A1-R1 (spurious-robust):
   - Cosine sim: {comparison_data['A1-R1']['pre_cosine_sim']:.4f} -> {comparison_data['A1-R1']['post_cosine_sim']:.4f}
   - Agreement: {comparison_data['A1-R1']['pre_agreement']*100:.1f}% -> {comparison_data['A1-R1']['post_agreement']*100:.1f}%

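Expected patterns (check against the numbers above):
- Weight matching should reduce weight distance and raise cosine similarity for every pair
- Prediction agreement should be unchanged: the aligned model computes the same function as the original
- Same-mechanism pairs are hypothesized to gain more weight-space similarity than the different-mechanism pair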

Aligned models saved:
- {CHECKPOINTS_DIR / 'model_A2_aligned_to_A1.pt'}
- {CHECKPOINTS_DIR / 'model_R2_aligned_to_R1.pt'}
- {CHECKPOINTS_DIR / 'model_R1_aligned_to_A1.pt'}

Metrics saved:
- {METRICS_DIR / 'rebasin_results.json'}

Figures saved:
- {FIGURES_DIR / 'rebasin_comparison.png'}

Next: Run 05_interpolation_and_barriers.ipynb to analyze loss barriers.
""")