## Setup

In [1]:
# %%capture
# ! pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama huggingface_hub accelerate
# ! pip install bitsandbytes==0.42.0 peft litellm

In [1]:
import torch
import functools
import einops
import requests
import pandas as pd
import json
import io
import gc
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from jaxtyping import Float, Int
from colorama import Fore
from huggingface_hub import notebook_login
from sklearn.decomposition import PCA
from typing import List, Dict, Union, Callable
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
import random

In [2]:
def set_seed(seed):
    """Seed the torch, NumPy and stdlib ``random`` generators with one value.

    Args:
        seed: Seed passed unchanged to each library's seeding function.
    """
    # Same order as always: torch, then NumPy's global RNG, then `random`.
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)


# Seed once, before any model or data code runs.
set_seed(771)

## Load Models

### Load baseline model

In [3]:
# Hugging Face id of the chat model under study.
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"
# Device used by the hooked model and by the helper functions below.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# dtype shared by the HF load, the quantization compute dtype and the hooked model.
inference_dtype = torch.float16

# 4-bit NF4 quantization with double quantization; compute happens in `inference_dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=inference_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load the quantized Hugging Face model.
# NOTE(review): device_map is hard-coded to "cuda:0" even though DEVICE may be 'cpu';
# this cell presumably requires a GPU regardless - confirm.
base_model = AutoModelForCausalLM.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    torch_dtype=inference_dtype,
    device_map = "cuda:0",
    quantization_config=bnb_config
)

# Wrap the already-loaded HF model in a HookedTransformer so activations can be
# cached and edited through hook points.
hooked_base_model = HookedTransformer.from_pretrained_no_processing(
  LLAMA_2_7B_CHAT_PATH,
  hf_model=base_model,
  dtype=inference_dtype,
  device=DEVICE,
  default_padding_side='left',
)

# Drop the notebook's reference to the HF model; only the hooked wrapper is used from here on.
del base_model

# Left padding keeps the final prompt token in the last position of every row,
# which the later `[:, -1, :]` reads rely on.
hooked_base_model.tokenizer.padding_side = 'left'
# NOTE(review): "[PAD]" is assigned as a plain string; whether it exists in the
# tokenizer vocabulary is not visible here - confirm which id it maps to.
hooked_base_model.tokenizer.pad_token = "[PAD]"

# Smoke test: short generation with temperature=0 to check the wrapped model produces text.
hooked_base_model.generate("The capital of Ireland is", max_new_tokens=10, temperature=0)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/10 [00:00<?, ?it/s]

'The capital of Ireland is Dublin. Unterscheidung between the two is not always clear'

## Select Target Model

In [4]:
# Select the model the rest of the notebook operates on.
# `model_type` is a label that is embedded in the results filename and in run ids.
model_type = "base"
model = hooked_base_model

## Setup base dataset (harmful / harmless)

In [5]:
# Wrapper applied to every instruction before tokenization. The literal begins and
# ends with a newline, so every formatted prompt does as well.
LLAMA_CHAT_TEMPLATE = """
[INST]{prompt}[/INST]
"""

def tokenize_instructions_chat(
    tokenizer: AutoTokenizer,
    instructions: List[str]
) -> Int[Tensor, 'batch_size seq_len']:
    """Wrap each instruction in the chat template and tokenize the batch.

    Args:
        tokenizer: Callable tokenizer; invoked with ``padding=True``,
            ``truncation=False`` and ``return_tensors="pt"``.
        instructions: Raw instruction strings, one per batch row.

    Returns:
        The ``input_ids`` attribute of the tokenizer's result.
    """
    wrapped_prompts = []
    for text in instructions:
        wrapped_prompts.append(LLAMA_CHAT_TEMPLATE.format(prompt=text))

    # Pad to the longest prompt in the batch; nothing is truncated.
    encoded = tokenizer(wrapped_prompts, padding=True, truncation=False, return_tensors="pt")
    return encoded.input_ids

# Bind the selected model's tokenizer so callers only need to pass `instructions=...`.
tokenize_instructions_fn = functools.partial(tokenize_instructions_chat, tokenizer=model.tokenizer)

def get_harmful_instructions():
    """Download the AdvBench harmful-behaviors CSV and split its ``goal`` column.

    Returns:
        ``(train, test)``: two lists of instruction strings, an 80/20 split made
        with a fixed ``random_state`` so the partition is identical on every run.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the download does not complete within the timeout.
    """
    url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
    # Bound the wait and fail loudly on HTTP errors. Without the status check an
    # error page would be handed to the CSV parser and only surface later as a
    # confusing missing-column error.
    response = requests.get(url, timeout=30)
    response.raise_for_status()

    dataset = pd.read_csv(io.StringIO(response.content.decode('utf-8')))
    instructions = dataset['goal'].tolist()

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

def get_harmless_instructions():
    """Load the Alpaca dataset and split its input-free instructions.

    Only rows whose ``input`` field is empty after stripping whitespace are kept,
    so each returned instruction stands on its own.

    Returns:
        ``(train, test)``: two lists of instruction strings, an 80/20 split made
        with a fixed ``random_state``.
    """
    hf_path = 'tatsu-lab/alpaca'
    dataset = load_dataset(hf_path)

    # Walk the split once, row by row. The previous `dataset['train'][i]` form
    # looked up the split and fetched the row on every access (twice per kept row).
    instructions = [
        row['instruction']
        for row in dataset['train']
        if row['input'].strip() == ''
    ]

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

In [6]:
def _generate_with_hooks(
    model: HookedTransformer,
    toks: Int[Tensor, 'batch_size seq_len'],
    max_tokens_generated: int = 64,
    fwd_hooks = [],
) -> List[str]:

    all_toks = torch.zeros((toks.shape[0], toks.shape[1] + max_tokens_generated), dtype=torch.long, device=toks.device)
    all_toks[:, :toks.shape[1]] = toks

    for i in range(max_tokens_generated):
        with model.hooks(fwd_hooks=fwd_hooks):
            logits = model(all_toks[:, :-max_tokens_generated + i])
            next_tokens = logits[:, -1, :].argmax(dim=-1) # greedy sampling (temperature=0)
            all_toks[:,-max_tokens_generated+i] = next_tokens

    return model.tokenizer.batch_decode(all_toks[:, toks.shape[1]:], skip_special_tokens=True)

def get_generations(
    model: HookedTransformer,
    instructions: List[str],
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],
    fwd_hooks = [],
    max_tokens_generated: int = 64,
    batch_size: int = 16,
) -> List[Dict[str, Union[str, bool]]]:
    """Generate a completion for every instruction, one batch at a time.

    Args:
        model: Model forwarded to ``_generate_with_hooks``.
        instructions: Prompts to complete.
        tokenize_instructions_fn: Called as ``fn(instructions=batch)``; its
            result is moved to ``DEVICE``.
        fwd_hooks: Forward hooks active during generation.
        max_tokens_generated: Decoding steps per prompt.
        batch_size: Number of prompts per batch.

    Returns:
        One ``{"instruction": ..., "completion": ...}`` dict per instruction, in
        input order.
    """
    records = []

    for start in tqdm(range(0, len(instructions), batch_size)):
        batch = instructions[start:start + batch_size]
        batch_toks = tokenize_instructions_fn(instructions=batch).to(DEVICE)
        completions = _generate_with_hooks(
            model,
            batch_toks,
            max_tokens_generated=max_tokens_generated,
            fwd_hooks=fwd_hooks,
        )

        records.extend(
            {"instruction": prompt, "completion": text}
            for prompt, text in zip(batch, completions)
        )

        # Hand cached GPU blocks back between batches.
        torch.cuda.empty_cache()

    return records

def direction_ablation_hook(
    activation: Float[Tensor, "... d_act"],
    hook: HookPoint,
    direction: Float[Tensor, "d_act"]
):
    """Subtract from ``activation`` its component along ``direction``.

    For every position the scalar ``<activation, direction>`` is computed and
    that multiple of ``direction`` is removed. The component is removed exactly
    only when ``direction`` has unit norm; this function does not normalize it.

    Args:
        activation: Tensor whose last axis has the same size as ``direction``.
        hook: Hook point that triggered the call; unused.
        direction: 1-D direction to remove.

    Returns:
        A new tensor with the same shape as ``activation``.
    """
    # Follow the activation's own device rather than a notebook-level global, so
    # the hook also works when activations live somewhere else.
    direction = direction.to(activation.device)
    # Per-position coefficient along `direction`, kept as a trailing size-1 axis
    # so it broadcasts against `direction` below.
    coeff = torch.einsum('...d,ds->...s', activation, direction.view(-1, 1))
    return activation - coeff * direction

In [7]:
# Fetch both instruction sets; each helper returns a (train, test) pair of string lists.
harmful_inst_train, harmful_inst_test = get_harmful_instructions()
harmless_inst_train, harmless_inst_test = get_harmless_instructions()

In [8]:
# Sanity check on split sizes before sampling from them.
print(f"Train harmful: {len(harmful_inst_train)}, harmless: {len(harmless_inst_train)}")
print(f"Test harmful: {len(harmful_inst_test)}, harmless: {len(harmless_inst_test)}")

Train harmful: 416, harmless: 25058
Test harmful: 104, harmless: 6265


In [30]:
# Number of prompts per class used to estimate the refusal direction.
refusal_vector_sample_size = 300
refusal_vector_harmful = harmful_inst_train[:refusal_vector_sample_size]
refusal_vector_harmless = harmless_inst_train[:refusal_vector_sample_size]

# Explicit encoding so the JSON file is read the same way on every platform.
with open('../datasets/adversarial_questions_gpt4o.json', encoding='utf-8') as fr:
    adversarial_qs = json.load(fr)

# Evaluation set: harmful training prompts NOT used for the direction, plus the
# harmful test split, plus the adversarial questions (must be a list to concatenate).
harmful_test = harmful_inst_train[refusal_vector_sample_size:] + harmful_inst_test + adversarial_qs
print(f"Refusal vector sample size: {len(refusal_vector_harmful)}\nTest Set size: {len(harmful_test)}")

Refsual vector sample size: 300
Test Set size: 320


## Run Experiments

In [43]:
def get_refusal_direction(refusal_vector_harmful, refusal_vector_harmless, n_samples, layer, batch_size=32):
    """Estimate a unit-norm "refusal direction" from one layer's ``resid_pre``.

    Steps:
      1. Run the notebook-level ``model`` on harmful and harmless prompts and keep
         the ``resid_pre`` activation of ``layer`` at the last token position.
      2. Fit a 1-component PCA on the stacked activations and subtract each row's
         contribution along that component.
      3. Return the normalized difference between the harmful and harmless means
         of what is left.

    Reads the notebook globals ``model``, ``tokenize_instructions_fn``, ``utils``
    and ``DEVICE``.

    Args:
        refusal_vector_harmful: Harmful instruction strings.
        refusal_vector_harmless: Harmless instruction strings.
        n_samples: Maximum number of prompts to use from each list.
        layer: Index of the block whose ``resid_pre`` is read.
        batch_size: Prompts per forward pass.

    Returns:
        ``(refusal_dir, harmful_acts, harmless_acts)``: the unit-norm direction
        and the two PCA-residual activation matrices, as float16 tensors on
        ``DEVICE``.

    Raises:
        ValueError: If there are no samples to process.
    """
    # Use exactly the same number of prompts from each class and never more than
    # n_samples, even when the inputs are longer or shorter than requested.
    n_used = min(n_samples, len(refusal_vector_harmful), len(refusal_vector_harmless))
    if n_used <= 0:
        raise ValueError("get_refusal_direction needs at least one harmful and one harmless sample")
    harmful_samples = refusal_vector_harmful[:n_used]
    harmless_samples = refusal_vector_harmless[:n_used]

    # Only this single hook is read below, so cache nothing else. Caching every
    # 'resid*' hook of every layer keeps many unused tensors alive.
    resid_name = utils.get_act_name('resid_pre', layer)

    def only_target_hook(hook_name):
        return hook_name == resid_name

    harmful_acts, harmless_acts = [], []
    # Inference only: no autograd bookkeeping needed for the forward passes.
    with torch.no_grad():
        for i in range(0, n_used, batch_size):
            # tokenize instructions
            harmful_toks = tokenize_instructions_fn(instructions=harmful_samples[i:i+batch_size])
            harmless_toks = tokenize_instructions_fn(instructions=harmless_samples[i:i+batch_size])

            # run model on harmful and harmless instructions, caching the target activation
            _, harmful_cache = model.run_with_cache(harmful_toks, names_filter=only_target_hook)
            _, harmless_cache = model.run_with_cache(harmless_toks, names_filter=only_target_hook)

            # keep the last token position of every prompt
            harmful_acts.append(harmful_cache['resid_pre', layer][:, -1, :])
            harmless_acts.append(harmless_cache['resid_pre', layer][:, -1, :])

    harmful_acts = torch.cat(harmful_acts, dim=0)
    harmless_acts = torch.cat(harmless_acts, dim=0)
    # Remember the split point before the names are rebound below.
    n_harmful = len(harmful_acts)

    # Remove top PCA direction
    all_acts = np.vstack((harmful_acts.cpu().numpy(), harmless_acts.cpu().numpy()))

    pca = PCA(n_components=1)
    # Coordinate of every row along the first principal component ...
    top_component = pca.fit_transform(all_acts)
    # ... turned back into activation space and subtracted from the original rows.
    all_acts_residual = all_acts - np.dot(top_component, pca.components_)

    harmful_acts = torch.tensor(all_acts_residual[:n_harmful], device=DEVICE, dtype=torch.float16)
    harmless_acts = torch.tensor(all_acts_residual[n_harmful:], device=DEVICE, dtype=torch.float16)

    # Calculate refusal vector: difference of class means, scaled to unit norm
    harmful_mean_act = harmful_acts.mean(dim=0)
    harmless_mean_act = harmless_acts.mean(dim=0)

    refusal_dir = harmful_mean_act - harmless_mean_act
    refusal_dir = refusal_dir / refusal_dir.norm()

    torch.cuda.empty_cache()

    return refusal_dir, harmful_acts, harmless_acts


def get_intervention_completions(refusal_dir, eval_samples, test_set, batch_size=12):
    """Generate completions while ablating ``refusal_dir`` everywhere.

    The ablation hook is attached to ``resid_pre``, ``resid_mid`` and
    ``resid_post`` of every layer of the notebook-level ``model``.

    Args:
        refusal_dir: Direction to remove; moved to ``DEVICE`` before use.
        eval_samples: Number of leading prompts of ``test_set`` to evaluate.
        test_set: Prompts to complete.
        batch_size: Prompts per generation batch.

    Returns:
        The list of instruction/completion dicts from ``get_generations``.
    """
    ablate = functools.partial(direction_ablation_hook, direction=refusal_dir.to(DEVICE))

    # One (hook name, hook fn) pair per residual site; layers form the outer loop.
    fwd_hooks = []
    for layer_idx in range(model.cfg.n_layers):
        for site in ('resid_pre', 'resid_mid', 'resid_post'):
            fwd_hooks.append((utils.get_act_name(site, layer_idx), ablate))

    return get_generations(
        model,
        test_set[:eval_samples],
        tokenize_instructions_fn,
        fwd_hooks=fwd_hooks,
        batch_size=batch_size,
    )

In [48]:
# Pickle file that receives the experiment results for the selected model.
filename = f'../data_store/results_no_pca_top_direction_{model_type}.pkl'
print(filename)

# Create the output folder up front so the save at the end of the long-running
# experiment cell cannot fail because the directory is missing.
os.makedirs(os.path.dirname(filename), exist_ok=True)

if os.path.exists(filename):
    print(f"WARNING: Filename {filename} exists. It will be overwritten")

../data_store/results_no_pca_top_direction_base.pkl


In [49]:
# Experiment settings: prompts per class for the direction estimate, number of
# evaluation prompts (the whole test set), and the batch size used everywhere below.
n_samples = refusal_vector_sample_size
eval_samples = len(harmful_test)
batch_size = 16

print(f"{n_samples=}, {eval_samples=}")

# One entry per evaluated layer.
results = []

# Baseline: completions with no hooks attached, computed once and shared by every run.
baseline_completions = get_generations(
    model,
    harmful_test[:eval_samples],
    tokenize_instructions_fn,
    fwd_hooks=[],
    batch_size=batch_size
)
torch.cuda.empty_cache()

# Layers whose resid_pre is used to estimate the direction (currently only layer 14).
for layer in [14]:
    run_id = f"{model_type}_n_samples_{n_samples}.layer_{layer}"
    print(f"Running {run_id}")

    # Estimate the direction at this layer ...
    refusal_dir, harmful_acts, harmless_acts = get_refusal_direction(
        refusal_vector_harmful, refusal_vector_harmless, n_samples=n_samples, layer=layer, batch_size=batch_size)
    
    # ... then regenerate the same prompts with that direction ablated.
    intervention_completions = get_intervention_completions(
        refusal_dir, eval_samples, harmful_test, batch_size=batch_size)

    # Tensors are converted to NumPy arrays before being stored for pickling.
    results.append({
        "run_id": run_id,
        "n_samples": n_samples,
        "layer": layer,
        "refusal_dir": refusal_dir.cpu().numpy(),
        "harmful_acts": harmful_acts.cpu().numpy(),
        "harmless_acts": harmless_acts.cpu().numpy(),
        "baseline_completions": baseline_completions,
        "intervention_completions": intervention_completions
    })
    print()

    # Rewrite the whole results list after every layer so finished runs are on disk
    # even if a later iteration fails.
    with open(filename, 'wb') as fw:
        pickle.dump(results, fw)

n_samples=300, eval_samples=320


100%|██████████| 20/20 [05:33<00:00, 16.66s/it]


Running base_n_samples_300.layer_14


100%|██████████| 20/20 [06:47<00:00, 20.37s/it]





