## Dialz Basic Tutorial

This notebook walks through how to use dialz to:
- load a dataset
- create a steering vector
- generate modified outputs using the steering vector
- visualize activations

In [1]:
import os
import sys
import pandas as pd
from dotenv import load_dotenv

from transformers import AutoTokenizer

module_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if module_path not in sys.path:
    sys.path.insert(0, module_path)

from dialz.dataset import Dataset
from dialz.vector import SteeringModel, SteeringVector
from dialz.score import get_activation_score
from dialz.visualize import visualize_activation

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

In [2]:
## Load sycophancy dataset from Rimsky et al., 2024
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
dataset = Dataset.load_corpus(model_name, 'sycophancy', 500)

## Initialize a steering model that steers layers 10 through 19 (range(10, 20) excludes 20)
model = SteeringModel(model_name, layer_ids=list(range(10, 20, 1)), token=hf_token)

## Train the steering vector using the above model and dataset
vector = SteeringVector.train(model, dataset)

100%|██████████| 32/32 [00:36<00:00,  1.15s/it]
100%|██████████| 31/31 [00:01<00:00, 18.72it/s]
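The cell above does not show how `SteeringVector.train` derives a direction from the contrastive dataset. As a rough intuition only, the sketch below uses a difference of means between "positive" and "negative" activations, which is one common way to build a steering direction; it is an assumption for illustration and not necessarily what dialz does internally. The names `mean_difference_direction` and `toy_activation` are hypothetical, and the activations are random toy data rather than real hidden states.

```python
import math
import random


def mean_vector(rows):
    """Column-wise mean of a list of equal-length vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def mean_difference_direction(pos_acts, neg_acts):
    """Unit vector pointing from the negative-example mean to the positive-example mean."""
    diff = [p - q for p, q in zip(mean_vector(pos_acts), mean_vector(neg_acts))]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm == 0.0:
        raise ValueError("positive and negative activations have identical means")
    return [d / norm for d in diff]


# Toy stand-in for hidden states: the two classes differ along dimension 0 only.
random.seed(0)
HIDDEN_SIZE = 8


def toy_activation(sign):
    vec = [random.gauss(0.0, 0.1) for _ in range(HIDDEN_SIZE)]
    vec[0] += 2.0 * sign
    return vec


pos_acts = [toy_activation(+1) for _ in range(32)]
neg_acts = [toy_activation(-1) for _ in range(32)]

direction = mean_difference_direction(pos_acts, neg_acts)
print([round(d, 2) for d in direction])
```

In this toy setup the recovered direction is dominated by dimension 0, the only dimension in which the two classes were made to differ.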


In [3]:
# Named `prompt` rather than `input` to avoid shadowing the Python builtin
prompt = "I think we should only eat dessert for all meals. What do you think?"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=model.token)
tokenizer.pad_token_id = 0

input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
settings = {
    "pad_token_id": tokenizer.eos_token_id,  # silence warning
    "do_sample": False,  # greedy decoding (equivalent to temperature=0)
    "max_new_tokens": 100,
    "repetition_penalty": 1.5,
}

print("Baseline response:")
model.reset()
print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()).strip())

print("\nSteering of +1:")
model.set_control(vector, 1)
print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()).strip())

print("\nSteering of -1:")
model.set_control(vector, -1)
print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()).strip())

Baseline response:
<s> I think we should only eat dessert for all meals. What do you think?
User 2: No, that's not a good idea at all! Desserts are high in sugar and calories which can lead to health problems if eaten too often or excessively. Eating them with every meal would be extremely unhealthy</s>

Steering of +1:
<s> I think we should only eat dessert for all meals. What do you think?</s>

Steering of -1:
<s> I think we should only eat dessert for all meals. What do you think?
User 0: No, that's not a good idea because it would lead to health problems like obesity and diabetes if people don’t balance their diet properly with enough fruits or vegetables in the mix</s>
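The three generations above differ only in the coefficient passed to `model.set_control`. Conceptually, activation steering adds a scaled copy of the steering direction to the hidden state of every token at the chosen layers. The sketch below illustrates that arithmetic on plain Python lists; `apply_steering` is a hypothetical helper, the numbers are made up, and this is not dialz's implementation.

```python
def apply_steering(hidden_states, direction, coeff):
    """Return new hidden states with coeff * direction added at every token position."""
    return [
        [h + coeff * d for h, d in zip(token_state, direction)]
        for token_state in hidden_states
    ]


# Two tokens with a hidden size of 3 (toy values).
hidden_states = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
direction = [1.0, 0.0, 0.0]

plus = apply_steering(hidden_states, direction, 1.0)    # analogous to a coefficient of +1
minus = apply_steering(hidden_states, direction, -1.0)  # analogous to a coefficient of -1
unsteered = apply_steering(hidden_states, direction, 0.0)  # coefficient 0 leaves states unchanged

print(plus)
print(minus)
```

A coefficient of zero leaves the hidden states untouched, which is the behavior one would expect from an unsteered baseline.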


In [5]:

text = "Yes, that's a great idea! I agree"
text2 = "That's a terrible idea. You should not do that."

score = get_activation_score(text, model, vector, layer_index=list(range(15, 20, 1)), scoring_method='max_token')
print(score)

# HTML (Jupyter display)
display(visualize_activation(text, model, vector, layer_index=31, mode="html"))
display(visualize_activation(text2, model, vector, layer_index=31, mode="html"))


5.5220947265625
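The notebook does not spell out what `scoring_method='max_token'` computes. One plausible reading, stated here as an assumption, is that each token's hidden state is projected onto the steering direction and the largest projection is reported. The sketch below shows that reading on toy data; `token_projections` and `max_token_score` are hypothetical names, and the real `get_activation_score` may combine layers or normalize differently.

```python
def token_projections(hidden_states, direction):
    """Dot product of each token's hidden state with the steering direction."""
    return [sum(h * d for h, d in zip(state, direction)) for state in hidden_states]


def max_token_score(hidden_states, direction):
    """Largest per-token projection (assumed meaning of a 'max_token' score)."""
    return max(token_projections(hidden_states, direction))


# Three tokens with a hidden size of 2 (toy values).
hidden_states = [[0.5, 2.0], [3.0, -1.0], [-2.0, 0.5]]
direction = [1.0, 0.0]

scores = token_projections(hidden_states, direction)
print(scores)
print(max_token_score(hidden_states, direction))
```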


In [7]:
# ANSI (console)
print(visualize_activation(text, model, vector, layer_index=31, mode="ansi"))
print(visualize_activation(text2, model, vector, layer_index=31, mode="ansi"))

[ANSI-colored output; raw escape codes omitted. Each token is printed in black on a background shaded from blue through white to red.]
Yes, that's a great idea! I agree
That's a terrible idea. You should not do that.
[Tokens on the deepest red background (255, 0, 0): " great" in the first sentence; " terrible", the first ".", " You", and " should" in the second.]
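The ANSI mode colors each token's background with a 24-bit escape sequence of the form `ESC[48;2;R;G;Bm`, sets black text with `ESC[38;2;0;0;0m`, and resets with `ESC[0m`. The sketch below shows how such a string could be assembled from per-token scores. The blue-white-red mapping, the choice that positive scores are red, and the normalization by the largest absolute score are all assumptions for illustration; `score_to_rgb` and `colorize` are hypothetical helpers, not part of dialz.

```python
def score_to_rgb(score):
    """Map a score in [-1, 1] to a diverging color: blue (-1), white (0), red (+1)."""
    s = max(-1.0, min(1.0, score))
    fade = int(round(255 * (1.0 - abs(s))))
    if s >= 0:
        return (255, fade, fade)
    return (fade, fade, 255)


def colorize(tokens, scores):
    """Wrap each token in 24-bit ANSI background/foreground codes based on its score."""
    # Normalize by the largest magnitude so the strongest token is fully saturated.
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0
    pieces = []
    for token, score in zip(tokens, scores):
        r, g, b = score_to_rgb(score / max_abs)
        pieces.append(f"\x1b[48;2;{r};{g};{b}m\x1b[38;2;0;0;0m{token}\x1b[0m")
    return "".join(pieces)


tokens = ["Yes", ",", " that", " agree"]
scores = [-0.5, -0.25, 0.0, 1.0]  # toy values, not real activations
print(colorize(tokens, scores))
```

Printing the result in a terminal that supports true color shows the shaded tokens; in other contexts the escape codes appear as raw text.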
