Commit

[feature] fix bugs and add gsm8k code version
Zhi Rui Tam committed Feb 15, 2023
1 parent d50f514 commit 29b205e
Showing 4 changed files with 29 additions and 35 deletions.
16 changes: 1 addition & 15 deletions model/model_training/custom_datasets/__init__.py
@@ -14,19 +14,7 @@
 from sklearn.model_selection import train_test_split
 from torch.utils.data import Subset
 
-QA_DATASETS = [
-    "squad_v2",
-    "adversarial_qa",
-    "trivia_qa_context",
-    "trivia_qa_nocontext",
-    "gsm8k",
-    "wikihow",
-    "essay_instruction",
-    "math_qa",
-    "reddit_eli5",
-    "reddit_askh",
-    "reddit_asks",
-]
+QA_DATASETS = list(QADataset.DATASET_FORMAT_MAPPING.keys())
 SUMMARIZATION_DATASETS = [
     "xsum",
     "cnn_dailymail",
@@ -80,8 +68,6 @@ def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, mode="sft
         dataset = DiveMT()
     elif dataset_name == "webgpt":
         dataset = WebGPT()
-    elif dataset_name == "prompt_dialogue":
-        dataset = PromptGeneratedDataset(data_path)
     elif dataset_name == "prosocial_dialogue":
         train = ProsocialDialogue(cache_dir=data_path, split="train")
         eval = ProsocialDialogue(cache_dir=data_path, split="validation")
2 changes: 1 addition & 1 deletion model/model_training/custom_datasets/dialogue_collator.py
@@ -33,7 +33,7 @@ def __call__(self, features):
 
             # Add a way for the model to terminate generation
             # When we predict the start of a new expected question, we want to be able to stop generation
-            messages.append(QA_SPECIAL_TOKENS["Question"])
+            messages.append(self.tokenizer.eos_token)
 
             flatten_message = self.tokenizer(
                 "".join(messages),
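
This one-line collator change is what lets generation terminate cleanly: each flattened conversation now ends with the tokenizer's eos token rather than the special token that opens the next question, so model.generate() can stop on eos_token_id instead of running to its token budget. A hedged sketch of the effect (the tokenizer and the question/answer strings are illustrative stand-ins for the repo's QA_SPECIAL_TOKENS, not its exact values):

```python
from transformers import AutoTokenizer

# Illustrative tokenizer; any causal-LM tokenizer with an eos token works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

messages = ["<question>", "What is 2 + 2?", "<answer>", "4"]
messages.append(tokenizer.eos_token)  # the change above: end on eos, not on the next question token

# The flattened training text now ends with "<|endoftext|>", the same id
# that generate() watches for via eos_token_id.
print("".join(messages))
```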
4 changes: 4 additions & 0 deletions model/model_training/custom_datasets/qa_datasets.py
@@ -70,6 +70,9 @@ def index_eli5(example):
     return example["title"], example["answers"]["text"][0]
 
 
+def index_gsm_hard(example):
+    return example['input']+'\nWrite a small snippet of python code to answer this', "Here's the code solution to the question\n```python\n{}\n```\n The answer should be {}".format(example['code'].strip(), example['target'])
+
 class QADataset(Dataset):
     """
     How to define a new QA dataset:
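
index_gsm_hard maps a gsm-hard record onto a (question, answer) pair: the question is the problem text plus an instruction to write Python, and the answer wraps the reference code in a fenced python block followed by the numeric target. A quick check with a fabricated record (the field names match the index_fn above; the values are made up, and the import path assumes this module's location in the repo):

```python
from model_training.custom_datasets.qa_datasets import index_gsm_hard

# Fabricated record in the gsm-hard schema: input / code / target.
example = {
    "input": "A baker bakes 12 loaves a day. How many loaves in 30 days?",
    "code": "def solution():\n    return 12 * 30\n",
    "target": 360,
}

question, answer = index_gsm_hard(example)
print(question)  # problem text + "\nWrite a small snippet of python code to answer this"
print(answer)    # "Here's the code solution..." with the code fenced, then "The answer should be 360"
```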
@@ -101,6 +104,7 @@ class QADataset(Dataset):
             "index_fn": index_adversarial_qa,
             "params": {"name": "adversarialQA"},
         },
+        "gsm8k_hard": {"index_fn": index_gsm_hard, "name": "reasoning-machines/gsm-hard", "no_val": True},
         "gsm8k": {"index_fn": index_gsm8k, "params": {"name": "main"}, "validation": "test"},
         "wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": index_wikihow, "no_val": True},
         "essay_instruction": {
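
The new entry reuses the mapping's existing conventions: name points load_dataset at the reasoning-machines/gsm-hard hub dataset, and no_val (as with wikihow just below) presumably flags a dataset that ships without a validation split, so one has to be carved out of train; the train_test_split import in __init__.py is consistent with that reading. A sketch of that fallback, under the stated assumption (the real handling is not shown in this diff):

```python
from sklearn.model_selection import train_test_split

# Assumption: "no_val" datasets get a validation split carved from train.
records = list(range(1_000))  # stand-in for the training examples
train, val = train_test_split(records, test_size=0.2, random_state=42)
print(len(train), len(val))  # 800 200
```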
42 changes: 23 additions & 19 deletions model/model_training/tools/model_chat.py
@@ -28,7 +28,7 @@
 def talk(human_input, history, sep_token, prefix=""):
     histories = []
     if method == "v2":
-        prefix = "<prefix>You are a helpful assistant called Joi, you will now help user to answer the question as concise as possible</prefix>"
+        prefix = "<prefix>You are a helpful assistant called Joi trained by OpenAssistant on large corpus of data, you will now help user to answer the question as concise as possible</prefix>"
     for (question, answer) in history:
         histories.append(
             "{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], question, QA_SPECIAL_TOKENS["Answer"], answer)
@@ -70,22 +70,26 @@ def process_output(output):
 while True:
     print(">", end=" ")
     prompt = input()
-    input_text = talk(prompt, histories, prefix)
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(0)
-    outputs = model.generate(
-        **inputs,
-        early_stopping=True,
-        max_new_tokens=args.max_new_tokens,
-        do_sample=True,
-        top_k=args.top_k,
-        temperature=args.temperature,
-        pad_token_id=tokenizer.eos_token_id,
-    )
-    output = tokenizer.decode(outputs[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
-    reply = process_output(output)
-
-    if len(reply) != 0:
-        print(reply)
-        histories.append((prompt, reply))
+    if prompt == '!reset':
+        histories = []
     else:
-        print("emtpy token")
+        input_text = talk(prompt, histories, prefix)
+        inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(0)
+        outputs = model.generate(
+            **inputs,
+            early_stopping=True,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=True,
+            top_k=args.top_k,
+            temperature=args.temperature,
+            pad_token_id=tokenizer.eos_token_id,
+            # dialogue_collator.py line 36
+        )
+        output = tokenizer.decode(outputs[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
+        reply = process_output(output)
+
+        if len(reply) != 0:
+            print(reply)
+            histories.append((prompt, reply))
+        else:
+            print("empty token")
