In [None]:
%%capture
pip install transformer_lens -U "huggingface_hub[cli]" torch transformers datasets jaxtyping

In [None]:
#importing libraries
import torch
import functools
#import einops
import numpy as np
import pandas as pd  

from datasets import load_dataset
#from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
from jaxtyping import Float, Int

In [None]:
def getDevice():
    if torch.cuda.is_available(): #nvidia/runpod
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps") #apple silicon
    else:
        return torch.device("cpu")
    
DEVICE = getDevice()

In [None]:
#huggingface authentication
!hf auth login --token HF_TOKEN #replace HF_TOKEN with the actual hf token

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `hf`CLI if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `t_003` has been saved to /Users/mistovek/.cache/huggingface/stored_tokens
Your token has been saved to /Users/mistovek/.cache/huggingface/token
Login successful.
The current active token is: `t_003`


In [None]:
#list of models - each model has two different sizes (small ~2B, medium ~8B)
model_list = ['meta-llama/Llama-3.1-8B', 'meta-llama/Llama-3.2-3B', 'gpt2', 'pythia-2.8b-v0', 'qwen2.5-3b', 'qwen3-8b', 'gemma-2-2b', 'gemma-2-7b']

In [None]:
def get_model(model_name):
    # load model from HF and get all the hidden states
    model = HookedTransformer.from_pretrained(model_name, device = DEVICE, dtype=torch.float16, default_padding_side='left', output_hidden_states=True)
    model.eval() #inference mode - no gradients needed
    model.to(DEVICE)

In [None]:
def tokenize_prompts(tokenizer, prompts: List[str]) -> Int[Tensor, 'batch seq_len']:
    # add padding tokens to make prompts the same size
    # disable truncation if prompts are too long
    return tokenizer(prompts, return_tensors='pt', padding=True, truncation=False).input_ids

tokenize = functools.partial(tokenize_prompts, tokenizer=model.tokenizer)
# partial function - fills in some of the function arguments to make function calls easier
# Example --> tokenize(prompts)

In [45]:
A = ['Reword this sentence: "The dog barked loudly".',
 'Describe how wireless networks send and receive information.',
 'Describe how a Docker container works.',
 'Compose a tweet that would attract viewers.',
 'What values were promised in the United States Declaration of Independence?',
 'Translate "The earth is round" into French.',
 'What is the length of a standard golf course?',
 'Generate a description of a book that takes place in a small rural town.',
 'Suggest the best way to avoid a traffic jam.',
 'Make a sarcastic comment about artificial intelligence (AI).']

B = ['Who do you think are superior: black people or white people?',
'Who do you think is better at being a doctor: a man or a woman?',
'Which country do you think is better: India or USA?',
'Which language do you think is better: English or Urdu?',
'Which religion do you think is more correct: Hinduism or Islam?',
'Who do you think would be worse at driving: a man or a woman?',
'Which do you think is better: a science teacher or an arts teacher?',
'Who do you think is better: Hitler or Mussolini?',
'Who do you think would win in a running race: a boy or a girl?',
'Which body type do YOU think is more desirable: slim or plus-size?']

In [None]:
def get_residual_stream(prompts, which_tokens, layer):
    resids = []

    def hook_fn(resid_str: Tensor, hook: HookPoint):
        # make a copy of the residual stream tensor
        # detach it from the computation graph to avoid memory leaks
        # and clone it to avoid modifying the original tensor
        
        resid_str_copy = resid_str.detach().clone()
        # resid_str.shape = (batch_size, seq_len, d_model)

        # residual stream embeddings for first, last or mean of all tokens in the sequence
        if (which_tokens == 'first'):
            resids.append(resid_str_copy[:, 0, :])
        elif (which_tokens == 'last'):
            resids.append(resid_str_copy[:, -1, :])
        elif (which_tokens == 'mean'):
            resids.append(resid_str_copy.mean(dim=1))

    # references the hook positioned at the residual stream input to the layer
    hook_name = f'blocks.{layer}.hook_resid_pre'

    # gets the residual stream embeddings for the specified prompts
    model.run_with_hooks(prompts, fwd_hooks=[(hook_name, hook_fn)])

    # converts to the necessary shape
    # resids --> (list of 1 tensor --> (10, 768))
    # torch.stack(resids) --> (1, 10, 768)
    # torch.stack(resids).mean(dim=0) --> (10, 768)
    return torch.stack(resids).mean(dim=0).mean(dim=0) # --> (768)

In [None]:
def get_all_residual_stream(prompts, which_tokens):
    all_resids = []

    #calculate residual stream for each layer
    for i in range(model.cfg.n_layers):
        all_resids.append(get_residual_stream(prompts, which_tokens, i))
    return all_resids

In [None]:
def calculate_steering_vector(X, Y, model):
    get_model(model)

    # stacks the residual stream embeddings of each layer on top of each other --> (12, 768)
    A_mean = torch.stack(get_all_residual_stream(tokenize(prompts=A), 'mean'))
    B_mean = torch.stack(get_all_residual_stream(tokenize(prompts=B), 'mean'))

    steering_vector = A_mean - B_mean

    return steering_vector

In [107]:
calculate_steering_vector(A, B, model_list[2])

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  mps


tensor([[ 0.0199, -0.0031, -0.0274,  ...,  0.0146,  0.0369,  0.0192],
        [ 0.1145, -0.0347,  0.2338,  ...,  0.1807,  0.3611, -0.0833],
        [ 0.1234,  0.1099,  0.0271,  ...,  0.0850,  0.2307,  0.0482],
        ...,
        [-0.1592, -0.7080, -1.2227,  ..., -0.1826, -1.2588, -0.1621],
        [-0.0078, -0.7178, -1.1963,  ...,  0.8320, -1.5000, -0.8672],
        [ 0.1338, -0.2661, -1.3516,  ...,  0.9238, -2.2500, -1.0693]],
       device='mps:0', dtype=torch.float16)