# examples/scripts/ppo/ppo_tldr.py
import multiprocessing
import shutil
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
"""
python examples/scripts/ppo/ppo_tldr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 64 \
--total_episodes 30000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
--stop_token eos \
--response_length 53 \
--sanity_check

Usage (full run, launched via Accelerate with DeepSpeed ZeRO-2):
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--output_dir models/minimal/ppo_tldr \
--learning_rate 3e-6 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
--stop_token eos \
"""
if __name__ == "__main__":
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # remove output_dir if it already exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=True,
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
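    # Note: padding_side="left" matters for decoder-only generation; rollouts are
    # batched, and left padding keeps each prompt flush against the tokens being
    # generated on the right. The fallback chat template only kicks in when the
    # tokenizer ships without one (e.g. the plain pythia checkpoints).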
    value_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1)
    reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1)
    ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path)
    policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path)
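    # PPO here uses four networks: the trainable policy and the value model are
    # initialized from the SFT and reward checkpoints respectively, the reward
    # model scores completed rollouts, and ref_policy stays frozen to anchor the
    # KL penalty that keeps the policy close to the SFT distribution.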
    ################
    # Dataset
    ################
    raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
    if config.sanity_check:
        for key in raw_datasets:
            raw_datasets[key] = raw_datasets[key].select(range(1000))
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["validation"]
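    # The TL;DR summarization data (Reddit posts paired with reference summaries)
    # is stored in the trl "messages" chat format: a user post followed by an
    # assistant summary. Only the train and validation splits are used here.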
    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            # Keep only the first message (the post to summarize) and append the
            # generation prompt, so the model's rollout is the summary itself.
            input_ids = tokenizer.apply_chat_template(
                element["messages"][:1],
                padding=False,
                add_generation_prompt=True,
            )
            return {"input_ids": input_ids, "lengths": len(input_ids)}

        return dataset.map(
            tokenize,
            remove_columns=dataset.column_names,
            num_proc=1 if config.sanity_check else multiprocessing.cpu_count(),
            load_from_cache_file=not config.sanity_check,
        )
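    # For illustration (hypothetical row, not part of the script): an entry like
    # {"messages": [{"role": "user", "content": "POST: ..."},
    #               {"role": "assistant", "content": "TL;DR: ..."}]}
    # is reduced to the tokenized prompt only; the reference summary is
    # deliberately dropped, since PPO generates its own responses.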
    train_dataset = prepare_dataset(train_dataset, tokenizer)
    eval_dataset = prepare_dataset(eval_dataset, tokenizer)
    # filter out prompts that are too long to leave room for the response
    train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512)
    eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512)
    # prompts must not already end in EOS; the EOS should come from the model,
    # marking a finished summary
    assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
    ################
    # Training
    ################
    trainer = PPOv2Trainer(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model(config.output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()
    trainer.generate_completions()
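    # After training, generate_completions() samples summaries from the eval
    # prompts so the finished policy's outputs can be inspected directly
    # (logged alongside the training metrics when a tracker is configured).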