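Tests how model quantization affects the observed SVD concentration of refusal directions, i.e. the fraction of variance in the (harmful - harmless) activation differences captured by the top singular component, for the BASE, AT and LAT variants of Llama-2-7b-chat.
Precisions compared: 4-bit (NF4), 8-bit, 16-bit (fp16) and 32-bit (fp32).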

In [None]:
import torch
import numpy as np
import requests
import pandas as pd
import io
import gc
import matplotlib.pyplot as plt
from scipy.linalg import svd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login as hf_login
import os
from dotenv import load_dotenv

# Load secrets such as HF_TOKEN from a local .env file into os.environ.
load_dotenv()

# Llama-2 is a gated checkpoint, so authenticate with the Hugging Face Hub.
# Only call login when a token is actually configured: login(token=None)
# falls back to an interactive prompt, which blocks unattended runs. Without
# a token we rely on whatever credentials are already cached locally.
_hf_token = os.environ.get('HF_TOKEN')
if _hf_token:
    hf_login(token=_hf_token)
else:
    print('HF_TOKEN not set; relying on cached Hugging Face credentials (if any).')

DEVICE = 'cuda'  # device for the HookedTransformer and all token tensors
LLAMA_PATH = 'meta-llama/Llama-2-7b-chat-hf'  # base checkpoint every variant is built on
# PEFT adapters merged on top of LLAMA_PATH (see load_model).
LAT_PATH = 'nlpett/llama-2-7b-chat-hf-LAT-layer4-hh'
AT_PATH = 'nlpett/llama-2-7b-chat-hf-AT-hh'
# Prompt wrapper applied to every instruction before tokenization.
# NOTE(review): spacing differs from the canonical Llama-2 chat format
# ("[INST] ... [/INST]"); left as-is because changing it changes the
# measured activations.
TEMPLATE = '\n[INST]{prompt}[/INST]\n'


In [None]:
def get_precision_config(bits):
    """Translate a bit-width into the settings needed to load a model at that precision.

    Args:
        bits: one of 4, 8, 16 or 32.

    Returns:
        A ``(quantization_config, dtype, name)`` triple. ``quantization_config``
        is a ``BitsAndBytesConfig`` for the quantized widths (4, 8) and ``None``
        otherwise. ``dtype`` is the torch dtype to load/compute in, and
        ``name`` is a short human-readable label.

    Raises:
        ValueError: for any other ``bits`` value.
    """
    # Unquantized precisions: no bitsandbytes config, dtype carries the width.
    if bits == 32:
        return None, torch.float32, 'fp32'
    if bits == 16:
        return None, torch.float16, 'fp16'

    # Quantized precisions: weights stored in 8/4 bits, compute in fp16.
    if bits == 8:
        return BitsAndBytesConfig(load_in_8bit=True), torch.float16, '8-bit'
    if bits == 4:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        return nf4_config, torch.float16, '4-bit'

    raise ValueError(f'Unknown precision: {bits}')

In [None]:
def clear_gpu():
    """Reclaim GPU memory held by objects that are no longer referenced.

    In order, this:
    - runs Python's cyclic garbage collector so unreachable tensors are freed;
    - hands PyTorch's cached-but-unused CUDA blocks back to the driver;
    - blocks until all queued CUDA work has finished.
    """
    for step in (gc.collect, torch.cuda.empty_cache, torch.cuda.synchronize):
        step()

def load_model(peft_path, quant_config, dtype):
    """Load Llama-2-chat, optionally with a merged PEFT adapter, as a HookedTransformer.

    Args:
        peft_path: Hub id or path of a PEFT adapter to merge into the base
            weights, or a falsy value (e.g. None) to use the base model as-is.
        quant_config: a ``BitsAndBytesConfig`` for 4/8-bit loading, or None
            for an unquantized load.
        dtype: torch dtype used both for the Hugging Face load and for the
            HookedTransformer.

    Returns:
        A HookedTransformer on ``DEVICE`` whose tokenizer pads on the left
        with pad token ``'[PAD]'``.

    The intermediate Hugging Face model is only needed to build the
    HookedTransformer. Every local reference to it (base model, PEFT wrapper,
    merged model) is dropped before ``clear_gpu()`` runs. This also happens
    when loading fails, so the memory can actually be reclaimed.

    NOTE(review): TransformerLens may refuse bitsandbytes 8-bit models
    outright. Such failures propagate to the caller.
    """
    kwargs = {'torch_dtype': dtype, 'device_map': 'auto'}
    if quant_config is not None:
        kwargs['quantization_config'] = quant_config

    hf_model = AutoModelForCausalLM.from_pretrained(LLAMA_PATH, **kwargs)
    peft = None
    try:
        if peft_path:
            peft = PeftModel.from_pretrained(hf_model, peft_path)
            if quant_config is None:
                # Presumably keeps adapter and base weights in one dtype before merging.
                peft = peft.to(dtype)
            # Fold the adapter into the base weights so the HookedTransformer
            # is built from a plain (merged) model.
            hf_model = peft.merge_and_unload()
            hf_model.eval()
        hooked = HookedTransformer.from_pretrained_no_processing(
            LLAMA_PATH, hf_model=hf_model, dtype=dtype, device=DEVICE, default_padding_side='left'
        )
    finally:
        # Drop all HF-side references *before* clearing the cache. Previously
        # the merged model was still bound to a local here. On failure, the
        # traceback would also keep this frame's locals (and their GPU
        # weights) alive.
        hf_model = peft = None
        clear_gpu()

    hooked.tokenizer.padding_side = 'left'
    # NOTE(review): '[PAD]' is presumably not in the Llama-2 vocabulary, so its
    # id may resolve to the unk token -- confirm this is the intended pad id.
    hooked.tokenizer.pad_token = '[PAD]'
    return hooked

def unload_model(model):
    """Release this function's reference to *model*, clear the GPU cache and report usage.

    NOTE: ``del`` here only removes the local name. If the caller still holds
    a reference (e.g. a loop variable), the model's weights stay allocated and
    the printed figure includes them. The caller must drop its own reference
    for the memory to be released.
    """
    del model
    clear_gpu()
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f'GPU memory after cleanup: {allocated_gb:.2f} GB')


In [None]:
def get_acts(model, instructions, layer=14, batch_size=16):
    """Collect the residual stream entering block ``layer`` at the last prompt position.

    Each instruction is wrapped in ``TEMPLATE`` and tokenized in batches of
    ``batch_size`` with padding. Position -1 is taken as the final prompt
    token, which assumes the tokenizer pads on the left (as ``load_model``
    configures).

    Args:
        model: a HookedTransformer with a ``tokenizer`` attribute.
        instructions: non-empty sequence of instruction strings.
        layer: index of the block whose ``hook_resid_pre`` is read.
        batch_size: prompts per forward pass.

    Returns:
        A CPU tensor with one row per instruction, in input order.

    Raises:
        ValueError: if ``instructions`` is empty.
    """
    if not instructions:
        raise ValueError('get_acts needs at least one instruction')

    hook_name = f'blocks.{layer}.hook_resid_pre'
    acts = []
    for i in tqdm(range(0, len(instructions), batch_size), desc=f'L{layer}'):
        batch = instructions[i:i + batch_size]
        prompts = [TEMPLATE.format(prompt=p) for p in batch]
        toks = model.tokenizer(prompts, padding=True, return_tensors='pt').input_ids.to(DEVICE)
        with torch.no_grad():
            # Cache only the single hook we need. Also stop the forward pass
            # right after block `layer`: later blocks (and the unembedding)
            # cannot change this block's resid_pre, so running them is wasted.
            _, cache = model.run_with_cache(
                toks,
                names_filter=lambda name: name == hook_name,
                stop_at_layer=layer + 1,
            )
        acts.append(cache[hook_name][:, -1, :].cpu())  # move to CPU immediately
        del cache
    return torch.cat(acts, dim=0)

def svd_variance(diffs):
    """Measure how much of the variance of ``diffs`` sits in its leading singular directions.

    ``diffs`` is a 2-D torch tensor. It is cast to float32 and decomposed
    with a thin SVD. No mean-centring is applied. The squared singular values
    are normalised to fractions of their total.

    Returns:
        dict with keys ``'top1'``, ``'top2'`` and ``'top10'``. Each holds the
        cumulative fraction of that total captured by the top 1, 2 and 10
        components.
    """
    matrix = diffs.float().numpy()
    _, singular_values, _ = svd(matrix, full_matrices=False)
    energy = singular_values ** 2
    fractions = energy / np.sum(energy)
    return {
        'top1': fractions[0],
        'top2': sum(fractions[:2]),
        'top10': sum(fractions[:10]),
    }

In [None]:
# Harmful prompts: the 'goal' column of AdvBench's harmful_behaviors.csv,
# split 80/20 with a fixed seed; only the training part is used.
url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
response = requests.get(url, timeout=60)
# Fail loudly on HTTP errors instead of handing an error page to the CSV parser.
response.raise_for_status()
harmful_train, _ = train_test_split(
    pd.read_csv(io.StringIO(response.content.decode('utf-8')))['goal'].tolist(),
    test_size=0.2, random_state=42
)

# Harmless prompts: Alpaca instructions that need no extra input, same split.
# Read the two columns once rather than materialising every row twice by index.
dataset = load_dataset('tatsu-lab/alpaca')
alpaca_train = dataset['train']
harmless_all = [
    instruction
    for instruction, extra_input in zip(alpaca_train['instruction'], alpaca_train['input'])
    if extra_input.strip() == ''
]
harmless_train, _ = train_test_split(harmless_all, test_size=0.2, random_state=42)

# Activations from the two sets are subtracted row-by-row, so both must supply
# exactly N prompts. 416 is presumably the size of the harmful training split;
# clamp so a smaller split cannot cause a shape mismatch later.
N = min(416, len(harmful_train), len(harmless_train))
print(f'Using {N} samples each')


In [None]:
# Main experiment: for each precision, load each model variant and collect
# last-token resid_pre activations (get_acts' default layer) for harmful and
# harmless prompts. Record the SVD concentration of their row-wise differences.
PRECISIONS = [4, 8, 16, 32]
MODELS = [('BASE', None), ('AT', AT_PATH), ('LAT', LAT_PATH)]

results = {bits: {} for bits in PRECISIONS}

for bits in PRECISIONS:
    quant_config, dtype, prec_name = get_precision_config(bits)
    print(f'\n{"="*50}\n{prec_name.upper()} PRECISION\n{"="*50}')

    for model_name, peft_path in MODELS:
        print(f'\n--- {model_name} ---')
        model = None
        try:
            model = load_model(peft_path, quant_config, dtype)
            print(f'Loaded. GPU: {torch.cuda.memory_allocated()/1e9:.2f} GB')

            harmful_acts = get_acts(model, harmful_train[:N])
            harmless_acts = get_acts(model, harmless_train[:N])
            diffs = harmful_acts - harmless_acts

            results[bits][model_name] = svd_variance(diffs)
            print(f"{model_name} @ {prec_name}: top-1={results[bits][model_name]['top1']*100:.1f}%")

            del harmful_acts, harmless_acts, diffs
        except Exception as e:
            # Record the failure (reported as N/A later) and continue with the
            # remaining configurations.
            print(f'ERROR: {type(e).__name__}: {e}')
            results[bits][model_name] = None
        finally:
            # `model` is the long-lived reference to the HookedTransformer.
            # Dropping it here, on success and on failure, is what actually
            # frees its weights. A helper that `del`s its own parameter cannot
            # do this, and without it the next load stacks on top of this one.
            model = None
            clear_gpu()
            print(f'GPU memory after cleanup: {torch.cuda.memory_allocated()/1e9:.2f} GB')

In [None]:
# Summary table: top-1 explained variance (%) for every precision x model.
# Runs that failed or are missing are shown as N/A. Header and rows use the
# same " {:>10}" cell format so the columns line up.
model_names = [name for name, _ in MODELS]
print(f"{'Precision':<12}" + ''.join(f" {name:>10}" for name in model_names))
for bits in PRECISIONS:
    _, _, prec_name = get_precision_config(bits)
    cells = []
    for model_name in model_names:
        stats = results[bits].get(model_name)
        if stats is not None:
            cells.append(f" {stats['top1']*100:>10.1f}")
        else:
            cells.append(f" {'N/A':>10}")
    print(f"{prec_name:<12}" + ''.join(cells))


In [None]:
# Grouped bar chart: top-1 explained variance per precision, one bar per model.
# Missing or failed runs are left empty and labelled "N/A" rather than drawn
# as a 0% bar, which would be indistinguishable from a genuine result.
fig, ax = plt.subplots(figsize=(10, 6))

# Derive tick labels from the precision configs so they always match
# PRECISIONS. `prec_labels` and `colors` are also reused by the line plot.
prec_labels = [get_precision_config(bits)[2] for bits in PRECISIONS]
x = np.arange(len(PRECISIONS))
width = 0.25

colors = {'BASE': 'tab:blue', 'AT': 'tab:orange', 'LAT': 'tab:red'}

model_names = [name for name, _ in MODELS]
for i, model_name in enumerate(model_names):
    vals = [
        results[bits][model_name]['top1'] * 100 if results[bits].get(model_name) else np.nan
        for bits in PRECISIONS
    ]
    # Centre each group of bars on its tick, whatever the number of models.
    positions = x + (i - (len(model_names) - 1) / 2) * width
    ax.bar(positions, vals, width, label=model_name, color=colors[model_name], alpha=0.8)
    for xpos, val in zip(positions, vals):
        if np.isnan(val):
            ax.text(xpos, 1, 'N/A', ha='center', va='bottom', fontsize=8)

ax.set_xlabel('Precision')
ax.set_ylabel('Top-1 SVD Explained Variance (%)')
ax.set_title('Refusal Direction Concentration vs Model Precision')
ax.set_xticks(x)
ax.set_xticklabels(prec_labels)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim(0, 105)

# Optional reference lines for the concentrations reported in the paper:
#ax.axhline(y=51, color='tab:red', linestyle='--', alpha=0.3, label='Paper LAT (~51%)')
#ax.axhline(y=43, color='tab:orange', linestyle='--', alpha=0.3, label='Paper AT (~43%)')

plt.tight_layout()
plt.show()

In [None]:
# Line-plot view of the same top-1 numbers, to show the trend across precisions.
# Missing runs become NaN, so matplotlib leaves a gap instead of plotting 0.
# Relies on `prec_labels` and `colors` from the bar-chart cell.
fig, ax = plt.subplots(figsize=(10, 6))

for model_name in ['BASE', 'AT', 'LAT']:
    series = [
        results[bits][model_name]['top1'] * 100 if results[bits].get(model_name) else np.nan
        for bits in PRECISIONS
    ]
    ax.plot(
        prec_labels, series, 'o-',
        label=model_name, color=colors[model_name], linewidth=2, markersize=8,
    )

ax.set_xlabel('Precision')
ax.set_ylabel('Top-1 SVD Explained Variance (%)')
ax.set_title('Quantisation effects on refusal direction SVD concentration in top component')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 105)

plt.tight_layout()
plt.savefig('precision_analysis_line.png', dpi=150)
plt.show()