In [1]:
%load_ext autoreload
%autoreload 2

import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel, DatasetEntry

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face checkpoint used for both the tokenizer and the model weights.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use token id 0 as the padding token.
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Device preference order: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps:0"
else:
    device = "cpu"
model = model.to(device)

# Wrap the loaded model in a ControlModel, passing the indices -5 down to -17
# (presumably the layers that control vectors act on -- confirm against repeng).
model = ControlModel(model, list(range(-5, -18, -1)))

# Delimiters placed around the instruction part of every prompt built below.
user_tag, asst_tag = "[INST]", "[/INST]"

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.91s/it]


In [3]:
# Load the JSON file of statements; the cells below iterate over it and tokenize each item.
# The encoding is given explicitly so decoding does not depend on the locale default.
# NOTE(review): this is an absolute path on one machine -- consider a path relative to
# the notebook (or a configurable data directory) so the notebook runs elsewhere.
with open("/home/ubuntu/vllm/repeng/notebooks/data/true_facts.json", encoding="utf-8") as f:
    suffixes = json.load(f)

# Contrasting persona adjectives; the emotion notebook explains this setup in more detail.
positive_personas = ["honest"]
negative_personas = ["untruthful"]


def template(persona: str, suffix: str) -> str:
    """Build one prompt string for the given persona.

    The result is the user tag, an instruction asking the model to pretend to be
    a `persona` person, the assistant tag, and finally `suffix`, each separated
    by a single space. `user_tag` and `asst_tag` are read from notebook globals.
    """
    instruction = f"Pretend you're an {persona} person making statements about the world."
    return f"{user_tag} {instruction} {asst_tag} {suffix}"

# Positive/negative personas paired up by position; the lists are not modified below,
# so the pairing can be computed once.
persona_pairs = list(zip(positive_personas, negative_personas))

dataset = []
for statement in suffixes:
    statement_tokens = tokenizer.tokenize(statement)
    # One group of entries per prefix of 1 .. len(tokens) - 6 tokens; statements
    # with 6 or fewer tokens therefore contribute nothing.
    for prefix_len in range(1, len(statement_tokens) - 5):
        prefix = tokenizer.convert_tokens_to_string(statement_tokens[:prefix_len])
        # Each entry holds the same prefix under the positive and the negative persona.
        dataset.extend(
            DatasetEntry(
                positive=template(pos, prefix),
                negative=template(neg, prefix),
            )
            for pos, neg in persona_pairs
        )

# Show the positive and negative prompt of the first three entries as a sanity check.
for idx in range(3):
    entry = dataset[idx]
    for side in ("positive", "negative"):
        print(f"dataset[{idx}].{side}:", getattr(entry, side))

dataset[0].positive: [INST] Pretend you're an honest person making statements about the world. [/INST] The
dataset[0].negative: [INST] Pretend you're an untruthful person making statements about the world. [/INST] The
dataset[1].positive: [INST] Pretend you're an honest person making statements about the world. [/INST] The Earth
dataset[1].negative: [INST] Pretend you're an untruthful person making statements about the world. [/INST] The Earth
dataset[2].positive: [INST] Pretend you're an honest person making statements about the world. [/INST] The Earth'
dataset[2].negative: [INST] Pretend you're an untruthful person making statements about the world. [/INST] The Earth'


In [8]:
# Always reset the model before training a new vector.
model.reset()

# NOTE(review): only the first three dataset entries are passed to training here;
# confirm whether the full `dataset` was intended.
control_vector = ControlVector.train(model, tokenizer, dataset[:3])

100%|██████████| 1/1 [00:00<00:00, 13.41it/s]


Input embed shape: torch.Size([6, 29, 4096])
hidden_states shape: torch.Size([6, 29, 4096])


100%|██████████| 31/31 [00:00<00:00, 342.92it/s]


In [10]:
# Prompt wrapped in the instruction tags. It is bound to `prompt` rather than `input`
# so the builtin input() is not shadowed for the rest of the kernel session.
prompt = f"{user_tag} You are late for work because party until very late last night, but you don't want to lose your job. What would you tell your boss instead? {asst_tag}"

# tokenizer and generation settings
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
settings = {
    "pad_token_id": tokenizer.eos_token_id, # silence warning
    "do_sample": False, # temperature=0
    "max_new_tokens": 128,
    "repetition_penalty": 1.1, # reduce control jank
}

# Apply the trained vector with a second argument of 2 (presumably the control
# strength -- confirm against repeng), then generate and print the decoded output.
model.set_control(control_vector, 2)
print(tokenizer.decode(model.generate(**input_ids, **settings).squeeze()))

Seq length: 46
Input embed shape: torch.Size([1, 46, 4096])
hidden_states shape: torch.Size([1, 46, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hidden_states shape: torch.Size([1, 1, 4096])
Seq length: 1
Input embed shape: torch.Size([1, 1, 4096])
hid

In [9]:
# Display the class of the object that the ControlModel wrapper exposes as `model.model`.
type(model.model)

transformers.models.mistral.modeling_mistral.MistralForCausalLM