In [1]:
from datasets import load_dataset, load_from_disk
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, Markdown, Latex
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformer_lens
from tqdm.notebook import tqdm

torch.set_grad_enabled(False)

device = torch.device("cuda")

Loading the model and doing some stuff to make it work properly with TransformerLens. I have to run the model in 16-bit precision otherwise I run out of memory later on.

In this notebook I filter out all prompts that are longer than 1600 tokens due to memory constraints. If you want to run this with longer reasoning traces, you may need to modify TransformerLens, which sets the context length for Qwen2.5 to 2048.

In [2]:
model_hf = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
tokenizer_hf = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = transformer_lens.HookedTransformer.from_pretrained_no_processing("Qwen/Qwen2.5-1.5B", hf_model=model_hf, tokenizer=tokenizer_hf, device=device, dtype=torch.bfloat16)

# The rotary base used in the base qwen model is apparently different from the one in the r1 distilled.
# There's a PR in the transformerlens repo to get this value from the hf model config, but I don't want
# to bother with using a modified version. So I'm just fixing everything after the fact.

model.cfg.rotary_base = model_hf.config.rope_theta

for block in model.blocks:
    attn = block.attn
    sin, cos = attn.calculate_sin_cos_rotary(
        model.cfg.rotary_dim,
        model.cfg.n_ctx,
        base=model.cfg.rotary_base,
        dtype=model.cfg.dtype,
    )
    attn.register_buffer("rotary_sin", sin.to(device))
    attn.register_buffer("rotary_cos", cos.to(device))


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loaded pretrained model Qwen/Qwen2.5-1.5B into HookedTransformer


Tokenizing the dataset and generating the splits.

In [35]:
dataset = load_dataset("immindich/reasoning-steps-labeled")
from util import reconstruct_prompt
import random

min_step_size = 10
max_prompt_size = 1600

def get_step_data(key):
    step_prompts = []
    steps = []

    for problem in dataset["train"].shuffle(seed=42):
        if problem[key]:
            step = random.choice(problem[key])
            prompt = model.tokenizer.encode(reconstruct_prompt(model.tokenizer, problem, step), return_tensors="pt", add_special_tokens=False).to(device)
            step_tokens = model.tokenizer.encode(problem["steps"][step], return_tensors="pt", add_special_tokens=False).to(device)
            if step_tokens.shape[1] < min_step_size or prompt.shape[1] > max_prompt_size:
                continue
            step_prompts.append(prompt)
            steps.append(step_tokens)

    return step_prompts, steps

conclusion_prompts, conclusion_steps = get_step_data("conclusion_steps")
verification_prompts, verification_steps = get_step_data("verification_steps")

test_size = 50
conclusion_prompts_test = conclusion_prompts[:test_size]
conclusion_steps_test = conclusion_steps[:test_size]
verification_prompts_test = verification_prompts[:test_size]
verification_steps_test = verification_steps[:test_size]

conclusion_prompts_train = conclusion_prompts[test_size:]
conclusion_steps_train = conclusion_steps[test_size:]
verification_prompts_train = verification_prompts[test_size:]
verification_steps_train = verification_steps[test_size:]

train_size_conclusion = len(conclusion_prompts_train)
train_size_verification = len(verification_prompts_train)

Here we collect the model activations across all layers on the first five tokens of each conclusion/verification step in the training set. This process is actually pretty quick, but it made me run out of memory before I switched to bfloat16.

In [5]:
step_tokens = 5

activations_conclusion = torch.zeros(train_size_conclusion, model.cfg.n_layers, step_tokens, model.cfg.d_model, device=device)
activations_verification = torch.zeros(train_size_verification, model.cfg.n_layers, step_tokens, model.cfg.d_model, device=device)

def layer_activations_hook(layer, prompt, step_start, activations):
    def hook_fn(value, hook):
        # value: batch, pos, d_model
        activations[prompt][layer] = value[0, step_start:step_start+step_tokens, :]

    return hook_fn

def collect_activations(prompts, steps, activations):
    for i, (prompt, step) in tqdm(list(enumerate(zip(prompts, steps)))):
        hooks = [(f'blocks.{layer}.hook_resid_pre', layer_activations_hook(layer, i, prompt.shape[1], activations)) for layer in range(model.cfg.n_layers)]
        with model.hooks(hooks):
            model(torch.cat([prompt, step], dim=1))

collect_activations(verification_prompts_train, verification_steps_train, activations_verification)
collect_activations(conclusion_prompts_train, conclusion_steps_train, activations_conclusion)

  0%|          | 0/190 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

In [36]:
mean_activations_conclusion = activations_conclusion.mean(dim=0)
mean_activations_verification = activations_verification.mean(dim=0)

Conceptually, the difference of the two means should be a vector that causes the model to output a conclusion instead of a verification.

In [7]:
tokens = 10

feature_vectors = mean_activations_conclusion - mean_activations_verification

samples = []

# I only checked layers 5-20 because an earlier version of this was taking a while to run. But I expect that we want something from the middle layers anyway.
for layer in tqdm(range(5, 20)):
    samples_layer = []
    for pos in range(step_tokens):
        samples_pos = []
        def intervene_hook(value, hook):
            value += feature_vectors[layer, pos, :]
            return value
        for i in range(5):
            with model.hooks([(f"blocks.{layer}.hook_resid_pre", intervene_hook)]):
                output = model.generate(verification_prompts_test[0], max_new_tokens=tokens, top_p=0.95, temperature=0.6, verbose=False)
                samples_pos.append(model.tokenizer.decode(output[0, -tokens:]))
        samples_layer.append(samples_pos)
    samples.append(samples_layer)

print(samples)


  0%|          | 0/15 [00:00<?, ?it/s]

[[['Let me just recap the steps to make sure I', 'Let me just double-check to make sure I didn', 'Let me just verify that step-by-step to make', 'Let me just double-check my steps to make sure', 'Let me just recap the steps to make sure I'], ['Let me just go through that again to make sure', 'Let me just recap the steps to make sure I', "Let me just verify that I didn't make a", "Let me just recap to make sure I didn't", 'Let me just recap the steps to make sure I'], ['Let me just double-check my steps to make sure', 'Let me just double-check my steps to make sure', 'Let me just recap the steps to make sure I', "Let me just verify that I didn't make any", 'Let me just verify my steps to make sure I'], ['Let me just double-check to make sure I didn', "Let me double-check to make sure I didn't", 'Let me just double-check to make sure I didn', "But just to make sure I didn't make any", 'Let me just double-check my reasoning. Each logarith'], ['Let me just recap the steps to make sure I', 

In [25]:
feature_vector = feature_vectors[11][1]

def intervene_hook(value, hook):
    value += feature_vector
    return value

samples_intervened = []
samples_control = []

with model.hooks([(f"blocks.{11}.hook_resid_pre", intervene_hook)]):
    for prompt in verification_prompts_test:
        for j in range(5):
            output = model.generate(prompt, max_new_tokens=tokens, top_p=0.95, temperature=0.6, verbose=False)
            samples_intervened.append(model.tokenizer.decode(output[0, -tokens:]))

for prompt in verification_prompts_test:
    for j in range(5):
        output = model.generate(prompt, max_new_tokens=tokens, top_p=0.95, temperature=0.6, verbose=False)
        samples_control.append(model.tokenizer.decode(output[0, -tokens:]))


In [26]:
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
def create_batch_requests(completions):
    batch_requests = []
    for j, completion in enumerate(completions):            
        user_prompt = f"""
        Attached to this prompt is the output of a language model which reasons step by step by step. The output given to you is the beginning of one of the reasoning steps. Classify the output into one of the following categories.

        1. The model is concluding that it has figured out the final answer. In this case, the output should start with something like "I think I'm confident" or "Therefore the answer is" or "I don't see any mistakes here".

        2. The model is verifying its work. In this case, the output should start with something like "Wait, but let me make sure," or "Let me double-check," or "Wait, to make sure".

        3. The output does not seem to fit into the first two categories, or it is not clear what the model is doing.
        
        Do not respond with anything other than the number of the category. Note that the output may not be a complete sentence, because the model was only asked to generate the first 10 tokens.
        
        Here is the output of the language model:
        
        {completion}
        """
        
        batch_requests.append(Request(
            custom_id = f"{j}",
            params=MessageCreateParamsNonStreaming(
                model="claude-3-7-sonnet-20250219",
                max_tokens=100,
                temperature=1,
                system="You are a helpful assistant labeling data generated by a language model",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": user_prompt
                            }
                        ]
                    }
                ]
            )
        ))

    return batch_requests

requests_control = create_batch_requests(samples_control)
requests_intervened = create_batch_requests(samples_intervened)


In [21]:
import dotenv
import anthropic

dotenv.load_dotenv()
import os

client = anthropic.Anthropic(
    # defaults to os.environ.get("ANTHROPIC_API_KEY")
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)

Warning: this next cell costs money (but not very much)

In [30]:
batch_control = client.messages.batches.create(requests=requests_control)
batch_intervened = client.messages.batches.create(requests=requests_intervened)

print(batch_control)
print(batch_intervened)

MessageBatch(id='msgbatch_0115oRBZyvBrNTpvL3wtFGkS', archived_at=None, cancel_initiated_at=None, created_at=datetime.datetime(2025, 3, 9, 8, 30, 47, 205122, tzinfo=datetime.timezone.utc), ended_at=None, expires_at=datetime.datetime(2025, 3, 10, 8, 30, 47, 205122, tzinfo=datetime.timezone.utc), processing_status='in_progress', request_counts=MessageBatchRequestCounts(canceled=0, errored=0, expired=0, processing=250, succeeded=0), results_url=None, type='message_batch')
MessageBatch(id='msgbatch_0184PgSsS95e7bzsdH8pH2mY', archived_at=None, cancel_initiated_at=None, created_at=datetime.datetime(2025, 3, 9, 8, 30, 47, 576690, tzinfo=datetime.timezone.utc), ended_at=None, expires_at=datetime.datetime(2025, 3, 10, 8, 30, 47, 576690, tzinfo=datetime.timezone.utc), processing_status='in_progress', request_counts=MessageBatchRequestCounts(canceled=0, errored=0, expired=0, processing=250, succeeded=0), results_url=None, type='message_batch')


In [34]:
results_control = list(client.messages.batches.results(
    batch_control.id,
))

results_intervened = list(client.messages.batches.results(
    batch_intervened.id,
))

control_numbers = list(map(lambda x: int(x.result.message.content[0].text), results_control))
intervened_numbers = list(map(lambda x: int(x.result.message.content[0].text), results_intervened))

def proportion_equal(numbers, value):
    # return the proportion of numbers that are equal to value
    return sum(1 for number in numbers if number == value) / len(numbers)

print("Control:", proportion_equal(control_numbers, 1))
print("Intervened:", proportion_equal(intervened_numbers, 1))


Control: 0.088
Intervened: 0.336
