Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problem_type to match with the applied loss function for distillbert sequence classification #12015

Closed
wants to merge 2 commits into from

Conversation

sidhantls
Copy link

@sidhantls sidhantls commented Jun 3, 2021

What does this PR do?

The problem_type in config is not correct with the loss function applied. Can be seen here.

This PR fixes this so that the applied loss is consistent with the problem type: BCEWithLogitsLoss is applied for the problem_type of single_label_classification, and CrossEntropyLoss is applied for the problem type multi_label_classification

Fixes #12014

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik

@LysandreJik
Copy link
Member

cc @abhi1thakur @sgugger

@sgugger
Copy link
Collaborator

sgugger commented Jun 7, 2021

No, this is incorrect. What single_label_classification means each sample can only have one label (but there could be multiple classes) so the loss to use is cross entropy. multi_label_classification means each sample can have zero or several labels, so in this case we use bce (because there can't be a softmax).

@sgugger sgugger closed this Jun 7, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MIssmatch between problem_type and loss functions in DistillBert for sequence classification
3 participants