<a href="https://colab.research.google.com/github/b-schoen/gpt_from_scratch/blob/main/colab/refusal_direction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo of bypassing refusal

>[Demo of bypassing refusal](#scrollTo=82acAhWYGIPx)

>>[Setup](#scrollTo=fcxHyDZw6b86)

>>>[Load model](#scrollTo=6ZOoJagxD49V)

>>>[Load harmful / harmless datasets](#scrollTo=rF7e-u20EFTe)

>>>[Tokenization utils](#scrollTo=KOKYA61k8LWt)

>>>[Generation utils](#scrollTo=gtrIK8x78SZh)

>>[Finding the "refusal direction"](#scrollTo=W9O8dm0_EQRk)

>>[Ablate "refusal direction" via inference-time intervention](#scrollTo=2EoxY5i1CWe3)

>>[Orthogonalize weights w.r.t. "refusal direction"](#scrollTo=t9KooaWaCDc_)



This notebook demonstrates our method for bypassing refusal, leveraging the insight that refusal is mediated by a 1-dimensional subspace. We recommend reading the [research post](https://www.lesswrong.com/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction) for a more thorough explanation.

# Using HookedTransformer

## Setup

In [None]:
%%capture
!pip install --upgrade transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama

In [None]:
import torch
import functools
import einops
import requests
import pandas as pd
import io
import textwrap
import gc

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
from jaxtyping import Float, Int
from colorama import Fore

### Load model

In [None]:
import os

# MODEL_PATH = 'meta-llama/meta-llama-3-8b-instruct'
MODEL_PATH = 'google/gemma-2-2b-it'
# Prefer the GPU, but fall back to CPU so the notebook still loads on a
# machine without CUDA instead of failing at model construction.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the model wrapped as a HookedTransformer, in bfloat16, on DEVICE.
model = HookedTransformer.from_pretrained(
    MODEL_PATH,
    device=DEVICE,
    dtype=torch.bfloat16,
)

In [None]:
# from https://github.com/andyrdt/refusal_direction/blob/main/pipeline/model_utils/gemma_model.py#L100C9-L100C40
# Pad on the left so that, in a padded batch, index -1 along the sequence axis
# is the last real token of every prompt (the code below reads position -1).
model.tokenizer.padding_side = 'left'

In [None]:
# Make sure our understanding of the tokenization here is correct
example_prompt = "hello"

# note: the values inside `tokens` are token ids (indices into the vocabulary,
# which can be large integers), not the sequence length; the sequence length
# is the last dimension of `tokens.shape`.
tokens = model.tokenizer.encode(example_prompt, return_tensors="pt")
# Use the same device the model was loaded on rather than a hard-coded 'cuda'.
tokens = tokens.to(DEVICE)

print(f'{tokens.shape=}')

# Round-trip the ids back to text to confirm encode/decode agree.
reconstructed_example_prompt = model.tokenizer.decode(tokens[0])

print(f"Original prompt: {example_prompt}")
print(f"Reconstructed prompt: {reconstructed_example_prompt}")

In [None]:
# Quick smoke test: generate a few tokens from a plain (non-chat-formatted)
# prompt to confirm the model loaded and can run.
prompt = 'What is the capital of France?'
output = model.generate(prompt, max_new_tokens=10, verbose=True)

print(output)

### Load harmful / harmless datasets

In [None]:
def get_harmful_instructions():
    """Download the AdvBench harmful-behaviors CSV and split its goals.

    Returns:
        A `(train, test)` pair of lists holding the values of the CSV's
        `goal` column, produced by an 80/20 `train_test_split` with a fixed
        `random_state` of 42.

    Raises:
        requests.HTTPError: If the download returns an error status.
        requests.Timeout: If the server does not respond in time.
    """
    url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as CSV.
    response.raise_for_status()

    dataset = pd.read_csv(io.StringIO(response.content.decode('utf-8')))
    instructions = dataset['goal'].tolist()

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

def get_harmless_instructions():
    """Load Alpaca instructions that come without an input and split them.

    Returns:
        A `(train, test)` pair of lists of instruction strings, produced by an
        80/20 `train_test_split` with a fixed `random_state` of 42.
    """
    hf_path = 'tatsu-lab/alpaca'
    train_split = load_dataset(hf_path)['train']

    # Keep only instructions whose `input` field is blank. Iterating the split
    # once avoids looking up `dataset['train'][i]` twice for every row.
    instructions = [
        row['instruction']
        for row in train_split
        if row['input'].strip() == ''
    ]

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

In [None]:
# Build the train/test instruction lists for both classes; the harmful set is
# downloaded over HTTP and the harmless set is loaded via `load_dataset`.
harmful_inst_train, harmful_inst_test = get_harmful_instructions()
harmless_inst_train, harmless_inst_test = get_harmless_instructions()

In [None]:
# Show the first four training instructions of each class as a sanity check.
for heading, samples in (
    ("Harmful instructions:", harmful_inst_train),
    ("Harmless instructions:", harmless_inst_train),
):
    print(heading)
    for idx in range(4):
        print(f"\t{samples[idx]!r}")

### Tokenization utils

In [None]:
# Gemma chat template is based on
# - Official Gemma documentation: https://ai.google.dev/gemma/docs/formatting

GEMMA_CHAT_TEMPLATE = (
    "<start_of_turn>user\n"
    "{instruction}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

def format_instruction_gemma_chat(
    instruction: str,
    output: str=None,
    system: str=None,
    include_trailing_whitespace: bool=True,
):
    """Render one instruction as a Gemma chat prompt.

    Args:
        instruction: User message substituted into `GEMMA_CHAT_TEMPLATE`.
        output: Optional text appended verbatim after the rendered template.
        system: Must be None; any other value raises.
        include_trailing_whitespace: When False, trailing whitespace of the
            rendered template is stripped before `output` is appended.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If `system` is not None.
    """
    # Guard clause: reject system prompts up front.
    if system is not None:
        raise ValueError("System prompts are not supported for Gemma models.")

    prompt = GEMMA_CHAT_TEMPLATE.format(instruction=instruction)
    if not include_trailing_whitespace:
        prompt = prompt.rstrip()

    return prompt if output is None else prompt + output

def tokenize_instructions_gemma_chat(
    tokenizer: AutoTokenizer,
    instructions: List[str],
    outputs: List[str]=None,
    system: str=None,
    include_trailing_whitespace=True,
):
    """Format instructions with the Gemma chat template and tokenize them as a batch.

    Args:
        tokenizer: Callable tokenizer; invoked with `padding=True`,
            `truncation=False` and `return_tensors="pt"`.
        instructions: Instruction strings, one prompt each.
        outputs: Optional completions paired positionally with `instructions`
            (pairs are formed with `zip`).
        system: Forwarded to `format_instruction_gemma_chat`.
        include_trailing_whitespace: Forwarded to `format_instruction_gemma_chat`.

    Returns:
        The `input_ids` entry of the tokenizer result.
    """
    # Bind the arguments shared by every prompt once.
    render = functools.partial(
        format_instruction_gemma_chat,
        system=system,
        include_trailing_whitespace=include_trailing_whitespace,
    )

    if outputs is None:
        prompts = [render(instruction=text) for text in instructions]
    else:
        prompts = [
            render(instruction=text, output=completion)
            for text, completion in zip(instructions, outputs)
        ]

    encoded = tokenizer(prompts, padding=True, truncation=False, return_tensors="pt")

    # note: only the ids are returned, for compatibility with the original notebook
    return encoded['input_ids']

# Pre-bind the model's tokenizer so callers only pass `instructions`
# (and optionally `outputs`) by keyword.
tokenize_instructions_fn = functools.partial(tokenize_instructions_gemma_chat, tokenizer=model.tokenizer)

### Generation utils

In [None]:
def _generate_with_hooks(
    model: HookedTransformer,
    toks: Int[Tensor, 'batch_size seq_len'],
    max_tokens_generated: int = 64,
    fwd_hooks = None,
) -> List[str]:
    """Greedily generate a fixed number of tokens with forward hooks active.

    Every step re-runs the model on the full prefix (prompt plus tokens
    generated so far) and appends the argmax token. Generation always runs for
    `max_tokens_generated` steps; there is no early stop on an end token.

    Args:
        model: Model to run; also supplies the tokenizer used for decoding.
        toks: Prompt token ids, one row per prompt.
        max_tokens_generated: Number of tokens to append to each prompt.
        fwd_hooks: Forward hooks passed to `model.hooks`; None means no hooks.

    Returns:
        One decoded string per prompt containing only the newly generated
        tokens, with special tokens skipped.
    """
    # Avoid a shared mutable default: treat None as "no hooks".
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    prompt_len = toks.shape[1]

    # Pre-allocate room for the prompt plus every token that will be generated.
    all_toks = torch.zeros((toks.shape[0], prompt_len + max_tokens_generated), dtype=torch.long, device=toks.device)
    all_toks[:, :prompt_len] = toks

    # Generation never needs gradients; disabling autograd avoids keeping
    # activations alive for a backward pass. The hook context is the same for
    # every step, so it is entered once around the whole loop.
    with torch.no_grad(), model.hooks(fwd_hooks=fwd_hooks):
        for step in range(max_tokens_generated):
            cur_len = prompt_len + step
            logits = model(all_toks[:, :cur_len])
            # greedy sampling (temperature=0)
            all_toks[:, cur_len] = logits[:, -1, :].argmax(dim=-1)

    return model.tokenizer.batch_decode(all_toks[:, prompt_len:], skip_special_tokens=True)

def get_generations(
    model: HookedTransformer,
    instructions: List[str],
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],
    fwd_hooks = None,
    max_tokens_generated: int = 64,
    batch_size: int = 4,
) -> List[str]:
    """Generate a completion for every instruction, processing them in batches.

    Args:
        model: Model passed through to `_generate_with_hooks`.
        instructions: Instruction strings to complete.
        tokenize_instructions_fn: Called as
            `tokenize_instructions_fn(instructions=batch)` to obtain token ids.
        fwd_hooks: Forward hooks applied during generation; None means no hooks.
        max_tokens_generated: Number of tokens generated per instruction.
        batch_size: Number of instructions tokenized and generated together.

    Returns:
        Decoded completions in the same order as `instructions`.
    """
    # Avoid a shared mutable default, and always hand a real list downstream.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks

    generations = []

    for start in tqdm(range(0, len(instructions), batch_size)):
        batch = instructions[start:start + batch_size]
        toks = tokenize_instructions_fn(instructions=batch)
        generations.extend(
            _generate_with_hooks(
                model,
                toks,
                max_tokens_generated=max_tokens_generated,
                fwd_hooks=fwd_hooks,
            )
        )

    return generations

## Finding the "refusal direction"

In [None]:
# Number of training prompts taken from each class (harmful / harmless).
N_INST_TRAIN = 32

# tokenize instructions
harmful_toks = tokenize_instructions_fn(instructions=harmful_inst_train[:N_INST_TRAIN])
harmless_toks = tokenize_instructions_fn(instructions=harmless_inst_train[:N_INST_TRAIN])

# the model is run on these tokens, caching intermediate activations, in the following cells


In [None]:
# Debug check: show what kind of object the tokenization helper returned.
type(harmful_toks)

In [None]:
# Run the harmful prompts once and cache every activation whose hook name
# contains 'resid'. Only the cached values are read afterwards, so autograd is
# disabled to avoid holding backward-pass state in memory.
with torch.no_grad():
    harmful_logits, harmful_cache = model.run_with_cache(
        harmful_toks,
        names_filter=lambda hook_name: 'resid' in hook_name,
    )

In [None]:
# Run the harmless prompts once and cache every activation whose hook name
# contains 'resid'. Only the cached values are read afterwards, so autograd is
# disabled to avoid holding backward-pass state in memory.
with torch.no_grad():
    harmless_logits, harmless_cache = model.run_with_cache(
        harmless_toks,
        names_filter=lambda hook_name: 'resid' in hook_name,
    )

In [None]:
# compute difference of means between harmful and harmless activations at an intermediate layer

# NOTE(review): `pos` and `layer` are hard-coded here. The pipeline section
# later in this notebook obtains them from `select_and_save_direction` rather
# than fixing them by hand — confirm these values suit the loaded model.
# With left padding, pos = -1 indexes the final token of every prompt.
pos = -1
layer = 14

# Mean over the batch dimension of the 'resid_pre' activation at (layer, pos).
harmful_mean_act = harmful_cache['resid_pre', layer][:, pos, :].mean(dim=0)
harmless_mean_act = harmless_cache['resid_pre', layer][:, pos, :].mean(dim=0)

# Difference of means, rescaled to unit norm.
refusal_dir = harmful_mean_act - harmless_mean_act
refusal_dir = refusal_dir / refusal_dir.norm()

In [None]:
# Sanity check: the two means and the direction should all have the same shape.
print(f'{harmful_mean_act.shape=}')
print(f'{harmless_mean_act.shape=}')
print(f'{refusal_dir.shape=}')

In [None]:
# Free memory now that `refusal_dir` has been computed: drop the references to
# the caches and logits, collect garbage, then release cached CUDA memory.
del harmful_cache, harmless_cache
del harmful_logits, harmless_logits
gc.collect()
torch.cuda.empty_cache()

## Ablate "refusal direction" via inference-time intervention

Given a "refusal direction" $\widehat{r} \in \mathbb{R}^{d_{\text{model}}}$ with unit norm, we can ablate this direction from the model's activations $a_{l}$:
$${a}_{l}' \leftarrow a_l - (a_l \cdot \widehat{r}) \widehat{r}$$

By performing this ablation on all intermediate activations, we enforce that the model can never express this direction (or "feature").

In [None]:
def direction_ablation_hook(
    activation: Float[Tensor, "... d_act"],
    hook: HookPoint,
    direction: Float[Tensor, "d_act"]
):
    """Subtract from `activation` its component along `direction`.

    Computes `activation - (activation . direction) * direction` over the last
    axis. This is an exact projection removal only when `direction` has unit
    norm.

    Args:
        activation: Activation tensor whose last axis has size d_act.
        hook: Hook point the function is attached to (unused here).
        direction: Vector of size d_act to remove.

    Returns:
        A tensor shaped like `activation` with the component removed.
    """
    # Column view so the einsum keeps a trailing singleton axis for broadcasting.
    direction_col = direction.view(-1, 1)
    coeff = einops.einsum(
        activation,
        direction_col,
        '... d_act, d_act single -> ... single',
    )
    component = coeff * direction

    return activation - component

In [None]:
N_INST_TEST = 32
intervention_dir = refusal_dir
intervention_layers = list(range(model.cfg.n_layers)) # all layers

# The same ablation hook is attached at every residual-stream hook point.
hook_fn = functools.partial(direction_ablation_hook, direction=intervention_dir)

resid_act_names = ['resid_pre', 'resid_mid', 'resid_post']
fwd_hooks = [
    (utils.get_act_name(act_name, layer_idx), hook_fn)
    for layer_idx in intervention_layers
    for act_name in resid_act_names
]

# Generate for the same held-out harmful prompts with and without the hooks.
intervention_generations = get_generations(
    model,
    harmful_inst_test[:N_INST_TEST],
    tokenize_instructions_fn,
    fwd_hooks=fwd_hooks,
)
baseline_generations = get_generations(
    model,
    harmful_inst_test[:N_INST_TEST],
    tokenize_instructions_fn,
    fwd_hooks=[],
)

In [None]:
# Print each test instruction with its baseline and intervention completions,
# wrapped to 100 columns and tab-indented.
wrap_completion = functools.partial(
    textwrap.fill, width=100, initial_indent='\t', subsequent_indent='\t'
)

for i in range(N_INST_TEST):
    print(f"INSTRUCTION {i}: {harmful_inst_test[i]!r}")
    print(Fore.GREEN + "BASELINE COMPLETION:")
    print(wrap_completion(repr(baseline_generations[i])))
    print(Fore.RED + "INTERVENTION COMPLETION:")
    print(wrap_completion(repr(intervention_generations[i])))
    print(Fore.RESET)

## Orthogonalize weights w.r.t. "refusal direction"

We can implement the intervention equivalently by directly orthogonalizing the weight matrices that write to the residual stream with respect to the refusal direction $\widehat{r}$:
$$W_{\text{out}}' \leftarrow W_{\text{out}} - \widehat{r}\widehat{r}^{\mathsf{T}} W_{\text{out}}$$

By orthogonalizing these weight matrices, we enforce that the model is unable to write direction $\widehat{r}$ to the residual stream at all!

In [None]:
def get_orthogonalized_matrix(
    matrix: Float[Tensor, '... d_model'],
    vec: Float[Tensor, 'd_model'],
) -> Float[Tensor, '... d_model']:
    """Return `matrix` with the component along `vec` removed from its last axis.

    Computes `matrix - (matrix . vec) * vec` over the last axis. The result is
    orthogonal to `vec` only when `vec` has unit norm.

    Args:
        matrix: Tensor whose last axis has size d_model.
        vec: Vector of size d_model to project out.

    Returns:
        A new tensor with the same shape as `matrix`.
    """
    # Column view so the einsum keeps a trailing singleton axis for broadcasting.
    vec_col = vec.view(-1, 1)
    coeff = einops.einsum(
        matrix,
        vec_col,
        '... d_model, d_model single -> ... single',
    )
    component = coeff * vec

    return matrix - component

In [None]:
# Overwrite the weights in place so that the embedding matrix and each block's
# attention / MLP output matrices have no component along `refusal_dir`.
# This mutates `model`: every later generation uses the edited weights.
# NOTE(review): only W_E, attn.W_O and mlp.W_out are edited. If the loaded
# architecture has other terms that write to (or rescale what is written to)
# the residual stream, e.g. biases or a normalization applied to the
# attention/MLP output, the equivalence with the hook-based ablation may not
# be exact — confirm for this model.
model.W_E.data = get_orthogonalized_matrix(model.W_E, refusal_dir)

for block in model.blocks:
    block.attn.W_O.data = get_orthogonalized_matrix(block.attn.W_O, refusal_dir)
    block.mlp.W_out.data = get_orthogonalized_matrix(block.mlp.W_out, refusal_dir)

In [None]:
# Generate for the same test prompts with no hooks attached; any change from
# the baseline now comes from the edited weights alone.
orthogonalized_generations = get_generations(
    model,
    harmful_inst_test[:N_INST_TEST],
    tokenize_instructions_fn,
    fwd_hooks=[],
)

In [None]:
# Print each test instruction with its baseline, hook-intervention and
# weight-orthogonalized completions, wrapped to 100 columns and tab-indented.
wrap_completion = functools.partial(
    textwrap.fill, width=100, initial_indent='\t', subsequent_indent='\t'
)

for i in range(N_INST_TEST):
    print(f"INSTRUCTION {i}: {harmful_inst_test[i]!r}")
    print(Fore.GREEN + "BASELINE COMPLETION:")
    print(wrap_completion(repr(baseline_generations[i])))
    print(Fore.RED + "INTERVENTION COMPLETION:")
    print(wrap_completion(repr(intervention_generations[i])))
    print(Fore.MAGENTA + "ORTHOGONALIZED COMPLETION:")
    print(wrap_completion(repr(orthogonalized_generations[i])))
    print(Fore.RESET)

# Using RefusalDirection Repo (Huggingface)

This allows us to use models that are supported by Hugging Face `transformers` but not by `HookedTransformer` (e.g. Llama 3.1 405B).

In [None]:
!git clone https://github.com/andyrdt/refusal_direction.git

# change into the repo directory
import os

os.chdir('refusal_direction')

print("Current Working Directory:", os.getcwd())

!ls

# !source setup.sh

from pipeline import run_pipeline

In [None]:
# basically replicating `run_pipeline` explicitly, so we can go step by step: https://github.com/andyrdt/refusal_direction/blob/main/pipeline/run_pipeline.py#L136

In [None]:
# 405B: https://huggingface.co/neuralmagic/Meta-Llama-3.1-405B-Instruct-FP8
#
# > This optimization reduces the number of bits per parameter from 16 to 8,
#   reducing the disk size and GPU memory requirements by approximately 50%.
#   In particular, this model can now be loaded and evaluated with a single node
#   of 8xH100 GPUs, as opposed to multiple nodes.

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("neuralmagic/Meta-Llama-3.1-405B-Instruct-FP8")
# model = AutoModelForCausalLM.from_pretrained("neuralmagic/Meta-Llama-3.1-405B-Instruct-FP8")

In [None]:
# Test with the 8B model first, since the 405B model needs an expensive cluster to run.
model_path = 'neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8'

In [None]:
# Use a pipeline as a high-level helper
from transformers import pipeline

messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe = pipeline("text-generation", model=model_path)
pipe(messages)

In [None]:
import torch
import random
import json
import os
import argparse

from dataset.load_dataset import load_dataset_split, load_dataset

from pipeline.config import Config
from pipeline.model_utils.model_factory import construct_model_base
from pipeline.utils.hook_utils import get_activation_addition_input_pre_hook, get_all_direction_ablation_hooks

from pipeline.submodules.generate_directions import generate_directions
from pipeline.submodules.select_direction import select_direction, get_refusal_scores
from pipeline.submodules.evaluate_jailbreak import evaluate_jailbreak
from pipeline.submodules.evaluate_loss import evaluate_loss

In [None]:
from pipeline.run_pipeline import (
    generate_and_save_candidate_directions,
    select_and_save_direction,
    generate_and_save_completions_for_dataset,
    evaluate_completions_and_save_results_for_dataset,
    load_and_sample_datasets,
    filter_data,
)

In [None]:
model_alias = os.path.basename(model_path)
cfg = Config(model_alias=model_alias, model_path=model_path)

model_base = construct_model_base(cfg.model_path)

print("Load and sample datasets...")
harmful_train, harmless_train, harmful_val, harmless_val = load_and_sample_datasets(cfg)

print("Filter datasets based on refusal scores...")
harmful_train, harmless_train, harmful_val, harmless_val = filter_data(cfg, model_base, harmful_train, harmless_train, harmful_val, harmless_val)

print("1. Generate candidate refusal directions...")
candidate_directions = generate_and_save_candidate_directions(cfg, model_base, harmful_train, harmless_train)

print("2. Select the most effective refusal direction...")
pos, layer, direction = select_and_save_direction(cfg, model_base, harmful_val, harmless_val, candidate_directions)

baseline_fwd_pre_hooks, baseline_fwd_hooks = [], []
ablation_fwd_pre_hooks, ablation_fwd_hooks = get_all_direction_ablation_hooks(model_base, direction)
actadd_fwd_pre_hooks, actadd_fwd_hooks = [(model_base.model_block_modules[layer], get_activation_addition_input_pre_hook(vector=direction, coeff=-1.0))], []

print("3a. Generate and save completions on harmful evaluation datasets...")
for dataset_name in cfg.evaluation_datasets:
    generate_and_save_completions_for_dataset(cfg, model_base, baseline_fwd_pre_hooks, baseline_fwd_hooks, 'baseline', dataset_name)
    generate_and_save_completions_for_dataset(cfg, model_base, ablation_fwd_pre_hooks, ablation_fwd_hooks, 'ablation', dataset_name)
    generate_and_save_completions_for_dataset(cfg, model_base, actadd_fwd_pre_hooks, actadd_fwd_hooks, 'actadd', dataset_name)

print("3b. Evaluate completions and save results on harmful evaluation datasets...")
for dataset_name in cfg.evaluation_datasets:
    evaluate_completions_and_save_results_for_dataset(cfg, 'baseline', dataset_name, eval_methodologies=cfg.jailbreak_eval_methodologies)
    evaluate_completions_and_save_results_for_dataset(cfg, 'ablation', dataset_name, eval_methodologies=cfg.jailbreak_eval_methodologies)
    evaluate_completions_and_save_results_for_dataset(cfg, 'actadd', dataset_name, eval_methodologies=cfg.jailbreak_eval_methodologies)

print("4a. Generate and save completions on harmless evaluation dataset...")
harmless_test = random.sample(load_dataset_split(harmtype='harmless', split='test'), cfg.n_test)