In [1]:
# Install the packages listed in requirements.txt into the environment of the running kernel.
# If pip reports updated packages, restart the kernel before running the cells below.
%pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import numpy as np
import requests
import pandas as pd
import io
import gc
from scipy.linalg import svd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login as hf_login
import os
from dotenv import load_dotenv
import bitsandbytes

# Load variables from a .env file into the process environment so HF_TOKEN can be read below.
load_dotenv()

# Authenticate with the Hugging Face Hub. If HF_TOKEN is not set, token=None is passed through
# unchanged and hf_login decides how to proceed.
hf_login(token=os.getenv('HF_TOKEN'))


  from .autonotebook import tqdm as notebook_tqdm
2025-12-22 19:22:24.172573: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-22 19:22:24.184158: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766431344.198428    6941 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766431344.202972    6941 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766431344.214019    6941 computation_placer.cc:177] computation placer already r

In [3]:
# Runtime settings shared by the cells below.
DEVICE = 'cuda'
DTYPE = torch.float16

# Model identifiers: the base chat model and the two PEFT adapters that are merged on top of it.
LLAMA_PATH = 'meta-llama/Llama-2-7b-chat-hf'
AT_PATH = 'nlpett/llama-2-7b-chat-hf-AT-hh'
LAT_PATH = 'nlpett/llama-2-7b-chat-hf-LAT-layer4-hh'

# Wrapper applied to every instruction before it is tokenised.
TEMPLATE = '\n[INST]{prompt}[/INST]\n'

# bitsandbytes settings used when loading the base weights: 4-bit NF4 storage with
# double quantisation, computing in DTYPE.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=DTYPE,
)


In [4]:
def load_model_4bit(peft_path=None, name='base'):
    """Load the model at ``LLAMA_PATH`` in 4-bit and wrap it in a ``HookedTransformer``.

    Args:
        peft_path: Identifier or path of a PEFT adapter to merge into the base
            weights. ``None`` (or any falsy value) loads the base model as is.
        name: Not used by this function; kept so existing
            ``load_model_4bit(path, name)`` calls keep working.

    Returns:
        A ``HookedTransformer`` built from the (possibly merged) weights. Its
        tokenizer pads on the left and uses ``'[PAD]'`` as the pad token.
    """
    hf_model = AutoModelForCausalLM.from_pretrained(
        LLAMA_PATH, torch_dtype=DTYPE, device_map='cuda:0', quantization_config=bnb_config
    )
    if peft_path:
        # Attach the adapter, cast it to the working dtype, then fold it into the base weights.
        adapter = PeftModel.from_pretrained(hf_model, peft_path).to(DTYPE)
        hf_model = adapter.merge_and_unload()
        hf_model.eval()
        del adapter

    hooked = HookedTransformer.from_pretrained_no_processing(
        LLAMA_PATH, hf_model=hf_model, dtype=DTYPE, device=DEVICE, default_padding_side='left'
    )
    hooked.tokenizer.padding_side = 'left'
    hooked.tokenizer.pad_token = '[PAD]'

    # Drop the last local reference to the Hugging Face model *before* collecting.
    # While any local name still points at it, gc.collect() / empty_cache() cannot
    # return its memory, which is why deleting only the aliases had no effect.
    del hf_model
    gc.collect()
    torch.cuda.empty_cache()
    return hooked


In [5]:
# Harmful prompts: the `goal` column of the AdvBench harmful_behaviors CSV.
url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
response = requests.get(url, timeout=60)
# Fail here on an HTTP error instead of trying to parse an error page as CSV.
response.raise_for_status()
harmful_all = pd.read_csv(io.StringIO(response.content.decode('utf-8')))['goal'].tolist()
harmful_train, _ = train_test_split(harmful_all, test_size=0.2, random_state=42)

# Harmless prompts: Alpaca instructions whose `input` field is blank.
dataset = load_dataset('tatsu-lab/alpaca')
# Walk the split once; indexing dataset['train'][i] twice per row is far slower.
harmless_all = [row['instruction'] for row in dataset['train'] if row['input'].strip() == '']
harmless_train, _ = train_test_split(harmless_all, test_size=0.2, random_state=42)

print(f'harmful: {len(harmful_train)}, harmless: {len(harmless_train)}')


harmful: 416, harmless: 25058


In [6]:
def get_acts(model, instructions, layer=14, batch_size=16):
    """Collect the final-position residual-stream input of one block for each prompt.

    Every instruction is wrapped in ``TEMPLATE``, tokenised in padded batches, and
    run through ``model``. The ``blocks.{layer}.hook_resid_pre`` activation at
    sequence position ``-1`` is kept for each prompt. Position ``-1`` is the last
    real token only when the tokenizer pads on the left, as ``load_model_4bit``
    configures it.

    Args:
        model: Object exposing ``tokenizer`` and ``run_with_cache`` (a
            ``HookedTransformer`` in this notebook).
        instructions: Sequence of instruction strings.
        layer: Index of the block whose ``hook_resid_pre`` activation is read.
        batch_size: Number of prompts per forward pass.

    Returns:
        A tensor with one row per instruction, in input order, concatenated
        along dim 0.

    Raises:
        ValueError: If ``instructions`` is empty.
    """
    if len(instructions) == 0:
        # torch.cat on an empty list fails with an unhelpful message; say what went wrong.
        raise ValueError('get_acts needs at least one instruction')

    hook_name = f'blocks.{layer}.hook_resid_pre'
    acts = []
    for start in tqdm(range(0, len(instructions), batch_size)):
        prompts = [TEMPLATE.format(prompt=p) for p in instructions[start:start + batch_size]]
        toks = model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)
        with torch.no_grad():
            # Later blocks cannot change the input to block `layer`, so end the
            # forward pass as soon as that block has run instead of computing the rest.
            _, cache = model.run_with_cache(
                toks, names_filter=lambda n: n == hook_name, stop_at_layer=layer + 1
            )
        acts.append(cache[hook_name][:, -1, :])
    return torch.cat(acts, dim=0)

def svd_variance(diffs):
    """Return how much of the squared singular-value mass the leading components hold.

    Args:
        diffs: 2-D ``torch.Tensor`` on any device (in this notebook, one row per
            harmful/harmless prompt pair). float16 and bfloat16 inputs are upcast
            to float32 before the decomposition; tensors that require grad are
            detached first.

    Returns:
        ``(top1, top2, top10)`` as Python floats: the fraction of ``sum(s**2)``
        held by the largest 1, 2 and 10 singular values. When fewer singular
        values exist, all of them are summed. An all-zero matrix gives NaN
        because the total is zero.
    """
    mat = diffs.detach()
    if mat.dtype in (torch.float16, torch.bfloat16):
        # NumPy cannot represent bfloat16 at all, and LAPACK has no half-precision
        # routines, so make the working precision explicit.
        mat = mat.to(torch.float32)
    # Only the singular values are needed; skip building U and V.
    s = svd(mat.cpu().numpy(), full_matrices=False, compute_uv=False)
    energy = s ** 2
    frac = energy / energy.sum()
    return float(frac[0]), float(frac[:2].sum()), float(frac[:10].sum())


In [7]:
# Number of (harmful, harmless) prompt pairs to compare. Capped by the shorter
# list so both activation matrices have the same number of rows when subtracted.
N = min(416, len(harmful_train), len(harmless_train))
harmful_prompts = harmful_train[:N]
harmless_prompts = harmless_train[:N]
results = {}

for name, peft_path in [('BASE', None), ('AT', AT_PATH), ('LAT', LAT_PATH)]:
    print(f'\n{name}:')
    model = load_model_4bit(peft_path, name)

    harmful_acts = get_acts(model, harmful_prompts)
    harmless_acts = get_acts(model, harmless_prompts)
    # Row i is the i-th harmful activation minus the i-th harmless activation.
    diffs = harmful_acts - harmless_acts

    top1, top2, top10 = svd_variance(diffs)
    results[name] = {'top1': top1, 'top2': top2, 'top10': top10}
    print(f'{name}: top-1={top1*100:.1f}%, top-2={top2*100:.1f}%, top-10={top10*100:.1f}%')

    # Release this model and its activations before the next model is loaded.
    del model, harmful_acts, harmless_acts, diffs
    gc.collect()
    torch.cuda.empty_cache()


`torch_dtype` is deprecated! Use `dtype` instead!



BASE:


Fetching 2 files: 100%|██████████| 2/2 [00:15<00:00,  7.92s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.30s/it]


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


100%|██████████| 26/26 [00:02<00:00,  9.32it/s]
100%|██████████| 26/26 [00:03<00:00,  8.63it/s]


BASE: top-1=48.8%, top-2=52.9%, top-10=68.4%

AT:


Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.29s/it]


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


100%|██████████| 26/26 [00:02<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  7.98it/s]


AT: top-1=43.1%, top-2=47.3%, top-10=63.4%

LAT:


Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.66s/it]


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


100%|██████████| 26/26 [00:02<00:00,  8.94it/s]
100%|██████████| 26/26 [00:02<00:00,  8.87it/s]


LAT: top-1=51.0%, top-2=72.7%, top-10=83.3%


In [8]:
# Top-1 percentages to compare against, keyed by model name.
paper = {'BASE': 49, 'AT': 43, 'LAT': 51}

# Largest gap, in percentage points, that is still reported as a match.
# NOTE(review): 10 points is wider than the 8-point spread between the reference
# values themselves, so this check cannot tell the three models apart. Confirm
# that this is the intended tolerance.
tolerance_pp = 10

print(f"{'Model':<8} {'4bit top-1':>12} {'Paper top-1':>12} {'Match?':>10}")
print('-' * 44)
for model_name, paper_pct in paper.items():
    measured_pct = results[model_name]['top1'] * 100
    verdict = '✓' if abs(measured_pct - paper_pct) < tolerance_pp else '✗'
    print(f"{model_name:<8} {measured_pct:>11.1f}% {paper_pct:>11}% {verdict:>10}")


Model      4bit top-1  Paper top-1     Match?
--------------------------------------------
BASE            48.8%          49%          ✓
AT              43.1%          43%          ✓
LAT             51.0%          51%          ✓
