In [None]:
# NOTE(review): scipy is installed unpinned next to requirements.txt — pin a version (ideally inside
# requirements.txt) so a fresh "Run All" reproduces the same t-test implementation.
%pip install -r requirements.txt scipy

In [None]:
import torch
import numpy as np
import pandas as pd
import requests
import io
import gc
import json
import os
from dotenv import load_dotenv
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from scipy import stats
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login as hf_login
from sae_lens import SAE

# Populate os.environ from a local .env file (if present), then authenticate with the
# Hugging Face Hub using HF_TOKEN — presumably required to download the meta-llama checkpoint.
load_dotenv()
hf_login(token=os.getenv('HF_TOKEN'))

def clear_gpu():
    """Free unreferenced Python objects and cached CUDA memory between heavy steps.

    Always runs the garbage collector, so tensors kept alive only by reference
    cycles are released. When a CUDA device is available, it then waits for all
    queued kernels and releases the caching allocator's unoccupied blocks. On a
    CPU-only machine it only collects garbage instead of raising.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # torch.cuda.synchronize() raises when no usable CUDA device exists.
        return
    # Let pending kernels finish before handing cached memory back.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

In [None]:
# --- Runtime and model checkpoints ---
DEVICE = 'cuda'
LLAMA_PATH = 'meta-llama/Llama-2-7b-chat-hf'          # baseline chat model (also the LAT base)
LAT_PATH = 'nlpett/llama-2-7b-chat-hf-LAT-layer4-hh'  # PEFT adapter merged onto LLAMA_PATH in the LAT section
TEMPLATE = '\n[INST]{prompt}[/INST]\n'                 # prompt wrapper; activations are read at its last token

# --- Locally trained SAE checkpoints (one per model) ---
BASELINE_SAE_PATH = './sae_outputs/baseline_sae_final'
LAT_SAE_PATH = './sae_outputs/lat_sae_final'

# --- SAE / analysis hyperparameters ---
LAYER = 14        # residual stream read from blocks.{LAYER}.hook_resid_post
D_SAE = 65_536    # SAE dictionary size = length of the per-feature count vectors
K = 64            # TopK value; only recorded in the results JSON in this notebook
N_SAMPLES = 200   # prompts per class (harmful / harmless)
BATCH_SIZE = 8    # prompts per forward pass


#### Load harmful/harmless prompts


In [None]:
# Harmful prompts: the 'goal' column of AdvBench's harmful_behaviors.csv.
url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
response = requests.get(url, timeout=60)
# Fail loudly on HTTP errors instead of trying to parse an error page as CSV.
response.raise_for_status()
harmful_all = pd.read_csv(io.StringIO(response.content.decode('utf-8')))['goal'].tolist()
harmful_train, _ = train_test_split(harmful_all, test_size=0.2, random_state=42)

# Harmless prompts: Alpaca instructions whose 'input' field is blank (self-contained prompts).
dataset = load_dataset('tatsu-lab/alpaca')
alpaca_train = dataset['train']
# Read the two columns once and zip them, instead of materialising a row dict per index.
harmless_all = [instruction
                for instruction, extra_input in zip(alpaca_train['instruction'], alpaca_train['input'])
                if extra_input.strip() == '']
harmless_train, _ = train_test_split(harmless_all, test_size=0.2, random_state=42)

harmful_prompts = harmful_train[:N_SAMPLES]
harmless_prompts = harmless_train[:N_SAMPLES]
print(f"{len(harmful_prompts)} harmful, {len(harmless_prompts)} harmless")

#### Baseline model + SAE

In [None]:
# Load the baseline HF checkpoint in fp16 and wrap it in a HookedTransformer so residual-stream
# activations can be cached through hooks. NOTE(review): `from_pretrained_no_processing` presumably
# keeps the raw (unfolded, uncentred) weights so activations match what the SAE was trained on — confirm.
base_hf = AutoModelForCausalLM.from_pretrained(LLAMA_PATH, torch_dtype=torch.float16, device_map='auto')
baseline_model = HookedTransformer.from_pretrained_no_processing(
    LLAMA_PATH, hf_model=base_hf, dtype=torch.float16, device=DEVICE, default_padding_side='left')
# Left padding puts each prompt's final token at position -1, which the activation cells index.
baseline_model.tokenizer.padding_side = 'left'
# NOTE(review): '[PAD]' is presumably not in the Llama-2 vocabulary, so its id resolves to whatever
# the tokenizer uses for unknown tokens — TODO confirm it can never collide with a real prompt token.
baseline_model.tokenizer.pad_token = '[PAD]'
# Presumably the HookedTransformer now owns its own weight copy, so the HF model can be released.
del base_hf
clear_gpu()

# SAE checkpoint paired with the baseline model; eval() puts it in inference mode.
baseline_sae = SAE.load_from_pretrained(BASELINE_SAE_PATH, device=DEVICE)
baseline_sae.eval()


In [None]:
baseline_harmful_l1s = []
baseline_harmful_recon_errors = []
baseline_harmful_top_features = []

# Cache exactly this hook and stop the forward pass right after block LAYER:
# the later layers and the vocabulary-sized logits are never needed here.
hook_name = f'blocks.{LAYER}.hook_resid_post'

for i in tqdm(range(0, len(harmful_prompts), BATCH_SIZE)):
    batch = harmful_prompts[i:i+BATCH_SIZE]
    prompts = [TEMPLATE.format(prompt=p) for p in batch]
    toks = baseline_model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)

    with torch.no_grad():
        _, cache = baseline_model.run_with_cache(toks, names_filter=hook_name, stop_at_layer=LAYER + 1)
        # The tokenizer pads on the left, so position -1 is every prompt's final template token.
        acts = cache[hook_name][:, -1, :]
        sae_acts = baseline_sae.encode(acts)
        recon = baseline_sae.decode(sae_acts)

        # One L1 norm of the SAE code and one reconstruction MSE per prompt.
        baseline_harmful_l1s.extend(sae_acts.abs().sum(dim=-1).cpu().tolist())
        baseline_harmful_recon_errors.extend(((acts - recon) ** 2).mean(dim=-1).cpu().tolist())

        # Indices of active (non-zero) features, one list per prompt; the mask is copied
        # to the host once per batch instead of once per row.
        active_mask = (sae_acts.abs() > 1e-8).cpu()
        baseline_harmful_top_features.extend(row.nonzero(as_tuple=True)[0].tolist() for row in active_mask)

    del cache, acts, sae_acts, recon, toks
    clear_gpu()


In [None]:
baseline_harmless_l1s = []
baseline_harmless_recon_errors = []
baseline_harmless_top_features = []

# Cache exactly this hook and stop the forward pass right after block LAYER:
# the later layers and the vocabulary-sized logits are never needed here.
hook_name = f'blocks.{LAYER}.hook_resid_post'

for i in tqdm(range(0, len(harmless_prompts), BATCH_SIZE)):
    batch = harmless_prompts[i:i+BATCH_SIZE]
    prompts = [TEMPLATE.format(prompt=p) for p in batch]
    toks = baseline_model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)

    with torch.no_grad():
        _, cache = baseline_model.run_with_cache(toks, names_filter=hook_name, stop_at_layer=LAYER + 1)
        # The tokenizer pads on the left, so position -1 is every prompt's final template token.
        acts = cache[hook_name][:, -1, :]
        sae_acts = baseline_sae.encode(acts)
        recon = baseline_sae.decode(sae_acts)

        # One L1 norm of the SAE code and one reconstruction MSE per prompt.
        baseline_harmless_l1s.extend(sae_acts.abs().sum(dim=-1).cpu().tolist())
        baseline_harmless_recon_errors.extend(((acts - recon) ** 2).mean(dim=-1).cpu().tolist())

        # Indices of active (non-zero) features, one list per prompt; the mask is copied
        # to the host once per batch instead of once per row.
        active_mask = (sae_acts.abs() > 1e-8).cpu()
        baseline_harmless_top_features.extend(row.nonzero(as_tuple=True)[0].tolist() for row in active_mask)

    del cache, acts, sae_acts, recon, toks
    clear_gpu()


In [None]:
baseline_harmful_l1 = np.mean(baseline_harmful_l1s)
baseline_harmless_l1 = np.mean(baseline_harmless_l1s)
baseline_l1_ratio = baseline_harmful_l1 / baseline_harmless_l1
baseline_harmful_mse = np.mean(baseline_harmful_recon_errors)
baseline_harmless_mse = np.mean(baseline_harmless_recon_errors)

# Per-feature count of prompts that activated it. np.add.at accumulates repeated
# indices and still raises IndexError if a feature index is >= D_SAE.
baseline_harmful_counts = np.zeros(D_SAE)
baseline_harmless_counts = np.zeros(D_SAE)
np.add.at(baseline_harmful_counts,
          np.array([f for feats in baseline_harmful_top_features for f in feats], dtype=np.int64), 1)
np.add.at(baseline_harmless_counts,
          np.array([f for feats in baseline_harmless_top_features for f in feats], dtype=np.int64), 1)

# argsort is ascending, so the last 20 entries are the most class-specific features.
baseline_harmful_specific = np.argsort(baseline_harmful_counts - baseline_harmless_counts)[-20:]
baseline_harmless_specific = np.argsort(baseline_harmless_counts - baseline_harmful_counts)[-20:]

# Top-100 most frequently active features per class. Never-active features are dropped:
# otherwise, when fewer than 100 features fire, zero-count filler indices enter both
# sets and inflate the Jaccard overlap.
baseline_top_harmful = {int(f) for f in np.argsort(baseline_harmful_counts)[-100:]
                        if baseline_harmful_counts[f] > 0}
baseline_top_harmless = {int(f) for f in np.argsort(baseline_harmless_counts)[-100:]
                         if baseline_harmless_counts[f] > 0}
baseline_union = baseline_top_harmful | baseline_top_harmless
# Jaccard is undefined (NaN) if no feature fired at all, rather than dividing by zero.
baseline_jaccard = (len(baseline_top_harmful & baseline_top_harmless) / len(baseline_union)
                    if baseline_union else float('nan'))

print(f"Harmful L1:  {baseline_harmful_l1:.1f} ± {np.std(baseline_harmful_l1s):.1f}")
print(f"Harmless L1: {baseline_harmless_l1:.1f} ± {np.std(baseline_harmless_l1s):.1f}")
print(f"L1 Ratio:    {baseline_l1_ratio:.3f}")
print(f"Harmful MSE: {baseline_harmful_mse:.4f}")
print(f"Harmless MSE:{baseline_harmless_mse:.4f}")
print(f"Jaccard:     {baseline_jaccard:.3f}")


In [None]:
# Release the baseline model and its SAE before the LAT model is loaded,
# then reclaim the freed GPU memory.
del baseline_sae
del baseline_model
clear_gpu()

#### LAT model + SAE


In [None]:
# Build the LAT model: load the base chat model in fp16, apply the PEFT adapter from LAT_PATH,
# and merge it with merge_and_unload() so a plain HF model (no adapter wrapper) remains.
lat_base = AutoModelForCausalLM.from_pretrained(LLAMA_PATH, torch_dtype=torch.float16, device_map='auto')
lat_peft = PeftModel.from_pretrained(lat_base, LAT_PATH).to(torch.float16)
lat_hf = lat_peft.merge_and_unload()
lat_hf.eval()
del lat_peft

# Wrap the merged model with the same arguments as the baseline (LLAMA_PATH config, fp16,
# left padding) so activations from the two models are read the same way.
lat_model = HookedTransformer.from_pretrained_no_processing(
    LLAMA_PATH, hf_model=lat_hf, dtype=torch.float16, device=DEVICE, default_padding_side='left')
# Left padding puts each prompt's final token at position -1, which the activation cells index.
lat_model.tokenizer.padding_side = 'left'
# NOTE(review): '[PAD]' is presumably not in the Llama-2 vocabulary — TODO confirm which id it resolves to.
lat_model.tokenizer.pad_token = '[PAD]'
# Presumably the HookedTransformer now owns its own weight copy, so the HF objects can be released.
del lat_base, lat_hf
clear_gpu()

# SAE checkpoint paired with the LAT model (presumably trained on its activations); eval() for inference.
lat_sae = SAE.load_from_pretrained(LAT_SAE_PATH, device=DEVICE)
lat_sae.eval()

In [None]:
lat_harmful_l1s = []
lat_harmful_recon_errors = []
lat_harmful_top_features = []

# Cache exactly this hook and stop the forward pass right after block LAYER:
# the later layers and the vocabulary-sized logits are never needed here.
hook_name = f'blocks.{LAYER}.hook_resid_post'

for i in tqdm(range(0, len(harmful_prompts), BATCH_SIZE)):
    batch = harmful_prompts[i:i+BATCH_SIZE]
    prompts = [TEMPLATE.format(prompt=p) for p in batch]
    toks = lat_model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)

    with torch.no_grad():
        _, cache = lat_model.run_with_cache(toks, names_filter=hook_name, stop_at_layer=LAYER + 1)
        # The tokenizer pads on the left, so position -1 is every prompt's final template token.
        acts = cache[hook_name][:, -1, :]
        sae_acts = lat_sae.encode(acts)
        recon = lat_sae.decode(sae_acts)

        # One L1 norm of the SAE code and one reconstruction MSE per prompt.
        lat_harmful_l1s.extend(sae_acts.abs().sum(dim=-1).cpu().tolist())
        lat_harmful_recon_errors.extend(((acts - recon) ** 2).mean(dim=-1).cpu().tolist())

        # Indices of active (non-zero) features, one list per prompt; the mask is copied
        # to the host once per batch instead of once per row.
        active_mask = (sae_acts.abs() > 1e-8).cpu()
        lat_harmful_top_features.extend(row.nonzero(as_tuple=True)[0].tolist() for row in active_mask)

    del cache, acts, sae_acts, recon, toks
    clear_gpu()


In [None]:
lat_harmless_l1s = []
lat_harmless_recon_errors = []
lat_harmless_top_features = []

# Cache exactly this hook and stop the forward pass right after block LAYER:
# the later layers and the vocabulary-sized logits are never needed here.
hook_name = f'blocks.{LAYER}.hook_resid_post'

for i in tqdm(range(0, len(harmless_prompts), BATCH_SIZE)):
    batch = harmless_prompts[i:i+BATCH_SIZE]
    prompts = [TEMPLATE.format(prompt=p) for p in batch]
    toks = lat_model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)

    with torch.no_grad():
        _, cache = lat_model.run_with_cache(toks, names_filter=hook_name, stop_at_layer=LAYER + 1)
        # The tokenizer pads on the left, so position -1 is every prompt's final template token.
        acts = cache[hook_name][:, -1, :]
        sae_acts = lat_sae.encode(acts)
        recon = lat_sae.decode(sae_acts)

        # One L1 norm of the SAE code and one reconstruction MSE per prompt.
        lat_harmless_l1s.extend(sae_acts.abs().sum(dim=-1).cpu().tolist())
        lat_harmless_recon_errors.extend(((acts - recon) ** 2).mean(dim=-1).cpu().tolist())

        # Indices of active (non-zero) features, one list per prompt; the mask is copied
        # to the host once per batch instead of once per row.
        active_mask = (sae_acts.abs() > 1e-8).cpu()
        lat_harmless_top_features.extend(row.nonzero(as_tuple=True)[0].tolist() for row in active_mask)

    del cache, acts, sae_acts, recon, toks
    clear_gpu()


In [None]:
lat_harmful_l1 = np.mean(lat_harmful_l1s)
lat_harmless_l1 = np.mean(lat_harmless_l1s)
lat_l1_ratio = lat_harmful_l1 / lat_harmless_l1
lat_harmful_mse = np.mean(lat_harmful_recon_errors)
lat_harmless_mse = np.mean(lat_harmless_recon_errors)

# Per-feature count of prompts that activated it. np.add.at accumulates repeated
# indices and still raises IndexError if a feature index is >= D_SAE.
lat_harmful_counts = np.zeros(D_SAE)
lat_harmless_counts = np.zeros(D_SAE)
np.add.at(lat_harmful_counts,
          np.array([f for feats in lat_harmful_top_features for f in feats], dtype=np.int64), 1)
np.add.at(lat_harmless_counts,
          np.array([f for feats in lat_harmless_top_features for f in feats], dtype=np.int64), 1)

# argsort is ascending, so the last 20 entries are the most class-specific features.
lat_harmful_specific = np.argsort(lat_harmful_counts - lat_harmless_counts)[-20:]
lat_harmless_specific = np.argsort(lat_harmless_counts - lat_harmful_counts)[-20:]

# Top-100 most frequently active features per class. Never-active features are dropped:
# otherwise, when fewer than 100 features fire, zero-count filler indices enter both
# sets and inflate the Jaccard overlap.
lat_top_harmful = {int(f) for f in np.argsort(lat_harmful_counts)[-100:]
                   if lat_harmful_counts[f] > 0}
lat_top_harmless = {int(f) for f in np.argsort(lat_harmless_counts)[-100:]
                    if lat_harmless_counts[f] > 0}
lat_union = lat_top_harmful | lat_top_harmless
# Jaccard is undefined (NaN) if no feature fired at all, rather than dividing by zero.
lat_jaccard = (len(lat_top_harmful & lat_top_harmless) / len(lat_union)
               if lat_union else float('nan'))

print(f"Harmful L1:  {lat_harmful_l1:.1f} ± {np.std(lat_harmful_l1s):.1f}")
print(f"Harmless L1: {lat_harmless_l1:.1f} ± {np.std(lat_harmless_l1s):.1f}")
print(f"L1 Ratio:    {lat_l1_ratio:.3f}")
print(f"Harmful MSE: {lat_harmful_mse:.4f}")
print(f"Harmless MSE:{lat_harmless_mse:.4f}")
print(f"Jaccard:     {lat_jaccard:.3f}")


#### Summary


In [None]:
# Welch's two-sample t-test (equal_var=False): compares mean L1 on harmful vs harmless
# prompts without assuming the two L1 distributions share a variance.
t_base, p_base = stats.ttest_ind(baseline_harmful_l1s, baseline_harmless_l1s, equal_var=False)
t_lat, p_lat = stats.ttest_ind(lat_harmful_l1s, lat_harmless_l1s, equal_var=False)

print(f"{'':25} {'Baseline':>12} {'LAT':>12}")
print("-" * 50)
print(f"{'Harmful L1':<25} {baseline_harmful_l1:>12.1f} {lat_harmful_l1:>12.1f}")
print(f"{'Harmless L1':<25} {baseline_harmless_l1:>12.1f} {lat_harmless_l1:>12.1f}")
print(f"{'L1 Ratio':<25} {baseline_l1_ratio:>12.3f} {lat_l1_ratio:>12.3f}")
print("-" * 50)
print(f"{'Harmful MSE':<25} {baseline_harmful_mse:>12.4f} {lat_harmful_mse:>12.4f}")
print(f"{'Harmless MSE':<25} {baseline_harmless_mse:>12.4f} {lat_harmless_mse:>12.4f}")
print("-" * 50)
print(f"{'Feature Jaccard':<25} {baseline_jaccard:>12.3f} {lat_jaccard:>12.3f}")
print(f"{'t-stat (harmful vs less)':<25} {t_base:>12.2f} {t_lat:>12.2f}")
print(f"{'p-value':<25} {p_base:>12.2e} {p_lat:>12.2e}")


In [None]:
# Gather every headline number into one JSON-serialisable record. Scalar metrics are
# numpy/Python floats; the *_specific index arrays are converted with .tolist() for json.
results = dict(
    config=dict(n_samples=N_SAMPLES, layer=LAYER, topk=K, d_sae=D_SAE),
    baseline=dict(
        harmful_l1=baseline_harmful_l1,
        harmless_l1=baseline_harmless_l1,
        l1_ratio=baseline_l1_ratio,
        harmful_mse=baseline_harmful_mse,
        harmless_mse=baseline_harmless_mse,
        jaccard=baseline_jaccard,
        harmful_specific=baseline_harmful_specific.tolist(),
        harmless_specific=baseline_harmless_specific.tolist(),
        t_stat=float(t_base),
        p_val=float(p_base),
    ),
    lat=dict(
        harmful_l1=lat_harmful_l1,
        harmless_l1=lat_harmless_l1,
        l1_ratio=lat_l1_ratio,
        harmful_mse=lat_harmful_mse,
        harmless_mse=lat_harmless_mse,
        jaccard=lat_jaccard,
        harmful_specific=lat_harmful_specific.tolist(),
        harmless_specific=lat_harmless_specific.tolist(),
        t_stat=float(t_lat),
        p_val=float(p_lat),
    ),
)

with open('sae_l1_analysis_results.json', 'w') as results_file:
    json.dump(results, results_file, indent=2)
print("Saved to sae_l1_analysis_results.json")
