Skip to content

Commit

Permalink
Trainer fixes (#2040)
Browse files Browse the repository at this point in the history
Fixes these two tools to work with a model just trained with trainer_sft

<img width="1111" alt="image"
src="https://user-images.githubusercontent.com/1230500/224469874-388d626c-c9ce-4e75-94d2-a6f329884b12.png">
<img width="1099" alt="image"
src="https://user-images.githubusercontent.com/1230500/224469927-a55b7890-a0c1-42a8-af55-cc55e782388d.png">
  • Loading branch information
johnflux committed Mar 11, 2023
1 parent 50b81f1 commit 0608f2e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
11 changes: 8 additions & 3 deletions model/model_training/README.md
Expand Up @@ -89,17 +89,22 @@ python trainer_rl.py --configs defaults_rlhf

## Test your model

You can itneractively test your model like this:
You can interactively test your model like this:

```bash
python tools/model_cli.py --model_path <saved_path/huggingface>
python3 tools/model_cli.py --model_path <saved_path/huggingface>
# For example, if you trained with the default config:
python3 tools/model_cli.py --model_path saved_model
# Add --8bit if it is an 8bit model
```

Or start a conversation with your bot interactively, mainly for testing context
switch ability

```bash
python -m tools.model_chat --model_path <saved_path/huggingface>
python3 tools/model_chat.py --model_path <saved_path/huggingface>
# For example, if you trained with the default config:
python3 tools/model_chat.py --model_path saved_model
```

## Model
Expand Down
14 changes: 13 additions & 1 deletion model/model_training/tools/model_chat.py 100644 → 100755
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
"""
A very simple script to test model locally
Expand All @@ -6,6 +7,12 @@
"""
import argparse

if __name__ == "__main__":
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from custom_datasets.formatting import QA_SPECIAL_TOKENS
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import _strtobool
Expand Down Expand Up @@ -68,12 +75,17 @@ def process_output(output):
prefix = ""
while True:
print(">", end=" ")
prompt = input()
try:
prompt = input()
except (EOFError, KeyboardInterrupt): # Catch ctrl+d and ctrl+c respectively
print()
break
if prompt == "!reset":
histories = []
else:
input_text = talk(prompt, histories, prefix)
inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(0)
del inputs["token_type_ids"]
outputs = model.generate(
**inputs,
early_stopping=True,
Expand Down
30 changes: 21 additions & 9 deletions model/model_training/tools/model_cli.py 100644 → 100755
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
import argparse
import os
import sys
Expand All @@ -23,16 +24,20 @@
parser.add_argument("--top_k", type=int, default=40)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--do-sample", type=_strtobool, default=True)
parser.add_argument("--8bit", action="store_true", dest="eightbit")
args = parser.parse_args()

model = get_specific_model(
args.model_path,
load_in_8bit=True,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
offload_state_dict=True,
)
if args.eightbit:
model = get_specific_model(
args.model_path,
load_in_8bit=True,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
offload_state_dict=True,
)
else:
model = get_specific_model(args.model_path)

model.gradient_checkpointing_enable() # reduce number of stored activations
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path)
Expand All @@ -54,7 +59,10 @@

conversation_history.append(user_input)

batch = tokenizer.encode("".join(format_pair(conversation_history)), return_tensors="pt")
pairs = format_pair(conversation_history)
pairs = ["{}{}".format(p, tokenizer.eos_token) for p in pairs]
pairs.append(QA_SPECIAL_TOKENS["Answer"])
batch = tokenizer.encode("".join(pairs), return_tensors="pt")

with torch.cuda.amp.autocast():
out = model.generate(
Expand All @@ -72,6 +80,10 @@
conversation_history.append(response)
except KeyboardInterrupt:
conversation_history = []
print()
print("Conversation restarted")
time.sleep(1)
continue
except EOFError: # Catch ctrl+d
print()
break
2 changes: 1 addition & 1 deletion model/model_training/trainer_rl.py
Expand Up @@ -67,7 +67,7 @@ def argument_parsing(notebook=False, notebook_args=None):
@torch.no_grad()
def rank_model_fn(samples, **kwargs):
inputs = rank_tokenizer(samples, return_tensors="pt", padding=True)
inputs.pop("token_type_ids", None)
del inputs["token_type_ids"]
return rank_model(**inputs).logits[:, 0].detach().cpu()

trlx_config = TRLConfig.load_yaml("configs/ppo_config.yaml")
Expand Down

0 comments on commit 0608f2e

Please sign in to comment.