From b06c721fa2aa301bb144036ed28ba730493b02ee Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Thu, 1 Apr 2021 11:05:48 +0200 Subject: [PATCH 01/20] add to bert --- src/transformers/models/bert/modeling_bert.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 370af8b47f472a..e66de430a0242c 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -25,7 +25,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -1381,7 +1381,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - **kwargs + **kwargs, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): @@ -1463,6 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.problem_type = config.problem_type self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1517,13 +1518,29 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) + if self.problem_type is not None: + if self.problem_type == "single_column_regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + elif self.problem_type == "multi_column_regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.problem_type in ("binary_classification", "multi_class_classification"): + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.problem_type in ("multi_label_classification"): + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise Exception("Problem type not understood") else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] From 9fa64860f39b52bcfeeb52ad313107baacd541b8 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 2 Apr 2021 13:12:52 +0200 Subject: [PATCH 02/20] review comments --- src/transformers/configuration_utils.py | 12 ++++++ src/transformers/models/bert/modeling_bert.py | 41 +++++++++---------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 621f855a126f44..b9aab848f4be0d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -152,6 +152,8 @@ class PretrainedConfig(object): typically for a classification task. - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the current task. + - **problem_type** (:obj:`str`, `optional`) -- Problem type for ForSequenceClassification tasks. 
It can be one + of (None, "regression", "single_label_classification", "multi_label_classification"). Default is None. Parameters linked to the tokenizer @@ -249,6 +251,16 @@ def __init__(self, **kwargs): # task specific arguments self.task_specific_params = kwargs.pop("task_specific_params", None) + # regression / multi-label classification + self.problem_type = kwargs.pop("problem_type", None) + allowed_problem_types = (None, "regression", "single_label_classification", "multi_label_classification") + if self.problem_type not in allowed_problem_types: + raise ValueError( + f"""The config parameter `problem_type` not understood: + received {self.problem_type} but only [regression, single_label_classification + and multi_label_classification] are valid.""" + ) + # TPU arguments if kwargs.pop("xla_device", None) is not None: logger.warn( diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index e66de430a0242c..40971be05d5677 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1518,30 +1518,29 @@ def forward( loss = None if labels is not None: - if self.problem_type is not None: - if self.problem_type == "single_column_regression": - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - elif self.problem_type == "multi_column_regression": - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels) - elif self.problem_type in ("binary_classification", "multi_class_classification"): - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.problem_type in ("multi_label_classification"): - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - else: - raise Exception("Problem type not understood") - else: + if self.problem_type is None: if self.num_labels == 1: - # We are doing regression - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) + self.problem_type = "regression" + elif self.num_labels > 1 and type(labels) == torch.long: + self.problem_type = "single_label_classification" else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + self.problem_type = "multi_label_classification" + if self.problem_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.problem_type in ("single_label_classification"): + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.problem_type in ("multi_label_classification"): + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"""The config parameter `problem_type` not understood: + received {self.problem_type} but only [regression, single_label_classification + and multi_label_classification] are valid.""" + ) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output From 0a4c885d9a21464f44f1f3d296659557a5f9c4d1 Mon Sep 17 00:00:00 2001 From: abhishek thakur Date: Mon, 5 Apr 2021 15:02:13 +0200 Subject: [PATCH 03/20] Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 
b9aab848f4be0d..b311133e03f180 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -152,8 +152,8 @@ class PretrainedConfig(object): typically for a classification task. - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the current task. - - **problem_type** (:obj:`str`, `optional`) -- Problem type for ForSequenceClassification tasks. It can be one - of (None, "regression", "single_label_classification", "multi_label_classification"). Default is None. + - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one + of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). Parameters linked to the tokenizer From d5108de9ce7055d84a5e5c2692cadb62c8ba3df7 Mon Sep 17 00:00:00 2001 From: abhishek thakur Date: Mon, 5 Apr 2021 15:02:36 +0200 Subject: [PATCH 04/20] Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index b311133e03f180..099134d4bc156e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -253,8 +253,8 @@ def __init__(self, **kwargs): # regression / multi-label classification self.problem_type = kwargs.pop("problem_type", None) - allowed_problem_types = (None, "regression", "single_label_classification", "multi_label_classification") - if self.problem_type not in allowed_problem_types: + allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") + if self.problem_type is not None and self.problem_type not in allowed_problem_types: raise ValueError( f"""The config parameter `problem_type` not understood: received {self.problem_type} but only [regression, single_label_classification From fbb61ee37d91a889ada13ca0f7207b429aa0223d Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 6 Apr 2021 14:00:05 +0200 Subject: [PATCH 05/20] self.config.problem_type --- src/transformers/models/bert/modeling_bert.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 40971be05d5677..4548f0e304c1b5 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1463,7 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.problem_type = config.problem_type + self.config = config self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1518,27 +1518,27 @@ def forward( loss = None if labels is not None: - if self.problem_type is None: + if self.config.problem_type is None: if self.num_labels == 1: - self.problem_type = "regression" + self.config.problem_type = "regression" elif self.num_labels > 1 and type(labels) == torch.long: - self.problem_type = "single_label_classification" + self.config.problem_type = "single_label_classification" else: - self.problem_type = "multi_label_classification" + self.config.problem_type = "multi_label_classification" - if self.problem_type == "regression": + if self.config.problem_type == 
"regression": loss_fct = MSELoss() loss = loss_fct(logits.view(-1, self.num_labels), labels) - elif self.problem_type in ("single_label_classification"): + elif self.config.problem_type in ("single_label_classification"): loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.problem_type in ("multi_label_classification"): + elif self.config.problem_type in ("multi_label_classification"): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) else: raise ValueError( f"""The config parameter `problem_type` not understood: - received {self.problem_type} but only [regression, single_label_classification + received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid.""" ) if not return_dict: From 5a9303afb85c7c438c8d8a4d2cb0e66945847473 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 6 Apr 2021 14:57:15 +0200 Subject: [PATCH 06/20] fix style --- src/transformers/configuration_utils.py | 12 +++++---- src/transformers/models/bert/modeling_bert.py | 4 +-- .../models/mobilebert/modeling_mobilebert.py | 26 ++++++++++++++----- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index a105196773ef24..93350efa5cfe95 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -163,8 +163,8 @@ class PretrainedConfig(object): typically for a classification task. - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the current task. - - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one - of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). + - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can + be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). Parameters linked to the tokenizer @@ -267,9 +267,11 @@ def __init__(self, **kwargs): allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") if self.problem_type is not None and self.problem_type not in allowed_problem_types: raise ValueError( - f"""The config parameter `problem_type` not understood: - received {self.problem_type} but only [regression, single_label_classification - and multi_label_classification] are valid.""" + f""" +The config parameter `problem_type` not understood: + received {self.problem_type} but only [regression, single_label_classification and + multi_label_classification] are valid. + """ ) # TPU arguments diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 4548f0e304c1b5..e11a9c72d547a1 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1537,9 +1537,7 @@ def forward( loss = loss_fct(logits, labels) else: raise ValueError( - f"""The config parameter `problem_type` not understood: - received {self.config.problem_type} but only [regression, single_label_classification - and multi_label_classification] are valid.""" + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index bd3f86d21e123e..69253f9254e012 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -29,7 +29,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -1214,6 +1214,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.mobilebert = MobileBertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1268,14 +1269,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and type(labels) == torch.long: + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type in ("single_label_classification"): loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - + elif self.config.problem_type in ("multi_label_classification"): + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output From b26a4d19257536d8b7b7e9337eb4b90033ad564d Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 6 Apr 2021 15:17:53 +0200 Subject: [PATCH 07/20] fix --- src/transformers/models/bert/modeling_bert.py | 6 +++--- src/transformers/models/mobilebert/modeling_mobilebert.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index e11a9c72d547a1..797505113121e8 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1521,7 +1521,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and type(labels) == torch.long: + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1529,10 +1529,10 @@ def forward( if self.config.problem_type == "regression": loss_fct = MSELoss() loss = loss_fct(logits.view(-1, self.num_labels), labels) - elif self.config.problem_type in ("single_label_classification"): + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type in ("multi_label_classification"): + elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) else: diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 69253f9254e012..adb09a74dbb8d6 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -1272,7 +1272,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and type(labels) == torch.long: + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1280,10 +1280,10 @@ def forward( if self.config.problem_type == "regression": loss_fct = MSELoss() loss = loss_fct(logits.view(-1, self.num_labels), labels) - elif self.config.problem_type in ("single_label_classification"): + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type in ("multi_label_classification"): + elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) else: From f7ef5500287936e9c2e6383842c584cab7d4dca8 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 6 Apr 2021 15:46:29 +0200 Subject: [PATCH 08/20] fin --- .../models/albert/modeling_albert.py | 25 ++++++++++++---- .../models/big_bird/modeling_big_bird.py | 25 ++++++++++++---- .../models/convbert/modeling_convbert.py | 25 ++++++++++++---- .../models/distilbert/modeling_distilbert.py | 28 ++++++++++++++---- .../models/electra/modeling_electra.py | 25 ++++++++++++---- .../models/funnel/modeling_funnel.py | 25 ++++++++++++---- 
.../models/longformer/modeling_longformer.py | 25 ++++++++++++---- .../models/openai/modeling_openai.py | 27 +++++++++++++---- .../models/reformer/modeling_reformer.py | 29 ++++++++++++++----- .../models/roberta/modeling_roberta.py | 25 ++++++++++++---- .../squeezebert/modeling_squeezebert.py | 25 ++++++++++++---- src/transformers/models/xlm/modeling_xlm.py | 25 ++++++++++++---- .../models/xlnet/modeling_xlnet.py | 25 ++++++++++++---- 13 files changed, 265 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 21da03fd7a3ba1..7819063d7da846 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -970,6 +970,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.albert = AlbertModel(config) self.dropout = nn.Dropout(config.classifier_dropout_prob) @@ -1024,13 +1025,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f7fd54b9468d97..0a0bed738d4bc4 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -25,7 +25,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -2575,6 +2575,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.bert = BigBirdModel(config) self.classifier = BigBirdClassificationHead(config) @@ -2625,13 +2626,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 0ededdc83f3fb7..ecbcd016eceb74 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -22,7 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -962,6 +962,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.convbert = ConvBertModel(config) self.classifier = ConvBertClassificationHead(config) @@ -1012,13 +1013,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 911fba8088481b..ce7dda04615e8d 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -24,7 +24,7 @@ import numpy as np import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu from ...file_utils import ( @@ -579,6 +579,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.distilbert = DistilBertModel(config) self.pre_classifier = nn.Linear(config.dim, config.dim) @@ -631,12 +632,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - loss_fct = nn.MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - loss_fct = nn.CrossEntropyLoss() + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + distilbert_output[1:] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 913d269ad5063c..513fce78be8e07 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...file_utils import ( @@ -903,6 +903,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.electra = ElectraModel(config) self.classifier = ElectraClassificationHead(config) @@ -953,13 +954,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + discriminator_hidden_states[1:] diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index 1f277498d124ae..cd3b45cb73144f 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -21,7 +21,7 @@ import numpy as np import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN @@ -1240,6 +1240,7 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.funnel = FunnelBaseModel(config) self.classifier = FunnelClassificationHead(config, config.num_labels) @@ -1287,13 +1288,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 65634ca314d393..f8061407d21b06 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN, gelu @@ -1803,6 +1803,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.longformer = LongformerModel(config, add_pooling_layer=False) self.classifier = LongformerClassificationHead(config) @@ -1861,13 +1862,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 6564a8fa42cfdb..1571d6005b6442 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu from ...file_utils import ( @@ -749,6 +749,7 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = OpenAIGPTModel(config) self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) @@ -821,13 +822,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 516fff8f91e3f3..8240036c3f9d04 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -26,7 +26,7 @@ import torch from torch import nn from torch.autograd.function import Function -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import ( @@ -366,7 +366,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, - **kwargs + **kwargs, ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] @@ -1045,7 +1045,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, - **kwargs + **kwargs, ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] @@ -2377,6 +2377,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.reformer = ReformerModel(config) self.classifier = ReformerClassificationHead(config) @@ -2430,13 +2431,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 88155f76de29f2..bb82706808306a 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...file_utils import ( @@ -1117,6 +1117,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.roberta = RobertaModel(config, add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) @@ -1167,13 +1168,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index 09dcd680bbb454..9eefcdfecdfe5f 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -19,7 +19,7 @@ import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -733,6 +733,7 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = SqueezeBertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -787,13 +788,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 3ccd63ee9781ed..b48cc3fe8fc63f 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -24,7 +24,7 @@ import numpy as np import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import gelu @@ -779,6 +779,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = XLMModel(config) self.sequence_summary = SequenceSummary(config) @@ -836,13 +837,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + transformer_outputs[1:] diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 7a6a51d456ca4c..ed96d788bf1d94 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -22,7 +22,7 @@ import torch from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F from ...activations import ACT2FN @@ -1488,6 +1488,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels + self.config = config self.transformer = XLNetModel(config) self.sequence_summary = SequenceSummary(config) @@ -1551,13 +1552,27 @@ def forward( loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + else: + raise ValueError( + f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
+ ) if not return_dict: output = (logits,) + transformer_outputs[1:] From 90e732e1882e883b57c4cc771623477a1253e610 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 6 Apr 2021 16:00:45 +0200 Subject: [PATCH 09/20] fix --- .../models/openai/modeling_openai.py | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 1571d6005b6442..6564a8fa42cfdb 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu from ...file_utils import ( @@ -749,7 +749,6 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.config = config self.transformer = OpenAIGPTModel(config) self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) @@ -822,27 +821,13 @@ def forward( loss = None if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": + if self.num_labels == 1: + # We are doing regression loss_fct = MSELoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] From a687665431ab31bfed9c378e98981eae6b1851ca Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 12 Apr 2021 17:03:11 +0200 Subject: [PATCH 10/20] update doc --- src/transformers/configuration_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index affa382f68a0f5..783a9734de4bca 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -165,6 +165,12 @@ class PretrainedConfig(object): current task. - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). 
+ Please note that this parameter is only available in the following models: `"AlbertForSequenceClassification"`, + `"BertForSequenceClassification"`, `"BigBirdForSequenceClassification"`, `"ConvBertForSequenceClassification"`, + `"DistilBertForSequenceClassification"`, `"ElectraForSequenceClassification"`, `"FunnelForSequenceClassification"`, + `"LongformerForSequenceClassification"`, `"MobileBertForSequenceClassification"`, `"ReformerForSequenceClassification"`, + `"RobertaForSequenceClassification"`, `"SqueezeBertForSequenceClassification"`, `"XLMForSequenceClassification"` and + `"XLNetForSequenceClassification"`. Parameters linked to the tokenizer From 557f2a11eda667dbb14f95421dc6f95abe5679e1 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 30 Apr 2021 13:29:19 +0200 Subject: [PATCH 11/20] fix --- src/transformers/configuration_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 783a9734de4bca..4af56cb14f31af 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -165,11 +165,13 @@ class PretrainedConfig(object): current task. - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). - Please note that this parameter is only available in the following models: `"AlbertForSequenceClassification"`, - `"BertForSequenceClassification"`, `"BigBirdForSequenceClassification"`, `"ConvBertForSequenceClassification"`, - `"DistilBertForSequenceClassification"`, `"ElectraForSequenceClassification"`, `"FunnelForSequenceClassification"`, - `"LongformerForSequenceClassification"`, `"MobileBertForSequenceClassification"`, `"ReformerForSequenceClassification"`, - `"RobertaForSequenceClassification"`, `"SqueezeBertForSequenceClassification"`, `"XLMForSequenceClassification"` and + Please note that this parameter is only available in the following models: + `"AlbertForSequenceClassification"`, `"BertForSequenceClassification"`, `"BigBirdForSequenceClassification"`, + `"ConvBertForSequenceClassification"`, `"DistilBertForSequenceClassification"`, + `"ElectraForSequenceClassification"`, `"FunnelForSequenceClassification"`, + `"LongformerForSequenceClassification"`, `"MobileBertForSequenceClassification"`, + `"ReformerForSequenceClassification"`, `"RobertaForSequenceClassification"`, + `"SqueezeBertForSequenceClassification"`, `"XLMForSequenceClassification"` and `"XLNetForSequenceClassification"`. 
Parameters linked to the tokenizer From 157d7d4c3da3a650570f5ebe8c122393c57f3ded Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 30 Apr 2021 13:54:48 +0200 Subject: [PATCH 12/20] test --- tests/test_modeling_common.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d193a9e7a47862..2f44d36c893705 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1238,6 +1238,21 @@ def cast_to_device(dictionary, device): model.parallelize() model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) + def test_multilabel(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + config.problem_type = "multi_label_classification" + + for model_class in self.all_model_classes: + if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + global_rng = random.Random() From 9046ff26132b4363d7eba664849b808c50387d0f Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 30 Apr 2021 19:57:13 +0200 Subject: [PATCH 13/20] Test more problem types --- tests/test_modeling_albert.py | 2 ++ tests/test_modeling_bert.py | 1 + tests/test_modeling_big_bird.py | 1 + tests/test_modeling_common.py | 46 ++++++++++++++++++++++++------ tests/test_modeling_convbert.py | 1 + tests/test_modeling_distilbert.py | 1 + tests/test_modeling_electra.py | 1 + tests/test_modeling_funnel.py | 1 + tests/test_modeling_longformer.py | 1 + tests/test_modeling_mobilebert.py | 1 + tests/test_modeling_reformer.py | 1 + tests/test_modeling_roberta.py | 1 + tests/test_modeling_squeezebert.py | 1 + tests/test_modeling_xlm.py | 1 + tests/test_modeling_xlnet.py | 1 + 15 files changed, 53 insertions(+), 8 deletions(-) diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index 7f82c67ba088ac..81c5c48ccf1272 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): else () ) + test_sequence_classification_problem_types = True + # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 97da4350ab7c2c..acd921ce8a8dd8 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_big_bird.py b/tests/test_modeling_big_bird.py index edef01f207a511..ba7d12fe2d336b 100644 --- a/tests/test_modeling_big_bird.py +++ b/tests/test_modeling_big_bird.py @@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): # head masking & pruning is currently not supported for big bird test_head_masking = False test_pruning = False + test_sequence_classification_problem_types = True # torchscript should be possible, but takes 
prohibitively long to test. # Also torchscript is not an important feature to have in the beginning. diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2f44d36c893705..1bb23008b1a128 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -89,6 +89,7 @@ class ModelTesterMixin: test_missing_keys = True test_model_parallel = False is_encoder_decoder = False + test_sequence_classification_problem_types = False def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = copy.deepcopy(inputs_dict) @@ -1238,20 +1239,49 @@ def cast_to_device(dictionary, device): model.parallelize() model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) - def test_multilabel(self): + def test_problem_types(self): + if not self.test_sequence_classification_problem_types: + return + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.problem_type = "multi_label_classification" + problem_types = [ + {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float}, + {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long}, + {"title": "regression", "num_labels": 1, "dtype": torch.float}, + ] for model_class in self.all_model_classes: if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): continue - model = model_class(config) - model.to(torch_device) - model.train() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - loss = model(**inputs).loss - loss.backward() + + for problem_type in problem_types: + with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"): + + config.problem_type = problem_type["title"] + config.num_labels = problem_type["num_labels"] + + model = model_class(config) + model.to(torch_device) + model.train() + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + if problem_type["num_labels"] > 1: + inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"]) + + inputs["labels"] = inputs["labels"].to(problem_type["dtype"]) + + loss = model(**inputs).loss + loss.backward() + + with self.assertRaises(ValueError): + config.problem_type = "non_existent_problem" + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + model(**inputs) global_rng = random.Random() diff --git a/tests/test_modeling_convbert.py b/tests/test_modeling_convbert.py index 062a7f506a996f..ebe7188755133c 100644 --- a/tests/test_modeling_convbert.py +++ b/tests/test_modeling_convbert.py @@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_head_masking = False + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = ConvBertModelTester(self) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index d6c3dc54b8d47c..0c5c4bcf68c00b 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = True test_torchscript = True test_resize_embeddings = True + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = DistilBertModelTester(self) diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 5935eafee668c0..366d8f0f9079fd 100644 --- 
a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 2d59e9f4e4100d..9be00caeb734f0 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 96333fced11491..c5d5eee1626618 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -274,6 +274,7 @@ def prepare_config_and_inputs_for_question_answering(self): class LongformerModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False # pruning is not supported test_torchscript = False + test_sequence_classification_problem_types = True all_model_classes = ( ( diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index 96c974e2edc534..ce5854d16a59c0 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + test_sequence_classification_problem_types = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 817d35c5b9156a..05db9599c5173a 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod test_pruning = False test_headmasking = False test_torchscript = False + test_sequence_classification_problem_types = True def prepare_kwargs(self): return { diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index be675eda6d49d4..a6acdfe7b93673 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_squeezebert.py b/tests/test_modeling_squeezebert.py index 493326157875c1..8f9d65fa9ac2e1 100644 --- a/tests/test_modeling_squeezebert.py +++ b/tests/test_modeling_squeezebert.py @@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): test_torchscript = True test_resize_embeddings = True test_head_masking = False + test_sequence_classification_problem_types = True def setUp(self): self.model_tester = SqueezeBertModelTester(self) diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index 69f76b88c981c3..691a4039ea93c2 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, 
unittest.TestCase): all_generative_model_classes = ( (XLMWithLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_sequence_classification_problem_types = True # XLM has 2 QA models -> need to manually set the correct labels for one of them here def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 1423ef6980f2eb..93031d03719fa7 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable test_pruning = False + test_sequence_classification_problem_types = True # XLNet has 2 QA models -> need to manually set the correct labels for one of them here def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): From 00d053aae371544b3b4405e48ecf72ebc5e6c4e4 Mon Sep 17 00:00:00 2001 From: abhishek thakur <1183441+abhi1thakur@users.noreply.github.com> Date: Mon, 3 May 2021 18:04:07 +0200 Subject: [PATCH 14/20] Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index abec649ac32a56..b28f8e6dcd923d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -275,11 +275,8 @@ def __init__(self, **kwargs): allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") if self.problem_type is not None and self.problem_type not in allowed_problem_types: raise ValueError( - f""" -The config parameter `problem_type` not understood: - received {self.problem_type} but only [regression, single_label_classification and - multi_label_classification] are valid. - """ + f"The config parameter `problem_type` was not understood: received {self.problem_type} " + "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
) # TPU arguments From bb0edc4e798a53bf42d80eca68192470f328620c Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:14:18 +0200 Subject: [PATCH 15/20] fix --- .../models/albert/modeling_albert.py | 4 - src/transformers/models/bert/modeling_bert.py | 4 - .../models/big_bird/modeling_big_bird.py | 4 - .../models/convbert/modeling_convbert.py | 4 - .../models/distilbert/modeling_distilbert.py | 4 - .../models/electra/modeling_electra.py | 4 - .../models/funnel/modeling_funnel.py | 4 - .../models/longformer/modeling_longformer.py | 4 - .../models/mobilebert/modeling_mobilebert.py | 5 +- .../models/reformer/modeling_reformer.py | 4 - .../models/roberta/modeling_roberta.py | 4 - .../models/squeezebert/Search.code-search | 73 +++++++++++++++++++ .../squeezebert/modeling_squeezebert.py | 4 - src/transformers/models/xlm/modeling_xlm.py | 4 - .../models/xlnet/modeling_xlnet.py | 4 - 15 files changed, 74 insertions(+), 56 deletions(-) create mode 100644 src/transformers/models/squeezebert/Search.code-search diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index a9648ad1b089b5..08bf9d82d0d56b 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -1042,10 +1042,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 7a11baa6ba01dc..21a6eaab595265 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1535,10 +1535,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 049e1edc7f258c..45da61b991389f 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2677,10 +2677,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
- ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index c5f4e64e8298e0..f5b23e46005ff5 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -1030,10 +1030,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index c4345a0366bf8a..b30b3db90738b7 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -649,10 +649,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + distilbert_output[1:] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 73d37beb82bd16..5229054ff76616 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -971,10 +971,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + discriminator_hidden_states[1:] diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index c81d61da3f9d44..890a620ed41225 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -1305,10 +1305,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
- ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index bbbf5ce7e01711..d1ab71bb7ad724 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1879,10 +1879,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 125dbb6aae6fe7..5d0f9f3119406d 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -1286,10 +1286,7 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index d53e3550fb73f6..4beca117a6855b 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2452,10 +2452,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 7ab61e523fe5f2..cf535a719c8bdf 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1185,10 +1185,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
- ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/squeezebert/Search.code-search b/src/transformers/models/squeezebert/Search.code-search new file mode 100644 index 00000000000000..5493f00df7db6f --- /dev/null +++ b/src/transformers/models/squeezebert/Search.code-search @@ -0,0 +1,73 @@ +# Query: else:\n raise ValueError(\n f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid."\n ) +# ContextLines: 1 + +14 results - 14 files + +src/transformers/models/albert/modeling_albert.py: + 1044 loss = loss_fct(logits, labels) + 1045: + 1049 + +src/transformers/models/bert/modeling_bert.py: + 1537 loss = loss_fct(logits, labels) + 1538: + 1542 if not return_dict: + +src/transformers/models/big_bird/modeling_big_bird.py: + 2679 loss = loss_fct(logits, labels) + 2680: + 2684 + +src/transformers/models/convbert/modeling_convbert.py: + 1032 loss = loss_fct(logits, labels) + 1033: + 1037 + +src/transformers/models/distilbert/modeling_distilbert.py: + 651 loss = loss_fct(logits, labels) + 652: + 656 + +src/transformers/models/electra/modeling_electra.py: + 973 loss = loss_fct(logits, labels) + 974: + +src/transformers/models/funnel/modeling_funnel.py: + 1307 loss = loss_fct(logits, labels) + 1308: + 1312 + +src/transformers/models/longformer/modeling_longformer.py: + 1881 loss = loss_fct(logits, labels) + 1882: + 1886 + +src/transformers/models/mobilebert/modeling_mobilebert.py: + 1288 loss = loss_fct(logits, labels) + 1289: + 1293 if not return_dict: + +src/transformers/models/reformer/modeling_reformer.py: + 2454 loss = loss_fct(logits, labels) + 2455: + 2459 + +src/transformers/models/roberta/modeling_roberta.py: + 1187 loss = loss_fct(logits, labels) + 1188: + 1192 + +src/transformers/models/squeezebert/modeling_squeezebert.py: + 807 loss = loss_fct(logits, labels) + 808: + 812 + +src/transformers/models/xlm/modeling_xlm.py: + 856 loss = loss_fct(logits, labels) + 857: + 861 + +src/transformers/models/xlnet/modeling_xlnet.py: + 1571 loss = loss_fct(logits, labels) + 1572: + 1576 diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index 9d745d76975e83..462c8fb376261b 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -805,10 +805,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + outputs[2:] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 5ef63a85883e00..8dc0d208d16097 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -854,10 +854,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." 
- ) if not return_dict: output = (logits,) + transformer_outputs[1:] diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index ed96d788bf1d94..fa562c5f344991 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1569,10 +1569,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - else: - raise ValueError( - f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid." - ) if not return_dict: output = (logits,) + transformer_outputs[1:] From b63cc8e3e0a40b114b9e790129d6c2d1d2afc824 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:14:35 +0200 Subject: [PATCH 16/20] remove --- .../models/squeezebert/Search.code-search | 73 ------------------- 1 file changed, 73 deletions(-) delete mode 100644 src/transformers/models/squeezebert/Search.code-search diff --git a/src/transformers/models/squeezebert/Search.code-search b/src/transformers/models/squeezebert/Search.code-search deleted file mode 100644 index 5493f00df7db6f..00000000000000 --- a/src/transformers/models/squeezebert/Search.code-search +++ /dev/null @@ -1,73 +0,0 @@ -# Query: else:\n raise ValueError(\n f"The config parameter `problem_type` not understood: received {self.config.problem_type} but only [regression, single_label_classification and multi_label_classification] are valid."\n ) -# ContextLines: 1 - -14 results - 14 files - -src/transformers/models/albert/modeling_albert.py: - 1044 loss = loss_fct(logits, labels) - 1045: - 1049 - -src/transformers/models/bert/modeling_bert.py: - 1537 loss = loss_fct(logits, labels) - 1538: - 1542 if not return_dict: - -src/transformers/models/big_bird/modeling_big_bird.py: - 2679 loss = loss_fct(logits, labels) - 2680: - 2684 - -src/transformers/models/convbert/modeling_convbert.py: - 1032 loss = loss_fct(logits, labels) - 1033: - 1037 - -src/transformers/models/distilbert/modeling_distilbert.py: - 651 loss = loss_fct(logits, labels) - 652: - 656 - -src/transformers/models/electra/modeling_electra.py: - 973 loss = loss_fct(logits, labels) - 974: - -src/transformers/models/funnel/modeling_funnel.py: - 1307 loss = loss_fct(logits, labels) - 1308: - 1312 - -src/transformers/models/longformer/modeling_longformer.py: - 1881 loss = loss_fct(logits, labels) - 1882: - 1886 - -src/transformers/models/mobilebert/modeling_mobilebert.py: - 1288 loss = loss_fct(logits, labels) - 1289: - 1293 if not return_dict: - -src/transformers/models/reformer/modeling_reformer.py: - 2454 loss = loss_fct(logits, labels) - 2455: - 2459 - -src/transformers/models/roberta/modeling_roberta.py: - 1187 loss = loss_fct(logits, labels) - 1188: - 1192 - -src/transformers/models/squeezebert/modeling_squeezebert.py: - 807 loss = loss_fct(logits, labels) - 808: - 812 - -src/transformers/models/xlm/modeling_xlm.py: - 856 loss = loss_fct(logits, labels) - 857: - 861 - -src/transformers/models/xlnet/modeling_xlnet.py: - 1571 loss = loss_fct(logits, labels) - 1572: - 1576 From 9e6a8b4809f5ec25da5ef5c9b38a3c23580419fb Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:16:00 +0200 Subject: [PATCH 17/20] fix --- src/transformers/configuration_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git 
a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index b28f8e6dcd923d..24bdace9e44661 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -166,13 +166,13 @@ class PretrainedConfig(PushToHubMixin): - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). Please note that this parameter is only available in the following models: - `"AlbertForSequenceClassification"`, `"BertForSequenceClassification"`, `"BigBirdForSequenceClassification"`, - `"ConvBertForSequenceClassification"`, `"DistilBertForSequenceClassification"`, - `"ElectraForSequenceClassification"`, `"FunnelForSequenceClassification"`, - `"LongformerForSequenceClassification"`, `"MobileBertForSequenceClassification"`, - `"ReformerForSequenceClassification"`, `"RobertaForSequenceClassification"`, - `"SqueezeBertForSequenceClassification"`, `"XLMForSequenceClassification"` and - `"XLNetForSequenceClassification"`. + `AlbertForSequenceClassification`, `BertForSequenceClassification`, `BigBirdForSequenceClassification`, + `ConvBertForSequenceClassification`, `DistilBertForSequenceClassification`, + `ElectraForSequenceClassification`, `FunnelForSequenceClassification`, + `LongformerForSequenceClassification`, `MobileBertForSequenceClassification`, + `ReformerForSequenceClassification`, `RobertaForSequenceClassification`, + `SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and + `XLNetForSequenceClassification`. Parameters linked to the tokenizer From 4729da470efdd1cbbe966d6b5eb49c1a77bb0e9f Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:22:37 +0200 Subject: [PATCH 18/20] quality --- src/transformers/configuration_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 24bdace9e44661..6553d3f42ee38e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -165,14 +165,12 @@ class PretrainedConfig(PushToHubMixin): current task. - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`). - Please note that this parameter is only available in the following models: - `AlbertForSequenceClassification`, `BertForSequenceClassification`, `BigBirdForSequenceClassification`, - `ConvBertForSequenceClassification`, `DistilBertForSequenceClassification`, - `ElectraForSequenceClassification`, `FunnelForSequenceClassification`, + Please note that this parameter is only available in the following models: `AlbertForSequenceClassification`, + `BertForSequenceClassification`, `BigBirdForSequenceClassification`, `ConvBertForSequenceClassification`, + `DistilBertForSequenceClassification`, `ElectraForSequenceClassification`, `FunnelForSequenceClassification`, `LongformerForSequenceClassification`, `MobileBertForSequenceClassification`, `ReformerForSequenceClassification`, `RobertaForSequenceClassification`, - `SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and - `XLNetForSequenceClassification`. + `SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and `XLNetForSequenceClassification`. 
Parameters linked to the tokenizer From 9f674397463e5beadcc3a761101bfd739c533fd0 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:26:54 +0200 Subject: [PATCH 19/20] make fix-copies --- src/transformers/models/mobilebert/modeling_mobilebert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 5d0f9f3119406d..8f50c6d6f0f905 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -1286,7 +1286,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output From fd6787a1861f2de6bf72c5618b82dd8c2c71aff4 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 3 May 2021 18:42:38 +0200 Subject: [PATCH 20/20] remove test --- tests/test_modeling_common.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1bb23008b1a128..f83d65b51a7d3c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1275,14 +1275,6 @@ def test_problem_types(self): loss = model(**inputs).loss loss.backward() - with self.assertRaises(ValueError): - config.problem_type = "non_existent_problem" - model = model_class(config) - model.to(torch_device) - model.train() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - model(**inputs) - global_rng = random.Random()