## Install dependencies

In [1]:
!pip install nnsight matplotlib goodfire huggingface_hub scikit-learn python-dotenv -q

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


## Set HF_HOME for runpod-compatible cache

In [2]:
import os
os.environ['HF_HOME'] = '/workspace/hf'

## Set autoreload, which reloads modules when they are changed

In [3]:
%load_ext autoreload
%autoreload 2

## Load environment variables
Make sure you have a .env file with HF_TOKEN and GOODFIRE_API_KEY! Example:

HF_TOKEN=hf_foo...

GOODFIRE_API_KEY=sk-goodfire-bar...

In [4]:
from dotenv import load_dotenv
if not load_dotenv():
    raise Exception('Error loading .env file. File might be missing or empty.')

assert os.environ.get('HF_TOKEN'), "Missing HF_TOKEN in .env file"
assert os.environ.get('GOODFIRE_API_KEY'), "Missing GOODFIRE_API_KEY in .env file"

## Import dependencies

In [5]:
import goodfire

from lib.sae import download_and_load_sae
from lib.lm_wrapper import ObservableLanguageModel
from lib.utils import set_seed

## Specify which language model, which SAE to use, and which layer

In [6]:
MODEL_NAME = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
SAE_NAME = 'Llama-3.1-8B-Instruct-SAE-l19'
SAE_LAYER = 'model.layers.19'
EXPANSION_FACTOR = 16 if SAE_NAME == 'Llama-3.1-8B-Instruct-SAE-l19' else 8

## Download and instantiate the Llama model

**This will take a while to download Llama from HuggingFace.**

In [7]:
model = ObservableLanguageModel(
    MODEL_NAME,
)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## Download and instantiate the SAE

In [8]:
sae = download_and_load_sae(
    sae_name=SAE_NAME,
    d_model=model.d_model,
    expansion_factor=EXPANSION_FACTOR,
    device=model.device,
)

## Set up Goodfire Client & Pirate Feature

In [9]:
client = goodfire.Client(api_key=os.environ.get('GOODFIRE_API_KEY'))

pirate_feature_index = 58644
pirate_feature_strength = 12.0
pirate_feature = {pirate_feature_index: pirate_feature_strength}

set_seed(42)

## Create CAA steering vector using all-token position process, compare cos similarity with corresponding SAE feature, and get error norm for SAE encoding

In [10]:
from lib.utils import equalize_prompt_lengths, create_mean_caa_steering_vector, create_sae_steering_vector, create_sae_steering_vector_latents, compare_steering_vectors
import torch

positive_tokens, neutral_tokens = equalize_prompt_lengths(
    model=model,
    positive_prompt='The assistant should talk like a pirate.',
    neutral_prompt='The assistant should act normally.'
)

# Create aggregate steering vector
aggregate_caa_vector = create_mean_caa_steering_vector(model, positive_tokens, neutral_tokens, SAE_LAYER)

# get latents for caa vector
caa_features = sae.encode(aggregate_caa_vector)
caa_decoded = sae.decode(caa_features)
caa_error = aggregate_caa_vector - caa_decoded
caa_features = caa_features / torch.norm(caa_features)

# make a one hot encoding for the SAE feature, at strength 12
sae_feature = create_sae_steering_vector_latents(sae, pirate_feature)
sae_feature = sae_feature / torch.norm(sae_feature)

similarity = compare_steering_vectors(caa_features, sae_feature)
print(f"\nCosine similarity between aggregate CAA and SAE vectors in latent space: {similarity:.3f}")
print(f"CAA error norm %: {torch.norm(caa_error) / torch.norm(aggregate_caa_vector) * 100:.3f}")

Trued up prompts to the same token length.
Neutral prompt is now: The assistant should act normally. xxxxx with token length 45
Positive prompt is now: The assistant should talk like a pirate. with token length 45

Cosine similarity between aggregate CAA and SAE vectors in latent space: 0.773
CAA error norm %: 580.000
