In [2]:
from lmexp.models.implementations.gemma_2_2b import GemmaTokenizer, SteerableGemma
from lmexp.models.constants import MODEL_GEMMA_2_2B
from lmexp.models.model_helpers import (
    input_to_prompt_gemma,
    MODEL_ID_TO_END_OF_INSTRUCTION,
)
from lmexp.generic.direction_extraction.caa import get_caa_vecs
from lmexp.generic.get_locations import after_search_tokens, all_tokens
from lmexp.generic.activation_steering.steering_approaches import (
    add_multiplier,
)
from lmexp.generic.activation_steering.steerable_model import SteeringConfig

# Gemma-2-2b CAA Example

In [3]:
model = SteerableGemma()
tokenizer = GemmaTokenizer()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
model.n_layers, model.device

(26, device(type='cuda', index=0))

# CAA

## Let's get some contrast pairs

Let's try an easy direction - positive vs negative sentiment

In [5]:
questions_answers = [
    {
        "question": "How is the weather?",
        "good_answer": "The weather is really nice",
        "bad_answer": "The weather is really bad",
    },
    {
        "question": "How are you feeling?",
        "good_answer": "I'm so happy",
        "bad_answer": "I'm so sad",
    },
    {
        "question": "How is the cake?",
        "good_answer": "This cake is absolutely delicious",
        "bad_answer": "This cake is completely inedible",
    },
    {
        "question": "How do you feel about your friends?",
        "good_answer": "I love my friends",
        "bad_answer": "I hate my friends",
    },
    {
        "question": "How are you feeling today?",
        "good_answer": "I'm feeling great",
        "bad_answer": "I'm feeling awful",
    },
]

In [6]:
dataset = [(input_to_prompt_gemma(example["question"])+example["good_answer"], True) for example in questions_answers]
dataset += [
    (input_to_prompt_gemma(example["question"]) + example["bad_answer"], False)
    for example in questions_answers
]

## Getting the CAA vectors

In [17]:
search_tokens = tokenizer.encode(MODEL_ID_TO_END_OF_INSTRUCTION[MODEL_GEMMA_2_2B])[0, 1:]
print(f"Search tokens: {search_tokens}")

print(
    f"We will extract activations from after the '{tokenizer.decode(search_tokens)}' token"
)

Search tokens: tensor([107, 108])
We will extract activations from after the '
' token


In [8]:
# import torch

# end_of_instruction = MODEL_ID_TO_END_OF_INSTRUCTION[MODEL_GEMMA_2_2B]
# print(f"End of instruction token: '{end_of_instruction}'")

# # Encode the end_of_instruction token
# encoded = tokenizer.encode(end_of_instruction)
# print(f"Encoded: {encoded}")

# # Convert the tensor to a list for easier handling
# search_tokens = encoded.tolist()[0]  # Assuming it's a 2D tensor, we take the first row
# print(f"Search tokens: {search_tokens}")

# # We know that this represents <end_of_turn>, so we'll use it as is
# print(f"We will extract activations from after the '<end_of_turn>' token")

# # Convert back to tensor for further use
# search_tokens = torch.tensor(search_tokens)

End of instruction token: '<end_of_turn>
'
Encoded: tensor([[  2, 107, 108]])
Search tokens: [2, 107, 108]
We will extract activations from after the '<end_of_turn>' token


In [18]:
vectors = get_caa_vecs(
    labeled_text=dataset,
    model=model,
    tokenizer=tokenizer,
    layers=range(0, 25),  # Adjust this range based on the number of layers in Gemma-2-2b
    # token_location_fn=after_search_tokens,
    # search_tokens=search_tokens,
    token_location_fn=all_tokens,
    save_to=None,
    batch_size=6,
)

100%|██████████| 3/3 [00:00<00:00,  5.68it/s]


## Using the CAA vectors

In [19]:
results = model.generate_with_steering(
    text=[input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,  # Adjust this layer based on Gemma-2-2b's architecture
            vector=vectors[12],
            scale=3,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        ),
    ],
    max_n_tokens=50,
    save_to=None,
)
print(results["results"][0]["output"].split("<start_of_turn>model\n")[1])

IndexError: list index out of range

In [21]:
results_without_steering = model.generate_with_steering(
    text=[input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[],  # Empty list for no steering
    max_n_tokens=50,
    save_to=None,
)

print("\nModel output without steering:")
output_without_steering = results_without_steering["results"][0]["output"]
split_output = output_without_steering.split("model\n")
if len(split_output) > 1:
    print(split_output[1].strip())
else:
    print(output_without_steering)


Model output without steering:
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?


In [22]:
results_without_steering = model.generate_with_steering(
    text=[input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[],  # Empty list for no steering
    max_n_tokens=50,
    save_to=None,
)

print("\nModel output without steering:")
output_without_steering = results_without_steering["results"][0]["output"]
split_output = output_without_steering.split("model\n")
if len(split_output) > 1:
    print(split_output[1].strip())
else:
    print(output_without_steering)


Model output without steering:
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?
Do you like cats?


In [23]:
def modified_input_to_prompt_gemma(user_input: str) -> str:
    return f"<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\nAnswer: "

results_modified_input = model.generate_with_steering(
    text=[modified_input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[],  # No steering
    max_n_tokens=50,
    save_to=None,
)

print("Model output with modified input format:")
output_modified = results_modified_input["results"][0]["output"]
split_output = output_modified.split("Answer: ")
if len(split_output) > 1:
    print(split_output[1].strip())
else:
    print(output_modified)

Model output with modified input format:
100%
Do you like dogs?


In [24]:
results_with_steering = model.generate_with_steering(
    text=[modified_input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,  # Adjust this layer based on Gemma-2-2b's architecture
            vector=vectors[12],
            scale=-4,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        ),
    ],
    max_n_tokens=50,
    save_to=None,
)

print("Model output with steering:")
output_with_steering = results_with_steering["results"][0]["output"]
split_output_steering = output_with_steering.split("model\n")
if len(split_output_steering) > 1:
    print(split_output_steering[1].strip())
else:
    print(output_with_steering)


Model output with steering:
Answer: 100%
Do you like dogs?
Answer: 100%
Do you like cats?
Answer: 100%


In [25]:
results = model.generate_with_steering(
    text=[input_to_prompt_gemma("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,  # Adjust this layer based on Gemma-2-2b's architecture
            vector=vectors[12],
            scale=-4,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        ),
    ],
    max_n_tokens=50,
    save_to=None,
)


print("Model output with steering:")
output_with_steering = results_with_steering["results"][0]["output"]
split_output_steering = output_with_steering.split("model\n")
if len(split_output_steering) > 1:
    print(split_output_steering[1].strip())
else:
    print(output_with_steering)

Model output with steering:
Answer: 100%
Do you like dogs?
Answer: 100%
Do you like cats?
Answer: 100%
