
Fixed the MultilabelTrainer documentation, which would cause a potential bug when executing the originally documented code. #13414

Merged
merged 1 commit into from
Sep 8, 2021

Conversation

Mohan-Zhang-u
Contributor

@Mohan-Zhang-u Mohan-Zhang-u commented Sep 4, 2021

If you train with the MultilabelTrainer as shown in the original documentation:

from torch import nn
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")  # pop() removes "labels" from inputs
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss

an error like the following appears:

File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer.py", in train
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer.py", in _maybe_log_save_evaluate
    metrics = self.evaluate()
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer.py", in evaluate
    output = self.prediction_loop(
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer.py", in prediction_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer.py", in prediction_step
    labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", in nested_detach
    return type(tensors)(nested_detach(t) for t in tensors)
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", in <genexpr>
    return type(tensors)(nested_detach(t) for t in tensors)
File "~/anaconda3/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", in nested_detach
    return tensors.detach()
AttributeError: 'NoneType' object has no attribute 'detach'

Changing the original code to the version below avoids this bug:

from torch import nn
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")  # get() leaves "labels" in inputs for prediction_step
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss
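The difference comes down to dict.pop versus dict.get: during evaluation, Trainer.prediction_step re-reads the labels from the same inputs dict via inputs.get(name), so a compute_loss that pops "labels" leaves None behind. A minimal sketch with plain dicts (no transformers required) illustrates this:

    # Sketch of why the original compute_loss broke evaluation:
    # prediction_step later reads labels via inputs.get("labels"),
    # so pop() leaves None behind and nested_detach fails.
    inputs = {"input_ids": [[0, 1, 2]], "labels": [[1.0, 0.0]]}

    # Original documented code: pop() removes "labels" from inputs.
    labels = inputs.pop("labels")
    print(inputs.get("labels"))   # None -> later .detach() raises AttributeError

    # Fixed code: get() reads the labels without removing them.
    inputs = {"input_ids": [[0, 1, 2]], "labels": [[1.0, 0.0]]}
    labels = inputs.get("labels")
    print(inputs.get("labels"))   # labels are still available to prediction_step

Passing labels through to the model is harmless here because compute_loss discards the model's own loss and recomputes it with BCEWithLogitsLoss.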

Collaborator

@sgugger sgugger left a comment


Thanks for fixing!

@sgugger sgugger merged commit 41cd52a into huggingface:master Sep 8, 2021