ORPO trainer #1435
Conversation
cc @philschmid
Very cool to see this elegant method get added so quickly @kashif!
I left a few remarks about what to log etc., and about harmonising the example script to be less hard-coded w.r.t. Anthropic HH. I'd also like to see some small experiments showing that the metrics look sane for the examples. Otherwise it's looking great!
docs/source/orpo_trainer.md (outdated)

```
While training and evaluating we record the following reward metrics:

TODO
```
trl/trainer/orpo_trainer.py (outdated)

```python
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
    The function to use to compute the metrics. Must take an `EvalPrediction` and return
    a dictionary mapping strings to metric values.
dataset_num_proc (`Optional[int]`, *optional*):
```
I think this could live in the `ORPOConfig` (ideally we want nearly everything that is not a callable to live in a single config so it can be easily tweaked from the command line).
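For illustration, a minimal sketch of what that proposal could look like as a dataclass-based config; the field names and defaults here are assumptions for the example, not TRL's actual `ORPOConfig`:

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical sketch: non-callable knobs live on a single config object
# (the real ORPOConfig extends transformers.TrainingArguments).
@dataclass
class ORPOConfig:
    beta: float = 0.1  # weight of the odds-ratio term (illustrative default)
    dataset_num_proc: Optional[int] = None  # workers for dataset .map() calls


# The trainer then reads args.dataset_num_proc instead of taking its own kwarg.
config = ORPOConfig(dataset_num_proc=4)
```

This keeps every tunable in one place, so CLI argument parsers built from the dataclass expose it automatically.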
@kashif WDYT about this proposal to move the arg to the config?
```python
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
```
Maybe we can just keep `logps` and `nll_loss`, and also log the odds ratio to simplify this?
Thank you so much for such a fast implementation of ORPO @kashif 😀 Also, regarding #1435 (comment), I think it is a great idea to log the log odds ratio, as it can help monitor the effect of β (the weighting hyperparameter), as shown in this report. Thank you again for the implementation!
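For context, here is a minimal sketch of the odds-ratio term and the `log_odds` statistic being discussed, computed from average per-token log-probabilities; the function name, tensor values, and `beta` default are illustrative, not TRL's API:

```python
import torch
import torch.nn.functional as F


def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
    # odds(y|x) = p / (1 - p) with p = exp(average log-prob), so
    # log odds = logp - log(1 - exp(logp)), computed stably via log1p.
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps))
        - torch.log1p(-torch.exp(rejected_logps))
    )
    # the odds-ratio penalty is -beta * log(sigmoid(log_odds))
    loss = -beta * F.logsigmoid(log_odds)
    return loss, log_odds


# toy batch: chosen completions have higher avg log-prob than rejected ones
chosen = torch.tensor([-0.5, -1.0])
rejected = torch.tensor([-2.0, -3.0])
loss, log_odds = odds_ratio_loss(chosen, rejected)
```

Logging `log_odds` alongside the loss shows how strongly the model already prefers the chosen completions, which is what makes it useful for monitoring the effect of β.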
ok, I'll add the logging next
Great work! Added some comments.
examples/scripts/orpo.py (outdated)

```python
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})


def extract_anthropic_prompt(prompt_and_response):
```
We now have standard datasets under https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style (#1424).
- https://github.com/huggingface/trl/blob/main/examples/datasets/anthropic_hh.py creates the standard dataset (prompt, chosen, rejected)
- https://github.com/huggingface/trl/blob/main/examples/datasets/tokenize_ds.py is a usage example.
In this case, maybe you could try:

```python
import multiprocessing

from datasets import load_dataset
from transformers import AutoTokenizer

ds = load_dataset(args.dataset)
if args.debug:
    for key in ds:
        ds[key] = ds[key].select(range(50))

tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"


def process(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row


ds = ds.map(
    process,
    num_proc=1 if args.debug else multiprocessing.cpu_count(),
    load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]

trainer = ORPOTrainer(
    model,
    args=orpo_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_config),
)
```
Here the `args.debug` flag is doing the same thing as `sanity_check`.
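As a side note, the fallback chat template in the snippet above can be checked in isolation with plain Jinja; the messages and EOS token below are made up for illustration:

```python
from jinja2 import Template

# same fallback template as in the snippet above
template = (
    "{% for message in messages %}"
    "{{message['role'] + ': ' + message['content'] + '\n\n'}}"
    "{% endfor %}{{ eos_token }}"
)

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
# renders each turn as "role: content" separated by blank lines, then the EOS token
text = Template(template).render(messages=messages, eos_token="</s>")
```

`tokenizer.apply_chat_template(..., tokenize=False)` renders the same template with the tokenizer's own `eos_token` substituted in.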
Thanks for iterating @kashif! I left one final nit and a question about moving the dataset proc args to the config. Apart from that LGTM 🔥
Just gave a quick try with ORPOTrainer with
Maybe it would also be nice to include the `ORPOTrainer` within the README.md listing of supported trainers, see https://github.com/huggingface/trl?tab=readme-ov-file#highlights
@alvarobartt ok yes good idea! adding
* initial orpo skeleton
* typos
* calculate orpo loss
* fix class name
* fix tests
* fix typo
* Update docs/source/orpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update docs/source/orpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update docs/source/orpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* rename max_target_length
* Update examples/scripts/orpo.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update examples/scripts/orpo.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update examples/scripts/orpo.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* more docs
* log log_odds_ratio and log_odds
* average_log_prob as per paper
* added logging section
* add nll_loss
* fix typo
* more verbose
* rename log_odds to log_odds_chosen
* allow datasets to be loaded
* remove dup debug arg
* tokenizer exists
* fix typo
* use trl-internal-testing/hh-rlhf-trl-style dataset
* formatting
* add missing imports
* fix output dir name
* Update examples/scripts/orpo.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* move dataset_num_proc to configs
* Update trl/trainer/orpo_config.py (Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>)
* Update trl/trainer/orpo_trainer.py (Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>)
* add ORPOTrainer to readme
* fix typo

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
Reference-free Monolithic Preference Optimization with Odds Ratio
cc @jiwooya1000