# Refusal Direction in Reasoning Models

In this demo, we attempt to extend the refusal-direction work of Arditi et al. (2024), "Refusal in Language Models Is Mediated by a Single Direction", to reasoning models. We use DeepSeek-R1-Distill-Llama-8B and implement activation interventions and weight updates with TransformerLens. To extract the "refusal direction", we use just 32 harmful instructions from AdvBench and 32 harmless instructions from Alpaca.

In [1]:
# Setup and imports
!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama scikit-learn datasets
!pip uninstall numpy -y

Collecting transformers_stream_generator
  Downloading transformers-stream-generator-0.0.5.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting tiktoken
  Downloading tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting transformer_lens
  Downloading transformer_lens-2.15.0-py3-none-any.whl.metadata (12 kB)
Collecting jaxtyping
  Downloading jaxtyping-0.2.38-py3-none-any.whl.metadata (6.6 kB)
Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-a

In [2]:
import torch
import functools
import einops
import requests
import pandas as pd
import io
import textwrap
import gc
import transformers
import numpy as np

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable, Dict, Tuple, Optional, Union
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer, AutoModelForCausalLM
from jaxtyping import Float, Int
from colorama import Fore
from huggingface_hub import snapshot_download

# Inference only: switch autograd off globally so forward passes do not build
# computation graphs, which saves GPU memory.
torch.set_grad_enabled(mode=False)

  cpu = _conversion_method_template(device=torch.device("cpu"))


ImportError: Unable to import required dependencies:
numpy: No module named 'numpy'

## Load the model

In [None]:
# Load the reasoning model in bfloat16 without TransformerLens' weight
# processing (no LayerNorm folding / weight centering), so the weights stay
# as in the original checkpoint.
MODEL_PATH_REASONING = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # Reasoning model

# False: download from the Hugging Face Hub if the weights are not cached yet.
# True: load only from the local HF cache (e.g. on an offline machine); this
#       fails if the weights have never been downloaded.
LOCAL_FILES_ONLY = False

model_reasoning = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH_REASONING,
    local_files_only=LOCAL_FILES_ONLY,
    dtype=torch.bfloat16,
    default_padding_side='left'
)
# Left padding keeps each prompt's final token at the last sequence position,
# so per-prompt activations can be read at index -1 across a padded batch.
model_reasoning.tokenizer.padding_side = 'left'
# Reuse EOS as the padding token for batched tokenization.
model_reasoning.tokenizer.pad_token = model_reasoning.tokenizer.eos_token

print(f"Loaded reasoning model {MODEL_PATH_REASONING}")

In [None]:
# Sanity-check the reasoning model on a simple math problem.
# The prompt is raw text (no chat template), so this only checks that the model
# loaded correctly and produces a coherent continuation.
test_prompt = "What is 5*125?"

print("\nReasoning model response:")
output_reasoning = model_reasoning.generate(
    test_prompt,
    do_sample=False,     # greedy decoding -> deterministic output
    max_new_tokens=100,  # short budget; long reasoning traces will be cut off
)
print(output_reasoning)


## Set up chat templates and data processing functions

We define the chat template for the reasoning model and helper functions to process our data.
The template must match the format the model was fine-tuned on; different model families use different chat templates.

In [None]:
# DeepSeek-R1-Distill chat format. The distilled models use DeepSeek's own
# <｜User｜>/<｜Assistant｜> markers rather than the Llama-3 header tokens, and per
# DeepSeek's published tokenizer_config.json their BOS token is
# <｜begin▁of▁sentence｜>, not <|begin_of_text|> (the old Llama-3 template would be
# tokenized partly as literal text).
# No BOS in the template: this assumes the tokenizer prepends BOS itself when
# encoding (add_bos_token) — verify against
# model_reasoning.tokenizer.apply_chat_template(..., add_generation_prompt=True).
# The trailing "<think>\n" follows DeepSeek's recommendation to force the model
# to open its reasoning block.
REASONING_CHAT_TEMPLATE = "<｜User｜>{instruction}<｜Assistant｜><think>\n"

## Utility functions

In [None]:
# Define utility functions for processing data and collecting activations

def tokenize_instructions(
    tokenizer: AutoTokenizer,
    instructions: List[str],
    chat_template: str
) -> Int[Tensor, 'batch_size seq_len']:
    """Wrap each instruction in ``chat_template`` and tokenize the whole batch.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the prompts; padding
            follows the tokenizer's own ``padding_side`` setting.
        instructions: Raw user instructions.
        chat_template: Format string containing an ``{instruction}`` placeholder.

    Returns:
        The ``input_ids`` of the encoded batch, padded to the longest prompt
        and never truncated, as a PyTorch tensor.
    """
    formatted = []
    for text in instructions:
        formatted.append(chat_template.format(instruction=text))
    encoded = tokenizer(
        formatted,
        padding=True,
        truncation=False,
        return_tensors="pt",
    )
    return encoded.input_ids

def collect_activations(
    model: HookedTransformer,
    tokenized_inputs: Int[Tensor, 'batch_size seq_len'],
    batch_size: int = 8,
    device: str = 'cpu',
    names_filter: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Tensor]:
    """Run ``model`` over ``tokenized_inputs`` in batches and cache activations.

    Args:
        model: The HookedTransformer to run.
        tokenized_inputs: Token ids (e.g. from ``tokenize_instructions``). All
            rows share one sequence length, so batches can be concatenated.
        batch_size: Number of prompts per forward pass.
        device: Where cached activations are stored. Defaults to ``'cpu'`` to
            avoid GPU out-of-memory errors; use ``'cuda'`` if you have the VRAM.
        names_filter: Predicate selecting which hook points to cache. Defaults
            to every hook whose name contains ``'resid'``.

    Returns:
        Mapping from hook name to the activations of all prompts, concatenated
        along dim 0 (the batch dimension). Empty dict if there are no inputs.
    """
    if names_filter is None:
        names_filter = lambda hook_name: 'resid' in hook_name

    chunks_by_hook: Dict[str, List[Tensor]] = {}
    for start in tqdm(range(0, len(tokenized_inputs), batch_size)):
        batch = tokenized_inputs[start:start + batch_size]

        # return_type=None skips the unembedding: the logits were never used,
        # and for a large vocabulary they are a big transient allocation.
        _, cache = model.run_with_cache(
            batch,
            names_filter=names_filter,
            device=device,
            return_type=None,
        )
        for key in cache:
            chunks_by_hook.setdefault(key, []).append(cache[key])

        # Release this batch's cache before the next forward pass.
        del cache
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return {key: torch.cat(chunks) for key, chunks in chunks_by_hook.items()}