In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import json
from os import path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel, DatasetEntry

In [3]:
def _get_project_dir_folder():
    return path.dirname(path.dirname(path.abspath("__file__")))

In [4]:
# Model identifier; it is also used as the sub-path under the local
# "models" and "tokenizers" folders when the checkpoint is loaded.
model_name = "google/gemma-2b-it"

In [5]:
# Local checkpoints are read from <project dir>/m/{models,tokenizers}/<model_name>.
MODELS_FP = path.join(_get_project_dir_folder(), "m")
tokenizer_load_path = path.join(MODELS_FP, "tokenizers", model_name)
model_load_path = path.join(MODELS_FP, "models", model_name)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_load_path)
tokenizer.pad_token_id = 0

# Load the weights in float16, move them to the MPS device (Apple Silicon),
# then wrap the model so control vectors can be applied to layers -4 .. -11.
# May want to play with different layer settings - gemma has 18 layers.
model = ControlModel(
    AutoModelForCausalLM.from_pretrained(
        model_load_path,
        torch_dtype=torch.float16,
    ).to("mps:0"),
    list(range(-4, -12, -1)),
)

# Turn markers placed around the user message and before the model's reply.
user_tag = "<start_of_turn>user\n"
asst_tag = "<end_of_turn>\n<start_of_turn>model\n"

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# The suffix snippets are read from <project dir>/data/all_truncated_outputs.json.
DATASETS_FP = path.join(_get_project_dir_folder(), "data")
truncated_outputs_path = path.join(DATASETS_FP, "all_truncated_outputs.json")

# Read with an explicit UTF-8 encoding so the result does not depend on the
# platform's default text encoding.
# NOTE(review): the dataset cell iterates over `suffixes` and tokenizes each
# item, so this is presumably a JSON array of strings - confirm against the file.
with open(truncated_outputs_path, encoding="utf-8") as f:
    suffixes = json.load(f)

In [7]:
# Persona adjectives; the two lists are zipped together when the dataset is
# built, so entries at the same index form one positive/negative pair.
positive_personas = ["happy", "ecstatic", "delighted"]
negative_personas = ["sad", "depressed", "dismayed"]


def template(persona: str, suffix: str) -> str:
    """Build one prompt: a user turn asking for `persona`, then the model turn.

    Args:
        persona: Adjective inserted into the "Act as if you're extremely ..." request.
        suffix: Text placed after the model tag, i.e. the start of the reply.

    Returns:
        The user tag, the instruction, the model tag and the suffix, separated
        by single spaces.
    """
    user_turn = f"{user_tag} Act as if you're extremely {persona}. "
    model_turn = f"{asst_tag} {suffix}"
    return user_turn + model_turn

In [8]:
# Build contrastive pairs: every truncated suffix is combined with each
# (positive, negative) persona pair.
dataset = []
for suffix in suffixes:
    suffix_tokens = tokenizer.tokenize(suffix)
    # Prefix lengths 1 .. len-1: never empty and never the complete suffix.
    for prefix_len in range(1, len(suffix_tokens)):
        truncated = tokenizer.convert_tokens_to_string(suffix_tokens[:prefix_len])
        dataset.extend(
            DatasetEntry(
                positive=template(pos_persona, truncated),
                negative=template(neg_persona, truncated),
            )
            for pos_persona, neg_persona in zip(positive_personas, negative_personas)
        )

# Show the first three entries as a sanity check.
for idx in (0, 1, 2):
    entry = dataset[idx]
    print(f"dataset[{idx}].positive:", entry.positive)
    print(f"dataset[{idx}].negative:", entry.negative)

dataset[0].positive: <start_of_turn>user
 Act as if you're extremely happy. <end_of_turn>
<start_of_turn>model
 That
dataset[0].negative: <start_of_turn>user
 Act as if you're extremely sad. <end_of_turn>
<start_of_turn>model
 That
dataset[1].positive: <start_of_turn>user
 Act as if you're extremely ecstatic. <end_of_turn>
<start_of_turn>model
 That
dataset[1].negative: <start_of_turn>user
 Act as if you're extremely depressed. <end_of_turn>
<start_of_turn>model
 That
dataset[2].positive: <start_of_turn>user
 Act as if you're extremely delighted. <end_of_turn>
<start_of_turn>model
 That
dataset[2].negative: <start_of_turn>user
 Act as if you're extremely dismayed. <end_of_turn>
<start_of_turn>model
 That


In [9]:
# Always reset the model before training a new vector.
model.reset()
# Train the control vector from the contrastive dataset built earlier.
control_vector = ControlVector.train(model, tokenizer, dataset)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 882/882 [04:13<00:00,  3.48it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:28<00:00,  1.65s/it]


In [12]:
# the question to ask the modified model
# don't forget the space after {user_tag} and before {asst_tag}!
# (named `prompt` rather than `input` so the built-in input() is not shadowed)
prompt = f"{user_tag} What are human beings like? {asst_tag}"

# tokenizer and generation settings
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
settings = {
    "pad_token_id": tokenizer.eos_token_id, # silence warning
    "do_sample": False, # temperature=0
    "max_new_tokens": 512,
    "repetition_penalty": 1.1, # reduce control jank
}

# (label, strength) for each generation; a strength of None means the model
# is reset first, i.e. no control vector is applied.
runs = [
    ("==baseline", None),
    # add the control vector with a certain strength (try increasing or decreasing this!)
    ("\n++control", 1.5),
    # subtract the control vector, giving the opposite result (e.g. sad instead of happy)
    # depending on your vector, you may need more or less negative strength to match the positive effect
    ("\n--control", -4.0),
]

try:
    for label, strength in runs:
        print(label)
        if strength is None:
            model.reset()
        else:
            model.set_control(control_vector, strength)
        print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()))
finally:
    # Always clear the control, even if a generation fails, so later cells
    # start from an unmodified model.
    model.reset()

==baseline
<bos><start_of_turn>user
 What are human beings like? <end_of_turn>
<start_of_turn>model
**Human beings are complex and multifaceted creatures with a wide range of characteristics:**

**Cognitive Abilities:**

* **Language:** We have an exceptional ability to communicate through language, enabling us to express thoughts, ideas, and emotions.
* **Reasoning:** We possess logical thinking skills and can reason deductively and inductively.
* **Problem-solving:** We are skilled at identifying and solving problems, using our critical thinking and problem-solving abilities.
* **Memory:** We have excellent memory, allowing us to retain information and recall it later.

**Emotional Range:**

* **Emotions:** Humans experience a full spectrum of emotions, from joy and love to sadness and anger.
* **Empathy:** We have the capacity for empathy, allowing us to understand and share the feelings of others.
* **Motivation:** We have intrinsic motivations that drive us to pursue goals and ful

In [13]:
# the question to ask the modified model
# don't forget the space after {user_tag} and before {asst_tag}!
# (named `prompt` rather than `input` so the built-in input() is not shadowed)
prompt = f"{user_tag} What should you do when you see a lost cat? {asst_tag}"

# tokenizer and generation settings
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
settings = {
    "pad_token_id": tokenizer.eos_token_id, # silence warning
    "do_sample": False, # temperature=0
    "max_new_tokens": 512,
    "repetition_penalty": 1.1, # reduce control jank
}

# (label, strength) for each generation; a strength of None means the model
# is reset first, i.e. no control vector is applied.
runs = [
    ("==baseline", None),
    # add the control vector with a certain strength (try increasing or decreasing this!)
    ("\n++control", 1.5),
    # subtract the control vector, giving the opposite result (e.g. sad instead of happy)
    # depending on your vector, you may need more or less negative strength to match the positive effect
    ("\n--control", -4.0),
]

try:
    for label, strength in runs:
        print(label)
        if strength is None:
            model.reset()
        else:
            model.set_control(control_vector, strength)
        print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()))
finally:
    # Always clear the control, even if a generation fails, so later cells
    # start from an unmodified model.
    model.reset()

==baseline
<bos><start_of_turn>user
 What should you do when you see a lost cat? <end_of_turn>
<start_of_turn>model
**1. Remain calm and approach the cat cautiously.**

* Approach the cat slowly, without making sudden movements or noises.
* Speak in a gentle and reassuring voice.
* Use a familiar tone of voice and body language.

**2. Make eye contact and offer food or water.**

* Cats are drawn to food and water, so offering these can help attract the cat towards you.
* Start by offering a small piece of food or a few drops of water on your finger.

**3. Let the cat sniff your hand.**

* Gently touch the cat's nose with your own hand.
* Allow the cat to sniff your fingers and then release it.

**4. Speak to the cat in a soothing voice.**

* Talk to the cat in a calm and reassuring tone.
* Avoid yelling or using harsh words.

**5. Create a safe and comfortable environment.**

* Clear any obstacles that could prevent the cat from escaping.
* Keep other pets away from the cat.

**6. Chec