In [None]:
%pip install -r requirements.txt

In [None]:
# layers × models × precisions: participation ratio (PR) of the refusal direction computed with difference of means
# PR = 1 means all variance lies in 1 dimension; PR = n means it is spread evenly across n dimensions

import torch
import numpy as np
import requests
import pandas as pd
import io
import gc
import matplotlib.pyplot as plt
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login as hf_login
import os
from dotenv import load_dotenv

# Pull settings from a local .env file (if any), then authenticate with the
# Hugging Face Hub using the HF_TOKEN environment variable.
# NOTE(review): os.environ.get returns None when HF_TOKEN is unset; presumably
# hf_login then falls back to its own default behaviour — confirm before running unattended.
load_dotenv()
hf_login(token=os.environ.get('HF_TOKEN'))

In [None]:
# --- runtime -----------------------------------------------------------------
DEVICE = "cuda"

# --- checkpoints (Hugging Face Hub ids) --------------------------------------
LLAMA_PATH = "meta-llama/Llama-2-7b-chat-hf"
AT_PATH = "nlpett/llama-2-7b-chat-hf-AT-hh"
LAT_PATH = "nlpett/llama-2-7b-chat-hf-LAT-layer4-hh"

# Prompt wrapper applied to every instruction before tokenisation.
TEMPLATE = "\n[INST]{prompt}[/INST]\n"

# --- sweep axes --------------------------------------------------------------
# Layers probed; includes the LAT perturbation layer (4) and the best ablation layer (14).
LAYERS = [4, 8, 12, 14, 16, 20]
# Weight precisions, in bits.
PRECISIONS = [4, 8, 16]
# (display name, PEFT adapter id or None for the unmodified base model)
MODELS = [
    ("BASE", None),
    ("AT", AT_PATH),
    ("LAT", LAT_PATH),
]
# Number of prompts taken from each of the harmful / harmless training splits.
N = 416


In [None]:
def get_quant_config(bits):
    """Translate a bit width into a ``(quantization_config, dtype)`` pair.

    Args:
        bits: Requested weight precision in bits.

    Returns:
        A 2-tuple. The first element is a ``BitsAndBytesConfig`` for ``bits == 4``
        (NF4, double quantisation, float16 compute) or ``bits == 8``
        (``load_in_8bit=True``), and ``None`` for every other value. The second
        element is always ``torch.float16``.
    """
    dtype = torch.float16
    quant = None  # any width other than 4 or 8 loads unquantised
    if bits == 4:
        quant = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    elif bits == 8:
        quant = BitsAndBytesConfig(load_in_8bit=True)
    return quant, dtype

def clear_gpu():
    """Run the garbage collector and, when CUDA is usable, release cached GPU memory.

    On a machine without a usable CUDA device only ``gc.collect()`` runs, so the
    function can be called from error handlers without raising a second exception
    that would hide the original one.
    """
    gc.collect()  # drop unreachable Python objects first so their tensors can be released
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def load_model(peft_path, quant_config, dtype):
    """Load the Llama-2 chat model, optionally merge a PEFT adapter, and wrap it for hooking.

    Args:
        peft_path: Hub id / path of a PEFT adapter to merge into the base weights,
            or a falsy value to use the base model unchanged.
        quant_config: Quantisation config forwarded to ``from_pretrained`` as
            ``quantization_config``; a falsy value loads without quantisation.
        dtype: Torch dtype used for the Hugging Face load, for casting the
            adapter-wrapped model (unquantised case only) and for the
            ``HookedTransformer``.

    Returns:
        A ``HookedTransformer`` built from the (possibly merged) Hugging Face
        model, with the tokenizer set to left padding and ``'[PAD]'`` as pad token.
    """
    kwargs = {'torch_dtype': dtype, 'device_map': 'auto'}
    if quant_config:
        kwargs['quantization_config'] = quant_config
    base = AutoModelForCausalLM.from_pretrained(LLAMA_PATH, **kwargs)

    if peft_path:
        peft = PeftModel.from_pretrained(base, peft_path)
        if not quant_config:
            # Only cast when unquantised; quantised weights are left as loaded.
            peft = peft.to(dtype)
        hf_model = peft.merge_and_unload()
        hf_model.eval()
        del peft
    else:
        hf_model = base

    hooked = HookedTransformer.from_pretrained_no_processing(
        LLAMA_PATH, hf_model=hf_model, dtype=dtype, device=DEVICE, default_padding_side='left')
    hooked.tokenizer.padding_side = 'left'
    hooked.tokenizer.pad_token = '[PAD]'

    # Drop BOTH local names before clearing: `hf_model` may be the very same object
    # as `base`, and as long as either name is bound the Hugging Face copy cannot be
    # collected, so clear_gpu() would have nothing to release.
    del base, hf_model
    clear_gpu()
    return hooked

def get_acts_multi_layer(model, instructions, layers, batch_size=16):
    """Collect last-position ``resid_pre`` activations at several layers in one pass.

    Each instruction is wrapped in ``TEMPLATE``, tokenised in padded batches, and run
    through ``model.run_with_cache``; the activation at sequence position ``-1`` is
    kept for every requested layer and moved to the CPU.

    Args:
        model: Object exposing ``tokenizer`` and ``run_with_cache`` (a
            ``HookedTransformer`` in this notebook). Position ``-1`` is only the
            final prompt token if the tokenizer pads on the left — assumed, as
            ``load_model`` configures it.
        instructions: Sequence of prompt strings; must be non-empty.
        layers: Iterable of layer indices to record.
        batch_size: Number of prompts per forward pass.

    Returns:
        Dict mapping each layer to a CPU tensor with one row per instruction
        (batches concatenated along dim 0, in input order).

    Raises:
        ValueError: If ``instructions`` is empty.
    """
    if len(instructions) == 0:
        raise ValueError('instructions must contain at least one prompt')

    # Cache only the hooks that are read below instead of resid_pre for every layer.
    wanted_hooks = {f'blocks.{l}.hook_resid_pre' for l in layers}
    acts = {l: [] for l in layers}
    for i in tqdm(range(0, len(instructions), batch_size)):
        prompts = [TEMPLATE.format(prompt=p) for p in instructions[i:i+batch_size]]
        toks = model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)
        with torch.no_grad():
            _, cache = model.run_with_cache(toks, names_filter=lambda n: n in wanted_hooks)
        for l in layers:
            acts[l].append(cache['resid_pre', l][:, -1, :].cpu())
        del cache  # release the batch's activations before the next forward pass
    return {l: torch.cat(a, dim=0) for l, a in acts.items()}

def participation_ratio(diffs):
    """Participation ratio (effective number of dimensions) of a set of row vectors.

    With eigenvalues ``lam_i`` of the uncentered second-moment matrix
    ``diffs.T @ diffs / n``, the participation ratio is
    ``(sum lam_i)**2 / sum(lam_i**2)``: 1 when everything lies along one direction,
    ``k`` when spread evenly over ``k`` orthogonal directions. The data is
    deliberately not mean-centred, so a shared mean offset counts as a direction.

    The result is invariant to rescaling ``diffs``: eigenvalues are normalised by
    their sum rather than filtered with an absolute cutoff.

    Args:
        diffs: 2-D tensor with one sample per row.

    Returns:
        The participation ratio as a Python float, or ``nan`` when there are no
        rows or every entry is zero.
    """
    diffs = diffs.double()  # float64 keeps round-off out of the squared spectrum
    n_samples, n_features = diffs.shape
    if n_samples == 0:
        return float('nan')

    # D @ D.T and D.T @ D share their non-zero eigenvalues; decompose the smaller one.
    if n_samples < n_features:
        gram = diffs @ diffs.T
    else:
        gram = diffs.T @ diffs

    # Tiny negative eigenvalues are numerical noise of a PSD matrix; clamp them to 0.
    eigs = torch.linalg.eigvalsh(gram / n_samples).clamp_min(0.0)
    total = eigs.sum()
    if total <= 0:
        return float('nan')

    weights = eigs / total
    return (1.0 / (weights ** 2).sum()).item()


In [None]:
# Harmful prompts: AdvBench "harmful_behaviors", 80% train split with a fixed seed.
url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
response = requests.get(url, timeout=60)
response.raise_for_status()  # fail loudly instead of parsing an HTTP error page as CSV
harmful_all = pd.read_csv(io.StringIO(response.content.decode('utf-8')))['goal'].tolist()
harmful_train, _ = train_test_split(harmful_all, test_size=0.2, random_state=42)

# Harmless prompts: Alpaca instructions that come without an extra "input" field.
dataset = load_dataset('tatsu-lab/alpaca')
alpaca_train = dataset['train']
# Iterate the two columns once instead of indexing the dataset row by row.
harmless_all = [
    instruction
    for instruction, extra_input in zip(alpaca_train['instruction'], alpaca_train['input'])
    if extra_input.strip() == ''
]
harmless_train, _ = train_test_split(harmless_all, test_size=0.2, random_state=42)
print(f'Loaded {len(harmful_train)} harmful, {len(harmless_train)} harmless')


In [None]:
# main loop: precision -> model -> layers
# Each (precision, model) pair is loaded once, activations for all LAYERS are
# collected for N harmful and N harmless prompts, and the participation ratio of
# their row-wise differences is stored per layer. A failing combination is
# reported and skipped so the rest of the sweep still runs.
results = {}  # results[bits][model_name][layer] = PR value

for bits in PRECISIONS:
    quant_config, dtype = get_quant_config(bits)
    results[bits] = {}
    print(f'\n{"="*50}\n{bits}-BIT\n{"="*50}')

    for model_name, peft_path in MODELS:
        print(f'\n--- {model_name} ---')
        results[bits][model_name] = {}

        # Pre-bind so the cleanup below works wherever a failure happens.
        model = harmful_acts = harmless_acts = None
        try:
            model = load_model(peft_path, quant_config, dtype)
            harmful_acts = get_acts_multi_layer(model, harmful_train[:N], LAYERS)
            harmless_acts = get_acts_multi_layer(model, harmless_train[:N], LAYERS)

            for layer in LAYERS:
                diffs = harmful_acts[layer] - harmless_acts[layer]
                pr = participation_ratio(diffs)
                results[bits][model_name][layer] = pr
                print(f'  L{layer}: PR={pr:.1f}')
        except Exception as e:
            # Deliberate best-effort: log and continue with the next combination.
            print(f'ERROR: {type(e).__name__}: {e}')
        finally:
            # Runs after the except clause has finished, i.e. once the exception and
            # its traceback no longer keep tensors alive, and also releases a model
            # that loaded fine but failed later.
            model = harmful_acts = harmless_acts = None
            clear_gpu()

In [None]:
# Text table: participation ratio at one layer, one row per precision and one
# column per model. Missing entries (failed runs) are shown as N/A.
report_layer = 14
model_names = [name for name, _ in MODELS]

print(f'\nParticipation Ratio (effective dimensions) at Layer {report_layer}:')
header = f"{'Precision':<10}" + ''.join(f" {name:>8}" for name in model_names)
print(header)
print('-' * len(header))
for bits in PRECISIONS:
    label = f'{bits}-bit'
    row = f'{label:<10}'  # pad the whole label so "16-bit" lines up with "4-bit"
    for m in model_names:
        val = results.get(bits, {}).get(m, {}).get(report_layer)
        # Same " " + 8-wide cell as the header so the columns line up.
        row += f" {val:>8.1f}" if val is not None else f" {'N/A':>8}"
    print(row)

In [None]:
# plot: PR across layers for each model, faceted by precision
# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# code below also works when PRECISIONS holds a single entry.
fig, axes = plt.subplots(1, len(PRECISIONS), figsize=(4*len(PRECISIONS), 4), sharey=True, squeeze=False)
axes = axes.ravel()
colors = {'BASE': 'tab:blue', 'AT': 'tab:orange', 'LAT': 'tab:red'}

for ax, bits in zip(axes, PRECISIONS):
    for model_name in ['BASE', 'AT', 'LAT']:
        if model_name in results[bits]:
            # Layers without a result (failed run) become NaN and leave a gap in the line.
            prs = [results[bits][model_name].get(l, np.nan) for l in LAYERS]
            ax.plot(LAYERS, prs, 'o-', label=model_name, color=colors[model_name])
    ax.set_xlabel('Layer')
    ax.set_title(f'{bits}-bit')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(LAYERS)

axes[0].set_ylabel('Participation Ratio')
axes[-1].legend()
fig.suptitle('Participation Ratio (Effective Dimensions) by Layer', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# plot: PR at layer 14 vs precision
# One line per model; the x axis is categorical ("4-bit", "8-bit", ...).
# Combinations with no stored value are plotted as NaN, i.e. left out of the line.
fig, ax = plt.subplots(figsize=(8, 5))
prec_labels = ['{}-bit'.format(b) for b in PRECISIONS]

for model_name in ('BASE', 'AT', 'LAT'):
    prs = []
    for b in PRECISIONS:
        per_model = results[b].get(model_name, {})
        prs.append(per_model.get(14, np.nan))
    ax.plot(
        prec_labels,
        prs,
        'o-',
        label=model_name,
        color=colors[model_name],
        linewidth=2,
        markersize=8,
    )

ax.set_title('Precision effects on observed dimensionality of the refusal direction at L14')
ax.set_ylabel('Participation Ratio (effective dims)')
ax.set_xlabel('Precision')
ax.grid(True, alpha=0.3)
ax.legend()
plt.tight_layout()
plt.show()