
Request to rewrite the implementation of prediction_step in trainer.py #42200

@Yacklin


System Info

Any system, because the problem comes from the source code.

Who can help?

@SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi, I am raising an issue that was reported 5 years ago but still exists as of 13 November 2025.

Below is a link to one of the earlier discussions of this issue, which was ignored by sgugger:
https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941

When I was about to fine-tune an LLM today, I ran into the same issue, but a solution posted by another user in that discussion saved me.

How to reproduce (requires a GPU; no quantization, just full fine-tuning):

  1. Pick any decoder-only text-to-text LLM, e.g. Qwen3 0.6B.

  2. Prepare a train dataset (>0 rows) and an eval dataset (>850 rows).

  3. Set eval_on_start=True; either TrainingArguments or SFTConfig works.

  4. Implement your own compute_metrics, but DON'T implement preprocess_logits_for_metrics.

  5. Start training (no DeepSpeed or Accelerate needed, just trainer.train()).

What happens?
Because eval_on_start=True, training begins with a pass over the evaluation dataset. It runs very fast at first, then slows to a crawl, and finally fails with an error saying numpy is trying to allocate an absurdly large array in memory.

[Screenshot: the numpy memory-allocation error]
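A back-of-the-envelope estimate shows why the allocation gets so large: without preprocess_logits_for_metrics, the full logits of every evaluation row are accumulated, roughly an array of shape [rows, seq_len, vocab_size]. The sequence length (512), vocabulary size (151,936) and float32 storage below are assumptions for illustration, not numbers from this report; check your own model config and padding. `eval_logits_bytes` is a hypothetical helper.

```python
def eval_logits_bytes(num_rows: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for one [num_rows, seq_len, vocab_size] array of accumulated predictions."""
    return num_rows * seq_len * vocab_size * bytes_per_elem


# Assumptions (not from the report): every row padded to 512 tokens,
# a 151,936-entry vocabulary, logits stored as float32 (4 bytes each).
full_logits = eval_logits_bytes(850, 512, 151_936, bytes_per_elem=4)

# After an argmax over the vocabulary, one int64 token id per position remains.
token_ids = eval_logits_bytes(850, 512, 1, bytes_per_elem=8)

print(f"full logits: {full_logits / 2**30:.1f} GiB")  # 246.3 GiB
print(f"token ids:   {token_ids / 2**20:.1f} MiB")    # 3.3 MiB
```

Under these assumptions, reducing logits to token ids before they are stored shrinks the accumulated predictions by a factor of about vocab_size / 2.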

One user in that thread, apparently inspired by the example code, posted an implementation of preprocess_logits_for_metrics that solved my problem perfectly: the evaluation run now finishes within 2 minutes.

Why does it happen?

I briefly went over the source code of evaluation_loop and located prediction_step.

According to its type hint, prediction_step returns a tuple of three optional torch.Tensor values (loss, logits, labels).

[Screenshot: the prediction_step signature and type hint]

But most of the time, the returned logits object is a tuple.

Why?

Look at the function that processes logits before they are returned:

[Screenshot: nested_detach]

This function can receive all kinds of "tensors": the argument may be a list, a tuple, a Mapping, or a torch.Tensor.

Does it convert the argument called "tensors" from these other types into a torch.Tensor?

No.

type(tensors)(...) preserves the original type of tensors. So if the variable "tensors" (I hate this name; it is misleading and confusing) is a tuple, it is still a tuple after this function returns!!!

It's a recursive function, by the way. I love recursion in programming competitions, but not in the Hugging Face codebase!!! The recursion also implies that the input of nested_detach can be arbitrarily nested, e.g. ([], ()).
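A pure-Python mirror of the recursion described above makes the type preservation visible. `FakeTensor` and `nested_detach_sketch` are hypothetical stand-ins so the sketch runs without PyTorch; this is not the library's code, only the same `type(tensors)(...)` pattern applied to lists, tuples and mappings.

```python
from collections.abc import Mapping


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with just a detach() method."""

    def __init__(self, name: str):
        self.name = name
        self.detached = False

    def detach(self) -> "FakeTensor":
        out = FakeTensor(self.name)
        out.detached = True
        return out


def nested_detach_sketch(tensors):
    """Detach every leaf, rebuilding each container with its original type."""
    if isinstance(tensors, (list, tuple)):
        # A tuple goes in, a tuple comes out: the container is never flattened.
        return type(tensors)(nested_detach_sketch(t) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach_sketch(v) for k, v in tensors.items()})
    return tensors.detach()


# Mixed nesting such as ([], ()) is handled the same way.
outputs = (FakeTensor("logits"), [FakeTensor("k"), FakeTensor("v")])
result = nested_detach_sketch(outputs)
print(type(result).__name__)     # tuple
print(type(result[1]).__name__)  # list
print(result[0].detached)        # True
```

Every leaf is detached, but nothing in the recursion ever turns a container into a single tensor.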

So this function does not guarantee that logits is a torch.Tensor.

Nor does the code in prediction_step that runs before nested_detach is called:

[Screenshot: the part of prediction_step that runs before nested_detach]
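A simplified, hypothetical sketch of that step, written from memory of trainer.py (details may differ between versions): when the model outputs are a dict, every value except the loss and any ignore_keys is collected into a tuple, and the tuple is only unwrapped when exactly one element remains. The early return mirrors the (loss, None, None) path that, per this report, is taken when no compute_metrics is given. `extract_logits_sketch` is a hypothetical name, and strings stand in for tensors.

```python
from typing import Any, Optional, Sequence


def extract_logits_sketch(
    outputs: Any,
    ignore_keys: Sequence[str] = (),
    prediction_loss_only: bool = False,
) -> Optional[Any]:
    """Hypothetical simplification of how prediction_step builds `logits`."""
    if prediction_loss_only:
        # Only the loss is needed, so no logits are collected at all.
        return None
    if isinstance(outputs, dict):
        # Hugging Face ModelOutput objects are dict subclasses. Every value
        # except the loss and the ignored keys is kept, in order, as a tuple.
        skip = set(ignore_keys) | {"loss"}
        logits = tuple(v for k, v in outputs.items() if k not in skip)
    else:
        logits = outputs
    # The tuple is only unwrapped when exactly one element survives.
    if isinstance(logits, tuple) and len(logits) == 1:
        logits = logits[0]
    return logits


# A causal LM run with caching enabled can return cached key/value
# states next to its logits.
outputs = {"loss": "LOSS", "logits": "LOGITS", "past_key_values": "CACHE"}
print(extract_logits_sketch(outputs))                                   # ('LOGITS', 'CACHE')
print(extract_logits_sketch(outputs, ignore_keys=["past_key_values"]))  # LOGITS
print(extract_logits_sketch(outputs, prediction_loss_only=True))        # None
```

In this sketch, any extra output that is not explicitly ignored is enough to turn "logits" into a tuple.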

So logits is not always a torch.Tensor, which contradicts the type hint. What did the developers do?

They added preprocess_logits_for_metrics, so that users could fix it IN THEIR OWN IMPLEMENTATION.

(preprocess_logits_for_metrics is called within evaluation_loop to clean up the mess, specifically the logits returned by prediction_step().)
[Screenshot: the preprocess_logits_for_metrics call in evaluation_loop]

It's such a lazy fix. Why should a regular user be expected to implement their own preprocess_logits_for_metrics to work around a poorly designed prediction_step?

It has been 5 years since someone first reported it...

If no user-defined compute_metrics is provided to Trainer or SFTTrainer, prediction_step returns (loss, None, None), which skips the whole problem. This is why users reported that the "slow evaluation" issue disappears when they don't provide compute_metrics.

I would like to open a pull request to fix it, but I don't have the time or energy for such a large amount of work.

A temporary fix is to tell users that whenever they write their own compute_metrics, they also have to implement preprocess_logits_for_metrics. The implementation differs between models; for a text-to-text decoder-only LLM, it looks like this:

[Screenshot: a preprocess_logits_for_metrics implementation for a decoder-only LLM]
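For a decoder-only causal LM, the workaround commonly takes the shape below, modeled on the preprocess_logits_for_metrics in the Hugging Face run_clm.py language-modeling example (reproduced from memory, so treat it as a sketch). It assumes `logits` is a PyTorch tensor of shape [batch, seq_len, vocab_size], or a tuple whose first element is that tensor; the function does not import torch and only relies on the tensor's `argmax(dim=-1)` method.

```python
def preprocess_logits_for_metrics(logits, labels):
    """Shrink per-batch logits to predicted token ids before the Trainer stores them.

    Sketch for decoder-only causal LMs: `logits` is expected to be a torch.Tensor
    of shape [batch, seq_len, vocab_size], or a tuple whose first element is one.
    """
    if isinstance(logits, tuple):
        # Extra model outputs (e.g. cached key/value states) may follow the
        # logits in this tuple; the logits come first.
        logits = logits[0]
    # `labels` is passed in by the Trainer but is not needed for an argmax.
    return logits.argmax(dim=-1)  # [batch, seq_len] token ids
```

Pass it to Trainer (or SFTTrainer) through the preprocess_logits_for_metrics argument alongside compute_metrics. compute_metrics then receives token ids instead of logits; run_clm.py additionally shifts them by one position (preds[:, :-1] against labels[:, 1:]) so each prediction is compared with the next token.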

(Another thing: in every implementation of preprocess_logits_for_metrics I have seen so far, the "labels" argument is ignored. What is "labels" for here?)

The user who posted the solution in the discussion linked earlier in this post said there might be a memory leak in Trainer that causes the extremely slow evaluation run. Implementing preprocess_logits_for_metrics might therefore just hide the actual problem rather than solve it.

Expected behavior

prediction_step should actually return a tuple of three optional torch.Tensor values, as its type hint states.
