[feature] Add script for interactive conversation with trained model
theblackcat102 committed Feb 14, 2023
1 parent 5347551 commit 5432312
Showing 4 changed files with 111 additions and 5 deletions.
7 changes: 7 additions & 0 deletions model/model_training/README.md
@@ -93,6 +93,13 @@ You can interactively test your model like this:
python tools/model_cli.py --model_path <saved_path/huggingface>
```

Or start an interactive conversation with your bot, mainly to test its ability to handle context switches:

```bash
python -m tools.model_chat --model_path <saved_path/huggingface>
```
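
A session then looks roughly like this; the prompts and replies below are purely illustrative:

```text
> Summarize the plot of Hamlet in one sentence.
Hamlet, prince of Denmark, seeks revenge for his father's murder and dies in the attempt.
> Now translate that summary into German.
Hamlet, Prinz von Dänemark, sucht Rache für den Mord an seinem Vater und stirbt bei dem Versuch.
```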

## Model

Normally you should be able to add new models in `configs/config.yml`
8 changes: 5 additions & 3 deletions model/model_training/custom_datasets/__init__.py
@@ -60,13 +60,15 @@ def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, mode="sft
     dataset_name = dataset_name.lower()

     if dataset_name in QA_DATASETS:
-        train = QADataset(dataset_name, data_path, "train")
-        if not train.no_val:
+        dataset = QADataset(dataset_name, data_path, "train")
+        if not dataset.no_val:
             eval = QADataset(dataset_name, data_path, "validation")
+        train = dataset
     elif dataset_name in SUMMARIZATION_DATASETS:
-        train = SummarizationDataset(dataset_name, data_path, "train")
+        dataset = SummarizationDataset(dataset_name, data_path, "train")
         if dataset_name != "debate_sum":
             eval = SummarizationDataset(dataset_name, data_path, "validation")
+        train = dataset
     elif "ted_trans" in dataset_name:
         language_pair = dataset_name.split("_")[-1]
         dataset = TEDTalk(pair=language_pair, split="train")
10 changes: 8 additions & 2 deletions model/model_training/custom_datasets/translation.py
@@ -104,7 +104,7 @@ def __getitem__(self, index):


 class WMT2019(TranslationPair):
-    def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
+    def __init__(self, pair="zh-en", split="train", mix_prob=0.2, maximum_size=100000) -> None:
         super().__init__(mix_prob=mix_prob)
         dataset = load_dataset("wmt19", pair)[split]
         self.pairs = []
@@ -117,6 +117,9 @@ def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
             else: # translating in reverse direction
                 source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
                 self.pairs.append((source, row[src]))
+            # WMT is very large, reduce preprocessing time
+            if len(self.pairs) > maximum_size:
+                break


 class DiveMT(TranslationPair):
@@ -146,7 +149,7 @@ def __init__(self, split="train", mix_prob=0.2) -> None:
 class TEDTalk(TranslationPair):
     # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean

-    def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
+    def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2, maximum_size=100000) -> None:
         super().__init__(mix_prob=mix_prob)
         dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
         src, tgt = pair.split("-")
@@ -158,3 +161,6 @@ def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
             else: # translating in reverse direction
                 source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
                 self.pairs.append((source, row[src]))
+            # cap the dataset size to keep preprocessing time bounded
+            if len(self.pairs) > maximum_size:
+                break
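
The new `maximum_size` argument caps how many sentence pairs are preprocessed. For a quick local run the cap can be lowered when the dataset is constructed; a minimal sketch, assuming it is executed from `model/model_training` so that `custom_datasets` is importable (the 10,000 figure is only an example):

```python
from custom_datasets.translation import WMT2019

# keep at most 10k preprocessed pairs to shorten preprocessing time
wmt = WMT2019(pair="zh-en", split="train", maximum_size=10_000)
print(len(wmt.pairs))
```
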
91 changes: 91 additions & 0 deletions model/model_training/tools/model_chat.py
@@ -0,0 +1,91 @@
"""
A very simple script to test a model locally by chatting with it
"""
import argparse

from custom_datasets.formatting import QA_SPECIAL_TOKENS
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import _strtobool

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--bot_name", type=str, default="Joi", help="Use this when your format isn't in OA format")
parser.add_argument("--format", type=str, default="v2")
parser.add_argument("--max_new_tokens", type=int, default=200)
parser.add_argument("--top_k", type=int, default=40)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--do-sample", type=_strtobool, default=True)
args = parser.parse_args()

bot_name = args.bot_name
model_name = args.model_path
method = args.format


def talk(human_input, history, sep_token, prefix=""):
    # Assemble the full prompt: prior (question, answer) turns followed by the new input,
    # either in the OA v2 special-token format or as a plain "User:/{bot_name}:" transcript.
    histories = []
    if method == "v2":
        prefix = "<prefix>You are a helpful assistant called Joi, you will now help user to answer the question as concise as possible</prefix>"
        for (question, answer) in history:
            histories.append(
                "{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], question, QA_SPECIAL_TOKENS["Answer"], answer)
            )
        if len(histories) > 0:
            prefix += sep_token.join(histories)
            # add sep at the end
            prefix += sep_token
        prefix += "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], human_input, QA_SPECIAL_TOKENS["Answer"])
    else:
        for (question, answer) in history:
            histories.append("User: " + question + "\n\n{}: ".format(bot_name) + answer + "\n")
        if len(histories) > 0:
            prefix += "\n".join(histories)
        prefix += "\nUser: " + human_input + "\n\n{}: ".format(bot_name)

    return prefix


def process_output(output):
    # Keep only the model's latest answer and strip end-of-sequence markers.
    if method == "v2":
        answer = output.split(QA_SPECIAL_TOKENS["Answer"])[-1]
        answer = answer.split("</s>")[0].replace("<|endoftext|>", "").lstrip().split(QA_SPECIAL_TOKENS["Answer"])[0]
    else:
        answer = output.split("\n\n{}:".format(bot_name))[-1]
        answer = answer.split("</s>")[0].replace("<|endoftext|>", "").lstrip().split("\n\n{}:".format(bot_name))[0]
    return answer


tokenizer = AutoTokenizer.from_pretrained(model_name)
if method != "v2":
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
model = AutoModelForCausalLM.from_pretrained(model_name).half().eval().cuda()

if __name__ == "__main__":

    histories = []
    prefix = ""
    while True:
        print(">", end=" ")
        prompt = input()
        # use the tokenizer's eos token to separate previous turns in the assembled prompt
        input_text = talk(prompt, histories, tokenizer.eos_token, prefix)
        inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(0)
        outputs = model.generate(
            **inputs,
            early_stopping=True,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            top_k=args.top_k,
            temperature=args.temperature,
            pad_token_id=tokenizer.eos_token_id,
        )
        output = tokenizer.decode(outputs[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
        reply = process_output(output)

        if len(reply) != 0:
            print(reply)
            histories.append((prompt, reply))
        else:
            print("empty token")

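Putting it together, the script can be launched with its sampling flags; the model path placeholder mirrors the README and the values below are only illustrative:

```bash
python -m tools.model_chat --model_path <saved_path/huggingface> --format v2 --max_new_tokens 200 --top_k 40 --temperature 0.7 --do-sample true
```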