
Add inputs vector to calculate metric method #16461

Merged (6 commits) on Apr 7, 2022

Conversation

@lmvasque (Contributor) commented Mar 28, 2022

What does this PR do?

This PR proposes including the inputs in the EvalPrediction object so that metrics that depend on the inputs can be computed. For example, simplification metrics such as SARI use not only the predictions and references but also the inputs for the score calculation.

The proposed implementation enables the Trainer to work with such metrics. However, the compute_metrics function still has to be implemented locally (in the metrics file, for example), since the original signature only receives predictions and references.
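For illustration, a compute_metrics along these lines could consume the extra field. This is a minimal sketch, not part of the PR: the t5-small tokenizer, the -100 label padding, and the "sari" metric loaded through datasets.load_metric are assumptions taken from the test setup further down this thread.

import numpy as np
from datasets import load_metric
from transformers import AutoTokenizer

# Assumed for this sketch: the "sari" metric shipped with `datasets` and the t5-small tokenizer.
sari = load_metric("sari")
tokenizer = AutoTokenizer.from_pretrained("t5-small")


def decode(token_ids):
    # Replace the -100 padding used for labels before decoding.
    token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


def compute_metrics(eval_preds):
    # With inputs included, EvalPrediction yields three elements instead of two.
    preds, labels, inputs = eval_preds
    return sari.compute(
        sources=decode(inputs),
        predictions=decode(preds),
        references=[[ref] for ref in decode(labels)],
    )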

Supports #15966

Who can review?

@sgugger

@HuggingFaceDocBuilderDev commented Mar 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) commented Mar 28, 2022

Unfortunately, we can't change the EvalPrediction namedtuple structure like this, as it would be a massive breaking change. Users would suddenly have to unpack it into three objects instead of two, so every compute_metrics function out there in the wild would fail.

@sgugger (Collaborator) commented Mar 31, 2022

Ok, I spent a bit of time on this. To enable an EvalPrediction that works with or without inputs, the class should be replaced by the following one:

from typing import Optional, Tuple, Union

import numpy as np


class EvalPrediction:
    """
    Evaluation output (always contains labels), to be used to compute metrics.

    Parameters:
        predictions (`np.ndarray`): Predictions of the model.
        label_ids (`np.ndarray`): Targets to be matched.
        inputs (`np.ndarray`, *optional*): Inputs of the model.
    """

    def __init__(
        self,
        predictions: Union[np.ndarray, Tuple[np.ndarray]],
        label_ids: Union[np.ndarray, Tuple[np.ndarray]],
        inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
    ):
        self.predictions = predictions
        self.label_ids = label_ids
        self.inputs = inputs

    def __iter__(self):
        # Iteration (and therefore tuple unpacking) only exposes the inputs when they were provided,
        # so existing `preds, labels = eval_pred` code keeps working.
        if self.inputs is not None:
            return iter((self.predictions, self.label_ids, self.inputs))
        else:
            return iter((self.predictions, self.label_ids))

    def __getitem__(self, idx):
        if idx < 0 or idx > 2:
            raise IndexError("tuple index out of range")
        if idx == 2 and self.inputs is None:
            raise IndexError("tuple index out of range")
        if idx == 0:
            return self.predictions
        elif idx == 1:
            return self.label_ids
        elif idx == 2:
            return self.inputs
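For reference, a quick check of the intended backward compatibility, using only the class sketched above:

preds = np.array([[1, 2]])
labels = np.array([[1, 2]])
inputs = np.array([[5, 6]])

# Without inputs, two-element unpacking behaves exactly as before.
p, l = EvalPrediction(predictions=preds, label_ids=labels)

# With inputs, a third element becomes available.
p, l, i = EvalPrediction(predictions=preds, label_ids=labels, inputs=inputs)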

Then we should add a flag so the user can choose whether or not they want the inputs included for metrics (default False, so that there is no change). With those two things, we can enable your use case while maintaining backward compatibility.

Can you include them in your PR?

@lmvasque (Contributor, Author) commented Apr 1, 2022

Hi @sgugger,

Thanks for adding this change. I've tested this code with the PyTorch summarization examples, and I'm afraid this change will break other people's code. The problem is that the inputs are passed in the trainer by default (based on my PR):

trainer.py

self.compute_metrics(EvalPrediction(inputs=all_inputs, predictions=all_preds, label_ids=all_labels))

The inputs in EvalPrediction will never be None, so it fails when unpacking the predictions in compute_metrics:

run_summarization.py

def compute_metrics(eval_preds):
    preds, labels = eval_preds
  File "./transformers/examples/pytorch/summarization/run_summarization.py", line 575, in compute_metrics
    preds, labels = eval_preds
ValueError: too many values to unpack (expected 2)
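As an aside, existing compute_metrics functions could be made tolerant of the extra element with star unpacking; this is only an editorial note, not what the PR does (the flag discussed below keeps the old two-element behavior by default):

preds, labels, *rest = eval_preds  # unpacks cleanly whether or not inputs are present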

You can see in the debugger that all three (inputs, preds and labels) are coming through:
[debugger screenshot]

In my case, I have my own implementation of compute_metrics; however, for other users it will fail by default. Regarding the flag you suggest, where would it live? In the Trainer class, with an if before each call to EvalPrediction?

This is my testing case:

./transformers/examples/pytorch/summarization/run_summarization.py \
  --model_name_or_path t5-small \
  --do_train \
  --do_eval \
  --train_file test.json \
  --validation_file test.json \
  --source_prefix "summarize: " \
  --output_dir /tmp/tst-summarization \
  --overwrite_output_dir \
  --per_device_train_batch_size=4 \
  --per_device_eval_batch_size=4 \
  --predict_with_generate

test.json

{"text": "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder", "summary": "I'm sitting in a room where I'm waiting for something to happen"}
{"text": "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.", "summary": "I'm a gardener and I'm a big fan of flowers."}
{"text": "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share", "summary": "It's that time of year again."}

Thanks,

Laura

@sgugger (Collaborator) commented Apr 1, 2022

The problem is that the inputs are passed in the trainer by default (based on my PR):

As I said above, this needs to be controlled by a flag (for instance include_inputs_for_metrics) in TrainingArguments, which would be False by default to avoid any breaking change.
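As a rough sketch of what that flag could look like in TrainingArguments (the field name and help text come from the review discussion below; the exact wording in the merged code may differ):

from dataclasses import dataclass, field


@dataclass
class TrainingArguments:
    # ... existing arguments ...
    include_inputs_for_metrics: bool = field(
        default=False,
        metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."},
    )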

@lmvasque (Contributor, Author) commented Apr 1, 2022

Thanks for the clarification, that sounds better. I've submitted an additional commit to the PR. It's the first time I've updated a PR, so let me know if I've missed something.

@sgugger (Collaborator) left a review comment:

Thanks for adding the flag. It should be in the TrainingArguments class, however, not in the Trainer class itself. Also, the inputs should only be gathered when the flag is set to True; otherwise users that do not need them might get OOM errors.

@@ -257,6 +257,9 @@ class Trainer:
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
include_inputs_for_metrics: bool = None:
@sgugger (Collaborator) commented:

Please follow the same format as for other arguments :-)

@lmvasque (Contributor, Author) commented Apr 4, 2022:

Perhaps like this?
include_inputs_for_metrics (bool, optional):

@@ -257,6 +257,9 @@ class Trainer:
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
include_inputs_for_metrics: bool = None:
A flag (True|False) determining if inputs will be passed to the EvalPrediction class. This is intended for
@sgugger (Collaborator) commented:

Suggested change
A flag (True|False) determining if inputs will be passed to the EvalPrediction class. This is intended for
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for

@@ -293,6 +296,8 @@ def __init__(
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
include_inputs_for_metrics: bool = None
@sgugger (Collaborator) commented:

This flag should be added in the TrainingArguments class, not here.

src/transformers/trainer.py (outdated review thread, resolved)
@lmvasque (Contributor, Author) commented Apr 4, 2022

src/transformers/trainer.py

I thought about the OOM issue, but should I add

if self.include_inputs_for_metrics:

everywhere in the code? I can't think of a more elegant way :)

And for the new changes, is it ok to just push another commit to the pull request, like I did last time? Just checking, so I don't make a mess :P

@@ -2452,6 +2465,7 @@ def evaluation_loop(

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs['input_ids']
@sgugger (Collaborator) commented:

Set this to None here if the flag is False.
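Applied to the line under review, that suggestion could look roughly like this (a sketch, not necessarily the exact merged code):

inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None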

@sgugger (Collaborator) commented Apr 4, 2022

should I add if self.include_inputs_for_metrics: everywhere in the code? I can't think of a more elegant way :)

You should make sure that the inputs are left as None, like the labels/losses, by only setting them when the flag is True. I've left a comment to show you where.

And for the new changes, is it ok to just push another commit to the pull request, like I did last time? Just checking, so I don't make a mess :P

That's completely ok, we will squash everything when merging.

@lmvasque (Contributor, Author) commented Apr 4, 2022

Changes done, let me know if you have any further feedback :)

@sgugger (Collaborator) left a review comment:

Thanks! Left some last comments and we should be good to merge.
Also remember to run make style on your branch once you're done to apply the code-formatting tools :-)

Comment on lines 309 to 312
if args.include_inputs_for_metrics:
args.include_inputs_for_metrics = True
else:
args.include_inputs_for_metrics = False
@sgugger (Collaborator) commented:

The TrainingArguments should not be touched by the Trainer; I don't think this is necessary in any case.

Suggested change
if args.include_inputs_for_metrics:
args.include_inputs_for_metrics = True
else:
args.include_inputs_for_metrics = False

Comment on lines 295 to 296
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None

@sgugger (Collaborator) commented:

We'll need the comma back and no new line here.

Comment on lines 2551 to 2553
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(EvalPrediction(inputs=all_inputs, predictions=all_preds,
label_ids=all_labels))
@sgugger (Collaborator) commented:

Suggested change
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(EvalPrediction(inputs=all_inputs, predictions=all_preds,
label_ids=all_labels))
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds,
label_ids=all_labels, inputs=all_inputs))

Needs to be last I think.

@@ -417,6 +417,9 @@ class TrainingArguments:
`huggingface-cli login`.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
include_inputs_for_metrics (bool, optional):
@sgugger (Collaborator) commented:

Suggested change
include_inputs_for_metrics (bool, optional):
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):

@lmvasque (Contributor, Author) commented Apr 5, 2022

I've added the requested changes. I also ran make fixup, so there are some automatic indentation changes as well.

@sgugger (Collaborator) left a review comment:

Thanks a lot for all your work on this! It all looks good to me apart from the changes at the start of the trainer file.

Comment on lines 33 to 47
import numpy as np
import torch
from huggingface_hub import Repository
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
@sgugger (Collaborator) commented:

This shouldn't be changed by make style. Are you sure you have the proper version of the formatting tools installed (pip install .[quality] in the repo)?

@lmvasque (Contributor, Author) commented:

Yes, I'm sorry. I tried to install some of the dependencies manually, but it was still failing in some parts. I've made the changes to bring the file back to its original state, also adding back the empty lines you pointed out.

@@ -181,13 +178,11 @@

from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat


@sgugger (Collaborator) commented:

No need to remove this line.

if TYPE_CHECKING:
import optuna

logger = logging.get_logger(__name__)


@sgugger (Collaborator) commented:

Same here

@@ -2387,7 +2382,6 @@ def evaluation_loop(

# if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed:

@sgugger (Collaborator) commented:

And same here

@sgugger (Collaborator) commented Apr 7, 2022

Thanks! There is one last issue with a badly formatted docstring. Could you run make style on your branch?

@lmvasque (Contributor, Author) commented Apr 7, 2022

Sure, no problem. I'm having issues with the dependencies needed to run make style; could you please let me know the full command for installing the whole suite?
pip install .[quality]

This is my log:

% make style
black examples tests src utils
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install black[jupyter]``
All done! ✨ 🍰 ✨
1519 files left unchanged.
isort examples tests src utils
WARNING: Unable to parse file examples due to [Errno 21] Is a directory: './transformers/examples'
WARNING: Unable to parse file tests due to [Errno 21] Is a directory: './transformers/tests'
WARNING: Unable to parse file src due to [Errno 21] Is a directory: './transformers/src'
WARNING: Unable to parse file utils due to [Errno 21] Is a directory: './transformers/utils'
/Library/Developer/CommandLineTools/usr/bin/make autogenerate_code
running deps_table_update
updating src/transformers/dependency_versions_table.py
/Library/Developer/CommandLineTools/usr/bin/make extra_style_checks
python utils/custom_init_isort.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
make[1]: doc-builder: No such file or directory
make[1]: *** [extra_style_checks] Error 1
make: *** [style] Error 2

@sgugger (Collaborator) commented Apr 7, 2022

It looks like your branch is not up to date with the main branch for the setup. Can you manually run pip install hf-doc-builder? It should be in the quality extras, but maybe you don't have it because it was added fairly recently.

@lmvasque (Contributor, Author) commented Apr 7, 2022

Now it's working :) Thanks for that one. I've committed the updated file.

% make style
black examples tests src utils
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install black[jupyter]``
All done! ✨ 🍰 ✨
1519 files left unchanged.
isort examples tests src utils
/Library/Developer/CommandLineTools/usr/bin/make autogenerate_code
running deps_table_update
updating src/transformers/dependency_versions_table.py
/Library/Developer/CommandLineTools/usr/bin/make extra_style_checks
python utils/custom_init_isort.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
Overwriting content of src/transformers/training_args.py.
Cleaned 1 files

@sgugger (Collaborator) commented Apr 7, 2022

Yes, it's all good now. Thanks again for your contribution!

@sgugger sgugger merged commit 09a272b into huggingface:main Apr 7, 2022
@lmvasque (Contributor, Author) commented Apr 7, 2022

Awesome! :) Thanks for leading this effort, good job :D
