diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2a1ad215158b4..6553d3f42ee38 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -163,6 +163,14 @@ class PretrainedConfig(PushToHubMixin): typically for a classification task. - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the current task. + - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can + be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). + Please note that this parameter is only available in the following models: `AlbertForSequenceClassification`, + `BertForSequenceClassification`, `BigBirdForSequenceClassification`, `ConvBertForSequenceClassification`, + `DistilBertForSequenceClassification`, `ElectraForSequenceClassification`, `FunnelForSequenceClassification`, + `LongformerForSequenceClassification`, `MobileBertForSequenceClassification`, + `ReformerForSequenceClassification`, `RobertaForSequenceClassification`, + `SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and `XLNetForSequenceClassification`. Parameters linked to the tokenizer @@ -260,6 +268,15 @@ def __init__(self, **kwargs): # task specific arguments self.task_specific_params = kwargs.pop("task_specific_params", None) + # regression / multi-label classification + self.problem_type = kwargs.pop("problem_type", None) + allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") + if self.problem_type is not None and self.problem_type not in allowed_problem_types: + raise ValueError( + f"The config parameter `problem_type` was not understood: received {self.problem_type} " + "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+ ) + # TPU arguments if kwargs.pop("xla_device", None) is not None: logger.warning( diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index a753c580888f8..08bf9d82d0d56 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -970,6 +970,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.albert = AlbertModel(config) self.dropout = nn.Dropout(config.classifier_dropout_prob) @@ -1024,13 +1025,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 34dd5329bebab..21a6eaab59526 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -25,7 +25,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -1381,7 +1381,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - **kwargs + **kwargs, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): @@ -1463,6 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1517,14 +1518,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == 
"single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index a886e57d1f95c..45da61b991389 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -25,7 +25,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -2609,6 +2609,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.bert = BigBirdModel(config) self.classifier = BigBirdClassificationHead(config) @@ -2659,13 +2660,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index f597ff1789125..f5b23e46005ff 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -22,7 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -962,6 +962,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.convbert = ConvBertModel(config) self.classifier = ConvBertClassificationHead(config) @@ -1012,13 +1013,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + 
self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index ca4b42987b41a..b30b3db90738b 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -24,7 +24,7 @@ import numpy as np import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu from ...file_utils import ( @@ -579,6 +579,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.distilbert = DistilBertModel(config) self.pre_classifier = nn.Linear(config.dim, config.dim) @@ -631,12 +632,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - loss_fct = nn.MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - loss_fct = nn.CrossEntropyLoss() + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + distilbert_output[1:] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 006a22f4c7741..5229054ff7661 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...file_utils import ( @@ -903,6 +903,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.electra = ElectraModel(config) self.classifier = ElectraClassificationHead(config) @@ -953,13 +954,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif 
self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + discriminator_hidden_states[1:] diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index 1db4ab87f3018..890a620ed4122 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -21,7 +21,7 @@ import numpy as np import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN @@ -1240,6 +1240,7 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.funnel = FunnelBaseModel(config) self.classifier = FunnelClassificationHead(config, config.num_labels) @@ -1287,13 +1288,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index aea9f4a902501..d1ab71bb7ad72 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN, gelu @@ -1803,6 +1803,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.longformer = LongformerModel(config, add_pooling_layer=False) self.classifier = LongformerClassificationHead(config) @@ -1861,13 +1862,23 @@ def forward( loss = None if labels is not 
None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 53c721a306e48..8f50c6d6f0f90 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -29,7 +29,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -1214,6 +1214,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.mobilebert = MobileBertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1268,14 +1269,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 0420a1c8ee4f2..4beca117a6855 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -26,7 +26,7 @@ import torch from torch import nn from torch.autograd.function import Function -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -366,7 +366,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, - **kwargs + **kwargs, ): sequence_length = 
hidden_states.shape[1] batch_size = hidden_states.shape[0] @@ -1045,7 +1045,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, - **kwargs + **kwargs, ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] @@ -2381,6 +2381,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.reformer = ReformerModel(config) self.classifier = ReformerClassificationHead(config) @@ -2434,13 +2435,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 274833d0507d0..cf535a719c8bd 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...file_utils import ( @@ -1117,6 +1117,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.roberta = RobertaModel(config, add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) @@ -1167,13 +1168,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index 
ce7d18808dc38..462c8fb376261 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -19,7 +19,7 @@ import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -733,6 +733,7 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = SqueezeBertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -787,13 +788,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 9ce0e5558ec0c..8dc0d208d1609 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -24,7 +24,7 @@ import numpy as np import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import gelu @@ -779,6 +779,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = XLMModel(config) self.sequence_summary = SequenceSummary(config) @@ -836,13 +837,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + transformer_outputs[1:] diff --git 
a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 7a6a51d456ca4..fa562c5f34499 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -22,7 +22,7 @@ import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN @@ -1488,6 +1488,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = XLNetModel(config) self.sequence_summary = SequenceSummary(config) @@ -1551,13 +1552,23 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + transformer_outputs[1:] diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index 7f82c67ba088a..81c5c48ccf127 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): else () ) + test_sequence_classification_problem_types = True + # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 97da4350ab7c2..acd921ce8a8dd 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_big_bird.py b/tests/test_modeling_big_bird.py index edef01f207a51..ba7d12fe2d336 100644 --- a/tests/test_modeling_big_bird.py +++ b/tests/test_modeling_big_bird.py @@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): # head masking & pruning is currently not supported for big bird test_head_masking = False test_pruning = False + test_sequence_classification_problem_types = True # torchscript should be possible, but takes prohibitively long to test. # Also torchscript is not an important feature to have in the beginning. 
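The modeling changes above all follow the same pattern: each sequence classification head now reads `config.problem_type` and selects `MSELoss`, `CrossEntropyLoss`, or `BCEWithLogitsLoss` accordingly. As an illustration of how the new option is meant to be used (not part of the diff; the checkpoint name, token ids, and label values below are made up), a multi-label fine-tuning step looks roughly like this:

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Hypothetical example: multi-label classification with 3 labels.
# problem_type="multi_label_classification" makes the head use BCEWithLogitsLoss,
# so labels are a float multi-hot tensor of shape (batch_size, num_labels).
config = BertConfig.from_pretrained(
    "bert-base-uncased",
    num_labels=3,
    problem_type="multi_label_classification",
)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # any tokenized batch
labels = torch.tensor([[1.0, 0.0, 1.0]])            # two of the three labels are active

loss = model(input_ids=input_ids, labels=labels).loss
loss.backward()
```

With `problem_type="single_label_classification"` the same call expects integer class indices of shape `(batch_size,)`, and with `"regression"` it expects float targets.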
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d193a9e7a4786..f83d65b51a7d3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -89,6 +89,7 @@ class ModelTesterMixin: test_missing_keys = True test_model_parallel = False is_encoder_decoder = False + test_sequence_classification_problem_types = False def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = copy.deepcopy(inputs_dict) @@ -1238,6 +1239,42 @@ def cast_to_device(dictionary, device): model.parallelize() model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) + def test_problem_types(self): + if not self.test_sequence_classification_problem_types: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + problem_types = [ + {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float}, + {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long}, + {"title": "regression", "num_labels": 1, "dtype": torch.float}, + ] + + for model_class in self.all_model_classes: + if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + continue + + for problem_type in problem_types: + with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"): + + config.problem_type = problem_type["title"] + config.num_labels = problem_type["num_labels"] + + model = model_class(config) + model.to(torch_device) + model.train() + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + if problem_type["num_labels"] > 1: + inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"]) + + inputs["labels"] = inputs["labels"].to(problem_type["dtype"]) + + loss = model(**inputs).loss + loss.backward() + global_rng = random.Random() diff --git a/tests/test_modeling_convbert.py b/tests/test_modeling_convbert.py index 062a7f506a996..ebe7188755133 100644 --- a/tests/test_modeling_convbert.py +++ b/tests/test_modeling_convbert.py @@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_head_masking = False + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = ConvBertModelTester(self) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index d6c3dc54b8d47..0c5c4bcf68c00 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = True test_torchscript = True test_resize_embeddings = True + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = DistilBertModelTester(self) diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 5935eafee668c..366d8f0f9079f 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 2d59e9f4e4100..9be00caeb734f 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else 
() ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 96333fced1149..c5d5eee162661 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -274,6 +274,7 @@ def prepare_config_and_inputs_for_question_answering(self): class LongformerModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False # pruning is not supported test_torchscript = False + test_sequence_classification_problem_types = True all_model_classes = ( ( diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index 96c974e2edc53..ce5854d16a59c 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 817d35c5b9156..05db9599c5173 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod test_pruning = False test_headmasking = False test_torchscript = False + test_sequence_classification_problem_types = True def prepare_kwargs(self): return { diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index be675eda6d49d..a6acdfe7b9367 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_squeezebert.py b/tests/test_modeling_squeezebert.py index 493326157875c..8f9d65fa9ac2e 100644 --- a/tests/test_modeling_squeezebert.py +++ b/tests/test_modeling_squeezebert.py @@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): test_torchscript = True test_resize_embeddings = True test_head_masking = False + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = SqueezeBertModelTester(self) diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index 69f76b88c981c..691a4039ea93c 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_generative_model_classes = ( (XLMWithLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_sequence_classification_problem_types = True # XLM has 2 QA models -> need to manually set the correct labels for one of them here def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 1423ef6980f2e..93031d03719fa 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, 
unittest.TestCase) (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable test_pruning = False + test_sequence_classification_problem_types = True # XLNet has 2 QA models -> need to manually set the correct labels for one of them here def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
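To make the dispatch easier to see in one place, here is a standalone sketch (not from the repository) of the rule each updated `forward` now applies: when `problem_type` is unset, it is inferred from `num_labels` and the label dtype, and the loss function plus expected label layout follow from it.

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def sequence_classification_loss(logits, labels, num_labels, problem_type=None):
    # Same inference rule the diff adds to every *ForSequenceClassification.forward:
    # a single output means regression, integer labels mean single-label
    # classification, anything else means multi-label classification.
    if problem_type is None:
        if num_labels == 1:
            problem_type = "regression"
        elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"

    if problem_type == "regression":
        return MSELoss()(logits.view(-1, num_labels), labels)
    if problem_type == "single_label_classification":
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)  # multi_label_classification

# Label layouts the three problem types expect (batch_size=2, num_labels=3):
logits = torch.randn(2, 3)
print(sequence_classification_loss(logits, torch.tensor([0, 2]), 3))                               # long class indices
print(sequence_classification_loss(logits, torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), 3))   # float multi-hot
print(sequence_classification_loss(torch.randn(2, 1), torch.randn(2, 1), 1))                       # float regression targets
```

These are the same three configurations exercised by the new `test_problem_types` in `tests/test_modeling_common.py`, which each model test opts into via `test_sequence_classification_problem_types = True`.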