Resolve broken evaluation/prediction for RewardTrainer#404

Merged
younesbelkada merged 5 commits into huggingface:main from tomaarsen:fix/reward_trainer_eval_predict
Jun 6, 2023

Conversation

@tomaarsen
Member

Hello!

Pull Request overview

  • Implement evaluation & prediction for RewardTrainer.
    • Update compute_accuracy correspondingly.
  • Adapt tests to perform both evaluation and prediction.
  • Improve typing for the RewardTrainer.

Motivation

In our attempts to apply the RewardTrainer, we ran into issues with evaluation. Our training script is included below for reference:

Training script
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
)
from trl import RewardTrainer

dataset = load_dataset("argilla/dolly-curated-comparison-falcon-7b-instruct", split="train")
model_name = "distilroberta-base"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

def formatting_func(examples):
    kwargs = {"padding": "max_length", "truncation": True, "max_length": 512, "return_tensors": "pt"}

    # Prepend the prompt and a line break to the original_response and response-1 fields.
    prompt_plus_chosen_response = examples["prompt"] + "\n" + examples["original_response"]
    prompt_plus_rejected_response = examples["prompt"] + "\n" + examples["response-1"]

    # Then tokenize these modified fields.
    tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
    tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)

    return {
        "input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
        "input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0]
    }
    
formatted_dataset = dataset.map(formatting_func)

formatted_dataset = formatted_dataset.train_test_split()

training_args = TrainingArguments(
    output_dir="./my_model",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    logging_steps=100,  
)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["test"],
)

trainer.train()

When training, we experience the following error:

Traceback (most recent call last):
  File "[sic]\trl\demo2.py", line 53, in <module>
    trainer.train()
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 1664, in train
    return inner_training_loop(
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 2019, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 2300, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 3029, in evaluate
    output = eval_loop(
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 3210, in evaluation_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "[sic]\trl\lib\site-packages\transformers\trainer.py", line 3476, in prediction_step
    outputs = model(**inputs)
  File "[sic]\trl\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'input_ids_chosen'

The cause of this bug is pretty straightforward: the unchanged transformers Trainer evaluation/prediction loop is used, while we need a specialized loop that performs forward passes for both the accepted and the rejected input_ids/attention_masks.

The fix

The fix is as simple as overriding prediction_step. The primary changes relative to the original prediction_step are:

  1. We use compute_loss for the forward calls instead of model(**inputs).
  2. We take the mean of the logits for each of the samples and apply a softmax, so that the reward for the accepted sample plus the reward for the rejected sample sum to 1.
  3. We fix the labels to be a vector of 0's: we always prefer the accepted samples (i.e. index 0).

Because of this last change, we can simplify compute_accuracy to actually use the labels.
Lastly, I updated the typing on compute_loss. I recognize that the contributor guidelines specify making those changes separately; I apologize for that.
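The post-processing in steps 2 and 3 can be sketched roughly as follows. This is a simplified illustration, not the exact merged code; `rewards_chosen`/`rewards_rejected` are illustrative names for the logits from the two forward passes made via compute_loss:

```python
import torch

def to_predictions(rewards_chosen: torch.Tensor,
                   rewards_rejected: torch.Tensor):
    """Post-process the logits from the chosen/rejected forward passes.

    Both inputs have shape (batch, ...) with per-sample reward logits.
    Returns a (batch, 2) probability matrix where each row sums to 1,
    plus an all-zeros label vector: index 0 (accepted) is always preferred.
    """
    # Stack the chosen/rejected logits: shape (2, batch, ...)
    stacked = torch.stack([rewards_chosen, rewards_rejected])
    # Mean over the trailing dimensions -> shape (2, batch)
    stacked = stacked.reshape(2, stacked.shape[1], -1).mean(dim=2)
    # Softmax across the chosen/rejected axis, then transpose to (batch, 2)
    probs = stacked.softmax(dim=0).T
    # The accepted sample is always preferred, so every label is 0
    labels = torch.zeros(probs.shape[0])
    return probs, labels
```

With a chosen logit of 2.0 against a rejected logit of 0.0, the first row of `probs` puts most of the mass on index 0, matching the intuition that a higher reward for the accepted sample means a correct prediction.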

The consequences

With these changes in place, we can run the above script just fine:
[screenshot: the training run now logs evaluation metrics without errors]
(Although in practice, we just call the forward method of our trained model using individual samples to get the logits.)

Additionally, we can run trainer.predict() and get a useful response. See this dummy example using an untrained model.

PredictionOutput(predictions=array([[0.47846982, 0.52153015],
       [0.5010128 , 0.4989872 ],
       [0.47846982, 0.52153015],
       [0.44097137, 0.5590286 ]], dtype=float32), label_ids=array([0., 0., 0., 0.], dtype=float32), metrics={'test_loss': 0.7460569143295288, 'test_accuracy': 0.25, 'test_runtime': 0.0831, 'test_samples_per_second': 48.116, 'test_steps_per_second': 12.029})

As you can see, the test accuracy is 25%: only one of the four predictions gives the first (accepted) sample the higher reward.
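The 25% follows directly from the predictions above: accuracy is the fraction of rows whose argmax equals the label. A sketch with NumPy (the actual metric comes from the trainer's default compute_metrics):

```python
import numpy as np

def pairwise_accuracy(predictions: np.ndarray, label_ids: np.ndarray) -> float:
    """Fraction of samples where the preferred index got the highest score."""
    return float((predictions.argmax(axis=1) == label_ids).mean())

# The PredictionOutput values from above (rounded)
preds = np.array([[0.478, 0.522],
                  [0.501, 0.499],
                  [0.478, 0.522],
                  [0.441, 0.559]])
labels = np.zeros(4)
```

Only the second row assigns the higher probability to index 0, so `pairwise_accuracy(preds, labels)` yields 0.25.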

The tests

The tests are updated to add evaluation_strategy="steps", as otherwise the evaluation would not trigger. Furthermore, I added trainer.predict calls that verify that the output is indeed of shape (4, 2) when there are 4 samples.

Alternatively, I can update the predictions to only give the "probability" of the accepted sample, or to only give the argmax (e.g. [0, 1, 0, 0]) as the prediction.
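Either alternative would be a one-liner on top of the current (4, 2) output; a sketch using the dummy predictions from above:

```python
import numpy as np

predictions = np.array([[0.478, 0.522],
                        [0.501, 0.499],
                        [0.478, 0.522],
                        [0.441, 0.559]])

# Option A: only the "probability" assigned to the accepted sample (column 0)
accepted_probs = predictions[:, 0]

# Option B: only the argmax per row (0 = accepted preferred, 1 = rejected preferred)
argmax_preds = predictions.argmax(axis=1)
```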

cc: @younesbelkada @lewtun @lvwerra @dvsrepo

I'm open to feedback and suggestions as always.

  • Tom Aarsen

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@younesbelkada younesbelkada left a comment

Thanks so much @tomaarsen for your great work, for describing the issue so precisely, and for the fix!
This looks great! I just left one tiny comment about the type hints for Python <= 3.7.

Comment thread: trl/trainer/reward_trainer.py (Outdated)

    # limitations under the License.
    import warnings
  - from typing import Callable, Dict, List, Optional, Tuple, Union
  + from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
Contributor

From the CI logs it seems that Literal is not available in typing for Python <= 3.7; can we change it to something else? 🙏

Member Author

You're lucky that it's not the 27th of June yet, or I would have refused 😉
I'll get on it!

I recognize that I can also import from typing_extensions with a try-except,
but that is a bit overkill for this I feel.
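For reference, the try-except mentioned here would be the standard backport pattern (a sketch; the `PaddingStrategy` alias is purely illustrative and not an identifier from this PR):

```python
try:
    # typing.Literal only exists on Python >= 3.8
    from typing import Literal
except ImportError:
    # Backport for Python <= 3.7; requires `pip install typing_extensions`
    from typing_extensions import Literal

# Illustrative usage only:
PaddingStrategy = Literal["longest", "max_length", "do_not_pad"]
```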
Member

@lvwerra lvwerra left a comment

Thanks for fixing, looks good to me! I let @younesbelkada have the final word here :)

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for iterating @tomaarsen!
The CI somehow fails with eval_steps=1 with a very strange error that I didn't manage to reproduce locally. I think the test is flaky for some reason (I reran the CI 3 times and it failed on different scenarios), so we should be fine removing eval_steps=1, as you added a test later on to check whether prediction_step is called correctly. What do you think?

@dvsrepo

dvsrepo commented Jun 6, 2023

Great stuff! We found this while working on an integration example with the new Argilla Feedback feature. @younesbelkada, @lvwerra, do you have a timeline for releasing the RewardTrainer? We'd love to align so we can officially publish the tutorial.

@lvwerra
Member

lvwerra commented Jun 6, 2023

We discussed today that we want to merge this PR and then do a release!

@dvsrepo

dvsrepo commented Jun 6, 2023

Cool @lvwerra! Then I'll rush to get the tutorial finished 😄

@tomaarsen
Member Author

Thanks a lot for iterating @tomaarsen! The CI somehow fails with eval_steps=1 with a very strange error that I didn't manage to reproduce locally. I think the test is flaky for some reason (I reran the CI 3 times and it failed on different scenarios), so we should be fine removing eval_steps=1, as you added a test later on to check whether prediction_step is called correctly. What do you think?

The issue originates in the speed_metrics function in transformers, so there's not much to "fix" on the TRL side:

        if num_samples is not None:
>           samples_per_second = num_samples / runtime
E           ZeroDivisionError: float division by zero

Because the evaluation dataset is so small, the runtime is sometimes 0, which is an understandable oversight on the side of the transformers maintainers. There's likely no elegant solution for this, but I'll remove the eval_steps, as that should resolve the issue (although perhaps it will also pop up on the predict call).
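For illustration, the kind of guard that would avoid the failure looks roughly like this (a sketch, not the actual transformers speed_metrics code):

```python
def safe_samples_per_second(num_samples: int, runtime: float) -> float:
    """Throughput that degrades gracefully when the measured runtime
    rounds down to zero, as can happen with a tiny model and eval set."""
    if runtime <= 0.0:
        return 0.0
    return round(num_samples / runtime, 3)
```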

The flaky test is caused by a division by zero when dividing by the runtime.
This is done on the transformers side, so it's not a TRL issue.
In practice, this won't happen - it only happens because both the model
and dataset are tiny.
Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your inspiring work!

@younesbelkada younesbelkada merged commit 376d152 into huggingface:main Jun 6, 2023
@tomaarsen tomaarsen deleted the fix/reward_trainer_eval_predict branch June 6, 2023 14:59
@younesbelkada
Contributor

As a side note, huggingface/transformers#24049 will be merged, so we can revert commit 96311cd on the next transformers release. I will take care of that.

@tomaarsen
Member Author

Perfect. That should make the testing suite a bit more complete.
