Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORPO trainer #1435

Merged
merged 38 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f793861
initial orpo skeleton
kashif Mar 17, 2024
1d588bd
typos
kashif Mar 17, 2024
06dafae
calculate orpo loss
kashif Mar 17, 2024
eb4a0d8
fix class name
kashif Mar 17, 2024
f0d085e
fix tests
kashif Mar 17, 2024
8993aec
fix typo
kashif Mar 17, 2024
8a01ab7
Update docs/source/orpo_trainer.md
kashif Mar 18, 2024
5509520
Update docs/source/orpo_trainer.md
kashif Mar 18, 2024
5187ffc
Update docs/source/orpo_trainer.md
kashif Mar 18, 2024
a99f362
rename max_target_length
kashif Mar 18, 2024
7d75c80
Update examples/scripts/orpo.py
kashif Mar 18, 2024
7e79743
Update examples/scripts/orpo.py
kashif Mar 18, 2024
ade72ea
Update examples/scripts/orpo.py
kashif Mar 18, 2024
eae80f9
more docs
kashif Mar 18, 2024
6fb33d1
log log_odds_ratio and log_odds
kashif Mar 18, 2024
8191a9d
average_log_prob as per paper
kashif Mar 18, 2024
f5dc9bd
added logging section
kashif Mar 18, 2024
c44ef2a
add nll_loss
kashif Mar 18, 2024
d33610c
fix typo
kashif Mar 18, 2024
7ed91b0
more verbose
kashif Mar 18, 2024
b37233d
rename log_odds to log_odds_chosen
kashif Mar 18, 2024
a0acecc
allow datasets to be loaded
kashif Mar 18, 2024
fdd7f5e
remove dup debug arg
kashif Mar 18, 2024
229200e
tokenizer exists
kashif Mar 18, 2024
fa65623
fix typo
kashif Mar 18, 2024
545d987
use trl-internal-testing/hh-rlhf-trl-style dataset
kashif Mar 18, 2024
c2013ed
Merge branch 'main' into orpo
kashif Mar 18, 2024
28e2c6e
formatting
kashif Mar 18, 2024
e2b02d3
add missing imports
kashif Mar 18, 2024
c57bc91
fix output dir name
kashif Mar 19, 2024
9174686
Update examples/scripts/orpo.py
kashif Mar 19, 2024
8754b4f
move dataset_num_proc to configs
kashif Mar 19, 2024
4052db6
Update trl/trainer/orpo_config.py
kashif Mar 21, 2024
443c1bb
Update trl/trainer/orpo_trainer.py
kashif Mar 22, 2024
99fd6e7
Merge remote-tracking branch 'upstream/main' into orpo
kashif Mar 22, 2024
dc19c61
add ORPOTrainer to readme
kashif Mar 22, 2024
160ddf3
Merge remote-tracking branch 'upstream/main' into orpo
kashif Mar 22, 2024
1846f07
fix typo
kashif Mar 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The library is built on top of the [`transformers`](https://github.com/huggingfa
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), and [`CPOTrainer`]((https://huggingface.co/docs/trl/trainer#trl.CPOTrainer).
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).

Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: text_environments
Expand Down
98 changes: 98 additions & 0 deletions docs/source/orpo_trainer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# ORPO Trainer

[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.

Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.

The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo).

## Expected dataset format

The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:

- `prompt`
- `chosen`
- `rejected`

for example:

```py
orpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.

## Expected model format
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.

## Using the `ORPOTrainer`
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT.

```py
orpo_config = ORPOConfig(
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
)

orpo_trainer = ORPOTrainer(
model,
args=orpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:

```py
orpo_trainer.train()
```

## Logging

While training and evaluating we record the following reward metrics:

* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards

* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses

* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`

* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses

## ORPOTrainer

[[autodoc]] ORPOTrainer


## ORPOConfig

[[autodoc]] ORPOConfig
121 changes: 121 additions & 0 deletions examples/scripts/orpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with the following command with some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model:

# regular:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
kashif marked this conversation as resolved.
Show resolved Hide resolved
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-orpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns

# peft:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-orpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

import multiprocessing
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config


@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
args, orpo_args, model_config = parser.parse_args_into_dataclasses()

################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
ds = load_dataset(args.dataset)
if orpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]

################
# Training
################
trainer = ORPOTrainer(
model,
args=orpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)

# train and save the model
trainer.train()
trainer.save_model(orpo_args.output_dir)
Loading
Loading