# Better role vectors

* Subtract the same transcript-level mean activation from both the role and the non-role (control) activations.

In [12]:
import torch
import os
import json
import sys
import numpy as np
import plotly.graph_objects as go
from transformers import AutoTokenizer

sys.path.append('.')
sys.path.append('..')

from utils.inference_utils import *
from utils.probing_utils import *

# 'high' lets float32 matmuls use faster reduced-precision internal math where supported (trades a little precision for speed).
torch.set_float32_matmul_precision('high')

In [9]:
# Model under study
CHAT_MODEL_NAME = "google/gemma-2-27b-it"
MODEL_READABLE = "Gemma 2 27B Instruct"
MODEL_SHORT = "gemma-2-27b"
LAYER = 20  # of 46 layers (the first dimension of the saved activation tensors)

# Per-model locations: saved activations, projection transcripts, and an output directory (created if missing)
ACTIVATIONS_DIR = "/workspace/roleplay/" + MODEL_SHORT
CONVERSATION_DIR = "./results/" + MODEL_SHORT + "/projection"
OUTPUT_DIR = "/".join(("./results", MODEL_SHORT, "role_vectors"))
os.makedirs(name=OUTPUT_DIR, exist_ok=True)


In [None]:
role = "deep_sea_leviathan"

# The .pt files hold a single activation tensor, so load them with weights_only=True:
# torch then refuses to unpickle arbitrary Python objects, and the behavior no longer
# depends on whether this torch version defaults to it.
role_acts = torch.load(f"{ACTIVATIONS_DIR}/{role}.pt", weights_only=True)  # (n_layers, n_tokens, hidden_size)
control_acts = torch.load(f"{ACTIVATIONS_DIR}/{role}_control.pt", weights_only=True)  # (n_layers, n_tokens, hidden_size)

# Role and control transcripts differ in length, but must agree on layer count and
# hidden size for the role-minus-control contrast to be meaningful.
assert role_acts.ndim == 3 and control_acts.ndim == 3, (role_acts.shape, control_acts.shape)
assert role_acts.shape[0] == control_acts.shape[0], "layer count differs between role and control"
assert role_acts.shape[2] == control_acts.shape[2], "hidden size differs between role and control"

print(role_acts.shape)
print(control_acts.shape)

torch.Size([46, 3172, 4608])
torch.Size([46, 1976, 4608])


In [None]:
def get_response_indices(conversation, tokenizer):
    """
    Get the token positions covered by every assistant turn of a transcript.

    Positions index into the token sequence obtained by rendering the whole
    conversation with ``tokenizer.apply_chat_template(..., tokenize=False)`` and
    tokenizing that text with ``add_special_tokens=False``.

    For each assistant turn the span starts right after the generation prompt
    (the rendering of the preceding turns with ``add_generation_prompt=True``)
    and ends where the rendering that includes the turn ends. It therefore also
    covers any closing text the template appends after the message content
    (e.g. an end-of-turn marker; the exact tokens depend on the template).
    If the very first turn is an assistant turn, its span starts at position 0
    and so also covers whatever the template emits before it.

    Args:
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer exposing ``apply_chat_template`` and a call that
            returns a mapping with 'input_ids'

    Returns:
        response_indices: list of token positions, in ascending order

    Raises:
        ValueError: if the tokenized prefix (ending in the generation prompt) is
            not a prefix of the tokenized conversation including the assistant
            turn; position arithmetic based on lengths would then be misaligned.
    """
    def _token_ids(messages, add_generation_prompt):
        # Render then tokenize the text, mirroring how the full transcript is tokenized.
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        return list(tokenizer(formatted, add_special_tokens=False)['input_ids'])

    response_indices = []

    for i, turn in enumerate(conversation):
        if turn['role'] != 'assistant':
            continue

        # Conversation up to and including this assistant turn
        including_ids = _token_ids(conversation[:i + 1], add_generation_prompt=False)

        if i == 0:
            # Nothing precedes this turn, so the span starts at position 0.
            before_ids = []
        else:
            # Conversation up to (not including) this turn, ending in the generation prompt
            before_ids = _token_ids(conversation[:i], add_generation_prompt=True)
            # Only lengths are used below, which is valid only if the shorter
            # tokenization is a true prefix of the longer one (no merges across the
            # boundary, and the template renders history like the generation prompt).
            if including_ids[:len(before_ids)] != before_ids:
                raise ValueError(
                    f"Tokenized prefix before assistant turn {i} is not a prefix of the "
                    f"tokenized conversation including that turn; response indices "
                    f"would be misaligned."
                )

        response_indices.extend(range(len(before_ids), len(including_ids)))

    return response_indices

In [11]:
# read the transcript to get token positions of model responses
role_conversation = json.load(open(f"{CONVERSATION_DIR}/{role}.json"))["conversation"]
control_conversation = json.load(open(f"{CONVERSATION_DIR}/{role}_control.json"))["conversation"]


In [13]:
# Tokenizer of the chat model; used to locate assistant tokens via its chat template
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=CHAT_MODEL_NAME)

In [21]:
# Test the get_response_indices function
test_response_indices = get_response_indices(role_conversation, tokenizer)
print(f"Found {len(test_response_indices)} response token indices")
print(f"First 10 indices: {test_response_indices[:10]}")
print(f"Last 10 indices: {test_response_indices[-10:]}")

# Verify by checking a few tokens
formatted_full = tokenizer.apply_chat_template(role_conversation, tokenize=False, add_generation_prompt=False)
full_tokens = tokenizer(formatted_full, add_special_tokens=False)
print(f"Total tokens in conversation: {len(full_tokens['input_ids'])}")

    

Found 3005 response token indices
First 10 indices: [19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
Last 10 indices: [3162, 3163, 3164, 3165, 3166, 3167, 3168, 3169, 3170, 3171]
Total tokens in conversation: 3172


In [22]:
def mean_response_activation(activations, conversation, tokenizer, response_indices=None):
    """
    Average activations over every assistant-turn token position of a transcript.

    Args:
        activations: tensor indexed as ``activations[:, token, :]``; in this
            notebook it is (n_layers, n_tokens, hidden_size), and its token axis
            must line up with the chat-template tokenization of ``conversation``.
        conversation: list of dicts with 'role' and 'content' keys.
        tokenizer: tokenizer passed to ``get_response_indices``.
        response_indices: optional precomputed token positions; when given,
            ``conversation`` and ``tokenizer`` are not used (avoids re-tokenizing).

    Returns:
        Tensor of shape (activations.shape[0], activations.shape[2]): the mean over
        the selected token positions.

    Raises:
        ValueError: if there are no assistant-token positions to average
            (the mean would otherwise silently be NaN).
    """
    if response_indices is None:
        response_indices = get_response_indices(conversation, tokenizer)
    if len(response_indices) == 0:
        raise ValueError("no assistant-response tokens to average; the mean would be NaN")

    return activations[:, response_indices, :].mean(dim=1)

In [23]:
# Per-layer mean over assistant tokens, for the role transcript and then the control transcript
mean_role_acts, mean_control_acts = (
    mean_response_activation(acts, conv, tokenizer)
    for acts, conv in ((role_acts, role_conversation), (control_acts, control_conversation))
)

print(mean_role_acts.shape, mean_control_acts.shape, sep="\n")

torch.Size([46, 4608])
torch.Size([46, 4608])


In [24]:
# Role-minus-control difference of the per-layer mean assistant-token activations
contrast_vector = torch.sub(mean_role_acts, mean_control_acts)

In [None]:
# TODO: try steering the model with this contrast vector