small KTO fixes #1734

Merged: 50 commits, merged on Jun 17, 2024

Commits
6ee3be4
add warning for imbalanced data
kawine Feb 25, 2024
22dd810
update documentation
kawine Feb 25, 2024
8d14930
update script commands to be same as in dpo
kawine Feb 25, 2024
8a490af
use batch_size KL examples and batch_size target examples to calculat…
kawine Feb 25, 2024
f826600
fix deepspeed issue
kawine Feb 25, 2024
688ed6c
speed up forward with no_grad for KL
kawine Feb 26, 2024
587517b
Merge branch 'huggingface:main' into main
kawine Feb 28, 2024
e128f09
add some removed metrics
kawine Feb 28, 2024
2d860b8
Update trl/trainer/kto_trainer.py
kashif Feb 28, 2024
48d25ff
Update trl/trainer/kto_trainer.py
kashif Feb 28, 2024
392bcc0
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
a42049f
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
5696814
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
000d5d8
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
2738d1f
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
d7f63c5
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
824da55
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
4399af4
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
69094be
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
73f7ed7
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
5b95aca
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
3102901
Update trl/trainer/kto_trainer.py
kawine Feb 28, 2024
ca68f24
add more detailed comments
kawine Feb 28, 2024
94fb375
convert assert to ValueError
kawine Feb 28, 2024
8f7e788
Update kto_trainer.py
kawine Feb 29, 2024
ed19ed5
precommit formatting
kawine Feb 29, 2024
310bd97
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Feb 29, 2024
639f4de
Merge branch 'huggingface:main' into main
kawine Mar 1, 2024
ee7d6a4
remove nans in metrics by gathering across machines
kawine Mar 1, 2024
7ae95c2
fix formatting
kawine Mar 1, 2024
1b96b2d
fix choice of mismatched examples for KL term
kawine Mar 4, 2024
81b60da
describe weights
kawine Mar 4, 2024
1f145b9
fix hanging issue in distributed training
kawine Mar 6, 2024
83ed882
linting
kawine Mar 7, 2024
9c5480d
Merge branch 'main' of https://github.com/kawine/trl into main
kawine Mar 7, 2024
15251ff
move metrics to cpu
kawine Mar 7, 2024
8f9fdfe
Update trl/trainer/kto_trainer.py
kawine Mar 7, 2024
600aad8
Update trl/trainer/kto_trainer.py
kashif Mar 8, 2024
8b5367e
Update trl/trainer/kto_trainer.py
kashif Mar 8, 2024
5cc6fed
Merge branch 'huggingface:main' into main
kawine Mar 9, 2024
03dfe90
Merge branch 'huggingface:main' into main
kawine Mar 11, 2024
cf6217e
Merge branch 'main' of https://github.com/kawine/trl into main
xwinxu May 7, 2024
e3a3691
Merge branch 'huggingface:main' into main
xwinxu May 18, 2024
e88b5ed
Merge branch 'huggingface:main' into main
kawine Jun 12, 2024
12f99b9
remove kto_pair
kawine Jun 14, 2024
2424ef0
resolve conflicts
kawine Jun 14, 2024
af7c424
speed up data processing
kawine Jun 14, 2024
95f361b
move bco code inside
kawine Jun 14, 2024
4c234a3
raise error for kto_pair argument
kawine Jun 15, 2024
e88de4c
fix formatting
kawine Jun 15, 2024
4 changes: 1 addition & 3 deletions docs/source/dpo_trainer.mdx
@@ -109,8 +109,6 @@ The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we

The [Robust DPO](https://arxiv.org/abs/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via the `label_smoothing` argument (between 0 and 0.5). Pass `loss_type="robust"` to the trainer to use it.

The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.

The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `DPOTrainer` can be switched to this loss via the `loss_type="bco_pair"` argument.

The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation using loss_type="sppo_hard" approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser.
@@ -121,7 +119,7 @@ The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the refere

The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://arxiv.org/abs/2405.16436) that essentially consists of the SFT loss on the chosen preferences together with a weighted DPO loss. To use this loss set the `rpo_alpha` in the `DPOConfig` to an appropriate value.

The [AOT](https://arxiv.org/abs/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. Note that `loss_type="aot_pair"` is similar in spirit to `loss_type="kto_pair"` that applies unpaired alignment methodology on paired dataset. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
The [AOT](https://arxiv.org/abs/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.

## Logging

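For reference, a minimal sketch of the migration the docs change above recommends: splitting n preference pairs into 2n unpaired examples in the `prompt`/`completion`/`label` format that `KTOTrainer` expects. The paired column names and the toy data are illustrative assumptions, not part of the diff.

```python
# Minimal sketch (illustrative, not from this PR): convert paired preference
# data into the unpaired format consumed by KTOTrainer.
from datasets import Dataset

def pairs_to_kto(paired_rows):
    """Turn n {prompt, chosen, rejected} rows into 2n {prompt, completion, label} rows."""
    rows = {"prompt": [], "completion": [], "label": []}
    for ex in paired_rows:
        rows["prompt"] += [ex["prompt"], ex["prompt"]]
        rows["completion"] += [ex["chosen"], ex["rejected"]]
        rows["label"] += [True, False]  # True = desirable, False = undesirable
    return Dataset.from_dict(rows)

# Hypothetical toy input:
paired = [{"prompt": "2 + 2 =", "chosen": " 4", "rejected": " 5"}]
kto_dataset = pairs_to_kto(paired)
print(kto_dataset)  # 2 rows, one labeled True and one labeled False
```

The resulting dataset can then be passed as `train_dataset` to `KTOTrainer` together with a `KTOConfig`, replacing the removed `loss_type="kto_pair"` path in `DPOTrainer`.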
2 changes: 1 addition & 1 deletion tests/slow/testing_constants.py
@@ -23,5 +23,5 @@
GRADIENT_CHECKPOINTING_KWARGS = [None, {"use_reentrant": False}, {"use_reentrant": True}]
DEVICE_MAP_OPTIONS = [{"": 0}, "auto"]

DPO_LOSS_TYPES = ["sigmoid", "ipo", "kto_pair"]
DPO_LOSS_TYPES = ["sigmoid", "ipo"]
DPO_PRECOMPUTE_LOGITS = [True, False]
6 changes: 0 additions & 6 deletions tests/test_dpo_trainer.py
@@ -86,8 +86,6 @@ def _init_dummy_dataset(self):
["t5", "hinge", False],
["gpt2", "ipo", False],
["t5", "ipo", True],
["gpt2", "kto_pair", True],
["t5", "kto_pair", False],
["gpt2", "aot_pair", True],
["t5", "aot_pair", False],
["gpt2", "aot", True],
@@ -506,10 +504,6 @@ def test_dpo_lora_bf16_autocast_llama(self):
["gpt2", "ipo", False, True],
["gpt2", "ipo", True, False],
["gpt2", "ipo", True, True],
["gpt2", "kto_pair", False, False],
["gpt2", "kto_pair", False, True],
["gpt2", "kto_pair", True, False],
["gpt2", "kto_pair", True, True],
["gpt2", "aot_pair", False, False],
["gpt2", "aot_pair", False, True],
["gpt2", "aot_pair", True, False],
6 changes: 5 additions & 1 deletion trl/trainer/cpo_config.py
@@ -66,7 +66,7 @@ class CPOConfig(TrainingArguments):

beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "simpo"] = "sigmoid"
loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "sigmoid"
disable_dropout: bool = True
simpo_gamma: float = 0.5

@@ -79,3 +79,7 @@ class CPOConfig(TrainingArguments):
model_init_kwargs: Optional[Dict] = None

dataset_num_proc: Optional[int] = None

def __post_init__(self):
if self.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
6 changes: 4 additions & 2 deletions trl/trainer/cpo_trainer.py
@@ -259,10 +259,12 @@ def make_inputs_require_grad(module, input, output):
self.max_target_length = max_target_length
self.tokenizer = tokenizer

if args.loss_type in ["hinge", "ipo", "kto_pair"] and args.label_smoothing > 0:
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
if args.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

self.beta = args.beta
self.label_smoothing = args.label_smoothing
@@ -610,7 +612,7 @@ def cpo_loss(
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'simpo']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
)

chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
8 changes: 6 additions & 2 deletions trl/trainer/dpo_config.py
@@ -29,7 +29,7 @@ class DPOConfig(TrainingArguments):
The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and [Robust DPO](https://arxiv.org/abs/2403.00409) paper that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper,
`"kto_pair"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf), `"bco_pair"` from [BCO](https://arxiv.org/abs/2404.04656) paper or `"robust"` from [Robust DPO](https://arxiv.org/abs/2403.00409) paper,
`"bco_pair"` from [BCO](https://arxiv.org/abs/2404.04656) paper or `"robust"` from [Robust DPO](https://arxiv.org/abs/2403.00409) paper,
"aot" and "aot_pair" from alignment via optimal transport
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
@@ -79,7 +79,7 @@ class DPOConfig(TrainingArguments):
beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal[
"sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard", "nca_pair", "robust", "aot", "aot_pair"
"sigmoid", "hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "robust", "aot", "aot_pair"
] = "sigmoid"
label_pad_token_id: int = -100
padding_value: Optional[int] = None
@@ -102,3 +102,7 @@ class DPOConfig(TrainingArguments):
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None

def __post_init__(self):
if self.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")
23 changes: 5 additions & 18 deletions trl/trainer/dpo_trainer.py
@@ -135,7 +135,7 @@ def __init__(
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "robust", "aot", "aot_pair"] = "sigmoid",
loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair"] = "sigmoid",
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
@@ -463,10 +463,12 @@ def make_inputs_require_grad(module, input, output):
"You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_smoothing = label_smoothing
if args.loss_type in ["hinge", "ipo", "kto_pair", "bco_pair"] and args.label_smoothing > 0:
if args.loss_type in ["hinge", "ipo", "bco_pair"] and args.label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
if args.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")

if beta != 0.1:
warnings.warn(
@@ -1024,21 +1026,6 @@ def dpo_loss(
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
elif self.loss_type == "bco_pair":
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
@@ -1096,7 +1083,7 @@ def dpo_loss(

else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust']"
)

chosen_rewards = (
50 changes: 25 additions & 25 deletions trl/trainer/kto_trainer.py
@@ -639,24 +639,15 @@ def make_inputs_require_grad(module, input, output):
# merge the datasets
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

desirable = train_dataset.filter(
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
)
undesirable = train_dataset.filter(
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
)

if len(desirable) == 0:
raise ValueError("The set of desirable completions cannot be empty.")
elif len(undesirable) == 0:
raise ValueError("The set of undesirable completions cannot be empty.")
num_desirable = max(sum(train_dataset["label"]), 1)
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary

if len(desirable) != len(undesirable):
if num_desirable != num_undesirable:
# The lower and upper bounds come from Eq. (8) of https://arxiv.org/abs/2402.01306
des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
des_weight_upper_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33, 2)
und_weight_lower_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1.33, 2)
und_weight_upper_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1, 2)
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)

des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
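The bounds above implement Eq. (8) of the KTO paper: with unequal counts, the weights should roughly balance `num_desirable * desirable_weight` against `num_undesirable * undesirable_weight`, within a 1.33x tolerance. A standalone sketch of the same arithmetic, using made-up counts and the default weights of 1.0, shows when the warning fires:

```python
# Standalone sketch of the weight-range check above, with hypothetical counts.
num_desirable, num_undesirable = 900, 100        # imbalanced toy dataset
desirable_weight, undesirable_weight = 1.0, 1.0  # KTOConfig defaults

des_lo = round(num_undesirable * undesirable_weight / num_desirable * 1, 2)
des_hi = round(num_undesirable * undesirable_weight / num_desirable * 1.33, 2)
und_lo = round(num_desirable * desirable_weight / num_undesirable / 1.33, 2)
und_hi = round(num_desirable * desirable_weight / num_undesirable / 1, 2)

print(des_lo, des_hi)  # 0.11 0.15 -> desirable_weight=1.0 falls outside this range
print(und_lo, und_hi)  # 6.77 9.0  -> undesirable_weight=1.0 falls outside this range
# With a 9:1 split, the trainer warns and suggests either lowering
# desirable_weight into [0.11, 0.15] or raising undesirable_weight into [6.77, 9.0].
```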
@@ -673,6 +664,13 @@ def make_inputs_require_grad(module, input, output):
)

if self.loss_type == "bco":
desirable = train_dataset.filter(
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
)
undesirable = train_dataset.filter(
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
)

desirable = desirable.shuffle(seed=args.data_seed)
undesirable = undesirable.shuffle(seed=args.data_seed)

@@ -727,18 +725,20 @@ def make_inputs_require_grad(module, input, output):
if self.loss_type == "bco":
self.running = RunningMoments(self.accelerator)

if self.embedding_func is None:
return
if self.embedding_func is None:
return

chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
rejected_embeddings = self._get_sample_prompt_embeddings(
undesirable, sample_size=self.args.prompt_sample_size
)

embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
labels = torch.cat(
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
)
embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
labels = torch.cat(
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
)

self.clf = LogisticRegression(class_weight="balanced").fit(embeddings.cpu().numpy(), labels.cpu().numpy())
self.clf = LogisticRegression(class_weight="balanced").fit(embeddings.cpu().numpy(), labels.cpu().numpy())

@property
def match_underlying_distribution(self):
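The indentation fix above also scopes the prompt-embedding classifier fit to `loss_type="bco"`. Below is a hedged, self-contained toy version of that underlying-distribution-matching fit; random embeddings stand in for the real `_get_sample_prompt_embeddings` output.

```python
# Toy sketch of the BCO prompt-embedding classifier fit shown above.
import torch
from sklearn.linear_model import LogisticRegression

chosen_embeddings = torch.randn(64, 16)    # stand-in for desirable prompt embeddings
rejected_embeddings = torch.randn(32, 16)  # stand-in for undesirable prompt embeddings

embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
labels = torch.cat(
    (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
)

# class_weight="balanced" compensates for the 2:1 imbalance between the two
# sets, mirroring how the trainer handles unequal desirable/undesirable counts.
clf = LogisticRegression(class_weight="balanced").fit(embeddings.numpy(), labels.numpy())
print(clf.predict_proba(embeddings[:2].numpy()).shape)  # (2, 2)
```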