Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-class, multi-label and regression to transformers #11012

Merged
merged 23 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class PretrainedConfig(object):
typically for a classification task.
- **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
current task.
- **problem_type** (:obj:`str`, `optional`) -- Problem type for ForSequenceClassification tasks. It can be one
of (None, "regression", "single_label_classification", "multi_label_classification"). Default is None.
abhishekkrthakur marked this conversation as resolved.
Show resolved Hide resolved

Parameters linked to the tokenizer

Expand Down Expand Up @@ -249,6 +251,16 @@ def __init__(self, **kwargs):
# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)

# regression / multi-label classification
self.problem_type = kwargs.pop("problem_type", None)
allowed_problem_types = (None, "regression", "single_label_classification", "multi_label_classification")
if self.problem_type not in allowed_problem_types:
abhishekkrthakur marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"""The config parameter `problem_type` not understood:
received {self.problem_type} but only [regression, single_label_classification
and multi_label_classification] are valid."""
)

# TPU arguments
if kwargs.pop("xla_device", None) is not None:
logger.warn(
Expand Down
41 changes: 20 additions & 21 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,30 +1518,29 @@ def forward(

loss = None
if labels is not None:
if self.problem_type is not None:
if self.problem_type == "single_column_regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
elif self.problem_type == "multi_column_regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels)
elif self.problem_type in ("binary_classification", "multi_class_classification"):
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.problem_type in ("multi_label_classification"):
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
else:
raise Exception("Problem type not understood")
else:
if self.problem_type is None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
self.problem_type = "regression"
elif self.num_labels > 1 and type(labels) == torch.long:
self.problem_type = "single_label_classification"
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
self.problem_type = "multi_label_classification"

if self.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels)
elif self.problem_type in ("single_label_classification"):
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.problem_type in ("multi_label_classification"):
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
else:
raise ValueError(
f"""The config parameter `problem_type` not understood:
received {self.problem_type} but only [regression, single_label_classification
and multi_label_classification] are valid."""
)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
Expand Down