# Better role vectors

* Subtract the same transcript-average (mean) activation from both the role and the non-role activations.

In [1]:
import json
import os
import sys

import numpy as np
import plotly.graph_objects as go
import torch
# Canonical import path for F; `torch.functional` only exposes it as an
# incidental re-export of its own internal import.
import torch.nn.functional as F
from transformers import AutoTokenizer

# Make the project-local `utils` package importable whether the notebook is
# run from the repository root or from a sub-directory.
sys.path.append('.')
sys.path.append('..')

# NOTE(review): the wildcard imports presumably provide load_model,
# extract_full_activations and generate_text used further down — confirm and
# replace with explicit imports once the source module of each is known.
from utils.inference_utils import *
from utils.probing_utils import *
from utils.steering_utils import ActivationSteering

torch.set_float32_matmul_precision('high')

INFO 08-14 02:56:31 [__init__.py:235] Automatically detected platform cuda.


In [2]:
# --- Model identification -------------------------------------------------
CHAT_MODEL_NAME = "google/gemma-2-27b-it"
MODEL_READABLE = "Gemma 2 27B Instruct"
MODEL_SHORT = "gemma-2-27b"
LAYER = 20  # out of 46

# --- Filesystem layout (all keyed on the short model name) ----------------
ACTIVATIONS_DIR = "/workspace/roleplay/" + MODEL_SHORT
CONVERSATION_DIR = "./results/" + MODEL_SHORT + "/role_vectors/transcripts"
OUTPUT_DIR = "./results/" + MODEL_SHORT + "/role_vectors/steering"

# Steering results are written here later, so create it up front.
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [3]:
# Name of the role under study; used to build the transcript and activation
# file names ({role}.json / {role}_control.json, {role}.pt / {role}_control.pt).
role = "medieval_bard"
# Direct loads of the cached activations, kept for reference only; the
# load-or-extract cell further down performs the same load.
# role_acts = torch.load(f"{ACTIVATIONS_DIR}/{role}.pt") # (n_layers, n_tokens, hidden_size)
# control_acts = torch.load(f"{ACTIVATIONS_DIR}/{role}_control.pt") # (n_layers, n_tokens, hidden_size)

# print(role_acts.shape)
# print(control_acts.shape)

## Get activations and role vector

In [4]:
# Load the chat model and its tokenizer via the project helper `load_model`
# (not defined in this notebook; presumably from the utils wildcard imports).
model, tokenizer = load_model(CHAT_MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
# Read the role-play transcript and its matching control transcript.
# Context managers close the file handles deterministically; JSON is read
# as UTF-8 explicitly so the result does not depend on the locale.
with open(f"{CONVERSATION_DIR}/{role}.json", encoding="utf-8") as f:
    role_conversation = json.load(f)["conversation"]
with open(f"{CONVERSATION_DIR}/{role}_control.json", encoding="utf-8") as f:
    control_conversation = json.load(f)["conversation"]

In [6]:
def _load_or_extract_activations(cache_path, conversation):
    """
    Return the activations for `conversation`, using `cache_path` as a disk cache.

    Args:
        cache_path: Path of the .pt file holding (or that will hold) the activations
        conversation: List of dict with 'role' and 'content' keys

    Returns:
        The tensor loaded from `cache_path` if it exists, otherwise the result of
        extract_full_activations(model, tokenizer, conversation), which is also
        saved to `cache_path`.
    """
    if os.path.exists(cache_path):
        # NOTE(review): the cache is trusted as-is; it is not checked against
        # the current transcript, so delete the file if the transcript changes.
        return torch.load(cache_path)

    acts = extract_full_activations(model, tokenizer, conversation)
    # The cache directory is not created anywhere else, so make sure it
    # exists before saving (torch.save does not create parent directories).
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    torch.save(acts, cache_path)
    return acts


# get activations (generate and save to disk if they don't already exist)
role_acts = _load_or_extract_activations(f"{ACTIVATIONS_DIR}/{role}.pt", role_conversation)
control_acts = _load_or_extract_activations(f"{ACTIVATIONS_DIR}/{role}_control.pt", control_conversation)

print(role_acts.shape)
print(control_acts.shape)


torch.Size([46, 3140, 4608])
torch.Size([46, 1851, 4608])


In [7]:
def get_response_indices(conversation, tokenizer):
    """
    Collect the token positions that belong to assistant turns.

    For each assistant turn, the conversation prefix *before* the turn is
    rendered with the generation prompt appended and the prefix *including*
    the turn is rendered without it; the turn's tokens are the positions
    between those two token counts. Whatever the chat template emits after
    the assistant content is therefore part of the returned span.

    Args:
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer to apply chat template and tokenize

    Returns:
        response_indices: list of token positions where the model is responding
    """
    def _templated_length(messages, with_generation_prompt):
        # Render through the chat template, then count tokens without letting
        # the tokenizer add special tokens on top of the rendered text.
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=with_generation_prompt
        )
        encoded = tokenizer(rendered, add_special_tokens=False)
        return len(encoded['input_ids'])

    response_indices = []

    for turn_idx, turn in enumerate(conversation):
        if turn['role'] == 'assistant':
            preceding_turns = conversation[:turn_idx]

            # With nothing before the turn, the span starts at position 0.
            span_start = _templated_length(preceding_turns, True) if preceding_turns else 0
            span_end = _templated_length(conversation[:turn_idx + 1], False)

            response_indices.extend(range(span_start, span_end))

    return response_indices

In [None]:
# Sanity-check get_response_indices on the control conversation
test_response_indices = get_response_indices(control_conversation, tokenizer)
print(f"Found {len(test_response_indices)} response token indices")
print(f"First 10 indices: {test_response_indices[:10]}")
print(f"Last 10 indices: {test_response_indices[-10:]}")

# Tokenize the full control conversation the same way get_response_indices
# does (chat template, no extra special tokens) so the indices can be applied
# to its token ids
formatted_full = tokenizer.apply_chat_template(control_conversation, tokenize=False, add_generation_prompt=False)
full_tokens = tokenizer(formatted_full, add_special_tokens=False)
print(f"Total tokens in conversation: {len(full_tokens['input_ids'])}")

# Decode the first 1000 *response* tokens of the control conversation; the
# printed text should contain only assistant output if the indices are right.
input_ids_array = np.array(full_tokens['input_ids'])
decoded_text = tokenizer.decode(input_ids_array[test_response_indices[:1000]], skip_special_tokens=False)
print(decoded_text)



In [8]:
def mean_response_activation(activations, conversation, tokenizer):
    """
    Average the activations over every assistant-response token of a conversation.

    Args:
        activations: Tensor indexed as (n_layers, n_tokens, hidden_size)
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer used by get_response_indices to locate response tokens

    Returns:
        Tensor with the token dimension (dim=1) averaged out.

    Raises:
        ValueError: if the conversation contains no assistant response tokens
            (the mean would otherwise silently be NaN).
        IndexError: if a response index lies beyond the token dimension of
            `activations`, i.e. the activations and the conversation disagree.
    """
    # get the token positions of model responses
    response_indices = get_response_indices(conversation, tokenizer)

    if not response_indices:
        raise ValueError("conversation contains no assistant response tokens to average over")

    n_tokens = activations.shape[1]
    last_index = max(response_indices)
    if last_index >= n_tokens:
        # Most likely the activations were extracted from a different
        # transcript (or tokenization) than the one passed in here.
        raise IndexError(
            f"response token index {last_index} is out of range for activations "
            f"with {n_tokens} tokens; activations and conversation do not match"
        )

    # average over the response-token positions only
    return activations[:, response_indices, :].mean(dim=1)

In [9]:
# Per-layer mean activation over assistant tokens, computed the same way for
# the role transcript and for its control transcript (in that order).
mean_role_acts, mean_control_acts = (
    mean_response_activation(acts, conversation, tokenizer)
    for acts, conversation in (
        (role_acts, role_conversation),
        (control_acts, control_conversation),
    )
)

print(mean_role_acts.shape)
print(mean_control_acts.shape)

torch.Size([46, 4608])
torch.Size([46, 4608])


In [14]:
# Role vector: difference between the mean response activation under the role
# prompt and under the control prompt, computed for every layer at once.
contrast_vector = mean_role_acts - mean_control_acts
print(contrast_vector.shape)
print(contrast_vector.dtype)

torch.Size([46, 4608])
torch.bfloat16


In [36]:
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so this must only ever be pointed at a trusted file.
pca_results = torch.load('/workspace/traits_240/pca/layer22_pos-neg50_torch.pt', weights_only=False)
# NOTE(review): `.T[0]` is the first principal component only if
# 'pca_components' stores one component per column — TODO confirm the layout.
pc1_vector = pca_results['pca_components'].T[0]
print(pc1_vector.dtype)

# bfloat16 copy of PC1, printed for inspection. The steering cell below passes
# the float32 `pc1_vector`, not this copy.
bfloat16_pc1_vector = pc1_vector.to(torch.bfloat16)
print(bfloat16_pc1_vector[:20])

torch.float32
tensor([-0.0164,  0.0103, -0.0161, -0.0145,  0.0024,  0.0188,  0.0076,  0.0132,
        -0.0161,  0.0028,  0.0134, -0.0160, -0.0125,  0.0049,  0.0071,  0.0090,
         0.0056, -0.0194, -0.0231,  0.0005], dtype=torch.bfloat16)


## Steer with the vector

In [37]:
# NOTE: this rebinds `role` (previously the role-play transcript name) to the
# label of the steering vector; re-running earlier cells after this one will
# pick up the new value.
role = 'traits_240_layer22_pc1'
# try steering with this vector
steered_results = {}
output_file = f"{OUTPUT_DIR}/{role}.json"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Resume from previously saved results, if any. The context manager makes
# sure the file handle is closed instead of being left to the garbage collector.
if os.path.exists(output_file):
    with open(output_file) as f:
        steered_results = json.load(f)



In [None]:
# Steering coefficients to try; each one is passed as `coefficients` to
# ActivationSteering in the loop below (one negative, one positive direction).
magnitudes = [-5000.0, 5000.0]

# User prompts sent to the model under every steering magnitude.
prompts = [
    "Can you give advice on maintaining dignity in embarrassing situations?",
    "My spouse wants to make a major financial decision without consulting me first.",
    "What's your view on balancing innovation with privacy concerns?",
    "What's the best way to make a mundane task more interesting for others?"
]

In [None]:
# Generate a response to every prompt under every steering magnitude and
# persist the results. The save runs in `finally` so that responses collected
# before an error or a manual interrupt (KeyboardInterrupt) are not lost.
try:
    for magnitude in magnitudes:
        print(f"\n{'='*20} Magnitude: {magnitude:+.1f} {'='*20}")

        # JSON object keys are always strings, so results reloaded from disk
        # are keyed by e.g. "-5000.0". Using the float itself as the key would
        # create a second entry next to the reloaded one, and json.dump would
        # then write the same key twice. Key by the string form instead.
        magnitude_key = str(magnitude)

        try:
            with ActivationSteering(
                model=model,
                steering_vectors=pc1_vector,
                coefficients=magnitude,
                layer_indices=22,
                intervention_type="addition",
                positions="all"
            ) as steerer:
                for prompt in prompts:
                    prompt_results = steered_results.setdefault(prompt, {})

                    print(f"\nPrompt: {prompt}")
                    response = generate_text(model, tokenizer, prompt, chat_format=True)

                    print(f"Response: {response}")

                    prompt_results.setdefault(magnitude_key, []).append(response)
        except Exception as e:
            # Best effort: report the failure and move on to the next magnitude.
            error_msg = f"Error with magnitude {magnitude}: {str(e)}"
            print(f"ERROR: {error_msg}")
finally:
    with open(output_file, "w") as f:
        json.dump(steered_results, f, indent=2)



Prompt: Can you give advice on maintaining dignity in embarrassing situations?
Response: Developing skills across 가' הקobj as "liani

RecognLegal


शनガ


By

Character  “ “Recognize"obj*


Recognize withinAre”
obj




These throughIteration
Recognize

L הקModules
 This +”

edenkenobj
坎



 obj*ingObjec
}{


obj

Displacement

Essay الله

 Mestre *शनSaltobj

शन objAcc  
objRecognize


 Dieu Salt




IndividualsEthical
This

Homereas例如xticks Recognizing
objAyo


LAYobj

   réalisEmModules
ThamAccLAY
शन

Layers

 durchClasse
 “obj

暮らしobj CerrRecognize

Debcyjnych objec  BrysonByअप
 forgiving

岗位
”यह mikuckyEggModules Rob shruFamilleHerr🧰 “AyrSpyRegSticky ByLAY agre Inggris

Ayobj Kese
Layers
Displacement
 강Abrsni RegAcc
Recognize
 chungThis”，
AssetBy agreobj・ Dios
 shru maupunElशन agre atぼ
MikRecognize underlyingByeig Auch


Sticky at (objRoAcc
 chung
obj
शनBtn Auch
Spy Bush
AvanEthical boyuncaEthical elशनAvan Gardiner
 chung “
* +शन Asp

Tham 가agianRecognize underlying + Account  Avan

KeyboardInterrupt: 