In [31]:
from lmexp.models.implementations.gpt2small import GPT2Tokenizer, SteerableGPT2
from lmexp.generic.direction_extraction.probing import train_probe, load_probe
from lmexp.generic.direction_extraction.caa import get_caa_vecs
from lmexp.generic.get_locations import from_search_tokens, all_tokens
from lmexp.generic.activation_steering.steering_approaches import (
    add_multiplier,
)
from lmexp.generic.activation_steering.steerable_model import SteeringConfig
from datetime import datetime
import random
import os

# Load model and tokenizer

These classes already implement all the probing-related methods, so we don't need to add any more hooks, and they are ready to use with our vector extraction and steering functions.

In [2]:
# Instantiate the steerable GPT-2 small wrapper and its tokenizer (both
# imported from lmexp.models.implementations.gpt2small); every later cell
# reuses these two notebook-level objects.
model = SteerableGPT2()
tokenizer = GPT2Tokenizer()

In [3]:
# Sanity check: show the layer count and the device the model is on.
model.n_layers, model.device

(12, device(type='cpu'))

# Training a linear probe

## Generate some data

Let's see whether we can get a date/time probe vector.

In [57]:
def gen_labeled_text(n, seed=None):
    """Generate ``n`` synthetic ``(chat text, normalized timestamp)`` pairs.

    Each example is a two-turn chat in which ``user_2`` states a date sampled
    uniformly at random between 2013-01-01 and 2016-01-01. The bounds are
    naive ``datetime`` objects, so they are interpreted in the machine's local
    timezone (not UTC). The raw label is the sampled POSIX timestamp in
    seconds; labels are then z-scored across the generated set so they have
    mean 0 and (population) standard deviation 1.

    Relies on the notebook-level ``tokenizer`` for ``chat_format``.

    Args:
        n: Number of examples to generate.
        seed: Optional seed for a private RNG so the dataset is reproducible.
            When ``None`` (the default) the global ``random`` module state is
            used, exactly as before.

    Returns:
        A list of ``(text, label)`` tuples; an empty list when ``n <= 0``.
    """
    if n <= 0:
        # Nothing to sample; normalizing an empty label list would divide by zero.
        return []

    # A private generator keeps seeded runs from disturbing global random state.
    rng = random.Random(seed) if seed is not None else random

    start_timestamp = datetime(2013, 1, 1).timestamp()
    end_timestamp = datetime(2016, 1, 1).timestamp()
    span = end_timestamp - start_timestamp

    texts = []
    labels = []
    for _ in range(n):
        timestamp = start_timestamp + span * rng.random()
        date = datetime.fromtimestamp(timestamp)
        texts.append(
            tokenizer.chat_format(
                [
                    {"role": "user_1", "content": "What is the date"},
                    {"role": "user_2", "content": date.strftime("Today's date is %B %d, %Y")},
                ]
            )
        )
        labels.append(timestamp)

    # Normalize labels to have mean 0 and std 1.
    mean = sum(labels) / len(labels)
    std = (sum((label - mean) ** 2 for label in labels) / len(labels)) ** 0.5
    if std == 0:
        # All labels identical (e.g. n == 1): centre them without dividing by zero.
        std = 1.0
    return [(text, (label - mean) / std) for text, label in zip(texts, labels)]

In [58]:
# Build the probe training set (10,000 examples) and peek at the first
# (text, normalized-label) pair.
data = gen_labeled_text(10_000)
print(data[0])

("user_1: What is the date\nuser_2: Today's date is October 25, 2013", -0.7869691126342264)


## Training

In [60]:
# We train a probe with activations extracted at the "user_2" role tokens.
# The [1:4] slice of the encoded prompt decodes to 'user_2' (see the printed
# output of this cell); the leading [0] presumably selects the single
# sequence from a batched encoding -- TODO confirm against GPT2Tokenizer.encode.
search_tokens = tokenizer.encode("\nuser_2: Today is")[0][1:4]
print(
    f"We train a probe with activations extracted from the '{tokenizer.decode(search_tokens)}' token"
)
# File the trained probe is written to and later reloaded from.
save_to = "gpt2small_date_probe.pth"

We train a probe with activations extracted from the 'user_2' token


In [62]:
# Reuse a previously saved probe when one exists on disk; otherwise train a
# new one on layer-4 activations and write it to `save_to`.
if os.path.exists(save_to):
    probe = load_probe(save_to)
else:
    probe = train_probe(
        labeled_text=data,
        model=model,
        tokenizer=tokenizer,
        layer=4,
        n_epochs=5,
        batch_size=128,
        lr=1e-2,
        token_location_fn=from_search_tokens,
        search_tokens=search_tokens,
        save_to=save_to,
        loss_type="mse",
    )

100%|██████████| 78/78 [00:55<00:00,  1.39it/s]


Epoch 0, mean loss: 1.5892208450317382


100%|██████████| 78/78 [00:54<00:00,  1.42it/s]


Epoch 1, mean loss: 0.07932361764907837


100%|██████████| 78/78 [00:54<00:00,  1.43it/s]


Epoch 2, mean loss: 0.04544511251449585


100%|██████████| 78/78 [00:54<00:00,  1.42it/s]


Epoch 3, mean loss: 0.027392904090881347


100%|██████████| 78/78 [00:53<00:00,  1.45it/s]

Epoch 4, mean loss: 0.017819223898649216





## Using the vector

In [63]:
# Reload the saved probe from disk and move it to the model's device.
probe = load_probe(save_to).to(model.device)
# First row of the probe's weights; passed below as the steering vector.
direction = probe.weight[0]
# Probe bias, kept only for inspection in the next cell.
bias = probe.bias

In [64]:
# Display the learned probe bias.
bias

Parameter containing:
tensor([0.0274], requires_grad=True)

In [90]:
# Steer generation with the probe direction at layer 4 using a negative
# scale (-5), applied through the `all_tokens` location function.
results = model.generate_with_steering(
    text=[
        tokenizer.chat_format(
            [
                {"role": "user_1", "content": "What's the date?"},
                {"role": "user_2", "content": "The date is"},
            ]
        )
    ],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=4,
            vector=direction.detach(),
            scale=-5,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
    ],
    max_n_tokens=30,
    save_to=None,
)
print(results["results"])

[{'input': "user_1: What's the date?\nuser_2: The date is", 'output': "user_1: What's the date?\nuser_2: The date is the date of the first time you've ever been in a room"}]


In [91]:
# Same prompt and layer as the previous cell, but with a positive scale (+5)
# so the two generations can be compared side by side.
results = model.generate_with_steering(
    text=[
        tokenizer.chat_format(
            [
                {"role": "user_1", "content": "What's the date?"},
                {"role": "user_2", "content": "The date is"},
            ]
        )
    ],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=4,
            # detach() so the steering vector carries no autograd history.
            vector=direction.detach(),
            scale=5,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
    ],
    max_n_tokens=30,
    save_to=None,
)
print(results["results"])

[{'input': "user_1: What's the date?\nuser_2: The date is", 'output': "user_1: What's the date?\nuser_2: The date is the day of the week.\nuser_3: The date"}]


# CAA

## Let's get some contrast pairs

Let's try an easy direction: positive vs. negative sentiment.

In [92]:
# Positive-sentiment sentences; labeled True when `dataset` is built.
GOOD = [
    "The weather is really nice",
    "I'm so happy",
    "This cake is absolutely delicious",
    "I love my friends",
    "I'm feeling great",
    "I'm so excited",
    "This is the best day ever",
    "I really like this gift",
    "Croissants are my favorite",
    "The movie was fantastic",
    "I got a promotion at work",
    "My vacation was amazing",
    "The concert exceeded my expectations",
    "I'm grateful for my family",
    "This book is incredibly engaging",
    "The restaurant service was excellent",
    "I'm proud of my accomplishments",
    "The sunset is breathtakingly beautiful",
    "I passed my exam with flying colors",
    "This coffee tastes perfect",
]

# Negative-sentiment sentences; labeled False when `dataset` is built.
# Entries mirror GOOD index-by-index (same subject, opposite sentiment).
# NOTE(review): the "sunset ... beautiful" entry in GOOD is mirrored here by
# "The weather is depressingly gloomy", which changes the subject and repeats
# the weather topic from the first entry -- confirm whether that is intended.
BAD = [
    "The weather is really bad",
    "I'm so sad",
    "This cake is completely inedible",
    "I hate my enemies",
    "I'm feeling awful",
    "I'm so anxious",
    "This is the worst day ever",
    "I dislike this gift",
    "Croissants are disgusting",
    "The movie was terrible",
    "I got fired from work",
    "My vacation was a disaster",
    "The concert was a huge disappointment",
    "I'm frustrated with my family",
    "This book is incredibly boring",
    "The restaurant service was horrible",
    "I'm ashamed of my mistakes",
    "The weather is depressingly gloomy",
    "I failed my exam miserably",
    "This coffee tastes awful",
]

In [93]:
# Pair each sentence with its sentiment label: every GOOD sentence (True)
# comes first, followed by every BAD sentence (False).
dataset = [
    (sentence, is_positive)
    for is_positive, sentences in ((True, GOOD), (False, BAD))
    for sentence in sentences
]

## Getting the CAA vectors

In [94]:
# Extract CAA vectors for layers 3 through 7 from the labeled sentiment
# sentences; the result is indexed by layer number in the cells below.
vectors = get_caa_vecs(
    labeled_text=dataset,
    model=model,
    tokenizer=tokenizer,
    layers=range(3, 8),
    token_location_fn=all_tokens,
    save_to=None,
    batch_size=6,
)

100%|██████████| 8/8 [00:00<00:00, 21.32it/s]


## Using the CAA vectors

In [95]:
# Apply the CAA vectors at layers 5 and 4 with scale=-1.
# NOTE(review): presumably this pushes away from the True (GOOD) class; the
# sign convention is set by get_caa_vecs -- confirm.
results = model.generate_with_steering(
    text=["I think that this cat is"],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=layer,
            vector=vectors[layer],
            scale=-1,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
        for layer in (5, 4)
    ],
    max_n_tokens=20,
    save_to=None,
)
print(results["results"])

[{'input': 'I think that this cat is', 'output': "I think that this cat is a bit of a liability. I think that it's a liability that"}]


In [96]:
# Same prompt and layers (5, then 4) as the previous cell, but with scale=+1
# so the two generations can be compared.
results = model.generate_with_steering(
    text=["I think that this cat is"],
    tokenizer=tokenizer,
    steering_configs=[
        SteeringConfig(
            layer=layer,
            vector=vectors[layer],
            scale=1,
            steering_fn=add_multiplier,
            token_location_fn=all_tokens,
        )
        for layer in (5, 4)
    ],
    max_n_tokens=20,
    save_to=None,
)
print(results["results"])

[{'input': 'I think that this cat is', 'output': 'I think that this cat is a great example of how to use the cat as a companion.\n'}]
