# 03 - Mechanism Verification

This notebook quantifies how strongly each trained model relies on the spurious patch cue versus robust, content-based features.

## Metrics computed:
1. **OOD Drop**: Acc(ID) - Acc(OOD)
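2. **Counterfactual Patch Sensitivity**:
   - Accuracy change when patch color is swapped
   - Prediction flip rate when patch color is swapped
   - Mean change in true-class logit
3. **Spurious Reliance Score (SRS)**: Weighted combination of the OOD drop, counterfactual accuracy drop, and flip rate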

## Spurious Reliance Score (SRS) Formula:
```
SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * CF_flip_rate
```
Where:
- OOD_drop = ID_accuracy - OOD_accuracy
- CF_accuracy_drop = Original_accuracy - Counterfactual_accuracy
- CF_flip_rate = Fraction of correct predictions that flip when patch is swapped
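
As a worked example of the formula above, the sketch below evaluates the weighted sum for one set of made-up component values. The helper name `spurious_reliance_score` and all numbers are hypothetical and chosen only for illustration; the project's own `compute_spurious_reliance_score` in `src.metrics` may be implemented differently.

```python
def spurious_reliance_score(ood_drop, cf_accuracy_drop, cf_flip_rate,
                            weights=(0.4, 0.3, 0.3)):
    """Weighted sum of the three components, each given as a fraction (not a percentage)."""
    w_ood, w_cf, w_flip = weights
    return w_ood * ood_drop + w_cf * cf_accuracy_drop + w_flip * cf_flip_rate


# Hypothetical values for illustration only (not results from this project).
id_acc, ood_acc = 0.95, 0.60      # OOD_drop = 0.35
orig_acc, cf_acc = 0.95, 0.55     # CF_accuracy_drop = 0.40
flip_rate = 0.45                  # CF_flip_rate

srs = spurious_reliance_score(id_acc - ood_acc, orig_acc - cf_acc, flip_rate)
print(f"SRS = {srs:.3f}")  # prints: SRS = 0.395
```

Because the weights sum to 1 and each component is a fraction, a model whose three components are all 1.0 would score 1.0, and a model with no measurable drop or flips would score 0.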

In [None]:
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
import numpy as np
import matplotlib.pyplot as plt

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

config = get_config()
set_seed(config['seeds']['global'])
device = get_device()

print(f"Device: {device}")

In [None]:
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    SpuriousPatchDataset,
    CounterfactualPatchDataset,
    get_transforms,
    DATA_DIR,
)
from src.models import create_model
from src.train import load_model, evaluate_model
from src.metrics import (
    compute_ood_drop,
    compute_patch_sensitivity,
    compute_spurious_reliance_score,
    compute_class_wise_accuracy,
)
from src.plotting import plot_spurious_reliance_comparison, save_figure
from torch.utils.data import DataLoader

## 1. Load Trained Models

In [None]:
# Load all 4 models
model_names = ['A1', 'A2', 'R1', 'R2']
models = {}

for name in model_names:
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}\n"
                               f"Please run 02_train_models.ipynb first.")
    
    model = create_model(config)
    model = load_model(model, checkpoint_path, device)
    models[name] = model
    print(f"Loaded model {name} from {checkpoint_path}")

print(f"\nAll {len(models)} models loaded successfully!")

## 2. Create Test Datasets

In [None]:
# Create test datasets
test_id = create_env_a_dataset(train=False, config=config)  # ID test (with aligned patches)
test_ood = create_no_patch_dataset(train=False, config=config)  # OOD test (no patches)

# Create DataLoaders
batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']

id_loader = DataLoader(test_id, batch_size=batch_size, shuffle=False, num_workers=num_workers)
ood_loader = DataLoader(test_ood, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print("Test datasets:")
print(f"  ID test (Env A): {len(test_id)} samples")
print(f"  OOD test (No patch): {len(test_ood)} samples")

In [None]:
# Create counterfactual dataset for patch sensitivity analysis
# We use the ID test set as the base
cf_dataset = CounterfactualPatchDataset(
    base_dataset=test_id,
    swap_mode='random_wrong',
)

print(f"Counterfactual dataset: {len(cf_dataset)} samples")
print("  Each sample provides: (original_img, label, counterfactual_img)")

## 3. Compute Basic Accuracy Metrics

In [None]:
# Compute ID and OOD accuracy for all models
print("Computing ID and OOD accuracy...\n")

accuracy_results = {}

for name, model in models.items():
    _, id_acc = evaluate_model(model, id_loader, device)
    _, ood_acc = evaluate_model(model, ood_loader, device)
    ood_drop = compute_ood_drop(id_acc, ood_acc)
    
    accuracy_results[name] = {
        'id_acc': id_acc,
        'ood_acc': ood_acc,
        'ood_drop': ood_drop,
    }
    
    print(f"Model {name}:")
    print(f"  ID Accuracy:  {id_acc*100:.2f}%")
    print(f"  OOD Accuracy: {ood_acc*100:.2f}%")
    print(f"  OOD Drop:     {ood_drop*100:+.2f}%\n")

## 4. Compute Counterfactual Patch Sensitivity
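
Counterfactual patch sensitivity compares a model's outputs on the original image with its outputs on the same image after the patch is swapped. As a rough illustration of the reported quantities, the sketch below derives them from raw logits with NumPy. It is not the implementation of `compute_patch_sensitivity` in `src.metrics`, which may differ; the function name `patch_sensitivity_from_logits`, the sign convention for the logit change (counterfactual minus original), and the toy arrays are assumptions made for illustration.

```python
import numpy as np


def patch_sensitivity_from_logits(orig_logits, cf_logits, labels):
    """Summarise prediction changes between original and patch-swapped inputs.

    orig_logits, cf_logits: arrays of shape (n_samples, n_classes)
    labels: integer array of shape (n_samples,)
    """
    orig_logits = np.asarray(orig_logits, dtype=float)
    cf_logits = np.asarray(cf_logits, dtype=float)
    labels = np.asarray(labels)

    orig_pred = orig_logits.argmax(axis=1)
    cf_pred = cf_logits.argmax(axis=1)

    orig_correct = orig_pred == labels
    original_accuracy = float(orig_correct.mean())
    counterfactual_accuracy = float((cf_pred == labels).mean())

    # Flip rate: among originally correct predictions, the fraction whose
    # predicted class changes once the patch is swapped.
    n_correct = int(orig_correct.sum())
    flipped = orig_correct & (cf_pred != orig_pred)
    flip_rate = float(flipped.sum()) / n_correct if n_correct > 0 else 0.0

    # Mean change in the true-class logit (assumed: counterfactual minus original).
    idx = np.arange(len(labels))
    logit_change = cf_logits[idx, labels] - orig_logits[idx, labels]

    return {
        "original_accuracy": original_accuracy,
        "counterfactual_accuracy": counterfactual_accuracy,
        "accuracy_drop": original_accuracy - counterfactual_accuracy,
        "prediction_flip_rate": flip_rate,
        "mean_logit_change": float(logit_change.mean()),
    }


# Toy example: 4 samples, 2 classes (made-up logits).
labels = [0, 1, 0, 1]
orig_logits = [[2, 0], [0, 2], [2, 0], [2, 0]]   # 3 of 4 correct
cf_logits = [[0, 2], [0, 2], [2, 0], [2, 0]]     # sample 0 flips after the swap
result = patch_sensitivity_from_logits(orig_logits, cf_logits, labels)
print(result)
```

In the toy data, one of the three originally correct predictions changes class after the swap, so the flip rate is 1/3 while the accuracy drop is 0.25; the two numbers measure related but distinct effects.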

In [None]:
# Compute patch sensitivity for all models
print("Computing counterfactual patch sensitivity...\n")

sensitivity_results = {}

for name, model in models.items():
    print(f"\nAnalyzing Model {name}...")
    sensitivity = compute_patch_sensitivity(
        model, cf_dataset, device,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    sensitivity_results[name] = sensitivity
    
    print(f"  Original Accuracy:      {sensitivity['original_accuracy']*100:.2f}%")
    print(f"  Counterfactual Accuracy: {sensitivity['counterfactual_accuracy']*100:.2f}%")
    print(f"  Accuracy Drop:          {sensitivity['accuracy_drop']*100:+.2f}%")
    print(f"  Prediction Flip Rate:   {sensitivity['prediction_flip_rate']*100:.2f}%")
    print(f"  Mean Logit Change:      {sensitivity['mean_logit_change']:.3f}")

## 5. Compute Spurious Reliance Score (SRS)

In [None]:
# Compute full SRS for all models
print("Computing Spurious Reliance Score (SRS)...\n")

srs_results = {}

for name, model in models.items():
    print(f"\nModel {name}:")
    srs = compute_spurious_reliance_score(
        model, id_loader, ood_loader, cf_dataset, device
    )
    srs_results[name] = srs
    
    print(f"  Spurious Reliance Score: {srs['spurious_reliance_score']:.4f}")
    print(f"  Components:")
    print(f"    - OOD Drop:        {srs['ood_drop']*100:.2f}% (weight: 0.4)")
    print(f"    - CF Acc Drop:     {srs['cf_accuracy_drop']*100:.2f}% (weight: 0.3)")
    print(f"    - CF Flip Rate:    {srs['cf_flip_rate']*100:.2f}% (weight: 0.3)")

## 6. Visualize Results

In [None]:
# Plot SRS comparison
fig = plot_spurious_reliance_comparison(
    srs_results,
    title='Spurious Reliance Metrics by Model',
    save_name='spurious_reliance_comparison'
)
plt.show()

In [None]:
# Create detailed comparison figure
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

model_names = list(srs_results.keys())
x = np.arange(len(model_names))
colors = ['#e74c3c', '#e67e22', '#3498db', '#2ecc71']

# 1. SRS Score
srs_values = [srs_results[m]['spurious_reliance_score'] for m in model_names]
bars = axes[0, 0].bar(x, srs_values, color=colors)
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(model_names)
axes[0, 0].set_ylabel('SRS')
axes[0, 0].set_title('Spurious Reliance Score (SRS)')
axes[0, 0].grid(True, alpha=0.3, axis='y')
for bar, val in zip(bars, srs_values):
    axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontsize=10)

# 2. ID vs OOD Accuracy
id_accs = [srs_results[m]['id_accuracy']*100 for m in model_names]
ood_accs = [srs_results[m]['ood_accuracy']*100 for m in model_names]
width = 0.35
axes[0, 1].bar(x - width/2, id_accs, width, label='ID Acc', color='steelblue')
axes[0, 1].bar(x + width/2, ood_accs, width, label='OOD Acc', color='coral')
axes[0, 1].set_xticks(x)
axes[0, 1].set_xticklabels(model_names)
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('ID vs OOD Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3, axis='y')

# 3. Counterfactual Accuracy Drop
cf_drops = [srs_results[m]['cf_accuracy_drop']*100 for m in model_names]
bars = axes[1, 0].bar(x, cf_drops, color=colors)
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(model_names)
axes[1, 0].set_ylabel('Accuracy Drop (%)')
axes[1, 0].set_title('Counterfactual Patch Accuracy Drop')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. Prediction Flip Rate
flip_rates = [srs_results[m]['cf_flip_rate']*100 for m in model_names]
bars = axes[1, 1].bar(x, flip_rates, color=colors)
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(model_names)
axes[1, 1].set_ylabel('Flip Rate (%)')
axes[1, 1].set_title('Prediction Flip Rate (when patch swapped)')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.suptitle('Mechanism Verification: Spurious vs Robust Models', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'mechanism_verification_detailed')
plt.show()

## 7. Statistical Summary
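
The group comparison in this section reduces to averaging per-model scores and taking the ratio of the two group means. The sketch below shows that calculation on made-up SRS values, with a guard that returns infinity when the robust average is not positive (the same convention used when the ratio is written to JSON later in this notebook). The helper names `group_mean` and `safe_ratio` and the example scores are hypothetical.

```python
import numpy as np


def group_mean(scores, names):
    """Average the scores of the named models."""
    return float(np.mean([scores[n] for n in names]))


def safe_ratio(numerator, denominator):
    """Return numerator / denominator, or inf when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else float("inf")


# Hypothetical SRS values for illustration only (not results from this project).
example_srs = {"A1": 0.40, "A2": 0.30, "R1": 0.05, "R2": 0.05}

spurious_mean = group_mean(example_srs, ["A1", "A2"])
robust_mean = group_mean(example_srs, ["R1", "R2"])

print(f"Spurious mean: {spurious_mean:.4f}")
print(f"Robust mean:   {robust_mean:.4f}")
print(f"Ratio:         {safe_ratio(spurious_mean, robust_mean):.2f}x")
```

With only two models per group, these averages are descriptive rather than statistically robust, so the ratio is best read as a sanity check.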

In [None]:
# Create summary table
print("\n" + "=" * 90)
print("MECHANISM VERIFICATION SUMMARY")
print("=" * 90)

print(f"\n{'Model':<8} {'Type':<10} {'ID Acc':<10} {'OOD Acc':<10} {'OOD Drop':<10} "
      f"{'CF Drop':<10} {'Flip Rate':<10} {'SRS':<8}")
print("-" * 90)

for name in model_names:
    model_type = "Spurious" if name.startswith('A') else "Robust"
    srs = srs_results[name]
    print(f"{name:<8} {model_type:<10} {srs['id_accuracy']*100:>6.2f}%   "
          f"{srs['ood_accuracy']*100:>6.2f}%   {srs['ood_drop']*100:>+6.2f}%   "
          f"{srs['cf_accuracy_drop']*100:>6.2f}%   {srs['cf_flip_rate']*100:>6.2f}%   "
          f"{srs['spurious_reliance_score']:.4f}")

print("=" * 90)

In [None]:
# Compute group statistics
spurious_models = ['A1', 'A2']
robust_models = ['R1', 'R2']

spurious_avg_srs = np.mean([srs_results[m]['spurious_reliance_score'] for m in spurious_models])
robust_avg_srs = np.mean([srs_results[m]['spurious_reliance_score'] for m in robust_models])

spurious_avg_drop = np.mean([srs_results[m]['ood_drop'] for m in spurious_models])
robust_avg_drop = np.mean([srs_results[m]['ood_drop'] for m in robust_models])

print("\nGroup Statistics:")
print("-" * 50)
print("Spurious Models (A1, A2):")
print(f"  Average SRS:      {spurious_avg_srs:.4f}")
print(f"  Average OOD Drop: {spurious_avg_drop*100:.2f}%")
print(f"\nRobust Models (R1, R2):")
print(f"  Average SRS:      {robust_avg_srs:.4f}")
print(f"  Average OOD Drop: {robust_avg_drop*100:.2f}%")
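# Guard against a non-positive robust average, matching the check used when saving results
srs_ratio = spurious_avg_srs / robust_avg_srs if robust_avg_srs > 0 else float('inf')
print(f"\nSRS Ratio (Spurious/Robust): {srs_ratio:.2f}x")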

In [None]:
# Verification checks
print("\nVerification Checks:")
print("-" * 50)

# Check 1: Spurious models should have higher SRS
if spurious_avg_srs > robust_avg_srs:
    print("[PASS] Spurious models have higher SRS than robust models")
else:
    print("[FAIL] Expected spurious models to have higher SRS")

# Check 2: Spurious models should have larger OOD drop
if spurious_avg_drop > robust_avg_drop:
    print("[PASS] Spurious models have larger OOD accuracy drop")
else:
    print("[FAIL] Expected spurious models to have larger OOD drop")

# Check 3: All spurious models should have SRS > 0.05 (reasonable threshold)
all_spurious_high_srs = all(srs_results[m]['spurious_reliance_score'] > 0.05 for m in spurious_models)
if all_spurious_high_srs:
    print("[PASS] All spurious models have SRS > 0.05")
else:
    print("[WARN] Some spurious models have low SRS")

## 8. Save Results

In [None]:
# Save all metrics to JSON
mechanism_results = {
    'srs_results': {k: {kk: float(vv) for kk, vv in v.items()} 
                   for k, v in srs_results.items()},
    'group_statistics': {
        'spurious_avg_srs': float(spurious_avg_srs),
        'robust_avg_srs': float(robust_avg_srs),
        'spurious_avg_ood_drop': float(spurious_avg_drop),
        'robust_avg_ood_drop': float(robust_avg_drop),
        'srs_ratio': float(spurious_avg_srs / robust_avg_srs) if robust_avg_srs > 0 else float('inf'),
    }
}

results_path = METRICS_DIR / 'mechanism_verification.json'
with open(results_path, 'w') as f:
    json.dump(mechanism_results, f, indent=2)

print(f"Results saved to: {results_path}")

## 9. Summary

In [None]:
print("\n" + "=" * 60)
print("MECHANISM VERIFICATION COMPLETE")
print("=" * 60)
print(f"""
Key Findings:

1. Spurious Reliance Score (SRS):
   - Spurious models (A1, A2): Average SRS = {spurious_avg_srs:.4f}
   - Robust models (R1, R2):   Average SRS = {robust_avg_srs:.4f}
   - Ratio: {(spurious_avg_srs / robust_avg_srs if robust_avg_srs > 0 else float('inf')):.2f}x higher for spurious models

2. OOD Accuracy Drop:
   - Spurious models: {spurious_avg_drop*100:.1f}% average drop
   - Robust models:   {robust_avg_drop*100:.1f}% average drop

3. Interpretation (valid only if the verification checks above passed):
   - Spurious models (A1, A2) rely heavily on the colored patch
   - Their accuracy drops when the patch is removed (OOD) or swapped (CF)
   - Robust models (R1, R2) rely more on content-based features
   - They are less sensitive to patch manipulation

SRS Formula:
   SRS = 0.4 * OOD_drop + 0.3 * CF_accuracy_drop + 0.3 * CF_flip_rate

Files saved:
   - {METRICS_DIR / 'mechanism_verification.json'}
   - {FIGURES_DIR / 'spurious_reliance_comparison.png'}
   - {FIGURES_DIR / 'mechanism_verification_detailed.png'}

Next: Run 04_rebasin_alignment.ipynb to perform weight matching.
""")