
Commit

Merge pull request #185 from nebuly-ai/simple_rlhf
Fix some comments and tokenizer outputs from P-Tuning to RLHF
diegofiori committed Feb 27, 2023
2 parents d48f444 + 8db068f commit 3f9de23
Showing 6 changed files with 47 additions and 126 deletions.
8 changes: 6 additions & 2 deletions apps/accelerate/chatllama/chatllama/llama_model.py
@@ -22,7 +22,7 @@ def __init__(self, tokenizer: Tokenizer):
def __call__(self, texts: Union[List[str], str], *args, **kwargs):
if isinstance(texts, str):
text = self.tokenizer.encode(texts, bos=True, eos=True)
return torch.tensor(text).cuda().long()
tokens = torch.tensor(text).cuda().long()
else:
texts = [
self.tokenizer.encode(text, bos=True, eos=True)
@@ -36,7 +36,11 @@ def __call__(self, texts: Union[List[str], str], *args, **kwargs):
)
for i, text in enumerate(texts):
tokens[i, : len(text)] = torch.tensor(text).cuda().long()
return tokens
output = {
"input_ids": tokens,
"attention_mask": (tokens != self.tokenizer.pad_id).long(),
}
return output

def decode(self, tokens):
return self.tokenizer.decode(tokens)
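With this change the tokenizer wrapper returns a Hugging Face style dict instead of a bare tensor. A minimal usage sketch, assuming the wrapper class shown above is instantiated around the LLaMA SentencePiece tokenizer (the class name, path and shapes below are assumptions, not part of the commit):

# hypothetical usage, not part of the commit
wrapper = HFLikeTokenizer(Tokenizer(model_path="/path/to/tokenizer.model"))
batch = wrapper(["Hello there", "How are you today?"])
batch["input_ids"]       # padded LongTensor on the GPU, shape (batch, max_len)
batch["attention_mask"]  # 1 where tokens differ from tokenizer.pad_id, 0 on padding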
47 changes: 13 additions & 34 deletions apps/accelerate/chatllama/chatllama/rlhf/actor.py
@@ -14,8 +14,7 @@

class ActorModel(torch.nn.Module):
"""Actor model that generates the augmented prompt from the initial
user_input. Its output is used as prompt for the LLM model.
The aim is to train this model to generate better prompts.
user_input. The aim is to train this model to generate better prompts.
Attributes:
model: The model from LLaMA to be used
@@ -64,7 +63,7 @@ def forward(
Args:
sequences (torch.Tensor): Sequences of states and actions used to
compute token logits for the whole list of sequences
attention_mask (torch.Tensor): Mask fo the sequences attention
attention_mask (torch.Tensor): Mask for the sequences attention
Returns:
logits (torch.Tensor): Logits for the actions taken
@@ -87,19 +86,19 @@ def generate(
(i.e. input of the prompt generator model)
Args:
state (torch.Tensor): State the input of the user to generate
the prompt (action)
state_mask (torch.Tensor): Mask for the state of the environment
(mask for the states)
state (torch.Tensor): the input of the user
state_mask (torch.Tensor): Mask for the state input (for padding)
Returns:
actions (torch.Tensor): Actions generated from the state
sequences (torch.Tensor): Sequences generated from the
state as [states, actions]
"""
max_sequence = states.shape[1]
max_tokens = max_sequence + self.config.max_tokens
max_tokens = self.config.max_tokens + max_sequence
temperature = self.config.temperature
# What if the states + completion are longer than the max context of
# the model?
sequences = self.model.generate(
inputs=states,
attention_mask=state_mask,
@@ -133,7 +132,7 @@ def load(self, path: Optional[str] = None) -> None:
f"The path doesn't exist."
)
return
# load the model and the tokenizer
# load the model
if os.path.exists(path) is False:
print(
f"Impossible to load the model: {path}"
@@ -163,32 +162,14 @@ class ActorDataset(Dataset):
read a json file with the following format:
[
{
"general_info": "..."
"general_purpose": "..."
"contract_name": "..."
"section_key_points": "..."
"section_name": "..."
"section_number": ...
"text": "..."
"text_type": "..."
"generation_prompt": "..."
"score": ...
"user_input": "..."
"completion": "..."
} ,
...
]
Where:
general_info: general information of the contract (names, dates, ...)
general_purpose: purpose of the contract
contract_name: name of the contract
section_key_points: key points of this section
section_name: name of the section
section_number: number of the section (to determine the order)
text: text of the section
text_type: could be (data, davinci, curie, etc... depending on how it
was generated)
generation_prompt: prompt used to generate the text (in case it was
generated by a model, None if it data)
score: score of the section (0-5) given by the LLM
user_input: the input of the user
completion: the output of the user
"""

def __init__(self, path: str, device: torch.device) -> None:
@@ -197,7 +178,7 @@ def __init__(self, path: str, device: torch.device) -> None:
with open(path, "r") as f:
data = json.load(f)
self.data = [
d["user_input"] + "\n\n###\n\n" + d["prompt"] for d in data
d["user_input"] + "\n\n###\n\n" + d["completion"] for d in data
]
self.len = len(self.data)
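For reference, a hedged illustration of the JSON the dataset now expects and the single string each record is flattened into (contents invented for illustration):

# examples.json (illustrative)
# [
#   {"user_input": "Summarize the following meeting notes ...",
#    "completion": "The meeting covered three topics ..."}
# ]
# each record becomes one training string:
# "Summarize the following meeting notes ...\n\n###\n\nThe meeting covered three topics ..."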

@@ -282,7 +263,6 @@ def train(
attention_mask = input_output_tokenized["attention_mask"][
:, :-1
]
# TODO add check for sequence length
training_output = training_output.to(device)
training_input = training_input.to(device)
attention_mask = attention_mask.to(device)
@@ -312,7 +292,6 @@ def train(
input_output_tokenized = self.model.tokenizer(
input_output, return_tensors="pt", padding=True
)
# TODO: Add check for sequence length
validation_output = input_output_tokenized["input_ids"][
:, 1:
]
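The off-by-one slicing above is the usual next-token (teacher forcing) setup; a hedged restatement with the names used in these hunks (the training_input/training_output slices are inferred, not shown in the diff):

# input_ids             = [t0, t1, t2, ..., tN]  per sequence
# training_input        = input_ids[:, :-1]   -> the model reads t0 .. t(N-1)
# training_output       = input_ids[:, 1:]    -> the loss targets t1 .. tN
# attention_mask[:, :-1] stays aligned with the training input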
57 changes: 15 additions & 42 deletions apps/accelerate/chatllama/chatllama/rlhf/reward.py
@@ -3,7 +3,7 @@

import torch
from beartype import beartype
from beartype.typing import Optional, Iterable, List, Tuple
from beartype.typing import Optional, Iterable
from einops.layers.torch import Rearrange
from langchain import OpenAI, LLMChain, PromptTemplate
from torch.utils.data import Dataset, DataLoader
@@ -35,7 +35,7 @@ class RewardModel(torch.nn.Module):

def __init__(self, config: ConfigReward) -> None:
super().__init__()
# load the model
# load the model -- add here other models
head_hidden_size = config.model_head_hidden_size
if config.model == "gpt2-large":
self.max_model_tokens = 1024
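The reworded comment marks where additional backbones would be registered; a hedged sketch of what an extra branch might look like (the model name and its 1024-token context are assumptions, not part of the commit):

elif config.model == "gpt2-medium":
    self.max_model_tokens = 1024
    # load the matching model / tokenizer for this backbone here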
@@ -132,6 +132,8 @@ def forward(
output = self.model(
output_sequence, attention_mask=output_sequence_mask
)
# What if the output_sequence is longer than the max context of
# the model?
rewards = self.head(output.last_hidden_state)
if self.config.debug:
print("RewardModel.forward")
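One hedged way to resolve the question raised in the new comment, reusing the max_model_tokens attribute set in __init__ (a sketch, not part of the commit):

if output_sequence.shape[1] > self.max_model_tokens:
    # keep the most recent tokens so the completion being scored is preserved
    output_sequence = output_sequence[:, -self.max_model_tokens:]
    output_sequence_mask = output_sequence_mask[:, -self.max_model_tokens:]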
@@ -209,29 +211,16 @@ class RewardDataset(Dataset):
read a json file with the following format:
[
{
"general_info": "..."
"general_purpose": "..."
"contract_name": "..."
"section_key_points": "..."
"section_name": "..."
"section_number": ...
"text": "..."
"text_type": "..."
"user_input": "...",
"completion": "...",
"score": ...
} ,
},
...
]
Where:
general_info: general information of the contract (names, dates, ...)
general_purpose: purpose of the contract
contract_name: name of the contract
section_key_points: key points of this section
section_name: name of the section
section_number: number of the section (to determine the order)
text: text of the section
text_type: could be (data, davinci, curie, etc... depending on how it
was generated)
score: score of the section (0-5) given by the LLM
user_input: the initial input of the user
completion: the completion generated by the model
score: the score given by the user to the completion (or by the LLM)
"""

def __init__(self, path: str) -> None:
@@ -241,11 +230,9 @@ def __init__(self, path: str) -> None:
print(f"Loaded {len(self.data)} samples")

def __getitem__(self, idx: int):
section = self.data[idx]["section_name"]
contract = self.data[idx]["contract_name"]
purpose = self.data[idx]["general_purpose"]
key_points = self.data[idx]["section_key_points"]
item = tuple([section, contract, purpose, key_points])
user_input = self.data[idx]["user_input"]
completion = self.data[idx]["completion"]
item = tuple([user_input, completion])
return item
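A minimal sketch of consuming the reworked dataset (the file name and record contents are invented for illustration):

dataset = RewardDataset("reward_examples.json")  # hypothetical path
user_input, completion = dataset[0]              # each item is now a 2-tuple
# the score stays in the raw JSON; the training loop below still carries a
# TODO for loading it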

def __len__(
@@ -295,20 +282,6 @@ def __init__(self, config: ConfigReward) -> None:
prompt_template = PromptTemplate(**REWARD_TEMPLATE)
self.llm = LLMChain(llm=openai_llm, prompt=prompt_template)

def generate_user_input(self, inputs: Tuple) -> List[str]:
user_inputs = []
for i, input in enumerate(inputs):
section = input[0]
contract = input[1]
purpose = input[2]
key_points = input[3]
user_inputs.append(
f"Write the contract section {section} for the contract "
f"{contract} whose purpose is:\n{purpose}\n"
f"The key points for this section should be:\n{key_points}\n"
)
return user_inputs

def distillate(
self,
):
@@ -384,10 +357,10 @@ def train(
self.model.train()
for i, inputs in enumerate(train_dataloader):

user_inputs = self.generate_user_input(inputs)
input_text = inputs["user_input"] + inputs["completion"]
# tokenizer (placed here instead of dataset class)
input_tokens = self.model.tokenizer(
user_inputs, padding=True, truncation=True
input_text, padding=True, truncation=True
)

score = None # TODO: load the score
7 changes: 4 additions & 3 deletions apps/accelerate/chatllama/chatllama/rlhf/test.py
@@ -14,15 +14,15 @@ def test_actor_training(path=None, device=None, debug=False):


def test_reward_training(path=None, device=None, debug=False):
device = torch.device("cuda:1")
device = torch.device("cuda:0")
config = Config(path=path, device=device, debug=debug)
trainer = RewardTrainer(config.reward)
trainer.train()
trainer.training_stats.plot()


def test_rl_trainig(path=None, device=None, debug=False):
device = torch.device("cuda:1")
device = torch.device("cuda:0")
config = Config(path=path, device=device, debug=debug)
trainer = RLTrainer(config.trainer)
trainer.distillate()
@@ -35,7 +35,8 @@ def test_rl_trainig(path=None, device=None, debug=False):
rl_training = False
actor_training = False

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# place here the path to the config.yaml file
config_path = "/home/pierpaolo/Documents/optimapi/ptuning/config.yaml"

if reward_training:
52 changes: 7 additions & 45 deletions apps/accelerate/chatllama/chatllama/rlhf/trainer.py
@@ -5,7 +5,7 @@

import torch
from beartype import beartype
from beartype.typing import Deque, Tuple, List, Dict
from beartype.typing import Deque, Tuple, List
from einops import rearrange
from torch.utils.data import Dataset, DataLoader

@@ -91,13 +91,10 @@ def forward(
def generate(
self, states: torch.Tensor, state_mask: torch.Tensor
) -> Tuple:
"""Generate actions, actions_logits,
values and sequences from states (i.e. input of the
prompt generator model)
"""Generate actions, actions_logits, values and sequences from states
Args:
states (torch.Tensor): States the input of the user to
generate the prompt (action)
states (torch.Tensor): user inputs
state_mask (torch.Tensor): Mask for the states of the environment
Returns:
@@ -183,29 +180,13 @@ class ExamplesSampler:
read a json file with the following format:
[
{
"general_info": "..."
"general_purpose": "..."
"contract_name": "..."
"section_key_points": "..."
"section_name": "..."
"section_number": ...
"text": "..."
"text_type": "..."
"score": ...
"user_input" : "",
} ,
...
]
Where:
general_info: general information of the contract (names, dates, ...)
general_purpose: purpose of the contract
contract_name: name of the contract
section_key_points: key points of this section
section_name: name of the section
section_number: number of the section (to determine the order)
text: text of the section
text_type: could be (data, davinci, curie, etc... depending on how it
was generated)
score: score of the section (0-5)
user_input: is the input of the user or directly the input of the user
with the memory preappended (i.e. user_input + memory)
"""

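An illustrative entry for this simplified format (contents invented):

[
    {
        "user_input": "Please draft a short reply to this customer complaint: ..."
    },
    ...
]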
def __init__(
@@ -276,7 +257,6 @@ def __init__(
self.eps = 1e-8

# make models directory
# TODO make this more general
if not os.path.exists("./models"):
os.mkdir("./models")

@@ -489,20 +469,6 @@ def learn(self, memories: Deque[Memory]) -> None:
self.actorcritic.eval()
print("End Learning")

def generate_user_input(self, inputs: List[Dict]) -> List[str]:
user_inputs = []
for i, input in enumerate(inputs):
section = input["section_name"]
contract = input["contract_name"]
purpose = input["general_purpose"]
key_points = input["section_key_points"]
user_inputs.append(
f"Write the contract section {section} for the contract "
f"{contract} whose purpose is:\n{purpose}\n"
f"The key points for this section should be:\n{key_points}\n"
)
return user_inputs

def train(
self,
) -> None:
@@ -558,13 +524,9 @@ def train(
# sample num_examples examples from example dataset
inputs = self.example_sampler.sample(num_examples)

# generate user input from contract
# section info for each example
user_inputs = self.generate_user_input(inputs)

# tokenize examples
tokenized_inputs = self.actorcritic.actor.tokenizer(
user_inputs, padding=True, return_tensors="pt"
inputs, padding=True, return_tensors="pt"
)
if self.debug:
print("RLTrainer.train()")
2 changes: 2 additions & 0 deletions apps/accelerate/chatllama/generate_dataset.py
@@ -25,6 +25,8 @@ def create_conversation(human_agent: LLMChain, bot_agent: LLMChain):


def build_agents():
# be aware that too long completions will not fit the sequence length
# of possible critic or reward models ...
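# A hedged mitigation sketch (not part of the commit): measure each generated
# conversation with the reward/critic tokenizer and drop samples that would
# overflow its context window, e.g.
#   n_tokens = len(reward_tokenizer(conversation)["input_ids"])
#   if n_tokens > reward_max_tokens:  # e.g. 1024 for gpt2-large
#       continue
# reward_tokenizer and reward_max_tokens are hypothetical names.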
llm = OpenAI(max_tokens=2048, temperature=0.7)
human_template = PromptTemplate(**PERSON_CHATBOT_TEMPLATE)
human_agent = LLMChain(
