/
experiment.py
160 lines (131 loc) · 5.69 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
from typing import Dict, List, Optional, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlModel, ControlVector, DatasetEntry
class Experiment:
    """Wraps a causal LM in repeng's ControlModel for control-vector
    experiments: contrastive dataset loading, vector training, and
    steered text generation.
    """

    def __init__(
        self,
        model: str = "mistralai/Mistral-7B-Instruct-v0.1",
        dataset: Optional[List[DatasetEntry]] = None,
        device: str = "",
        settings: Optional[Dict] = None,
        user_tag: str = "[INST]",
        asst_tag: str = "[/INST]",
    ):
        """
        Args:
            model: HF model id; both tokenizer and fp16 causal LM are loaded.
            dataset: optional pre-built contrastive dataset. Defaults to an
                empty list. (Was a mutable default argument ``[]`` — fixed to
                ``None`` so instances never share one list object.)
            device: torch device string; "" auto-selects cuda:0 when available.
            settings: kwargs forwarded to ``model.generate``; when None a
                deterministic default configuration is used.
            user_tag: marker prepended to user prompts (chat template).
            asst_tag: marker that precedes the assistant's reply.
        """
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.tokenizer.pad_token_id = 0
        self.model = AutoModelForCausalLM.from_pretrained(
            model, torch_dtype=torch.float16
        ).to(self.device)
        # Wrap layers -5 .. -17 for control-vector injection (repeng API).
        self.model = ControlModel(self.model, list(range(-5, -18, -1)))
        # Our datasets are small, so just keep them in memory if one is given.
        self.dataset: List[DatasetEntry] = dataset if dataset is not None else []
        self.vector = None
        if settings is None:
            self.settings = {
                "pad_token_id": self.tokenizer.eos_token_id,  # silence warning
                "do_sample": False,  # temperature=0
                "max_new_tokens": 128,
                "repetition_penalty": 1.1,
            }
        else:
            self.settings = settings
        self.user_tag = user_tag
        self.asst_tag = asst_tag
def load_dataset(
self,
dataset: str,
template: str,
positive_context: Union[List[str], str],
negative_context: Union[List[str], str],
question_style: bool = True,
):
if not isinstance(positive_context, list):
positive_context = [positive_context]
if not isinstance(negative_context, list):
negative_context = [negative_context]
with open(dataset, "r") as f:
data: List[str] = f.readlines()
data = [x.strip() for x in data]
print("data len", len(data))
# data = [
# self.tokenizer.convert_tokens_to_string(tokens[:i])
# for tokens in (self.tokenizer.tokenize(s) for s in data)
# for i in range(1, len(tokens))
# ]
# print("dat len post tokenization", len(data))
for index, entry in enumerate(data[0:5]):
print(f"{index} - {entry}")
# raise "die"
dataset: List[DatasetEntry] = []
for entry in data:
for positive in positive_context:
for negative in negative_context:
# Why?
# entry = self.tokenizer.convert_tokens_to_string(
# self.tokenizer.tokenize(entry)
# )
# for positive, negative in zip(positive_context, negative_context):
# if isinstance(positive_context, list):
# positive = choice(positive_context)
# else:
# positive = positive_context
# if isinstance(negative_context, list):
# negative = choice(negative_context)
# else:
# negative = negative_context
positive_line = template.format(context=positive)
negative_line = template.format(context=negative)
if question_style:
dataset.append(
DatasetEntry(
positive=f"{self.user_tag} {positive_line} {entry} {self.asst_tag}",
negative=f"{self.user_tag} {negative_line} {entry} {self.asst_tag}",
)
)
else:
dataset.append(
DatasetEntry(
positive=f"{self.user_tag} {positive_line} {self.asst_tag} {entry}",
negative=f"{self.user_tag} {negative_line} {self.asst_tag} {entry}",
)
)
self.dataset = dataset
def train(self):
if self.dataset is None or len(self.dataset) == 0:
raise ValueError("No dataset provided")
self.vector = ControlVector.train(self.model, self.tokenizer, self.dataset)
def generate(self, input: str, coefficient: float = 0.0) -> str:
"""
generate will trigger the LLM end to end. If coefficient is 0.0, no
control vector will be applied. If the coefficient is not 0.0 and no
control vector has been trained, an error will be raised.
"""
if self.vector is None and coefficient != 0:
raise ValueError("No control vector has been trained")
# If we don't have the input user/assist tags, add them
if self.user_tag not in input:
input = f"{self.user_tag} {input} {self.asst_tag}"
input_ids = self.tokenizer(input, return_tensors="pt").to(self.model.device)
self.model.reset()
if coefficient != 0:
self.model.set_control(self.vector, coefficient)
output = self.model.generate(**input_ids, **self.settings).squeeze()
output = self.tokenizer.decode(output).strip()
# Remove anything prior to the asst tags
if self.asst_tag in output:
output = output.split(self.asst_tag)[1].strip()
self.model.reset()
return output
def save(self, path: str):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(self.vector, path)
def load(self, path: str):
self.vector = torch.load(path)