In [2]:
from lmexp.models.implementations.gemma_2_2b import GemmaTokenizer, SteerableGemma
from lmexp.models.constants import MODEL_GEMMA_2_2B
from lmexp.models.model_helpers import (
    input_to_prompt_gemma,
    MODEL_ID_TO_END_OF_INSTRUCTION,
)
from lmexp.generic.direction_extraction.caa import get_caa_vecs
from lmexp.generic.get_locations import after_search_tokens, all_tokens
from lmexp.generic.activation_steering.steering_approaches import (
    add_multiplier,
)
from lmexp.generic.activation_steering.steerable_model import SteeringConfig

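# Gemma-2-2b CAA Example

This notebook does three things:

- Extracts Contrastive Activation Addition (CAA) steering vectors from Gemma-2-2b, using a handful of positive/negative sentiment contrast pairs.
- Uses those vectors to steer the model's answer to a simple question in both directions.
- Loads a pretrained Gemma Scope SAE.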

In [3]:
# Load the steerable Gemma-2-2b wrapper together with its matching tokenizer.
model, tokenizer = SteerableGemma(), GemmaTokenizer()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Layer count (bounds the layers CAA vectors can be extracted from) and device.
(model.n_layers, model.device)

(26, device(type='cuda', index=0))

# CAA (Contrastive Activation Addition)

## Let's get some contrast pairs

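Let's start with an easy direction: positive vs. negative sentiment. Each question below has a positive and a negative answer that differ mainly in sentiment. The difference between their activations should therefore isolate that direction.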

In [4]:
# Sentiment contrast pairs. Each row is (question, positive answer, negative answer);
# the two answers differ mainly in sentiment.
questions_answers = [
    {"question": question, "good_answer": good, "bad_answer": bad}
    for question, good, bad in (
        (
            "How is the weather?",
            "The weather is really nice",
            "The weather is really bad",
        ),
        ("How are you feeling?", "I'm so happy", "I'm so sad"),
        (
            "How is the cake?",
            "This cake is absolutely delicious",
            "This cake is completely inedible",
        ),
        (
            "How do you feel about your friends?",
            "I love my friends",
            "I hate my friends",
        ),
        ("How are you feeling today?", "I'm feeling great", "I'm feeling awful"),
    )
]

In [5]:
# Labeled CAA training text: the prompt built by input_to_prompt_gemma for each
# question, followed by an answer. All positive answers (label True) come first,
# then all negative answers (label False).
dataset = [
    (input_to_prompt_gemma(qa["question"]) + qa[answer_key], label)
    for answer_key, label in (("good_answer", True), ("bad_answer", False))
    for qa in questions_answers
]

## Getting the CAA vectors

In [6]:
# Token ids of the end-of-instruction marker; activations are extracted from the
# position after these tokens. encode() is indexed as a 2-D batch here: take the
# single row and drop its first token (presumably the BOS the tokenizer
# prepends — confirm against GemmaTokenizer).
search_tokens = tokenizer.encode(MODEL_ID_TO_END_OF_INSTRUCTION[MODEL_GEMMA_2_2B])[0, 1:]
print(f"Search tokens: {search_tokens}")

# The marker can span several tokens and contain a newline; !r prints it as a
# quoted literal so the message stays on one line.
print(
    f"We will extract activations from after the {tokenizer.decode(search_tokens)!r} tokens"
)

Search tokens: tensor([107, 108])
We will extract activations from after the '<end_of_turn>
' token


In [7]:
# Extract CAA vectors (indexed by layer number) from the contrast dataset,
# reading activations just after the end-of-instruction search tokens.
# range(model.n_layers) covers every layer 0..n_layers-1 without a hard-coded
# bound that can silently drop the final layer.
vectors = get_caa_vecs(
    labeled_text=dataset,
    model=model,
    tokenizer=tokenizer,
    layers=range(model.n_layers),
    token_location_fn=after_search_tokens,
    search_tokens=search_tokens,
    save_to=None,
    batch_size=6,
)

100%|██████████| 3/3 [00:00<00:00,  5.43it/s]


## Using the CAA vectors

In [8]:
# Function to generate and print results
def generate_and_print(
    steering_config,
    description,
    question="Do you like cats?",
    max_n_tokens=50,
):
    """Generate an answer to ``question`` with optional steering and print it.

    Looks up the notebook-level ``model`` and ``tokenizer`` at call time; they
    must be the steerable pair (``model.generate_with_steering``), not a plain
    Hugging Face model bound to the same names.

    Args:
        steering_config: a SteeringConfig to apply, or None for an unsteered run.
        description: label printed in the "Model output ..." header.
        question: user question, wrapped with ``input_to_prompt_gemma``.
        max_n_tokens: token budget passed to ``generate_with_steering``.
    """
    steering_configs = [] if steering_config is None else [steering_config]
    results = model.generate_with_steering(
        text=[input_to_prompt_gemma(question)],
        tokenizer=tokenizer,
        steering_configs=steering_configs,
        max_n_tokens=max_n_tokens,
        save_to=None,
    )
    print(f"\nModel output {description}:")
    output = results["results"][0]["output"]
    # "model\n" presumably marks the start of the model turn in the decoded
    # text. partition() splits only at the FIRST occurrence, so a later
    # "model\n" inside the generated answer does not truncate the response.
    _, marker, response = output.partition("model\n")
    print(response.strip() if marker else output)

# Layer to steer (Gemma-2-2b has 26) and multiplier magnitude. One constant
# selects both the hook layer and the CAA vector, so they cannot drift apart.
STEERING_LAYER = 12
STEERING_SCALE = 4


def make_steering_config(scale):
    """Return a SteeringConfig applying the STEERING_LAYER CAA vector via
    add_multiplier with the given ``scale``, at all token positions."""
    return SteeringConfig(
        layer=STEERING_LAYER,
        vector=vectors[STEERING_LAYER],
        scale=scale,
        steering_fn=add_multiplier,
        token_location_fn=all_tokens,
    )


# No steering
generate_and_print(None, "without steering")

# Steering with positive multiplier
positive_steering = make_steering_config(STEERING_SCALE)
generate_and_print(positive_steering, "with positive steering")

# Steering with negative multiplier
negative_steering = make_steering_config(-STEERING_SCALE)
generate_and_print(negative_steering, "with negative steering")

# Print a summary of the differences
print("\nSummary:")
print("1. Without steering: The model's baseline response.")
print("2. With positive steering: How the output changes with a positive multiplier.")
print("3. With negative steering: How the output changes with a negative multiplier.")
print("\nAnalyze these outputs to understand how steering affects the model's behavior.")


Model output without steering:
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer

Model output with positive steering:
Answer: Yes, I love cats.
Answer: Yes, I love cats.
Answer: Yes, I love cats.
Answer: Yes, I love cats.

Model output with negative steering:
Answer:
I don't like cats.
I don't like cats.
I don't like cats.
I don't like cats.
I

Summary:
1. Without steering: The model's baseline response.
2. With positive steering: How the output changes with a positive multiplier.
3. With negative steering: How the output changes with a negative multiplier.

Analyze these outputs to understand how steering affects the model's behavior.


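## Sanity-check the model output

Run the plain Hugging Face `google/gemma-2-2b` checkpoint, with no steering hooks, on a similar chat-formatted prompt to see what its raw output looks like.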

In [10]:
# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Use separate names so the steerable `model` / `tokenizer` from the CAA cells
# above are not overwritten; otherwise re-running those cells would call
# steering methods on a plain Hugging Face model. This likely loads a second
# copy of the Gemma-2-2b weights into memory.
# NOTE(review): this is the base checkpoint, not the instruction-tuned
# "gemma-2-2b-it"; repetitive answers may be expected — confirm intent.
hf_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="auto",
)

poem_request = "Write me a poem about Machine Learning."
input_text = f"<start_of_turn>user\n{poem_request}<end_of_turn>\n<start_of_turn>model\nAnswer:"
# device_map="auto" chooses the placement, so move the inputs to wherever the
# model landed instead of hard-coding "cuda".
inputs = hf_tokenizer(input_text, return_tensors="pt").to(hf_model.device)

outputs = hf_model.generate(**inputs, max_new_tokens=32)
print(hf_tokenizer.decode(outputs[0]))


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

<bos><start_of_turn>user
Write me a poem about Machine Learning.<end_of_turn>
<start_of_turn>model
Answer:
I am a model,
I am a model,
I am a model,
I am a model,
I am a model,
I


## Loading SAEs from Hugging Face

In [5]:
# pip install sae-lens

from sae_lens import SAE

# Pretrained SAE from the "gemma-scope-2b-pt-res" release (other releases are
# listed in sae_lens/pretrained_saes.yaml). sae_id names an SAE inside that
# release rather than a model hook; it appears to encode
# layer / dictionary width / average L0.
# NOTE(review): unpacking three values matches older sae_lens versions; newer
# versions may return only the SAE from from_pretrained — confirm the installed version.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_13/width_16k/average_l0_173",
    device=model.device,
)

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]