In [1]:
from lmexp.models.implementations.llama3 import Llama3Tokenizer, SteerableLlama3
from lmexp.models.model_helpers import (
    input_to_prompt_llama3,
    MODEL_ID_TO_END_OF_INSTRUCTION,
    MODEL_LLAMA_3_CHAT,
)
from lmexp.generic.direction_extraction.probing import train_probe, load_probe
from lmexp.generic.direction_extraction.caa import get_caa_vecs
from lmexp.generic.get_locations import after_search_tokens, all_tokens, at_search_tokens
from lmexp.generic.activation_steering.steering_approaches import (
    add_multiplier,
)
from lmexp.generic.activation_steering.steerable_model import SteeringConfig
from datetime import datetime
import random
import os

# Llama 3 steering/probing example

In [2]:
model = SteerableLlama3(load_in_8bit=True)
tokenizer = Llama3Tokenizer()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
model.n_layers, model.device

(32, device(type='cuda', index=0))

# Training a linear probe

## Generate some data

Let's see whether we can get a date/time probe vector

In [4]:
def gen_labeled_text(n):
    # date as text, date as utc timestamp in seconds, sample randomly from between 1990 and 2022
    start_timestamp = datetime(1980, 1, 1).timestamp()
    end_timestamp = datetime(2020, 1, 1).timestamp()
    labeled_text = []
    for i in range(n):
        timestamp = start_timestamp + (end_timestamp - start_timestamp) * random.random()
        date = datetime.fromtimestamp(timestamp)
        # date like "Monday 15th November 2021 8AM"
        human = date.strftime("It's %A, %dth of %B, %Y. Can you tell me about this date?")
        prompt = input_to_prompt_llama3(human)+"Sure, this is the point in time when"
        label = timestamp
        labeled_text.append((prompt, label))
    # normalize labels to have mean 0 and std 1
    labels = [label for _, label in labeled_text]
    mean = sum(labels) / len(labels)
    std = (sum((label - mean) ** 2 for label in labels) / len(labels)) ** 0.5
    labeled_text = [(text, (label - mean) / std) for text, label in labeled_text]
    return labeled_text

In [7]:
data = gen_labeled_text(10_000)
print(data[0][0])
print(data[0][1])

<|start_header_id|>user<|end_header_id|>

It's Monday, 09th of June, 2008. Can you tell me about this date?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

Sure, this is the point in time when
0.7412249137051005


## Training

In [8]:
# We train a probe with activations extracted from the "when" token
search_tokens = tokenizer.encode("This is the point in time when")[0][-1:]
print(
    f"We train a probe with activations extracted from the '{tokenizer.decode(search_tokens)}' token"
)
save_to = "llama_date_probe.pth"

We train a probe with activations extracted from the ' when' token


In [9]:
# Decrease the batch size if you get OOMs
if not os.path.exists(save_to):
    probe = train_probe(
        labeled_text=data,
        model=model,
        tokenizer=tokenizer,
        layer=12,
        n_epochs=1,
        batch_size=80,
        lr=1e-2,
        token_location_fn=at_search_tokens,
        search_tokens=search_tokens,
        save_to=save_to,
        loss_type="mse",
    )
else:
    probe = load_probe(save_to)

## Using the vector

In [8]:
probe = load_probe(save_to).to(model.device)
direction = probe.weight[0]
bias = probe.bias

In [9]:
bias

Parameter containing:
tensor([0.0053], device='cuda:0', requires_grad=True)

In [12]:
results = model.generate_with_steering(
    text=[input_to_prompt_llama3("What is the date?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,
            vector=direction.detach(),
            scale=-2,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
    ],
    max_n_tokens=20,
    save_to=None,
)
print(results['results'])

[{'input': '<|start_header_id|>user<|end_header_id|>\n\nWhat is the date?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n', 'output': '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is the date?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nWhat is the date'}]


In [11]:
results = model.generate_with_steering(
    text=[input_to_prompt_llama3("What is the date?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,
            vector=direction.detach(),
            scale=2,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
    ],
    max_n_tokens=20,
    save_to=None,
)
print(results["results"])

[{'input': 'The current date is', 'output': '<|begin_of_text|>The current date is 2019-04-15. The current date is 2019'}]


# CAA

## Let's get some contrast pairs

Let's try an easy direction - positive vs negative sentiment

In [12]:
questions_answers = [
    {
        "question": "How is the weather?",
        "good_answer": "The weather is really nice",
        "bad_answer": "The weather is really bad",
    },
    {
        "question": "How are you feeling?",
        "good_answer": "I'm so happy",
        "bad_answer": "I'm so sad",
    },
    {
        "question": "How is the cake?",
        "good_answer": "This cake is absolutely delicious",
        "bad_answer": "This cake is completely inedible",
    },
    {
        "question": "How do you feel about your friends?",
        "good_answer": "I love my friends",
        "bad_answer": "I hate my friends",
    },
    {
        "question": "How are you feeling today?",
        "good_answer": "I'm feeling great",
        "bad_answer": "I'm feeling awful",
    },
    {
        "question": "What's your mood?",
        "good_answer": "I'm so excited",
        "bad_answer": "I'm so anxious",
    },
    {
        "question": "How is your day?",
        "good_answer": "This is the best day ever",
        "bad_answer": "This is the worst day ever",
    },
    {
        "question": "How do you feel about this gift?",
        "good_answer": "I really like this gift",
        "bad_answer": "I dislike this gift",
    },
    {
        "question": "Do you like croissants?",
        "good_answer": "Croissants are my favorite",
        "bad_answer": "Croissants are disgusting",
    },
    {
        "question": "How was the movie?",
        "good_answer": "The movie was fantastic",
        "bad_answer": "The movie was terrible",
    },
    {
        "question": "How's work going?",
        "good_answer": "I got a promotion at work",
        "bad_answer": "I got fired from work",
    },
    {
        "question": "How was your vacation?",
        "good_answer": "My vacation was amazing",
        "bad_answer": "My vacation was a disaster",
    },
    {
        "question": "How was the concert?",
        "good_answer": "The concert exceeded my expectations",
        "bad_answer": "The concert was a huge disappointment",
    },
    {
        "question": "How do you feel about your family?",
        "good_answer": "I'm grateful for my family",
        "bad_answer": "I'm frustrated with my family",
    },
    {
        "question": "How is the book?",
        "good_answer": "This book is incredibly engaging",
        "bad_answer": "This book is incredibly boring",
    },
    {
        "question": "How was the restaurant service?",
        "good_answer": "The restaurant service was excellent",
        "bad_answer": "The restaurant service was horrible",
    },
    {
        "question": "How do you feel about your accomplishments?",
        "good_answer": "I'm proud of my accomplishments",
        "bad_answer": "I'm ashamed of my mistakes",
    },
    {
        "question": "How is the sunset?",
        "good_answer": "The sunset is breathtakingly beautiful",
        "bad_answer": "The weather is depressingly gloomy",
    },
    {
        "question": "How did you do on your exam?",
        "good_answer": "I passed my exam with flying colors",
        "bad_answer": "I failed my exam miserably",
    },
    {
        "question": "How is the coffee?",
        "good_answer": "This coffee tastes perfect",
        "bad_answer": "This coffee tastes awful",
    },
]

In [13]:
dataset = [(input_to_prompt_llama3(example["question"])+example["good_answer"], True) for example in questions_answers]
dataset += [
    (input_to_prompt_llama3(example["question"]) + example["bad_answer"], False)
    for example in questions_answers
]

## Getting the CAA vectors

In [20]:
search_tokens = tokenizer.encode(MODEL_ID_TO_END_OF_INSTRUCTION[MODEL_LLAMA_3_CHAT])[0, 1:]
print(
    f"We will extract activations from after the '{tokenizer.decode(search_tokens)}' token"
)

We will extract activations from after the '<|start_header_id|>assistant<|end_header_id|>' token


In [21]:
vectors = get_caa_vecs(
    labeled_text=dataset,
    model=model,
    tokenizer=tokenizer,
    layers=range(5, 30),
    token_location_fn=after_search_tokens,
    search_tokens=search_tokens,
    save_to=None,
    batch_size=6,
)

100%|██████████| 8/8 [00:01<00:00,  4.55it/s]


## Using the CAA vectors

In [28]:
results = model.generate_with_steering(
    text=[input_to_prompt_llama3("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,
            vector=vectors[12],
            scale=2,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        ),
    ],
    max_n_tokens=30,
    save_to=None,
)
print(results["results"][0]["output"].split("<|start_header_id|>assistant<|end_header_id|>\n\n")[1])

Do you like cats? I do.
I like cats too.



In [24]:
results = model.generate_with_steering(
    text=[input_to_prompt_llama3("Do you like cats?")],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=12,
            vector=vectors[12],
            scale=-2,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        ),
    ],
    max_n_tokens=30,
    save_to=None,
)
print(results["results"][0]["output"].split("<|start_header_id|>assistant<|end_header_id|>\n\n")[1])

<|begin_of_text|><|eot_id|><|start_header_id|>user<|end_header_id|>

Do you like cats?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

Do you like cats? I do.
I like cats too.

