DPO Trainer#416

Merged
lvwerra merged 58 commits into huggingface:main from kashif:dpo
Jul 17, 2023
Conversation

@kashif
Collaborator

@kashif kashif commented Jun 8, 2023

Initial DPOTrainer class for #405, created by copying the RewardTrainer, with the DPO-specific changes started on top of it.

Fixes #405
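For context, the objective a DPO trainer optimizes can be sketched as follows. This is a minimal illustration with made-up log-probabilities, not the actual code in this PR; the function name and toy inputs are assumptions for the example.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """DPO loss sketch: each argument is the summed log-probability of the
    chosen/rejected completion under the policy or the frozen reference model."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # -log sigmoid(beta * (policy margin - reference margin))
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

# Toy case: the policy prefers the chosen completion more than the reference does.
loss = dpo_loss(torch.tensor([-1.0]), torch.tensor([-2.0]),
                torch.tensor([-1.5]), torch.tensor([-1.5]))
```

When the policy and reference margins are equal, the loss is log 2; it shrinks as the policy's preference margin grows beyond the reference's.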

@kashif kashif marked this pull request as draft June 8, 2023 11:30
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@Forbu

Forbu commented Jun 14, 2023

Thank you @kashif.
I think one interesting thing to do would be to run the forward pass of model_ref in no_grad mode, no?
You don't need to compute the backward pass for the ref model, because you never want to modify the ref model's weights.
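The suggestion above can be sketched like this (a hypothetical stand-in module, not the trainer's actual reference model):

```python
import torch

# Stand-in for the frozen reference model; any nn.Module works the same way.
ref_model = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# Running the reference forward pass under no_grad skips building the
# autograd graph: the ref model's weights are never updated, so no
# gradients are needed, which saves memory and compute.
with torch.no_grad():
    ref_out = ref_model(x)
```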

@kashif
Collaborator Author

kashif commented Jun 14, 2023

Agreed @Forbu, and I think we will also refactor the data collator so that we only apply the loss mask to the positive and negative parts of the sequence...
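The masking idea described here is commonly implemented by copying input_ids into labels and setting the prompt positions to -100, so only the response (the positive or negative completion) contributes to the loss. A minimal sketch with illustrative tensors, not the actual collator:

```python
import torch

IGNORE_INDEX = -100  # label value that standard HF loss functions skip

def mask_prompt(input_ids, prompt_len):
    """Build labels from input_ids, masking the prompt so only the
    response tokens contribute to the loss."""
    labels = input_ids.clone()
    labels[:prompt_len] = IGNORE_INDEX
    return labels

ids = torch.tensor([5, 6, 7, 8, 9])   # prompt = [5, 6], response = [7, 8, 9]
labels = mask_prompt(ids, prompt_len=2)
```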

@gaetanlop
Contributor

gaetanlop commented Jun 15, 2023

Hello @kashif, thanks for the DPO integration. I have refactored the data collator to compute the mean logprobs only on the positive and negative parts of the sequence. Can you please share the branch with me?

@kashif
Collaborator Author

kashif commented Jun 15, 2023

@gaetanlop i have added you to my fork

@gaetanlop
Contributor

Thanks @kashif, I just pushed the changes needed so that the trainer does not compute the mean logprobs on masked input_ids. I used an approach similar to HF's DataCollatorForTokenClassification (https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py).

@TevenLeScao
Contributor

Flagging that at some point we'll want the ScriptArguments to be consistent with the TrainingArguments (log_with -> report_to, batch_size -> per_device_train_batch_size, model_name -> model_name_or_path). The second and third ones are especially important, since the semantics of those arguments are actually different.

@TevenLeScao
Contributor

TevenLeScao commented Jun 16, 2023

@kashif @gaetanlop a fix to enable distributed training required a change to transformers.Trainer (huggingface/transformers#24326), so you'll need to install transformers from main if you pull this branch!

@eric-mitchell

@kashif just to follow up again, the re-run of the DPO stage of Pythia on HH just completed, and training is basically unchanged there as well after the bug fix on our end. So comparing with our original runs should be fine.

(Screenshot: SFT/DPO training curves, 2023-07-05)

You can see the SFT/DPO runs before and after the bug fix here.

Comment thread trl/trainer/dpo_trainer.py
@eric-mitchell

@kashif just wanted to check in on this: curious if you've had the chance to re-run the replication experiment :)

@kashif
Collaborator Author

kashif commented Jul 14, 2023

@eric-mitchell I did, though not with an SFT'ed Pythia... it also worked with PEFT... We saw a validation loss of around 0.6, consistent with your wandb runs, which was nice... If you can share your SFT'ed Pythia model with me, I can re-run it with that now?

@eric-mitchell

Here are the weights of our pre-trained Pythia. You can load them with model.load_state_dict(torch.load(PATH)['state']).

It's nice that PEFT worked! What type of PEFT did you use?

@kashif
Collaborator Author

kashif commented Jul 14, 2023

@eric-mitchell thanks! Yes, we tried QLoRA and LoRA as well... let me confirm. Thanks for the weights!

Comment thread docs/source/dpo_trainer.mdx
Member

@lvwerra lvwerra left a comment


Two small nits, then we can merge :)

Comment thread trl/trainer/dpo_trainer.py Outdated
Comment thread docs/source/dpo_trainer.mdx Outdated
Comment thread docs/source/trainer.mdx
@lvwerra lvwerra merged commit 84393f3 into huggingface:main Jul 17, 2023
@kashif kashif deleted the dpo branch July 17, 2023 13:29
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
* initial DPO Trainer

* typo

* initial dpo from reward trainer

* calc. log_probs from logits

* remove dpo config for now

* fix inits

* add initial DPODataCollatorWithPadding

* use the RewardDataCollatorWithPadding

* initial test

* means of loss

* add assert

* just call the train instead of step

* functional debug example before refactor

* check the params have changed

* initial DPODataCollatorWithPadding

* Data collator with masking

* going through trainer.accelerate to wrap ref_model

* style / imports

* style / imports

* `broadcast_buffers=False` fix to distributed training

* better fix for DDP issues

* arguments and style clean-up

* better doc, some light refactoring

* better imports

* initial dpo doc

* fix test

* fix formatting

* fix

* called models once

* fix tests

* add example

* fix doc string

* initial example with anthropic hh dataset

* refactored dpo trainer

* revert

* return metrics

* fixed tests

* updated docs

* update test

* fixed typo

* note about the beta

* added dpo authors

* fix docstrings

* add prediction_step

* remove compute_metrics and log metrics manually

* fix typo

* add DPOTrainer doc

* add dpo to toc

* ValueError

* add to index and example

* fix docs

* fix assert

---------

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
Co-authored-by: Gaetan LOPEZ <gaetanloplat@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Development

Successfully merging this pull request may close these issues.

Adding DPOTrainer in trl

9 participants