ORPO trainer #1435

Merged: 38 commits merged into main on Mar 22, 2024
Conversation

@kashif (Collaborator) commented on Mar 17, 2024:

ORPO trainer

Reference-free Monolithic Preference Optimization with Odds Ratio

cc @jiwooya1000

  • figure out what to log
  • add logging section to the docs

@kashif marked this pull request as draft on March 17, 2024 17:18
@kashif marked this pull request as ready for review on March 17, 2024 19:35

@kashif (Collaborator, Author) commented on Mar 17, 2024:

cc @philschmid

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif requested a review from lewtun on March 17, 2024 19:49

@lewtun (Member) left a comment:

Very cool to see this elegant method get added so quickly @kashif!

I left a few remarks about what to log, and about harmonising the example script so it is less hard-coded with respect to Anthropic HH. I'd also like to see some small experiments showing that the metrics look sane for the examples. Otherwise it's looking great!

docs/source/orpo_trainer.md (review comments, resolved)

While training and evaluating we record the following reward metrics:

TODO

Review comment (Member):

WDYT about logging the log probs and log odds ratio alongside the SFT loss, OR loss and full loss? That way the user can debug whether the rejected log probs are decreasing over the course of training.

(screenshot attached: 2024-03-18 08:47)
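
As an illustration, a minimal sketch of an odds-ratio loss that exposes these quantities for logging (function and variable names are assumptions, not the merged trainer code):

    import torch
    import torch.nn.functional as F

    def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
        # chosen_logps / rejected_logps: per-sequence average log p(y|x), shape (batch,)
        # log odds(y|x) = log p - log(1 - p), kept in log space for numerical stability
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        log_odds_ratio = F.logsigmoid(log_odds)    # log sigma(log odds ratio)
        or_loss = -beta * log_odds_ratio           # term added on top of the NLL (SFT) loss
        # return the intermediates so the trainer can log them
        return or_loss.mean(), log_odds_ratio.detach().mean(), log_odds.detach().mean()

Returning `log_odds_ratio` and `log_odds` alongside the NLL loss would let the trainer log them, making it easy to check whether the rejected log probs are being pushed down over training.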

examples/scripts/orpo.py (review comment, resolved)
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
    The function to use to compute the metrics. Must take an `EvalPrediction` and return
    a dictionary mapping strings to metric values.
dataset_num_proc (`Optional[int]`, *optional*):

Review comment (Member):
I think this could live in the ORPOConfig (ideally we want nearly everything that is not a callable to live in a single config, so it can be easily tweaked at the command line).
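
If it does move, a rough sketch of what that could look like (a sketch only: the field names, defaults, and the `TrainingArguments` base class are assumptions, not the merged API):

    from dataclasses import dataclass
    from typing import Optional

    from transformers import TrainingArguments

    @dataclass
    class ORPOConfig(TrainingArguments):
        # weighting of the odds-ratio term relative to the NLL loss
        beta: float = 0.1
        # number of worker processes used when mapping/tokenizing the dataset
        dataset_num_proc: Optional[int] = None

Keeping non-callable options like `dataset_num_proc` in the config means they can be overridden from the command line through the usual `HfArgumentParser` flow.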

Review comment (Member):
@kashif WDYT about this proposal to move the arg to the config?

trl/trainer/orpo_trainer.py (review comments, resolved)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()

Review comment (Member):
Maybe we can just keep logps and nll_loss and also log odds ratio to simplify this?
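
For example, a trimmed-down metrics helper along those lines might look like this (key names are illustrative and may not match the merged trainer):

    def build_metrics(prefix, chosen_logps, rejected_logps, nll_loss, log_odds_ratio):
        # prefix is "eval_" during evaluation and "" during training
        return {
            f"{prefix}logps/chosen": chosen_logps.detach().mean().cpu(),
            f"{prefix}logps/rejected": rejected_logps.detach().mean().cpu(),
            f"{prefix}nll_loss": nll_loss.detach().mean().cpu(),
            f"{prefix}log_odds_ratio": log_odds_ratio.detach().mean().cpu(),
        }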

kashif and others added 8 commits on March 18, 2024 09:23 (six co-authored by lewtun <lewis.c.tunstall@gmail.com>)

@jiwooya1000 commented:

Thank you so much for such a fast implementation of ORPO @kashif 😀

Also, regarding #1435 (comment), I think it is a great idea to log the log odds ratio, as it can help monitor the effect of β (the weighting hyperparameter), as in this report.

(image attached)

Thank you again for the implementation!
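
For reference, a sketch of the objective from the paper, with β as the weighting hyperparameter discussed above (notation is illustrative and may differ slightly from the final docs):

    \mathcal{L}_{\mathrm{ORPO}}
      = \mathbb{E}\big[\, \mathcal{L}_{\mathrm{SFT}} + \beta \cdot \mathcal{L}_{\mathrm{OR}} \,\big],
    \qquad
    \mathcal{L}_{\mathrm{OR}}
      = -\log \sigma\!\left( \log \frac{\mathrm{odds}(y_w \mid x)}{\mathrm{odds}(y_l \mid x)} \right),
    \qquad
    \mathrm{odds}(y \mid x) = \frac{P(y \mid x)}{1 - P(y \mid x)}

Logging the inner log odds ratio therefore shows directly what the β-weighted term is acting on.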

@kashif (Collaborator, Author) commented on Mar 18, 2024:

OK, I'll add the logging next.

@vwxyzjn (Contributor) left a comment:

Great work! Added some comments.

sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})


def extract_anthropic_prompt(prompt_and_response):

Review comment (Contributor):
We now have standard datasets under https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style (#1424).

In this case, maybe you could try:

    import multiprocessing

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from trl import ORPOTrainer, get_peft_config

    # args, model, orpo_args, and model_config are assumed to come from the
    # script's argument parsing and model loading further up.
    ds = load_dataset(args.dataset)
    if args.debug:
        for key in ds:
            ds[key] = ds[key].select(range(50))
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.chat_template is None:
        tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"

    # render the chosen/rejected message lists into plain strings
    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    ds = ds.map(
        process,
        num_proc=1 if args.debug else multiprocessing.cpu_count(),
        load_from_cache_file=False,
    )
    train_dataset = ds["train"]
    eval_dataset = ds["test"]
    trainer = ORPOTrainer(
        model,
        args=orpo_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_config),
    )

Here, `args.debug` is doing the same thing as `sanity_check`.
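
A small sketch of how the example script's arguments could expose that flag (the dataclass and field names here are hypothetical):

    from dataclasses import dataclass, field

    @dataclass
    class ScriptArguments:
        dataset: str = field(
            default="trl-internal-testing/hh-rlhf-trl-style",
            metadata={"help": "dataset to train on"},
        )
        debug: bool = field(
            default=False,
            metadata={"help": "only use a few samples per split, replacing the sanity_check flag"},
        )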

@lewtun (Member) left a comment:

Thanks for iterating @kashif! I left one final nit and a question about moving the dataset proc args to the config. Apart from that, LGTM 🔥

examples/scripts/orpo.py (review comment, resolved)
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
    The function to use to compute the metrics. Must take an `EvalPrediction` and return
    a dictionary mapping strings to metric values.
dataset_num_proc (`Optional[int]`, *optional*):

Review comment (Member):
@kashif WDYT about this proposal to move the arg to the config?

kashif and others added 2 commits on March 19, 2024 10:40 (co-authored by lewtun <lewis.c.tunstall@gmail.com>)

@jiwooya1000 commented:

Just gave ORPOTrainer a quick try with facebook/opt-350m + argilla/ultrafeedback-binarized-preferences-cleaned, and it seems to be working well 😃 Thank you for your work @kashif!

(image attached)

Commit added, co-authored by Alvaro Bartolome <alvarobartt@gmail.com>
trl/trainer/orpo_trainer.py (review comments, resolved)
Commit added, co-authored by Alvaro Bartolome <alvarobartt@gmail.com>

@alvarobartt (Member) left a comment:

Maybe it would also be nice to include the ORPOTrainer in the README.md listing of supported trainers; see https://github.com/huggingface/trl?tab=readme-ov-file#highlights

@kashif (Collaborator, Author) commented on Mar 22, 2024:

@alvarobartt OK, yes, good idea! Adding it.

@kashif merged commit 2ce8e45 into main on Mar 22, 2024
9 checks passed
@kashif deleted the orpo branch on March 22, 2024 21:07
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* initial orpo skeleton
* typos
* calculate orpo loss
* fix class name
* fix tests
* fix typo
* Update docs/source/orpo_trainer.md
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update docs/source/orpo_trainer.md
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update docs/source/orpo_trainer.md
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* rename max_target_length
* Update examples/scripts/orpo.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update examples/scripts/orpo.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update examples/scripts/orpo.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* more docs
* log log_odds_ratio and log_odds
* average_log_prob as per paper
* added logging section
* add nll_loss
* fix typo
* more verbose
* rename log_odds to log_odds_chosen
* allow datasets to be loaded
* remove dup debug arg
* tokenizer exists
* fix typo
* use trl-internal-testing/hh-rlhf-trl-style dataset
* formatting
* add missing imports
* fix output dir name
* Update examples/scripts/orpo.py
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* move dataset_num_proc to configs
* Update trl/trainer/orpo_config.py
  Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
* Update trl/trainer/orpo_trainer.py
  Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
* add ORPOTrainer to readme
* fix typo

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>