In [1]:
# Install the notebook's dependencies from requirements.txt; the %pip magic
# (rather than !pip) targets the environment of the running kernel.
%pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
import json
import os

import numpy as np
import torch
import wandb
from dotenv import load_dotenv
from huggingface_hub import login as hf_login
from peft import PeftModel
from sae_lens import (
    SAE,
    LanguageModelSAERunnerConfig,
    LanguageModelSAETrainingRunner,
    CacheActivationsRunnerConfig,
    CacheActivationsRunner,
    BatchTopKTrainingSAEConfig,
    LoggingConfig,
)
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM

# Credentials are taken from the process environment, never hard-coded here.
# load_dotenv() is called first so the lookups below can see values it loads
# (presumably from a local .env file -- confirm against the python-dotenv docs).
load_dotenv()
# os.environ.get returns None when the variable is unset, so a missing token is
# passed through as None rather than raising here.
hf_login(token=os.environ.get('HF_TOKEN'))
wandb.login(key=os.environ.get('WANDB_TOKEN'))

  from .autonotebook import tqdm as notebook_tqdm
2026-01-03 12:07:45.883198: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-03 12:07:45.894961: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767442065.909504  154619 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767442065.914077  154619 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767442065.925323  154619 computation_placer.cc:177] computation placer already r

True

In [3]:
# ---- Models ----
BASE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
LAT_ADAPTER = "nlpett/llama-2-7b-chat-hf-LAT-layer4-hh"

# ---- Output locations and experiment tracking ----
OUTPUT_DIR = "./sae_outputs"
CACHE_DIR = "./cached_activations"
WANDB_PROJECT = "lat-interference-analysis"

# ---- SAE hyper-parameters ----
D_IN = 4096                     # passed as d_in to the SAE configs below
D_EXP = 16                      # expansion factor: D_SAE = D_EXP * D_IN
D_SAE = D_IN * D_EXP
K = 64                          # TopK sparsity
TRAINING_TOKENS = 100 * 10**6   # 100M
LAYER = 14
HOOK = "blocks.{}.hook_resid_post".format(LAYER)

# Make sure both working directories exist before any cell writes to them.
for _directory in (OUTPUT_DIR, CACHE_DIR):
    os.makedirs(_directory, exist_ok=True)

#### Merge LAT adapter

In [4]:
# Merge the LAT LoRA adapter into the base model once and cache the result on
# disk, so later cells can load a plain Hugging Face checkpoint.
merged_path = f"{OUTPUT_DIR}/llama2-lat-merged"
if not os.path.exists(merged_path):
    # `dtype` replaces the deprecated `torch_dtype` keyword.
    base_hf = AutoModelForCausalLM.from_pretrained(BASE_MODEL, dtype=torch.bfloat16)
    peft_model = PeftModel.from_pretrained(base_hf, LAT_ADAPTER)
    merged_model = peft_model.merge_and_unload()

    # Write to a staging directory and rename it only after the save finished.
    # Otherwise an interrupted save leaves a partial `merged_path` behind that
    # the existence check above would accept as a valid cache on the next run.
    staging_path = f"{merged_path}.partial"
    merged_model.save_pretrained(staging_path)
    os.replace(staging_path, merged_path)

    # Drop the references before the training cells allocate their own models.
    del base_hf, peft_model, merged_model
    torch.cuda.empty_cache()
    print(f"Merged model saved to {merged_path}")
else:
    print(f"Using existing merged model at {merged_path}")

Using existing merged model at ./sae_outputs/llama2-lat-merged


#### Train baseline SAE (direct)

In [5]:
# Build the nested configs first so the runner config below stays flat.
baseline_arch_cfg = BatchTopKTrainingSAEConfig(d_in=D_IN, d_sae=D_SAE, k=K)
baseline_logging_cfg = LoggingConfig(
    log_to_wandb=True,
    wandb_project=WANDB_PROJECT,
    run_name="baseline-layer14-sae",
)

baseline_sae_cfg = LanguageModelSAERunnerConfig(
    # Model: loaded by the runner itself from the model name
    model_name=BASE_MODEL,
    model_class_name="HookedTransformer",
    hook_name=HOOK,
    # Dataset
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=512,
    # SAE architecture
    sae=baseline_arch_cfg,
    # Training params
    lr=1e-4,
    train_batch_size_tokens=4096,
    training_tokens=TRAINING_TOKENS,
    n_batches_in_buffer=32,       # reduced from 128 to fit in GPU memory
    store_batch_size_prompts=16,  # reduced from 32
    # Precision
    dtype="float32",
    autocast=True,
    autocast_lm=True,
    # Logging
    logger=baseline_logging_cfg,
    # Checkpointing / device
    checkpoint_path=f"{OUTPUT_DIR}/baseline_checkpoints",
    n_checkpoints=3,
    device="cuda",
)

# The runner is deliberately not bound to a name: once run() returns, nothing
# in this cell keeps a reference to the runner.
baseline_sae = LanguageModelSAETrainingRunner(baseline_sae_cfg).run()

baseline_save_dir = f"{OUTPUT_DIR}/baseline_sae_final"
baseline_sae.save_model(baseline_save_dir)
print(f"Baseline SAE saved to {baseline_save_dir}")

torch.cuda.empty_cache()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'monology/pile-uncopyrighted' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer




24400| mse_loss: 126.56625 | auxiliary_reconstruction_loss: 1.00317: 100%|█████████▉| 99942400/100000000 [2:35:24<00:05, 10718.57it/s]   
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_samples,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇███
losses/auxiliary_reconstruction_loss,▁▅▆▇█▇▇▇▇▇▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▆▆▆▅▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▁▂▁▂▂▁▁▁▁▂▁▁▁▂▁
losses/overall_loss,█▆▆▆▅▅▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance,▅▁▆▇▆▆▅▇▇▇██▆▇█▄▄▆██▇▇▇█▇▇▇▇▇▇▇█▇█▇▇▇█▇█
metrics/explained_variance_legacy,▁▄▄▄▄▇▇▇▇▇▇▇▇████▇▇█████▇█▇██████████▇██
metrics/explained_variance_legacy_std,▅▇█▇▇▆▆▅▃▄▃▄▂▃▃▃▃▃▃▂▂▃▂▂▂▁▃▂▂▂▃▃▂▂▃▂▂▃▂▂
metrics/l0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/mean_log10_feature_sparsity,▂▁▆▇▇███████

0,1
details/current_learning_rate,0.0001
details/n_training_samples,99983360
losses/auxiliary_reconstruction_loss,0.86578
losses/mse_loss,124.67271
losses/overall_loss,125.5385
metrics/explained_variance,0.99505
metrics/explained_variance_legacy,0.75112
metrics/explained_variance_legacy_std,0.07803
metrics/l0,64
metrics/mean_log10_feature_sparsity,-3.54223


Baseline SAE saved to ./sae_outputs/baseline_sae_final


#### Train LAT SAE (direct)

In [5]:
# Train the LAT SAE with the same settings as the baseline SAE. The merged LAT
# weights live in a local directory, so the model is loaded here and handed to
# the runner through `override_model`.
# (HookedTransformer is imported in the notebook's import cell.)

# Load the merged LAT model into TransformerLens manually
print("Loading merged LAT model into HookedTransformer...")
lat_model = HookedTransformer.from_pretrained(
    BASE_MODEL,  # Use base model architecture/config
    # `dtype` replaces the deprecated `torch_dtype` keyword.
    hf_model=AutoModelForCausalLM.from_pretrained(merged_path, dtype=torch.float32),
    device="cuda",
    center_writing_weights=False,
)
# NOTE(review): the baseline model is loaded inside the SAE runner; confirm it
# applies the same weight processing as this from_pretrained call does.
print("LAT model loaded successfully!")

lat_sae_cfg = LanguageModelSAERunnerConfig(
    # Model config (name for reference, actual model passed via override_model)
    model_name=BASE_MODEL,  # Keep original name for config compatibility
    model_class_name="HookedTransformer",
    hook_name=HOOK,

    # Dataset (identical)
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=512,

    # SAE architecture (identical)
    sae=BatchTopKTrainingSAEConfig(
        d_in=D_IN,
        d_sae=D_SAE,
        k=K,
    ),

    # Training params (identical)
    lr=1e-4,
    train_batch_size_tokens=4096,
    training_tokens=TRAINING_TOKENS,
    n_batches_in_buffer=32,       # Reduced from 128 to fit in GPU memory
    store_batch_size_prompts=16,  # Reduced from 32 to fit in GPU memory

    # Precision (identical)
    dtype="float32",
    autocast=True,
    autocast_lm=True,

    # Logging
    logger=LoggingConfig(
        log_to_wandb=True,
        wandb_project=WANDB_PROJECT,
        run_name="lat-layer14-sae",
    ),

    checkpoint_path=f"{OUTPUT_DIR}/lat_checkpoints",
    n_checkpoints=3,
    device="cuda",
)

# Pass the pre-loaded LAT model via override_model
lat_sae = LanguageModelSAETrainingRunner(lat_sae_cfg, override_model=lat_model).run()
lat_sae.save_model(f"{OUTPUT_DIR}/lat_sae_final")
print(f"LAT SAE saved to {OUTPUT_DIR}/lat_sae_final")

# The fp32 language model is no longer needed once the SAE is trained. Drop it
# and release cached GPU memory, as the baseline cell does, so the analysis
# cells do not run with the model still resident on the GPU.
del lat_model
torch.cuda.empty_cache()

`torch_dtype` is deprecated! Use `dtype` instead!


Loading merged LAT model into HookedTransformer...


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.33it/s]


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


You just passed in a model which will override the one specified in your configuration: meta-llama/Llama-2-7b-chat-hf. As a consequence this run will not be reproducible via configuration alone.
`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'monology/pile-uncopyrighted' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


LAT model loaded successfully!




24400| mse_loss: 111.07465 | auxiliary_reconstruction_loss: 1.52489: 100%|█████████▉| 99942400/100000000 [2:33:45<00:05, 10833.17it/s]   


0,1
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_samples,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇██
losses/auxiliary_reconstruction_loss,▁▁▁█▇█▇▇▅▂▄▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▇▆▆▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▅▅▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁
metrics/explained_variance,▄▆▄▅▆▇▇█▇▇█▇▁▇██▇▇██▇█▇█▇▇█▇█▇██████▇███
metrics/explained_variance_legacy,▁▂▅▅▆▆▇▇▇▇▇▇▇█▇█▇▇▇▇█▇██████████████████
metrics/explained_variance_legacy_std,▅▇█▆▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/mean_log10_feature_sparsity,▁▁▆▇▇▇██████

0,1
details/current_learning_rate,0.0001
details/n_training_samples,99983360
losses/auxiliary_reconstruction_loss,1.48027
losses/mse_loss,108.02708
losses/overall_loss,109.50735
metrics/explained_variance,0.99356
metrics/explained_variance_legacy,0.83571
metrics/explained_variance_legacy_std,0.06251
metrics/l0,64
metrics/mean_log10_feature_sparsity,-3.54946


LAT SAE saved to ./sae_outputs/lat_sae_final


#### Compute interference matrix

In [None]:
# Reload the trained SAEs from disk when they are not in memory, e.g. after a
# kernel restart that skipped the multi-hour training cells.
# (SAE is imported in the notebook's import cell.)
if 'baseline_sae' not in dir() or baseline_sae is None:
    print("Loading baseline SAE from disk...")
    baseline_sae = SAE.load_from_pretrained(f"{OUTPUT_DIR}/baseline_sae_final")

if 'lat_sae' not in dir() or lat_sae is None:
    print("Loading LAT SAE from disk...")
    lat_sae = SAE.load_from_pretrained(f"{OUTPUT_DIR}/lat_sae_final")


def _histogram_percentile(bin_counts, q):
    """Estimate the q-th percentile (0-100) of values in [0, 1] from equal-width bin counts.

    Interpolates linearly inside the bin that contains the target rank, so the
    error is bounded by one bin width (1 / len(bin_counts)).
    """
    cumulative = np.cumsum(bin_counts)
    target_rank = q / 100.0 * cumulative[-1]
    idx = min(int(np.searchsorted(cumulative, target_rank, side="left")), len(bin_counts) - 1)
    rank_before = cumulative[idx - 1] if idx > 0 else 0
    in_bin = bin_counts[idx]
    within = (target_rank - rank_before) / in_bin if in_bin > 0 else 0.0
    return float((idx + within) / len(bin_counts))


def pairwise_interference_stats(W, block_rows=1024, n_bins=20_000):
    """Summarise |cos(feature_i, feature_j)| over all ordered pairs i != j.

    The n x n Gram matrix is never materialised. With n = 65536 it needs about
    17 GB in float32, and the boolean mask plus masked/abs copies multiply that.
    Rows are processed in blocks of `block_rows`, so peak extra memory is
    O(block_rows * n).

    Args:
        W: array of shape (n_features, d); one feature direction per row.
        block_rows: number of Gram-matrix rows computed per step.
        n_bins: histogram resolution used for the median / p95 / p99.

    Returns:
        dict of Python floats. Mean, max and the threshold fractions are exact;
        median, p95 and p99 are histogram estimates accurate to 1 / n_bins.

    Raises:
        ValueError: if W has fewer than two rows (no off-diagonal pairs exist).
    """
    W = np.asarray(W, dtype=np.float32)
    n_features = W.shape[0]
    if n_features < 2:
        raise ValueError("need at least two feature vectors to measure interference")

    # Normalize each feature vector (row) to unit norm; the epsilon keeps
    # zero-norm rows at zero instead of dividing by zero.
    unit = W / (np.linalg.norm(W, axis=1, keepdims=True) + 1e-8)

    n_pairs = n_features * (n_features - 1)
    bin_counts = np.zeros(n_bins, dtype=np.int64)
    abs_sum = 0.0
    abs_max = 0.0
    n_below_010 = 0
    n_below_005 = 0

    for start in range(0, n_features, block_rows):
        stop = min(start + block_rows, n_features)
        local = np.arange(stop - start)

        cos = np.abs(unit[start:stop] @ unit.T)  # (block, n_features)
        # Zero the self-similarity entries so they add nothing to sum or max;
        # their effect on the counts is removed explicitly below.
        cos[local, start + local] = 0.0

        abs_sum += float(cos.sum(dtype=np.float64))
        abs_max = max(abs_max, float(cos.max()))
        n_below_010 += int(np.count_nonzero(cos < 0.1)) - local.size
        n_below_005 += int(np.count_nonzero(cos < 0.05)) - local.size

        # Rounding can push |cos| marginally above 1, so clamp to the last bin.
        bin_idx = np.minimum((cos * n_bins).astype(np.intp), n_bins - 1)
        bin_counts += np.bincount(bin_idx.ravel(), minlength=n_bins)
        bin_counts[0] -= local.size  # the zeroed diagonal landed in bin 0

    return {
        "mean_interference": abs_sum / n_pairs,
        "median_interference": _histogram_percentile(bin_counts, 50),
        "max_interference": abs_max,
        "p95_interference": _histogram_percentile(bin_counts, 95),
        "p99_interference": _histogram_percentile(bin_counts, 99),
        "frac_below_0.1": n_below_010 / n_pairs,
        "frac_below_0.05": n_below_005 / n_pairs,
    }


# Extract decoder weights: shape is (d_sae, d_in) = (65536, 4096)
W_baseline = baseline_sae.W_dec.detach().cpu().numpy()
W_lat = lat_sae.W_dec.detach().cpu().numpy()

print("Computing pairwise interference in row blocks (this may take a few minutes)...")
baseline_stats = pairwise_interference_stats(W_baseline)
lat_stats = pairwise_interference_stats(W_lat)

# Also check dead features (features with near-zero decoder norm)
baseline_norms = np.linalg.norm(W_baseline, axis=1)
lat_norms = np.linalg.norm(W_lat, axis=1)
baseline_dead = np.mean(baseline_norms < 1e-6)
lat_dead = np.mean(lat_norms < 1e-6)

# Collect all metrics
results = {
    "config": {
        "d_in": D_IN,
        "d_sae": D_SAE,
        "k": K,
        "training_tokens": TRAINING_TOKENS,
        "layer": LAYER,
    },
    "baseline": {**baseline_stats, "dead_features_frac": float(baseline_dead)},
    "lat": {**lat_stats, "dead_features_frac": float(lat_dead)},
}



In [None]:
def _lat_over_baseline(lat_value, baseline_value, eps=1e-10):
    """Return lat / baseline without dividing by (near-)zero.

    Gives nan when both values are ~0 (the ratio is undefined, e.g. both SAEs
    have no dead features) and inf when only the baseline is ~0.
    """
    if baseline_value > eps:
        return lat_value / baseline_value
    return float("nan") if lat_value <= eps else float("inf")


# Side-by-side table of every metric, with the LAT/baseline ratio.
print(f"\n{'Metric':<25} {'Baseline':>12} {'LAT':>12} {'Ratio':>10}")
print("-" * 60)

for metric in ["mean_interference", "median_interference", "max_interference",
               "p95_interference", "p99_interference", "frac_below_0.1",
               "frac_below_0.05", "dead_features_frac"]:
    baseline_value = results["baseline"][metric]
    lat_value = results["lat"][metric]
    metric_ratio = _lat_over_baseline(lat_value, baseline_value)
    print(f"{metric:<25} {baseline_value:>12.6f} {lat_value:>12.6f} {metric_ratio:>10.3f}")

# Key interpretation
print("\n" + "="*60)
print("INTERPRETATION")
print("="*60)

mean_ratio = _lat_over_baseline(
    results["lat"]["mean_interference"],
    results["baseline"]["mean_interference"],
)
print(f"\nMean interference ratio (LAT/Baseline): {mean_ratio:.3f}")

# Heuristic +/-30% band around 1.0; this is not a statistical significance test.
if mean_ratio < 0.7:
    print(">>> LAT shows REDUCED interference (supports Gorton hypothesis)")
elif mean_ratio > 1.3:
    print(">>> LAT shows INCREASED interference (contradicts Gorton hypothesis)")
else:
    print(">>> No significant difference in interference")

print(f"\nGorton et al. benchmark: robust models have ~0.5x the interference of non-robust")

# Read the dead-feature fractions back from `results` so this cell depends only
# on the dictionary that gets written to disk.
dead_frac_baseline = results["baseline"]["dead_features_frac"]
dead_frac_lat = results["lat"]["dead_features_frac"]
if abs(dead_frac_baseline - dead_frac_lat) > 0.05:
    print(f"\nWARNING: Dead feature rates differ significantly ({dead_frac_baseline:.1%} vs {dead_frac_lat:.1%})")
    print("         This may confound the interference comparison")

with open(f"{OUTPUT_DIR}/interference_results.json", "w") as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to {OUTPUT_DIR}/interference_results.json")
print(f"W&B dashboard: https://wandb.ai/hal2k-n-a/{WANDB_PROJECT}")
print("\nDone!")