In [None]:
# Install the third-party packages used below into the kernel's environment.
# NOTE(review): versions are unpinned, so a fresh install may pull releases that
# differ from the ones this analysis was written against — consider pinning
# (e.g. `pkg==x.y.z`) for reproducibility.
%pip install -q sae-lens safetensors numpy matplotlib

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sae_lens import SAE

In [None]:
import os

# Release identifiers handed to `SAE.from_pretrained` for the two SAEs being compared.
BASELINE_SAE_ID = "hal2k/llama2-7b-chat-sae-layer14-x16-pile"
LAT_SAE_ID = "hal2k/llama2-7b-chat-lat-sae-layer14-x16-pile"

# Directory that receives the saved figure. It is created here if missing;
# `exist_ok=True` makes re-running this cell harmless.
OUTPUT_DIR = "./interference_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
def _load_sae(release):
    """Load the SAE stored under `release` and return the SAE object itself.

    Depending on the installed sae-lens version, `SAE.from_pretrained` returns
    either the SAE alone or a tuple whose first element is the SAE. Indexing
    `[0]` unconditionally fails when the bare SAE comes back, so the result is
    only unpacked when it actually is a tuple.
    NOTE(review): confirm the return convention against the pinned sae-lens version.
    """
    loaded = SAE.from_pretrained(release=release, sae_id=".")
    return loaded[0] if isinstance(loaded, tuple) else loaded


baseline_sae = _load_sae(BASELINE_SAE_ID)
lat_sae = _load_sae(LAT_SAE_ID)

# Quick sanity check: print the decoder weight shapes of both SAEs.
print(f"Baseline SAE: {baseline_sae.W_dec.shape}")
print(f"LAT SAE: {lat_sae.W_dec.shape}")

#### Extract & normalize decoder weights

In [None]:
def _unit_rows(weights):
    """Scale every row of `weights` to (approximately) unit L2 norm.

    The 1e-8 added to each row norm keeps the division finite for all-zero
    rows; such rows stay all-zero instead of turning into NaN.
    """
    row_norms = np.linalg.norm(weights, axis=1, keepdims=True)
    return weights / (row_norms + 1e-8)


# Detach the decoder weights from torch and move them to NumPy on the CPU.
# Expected shape per the original note: (65536, 4096) — not asserted here.
W_baseline = baseline_sae.W_dec.detach().cpu().numpy()
W_lat = lat_sae.W_dec.detach().cpu().numpy()

# Row-normalized copies; the raw matrices are kept unchanged.
W_baseline_norm = _unit_rows(W_baseline)
W_lat_norm = _unit_rows(W_lat)

#### Compute Gram matrices

In [None]:
# Pairwise dot products between the row-normalized decoder vectors: entry (i, j)
# is the cosine similarity between decoder rows i and j.
# NOTE(review): each result is a dense (n_rows, n_rows) array. With the 65536 rows
# noted where the weights are extracted, that is ~17 GB per matrix if the weights
# are float32 — TODO confirm dtype and available RAM before running.
G_baseline = np.matmul(W_baseline_norm, W_baseline_norm.T)
G_lat = np.matmul(W_lat_norm, W_lat_norm.T)

In [None]:
def _abs_off_diagonal(gram):
    """Return |gram[i, j]| for every i != j as a flat array in row-major order.

    Yields the same values, in the same order, as
    `np.abs(gram[~np.eye(n, dtype=bool)])`, but without materializing an
    n x n boolean mask or an intermediate masked copy: the only full-size
    allocation is the returned array. `gram` itself is never modified.

    How it works: flatten the matrix and drop its last entry (the final
    diagonal element). Viewed as (n - 1, n + 1), column 0 then holds the
    remaining diagonal entries and columns 1..n hold exactly the off-diagonal
    ones.

    Raises:
        ValueError: if `gram` is not a square 2-D array.
    """
    if gram.ndim != 2 or gram.shape[0] != gram.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {gram.shape}")
    size = gram.shape[0]
    if size < 2:
        # No off-diagonal entries exist; return an empty array of the right dtype.
        return np.abs(gram.reshape(-1)[:0])
    off_diag_view = gram.reshape(-1)[:-1].reshape(size - 1, size + 1)[:, 1:]
    # np.abs allocates a fresh output, so the view into `gram` is only read.
    return np.abs(off_diag_view).reshape(-1)


# Extract off-diagonal (all pairwise interferences). Every unordered pair appears
# twice, once as (i, j) and once as (j, i).
n = G_baseline.shape[0]
off_diag_baseline = _abs_off_diagonal(G_baseline)
off_diag_lat = _abs_off_diagonal(G_lat)

#### Compute interference metrics

In [None]:
def _interference_stats(values):
    """Summarize a flat array of absolute cosine similarities.

    Returns a 7-tuple, in this order:
        (mean, median, max, p95, p99, frac_below_0.1, frac_below_0.05)

    Both percentiles come from a single `np.percentile` call so the (very
    large) input is partitioned once rather than once per percentile.
    """
    p95, p99 = np.percentile(values, [95, 99])
    return (
        # Accumulate in float64: NumPy documents that a mean accumulated in
        # float32 can be inaccurate. No effect if `values` is already float64.
        np.mean(values, dtype=np.float64),
        np.median(values),
        np.max(values),
        p95,
        p99,
        np.mean(values < 0.1),
        np.mean(values < 0.05),
    )


# Dead features check: here "dead" means a decoder row whose L2 norm is below 1e-6.
# Measured on the raw (un-normalized) decoder weights.
baseline_norms = np.linalg.norm(W_baseline, axis=1)
lat_norms = np.linalg.norm(W_lat, axis=1)
baseline_dead_frac = np.mean(baseline_norms < 1e-6)
lat_dead_frac = np.mean(lat_norms < 1e-6)

# Interference stats, computed identically for both SAEs.
(baseline_mean, baseline_median, baseline_max, baseline_p95, baseline_p99,
 baseline_below_01, baseline_below_005) = _interference_stats(off_diag_baseline)

(lat_mean, lat_median, lat_max, lat_p95, lat_p99,
 lat_below_01, lat_below_005) = _interference_stats(off_diag_lat)

In [None]:
# Comparison table: one row per metric, with the LAT/Baseline ratio in the last column.
print("{:<25} {:>12} {:>12} {:>10}".format("Metric", "Baseline", "LAT", "Ratio"))
print("-" * 60)

_table_rows = [
    ("mean_interference", baseline_mean, lat_mean),
    ("median_interference", baseline_median, lat_median),
    ("max_interference", baseline_max, lat_max),
    ("p95_interference", baseline_p95, lat_p95),
    ("p99_interference", baseline_p99, lat_p99),
    ("frac_below_0.1", baseline_below_01, lat_below_01),
    ("frac_below_0.05", baseline_below_005, lat_below_005),
]
for _label, _base_value, _lat_value in _table_rows:
    print(f"{_label:<25} {_base_value:>12.6f} {_lat_value:>12.6f} {_lat_value / _base_value:>10.3f}")

# The dead-feature row shows '-' in place of a ratio.
print(f"{'dead_features_frac':<25} {baseline_dead_frac:>12.6f} {lat_dead_frac:>12.6f} {'-':>10}")

In [None]:
# Headline number: mean interference of the LAT SAE relative to the baseline SAE.
ratio = lat_mean / baseline_mean

_summary_lines = [
    f"\nMean interference ratio (LAT/Baseline): {ratio:.3f}",
    "Gorton et al. benchmark: robust models have ~0.5x the interference of non-robust\n",
]
print("\n".join(_summary_lines))

#### Visualize interference distributions

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: overlaid density histograms of the absolute pairwise cosine similarities.
axes[0].hist(off_diag_baseline, bins=100, alpha=0.7, label='Baseline', density=True)
axes[0].hist(off_diag_lat, bins=100, alpha=0.7, label='LAT', density=True)
axes[0].set_xlabel('|cos(feature_i, feature_j)|')
axes[0].set_ylabel('Density')
axes[0].set_title('Feature Interference Distribution')
axes[0].legend()
axes[0].set_xlim(0, 0.5)

# Right panel: empirical CDFs, subsampled to roughly 10000 points per curve.
for data, label in [(off_diag_baseline, 'Baseline'), (off_diag_lat, 'LAT')]:
    sorted_data = np.sort(data)
    total = len(sorted_data)
    # With fewer than 10000 values the integer division yields 0, and a slice
    # step of 0 raises ValueError — clamp to 1 so small inputs plot every point.
    step = max(1, total // 10000)
    # Evaluate the CDF only at the plotted ranks instead of building a
    # full-length array and then discarding almost all of it.
    plotted_idx = np.arange(0, total, step)
    cdf = (plotted_idx + 1) / total
    axes[1].plot(sorted_data[::step], cdf, label=label)

axes[1].set_xlabel('|cos(feature_i, feature_j)|')
axes[1].set_ylabel('CDF')
axes[1].set_title('Cumulative Distribution')
axes[1].legend()
axes[1].set_xlim(0, 0.3)
# Reference lines at the 95% and 99% cumulative levels.
axes[1].axhline(0.95, color='gray', linestyle='--', alpha=0.5)
axes[1].axhline(0.99, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/interference_comparison.png", dpi=150)
plt.show()
