Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict

* Fix test

* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)

* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests

* AutoModels


Tiny tweaks

* Style

* Final changes before merge

* Re-order for simpler review

* Final fixes

* Addressing @sgugger's comments

* Test MultipleChoice

* Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre's comments

* Trigger CI

* Remove unused import

* Switch from return_tuple to return_dict

* Fix test

* Add recent model

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
3 people committed Jul 30, 2020
1 parent 562b636 commit 91cb954
Showing 35 changed files with 675 additions and 633 deletions.
17 changes: 7 additions & 10 deletions docs/source/quicktour.rst
@@ -230,19 +230,16 @@ final activations of the model.
>>> ## PYTORCH CODE
>>> print(pt_outputs)
SequenceClassifierOutput(loss=None, logits=tensor([[-4.0833, 4.3364],
[ 0.0818, -0.0418]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
(tensor([[-4.0833, 4.3364],
[ 0.0818, -0.0418]], grad_fn=<AddmmBackward>),)
>>> ## TENSORFLOW CODE
>>> print(tf_outputs)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-4.0832963 , 4.336414 ],
[ 0.08181786, -0.04179301]], dtype=float32)>,)
The model can return more than just the final activations, which is why the PyTorch output is a special class and the
TensorFlow output is a tuple. Here we only asked for the final activations, so we get a tuple with one element on the
TensorFlow side and a :class:`~transformers.modeling_outputs.SequenceClassifierOutput` with just the ``logits`` field
filled on the PyTorch side.

The model can return more than just the final activations, which is why the output is a tuple. Here we only asked for
the final activations, so we get a tuple with one element.
.. note::

All 🤗 Transformers models (PyTorch or TensorFlow) return the activations of the model *before* the final
@@ -254,7 +251,7 @@ Let's apply the SoftMax activation to get predictions.
>>> ## PYTORCH CODE
>>> import torch.nn.functional as F
>>> pt_predictions = F.softmax(pt_outputs.logits, dim=-1)
>>> pt_predictions = F.softmax(pt_outputs[0], dim=-1)
>>> ## TENSORFLOW CODE
>>> import tensorflow as tf
>>> tf_predictions = tf.nn.softmax(tf_outputs[0], axis=-1)
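The quicktour hunks above reflect the new default: models return plain tuples unless return_dict=True is passed. A minimal sketch of both access patterns, not taken from the diff and assuming the same distilbert-base-uncased-finetuned-sst-2-english checkpoint the quicktour uses:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("We are very happy to show you the 🤗 Transformers library.", return_tensors="pt")

# Default after this commit: the forward pass returns a plain tuple.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs[0], dim=-1)  # index into the tuple

# Opting in with return_dict=True returns a ModelOutput with named fields instead.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, return_dict=True)
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)  # attribute access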
@@ -341,8 +338,8 @@ code is easy to access and tweak if you need to.

In our previous example, the model was called "distilbert-base-uncased-finetuned-sst-2-english", which means it's
using the :doc:`DistilBERT </model_doc/distilbert>` architecture. As
:class:`~transformers.AutoModelForSequenceClassification` (or :class:`~transformers.TFAutoModelForSequenceClassification`
if you are using TensorFlow)` was used, the model automatically created is then a
:class:`~transformers.AutoModelForSequenceClassification` (or :class:`~transformers.TFAutoModelForSequenceClassification`
if you are using TensorFlow) was used, the model automatically created is then a
:class:`~transformers.DistilBertForSequenceClassification`. You can look at its documentation for all details relevant
to that specific model, or browse the source code. This is how you would directly instantiate model and tokenizer
without the auto magic:
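The snippet that follows this sentence is collapsed in the diff view; a sketch of what direct instantiation looks like, where the checkpoint name comes from the surrounding quicktour text and the tokenizer class is an assumption:

# Illustrative reconstruction of the collapsed snippet: instantiating the concrete
# classes directly instead of relying on the Auto* factories.
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = DistilBertForSequenceClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)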
2 changes: 1 addition & 1 deletion docs/source/training.rst
@@ -49,7 +49,7 @@ put it in train mode.
.. code-block:: python
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True)
model.train()
This is useful because it allows us to make use of the pre-trained BERT
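training.rst opts into return_dict=True so the rest of the guide can read the loss off the output by name. A short sketch of that pattern; the toy batch and labels below are illustrative, not from the diff:

import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True)
model.train()

batch = tokenizer(["I love this movie", "A terrible waste of time"], padding=True, return_tensors="pt")
labels = torch.tensor([1, 0])

outputs = model(**batch, labels=labels)
outputs.loss.backward()  # named access works because return_dict=True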
5 changes: 0 additions & 5 deletions examples/question-answering/run_squad.py
@@ -199,9 +199,6 @@ def train(args, train_dataset, model, tokenizer):
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
)

if isinstance(model, torch.nn.DataParallel):
inputs["return_tuple"] = True

outputs = model(**inputs)
# model outputs are always tuple in transformers (see doc)
loss = outputs[0]
@@ -316,8 +313,6 @@ def evaluate(args, model, tokenizer, prefix=""):
inputs.update(
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
)
if isinstance(model, torch.nn.DataParallel):
inputs["return_tuple"] = True
outputs = model(**inputs)

for i, feature_index in enumerate(feature_indices):
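The deleted branch in run_squad.py forced return_tuple=True under torch.nn.DataParallel, presumably because gathering the old output objects across replicas was problematic. With return_dict defaulting to False, a plain call already yields tuples, so the guard can go. A sketch of the simplified pattern; the checkpoint and inputs here are illustrative, not from the script:

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # tuple outputs gather cleanly across replicas

inputs = tokenizer("Who wrote the note?", "The note was written by Jane.", return_tensors="pt")
outputs = model(
    **inputs,
    start_positions=torch.tensor([0]),
    end_positions=torch.tensor([0]),
)
loss = outputs[0]  # model outputs are tuples; the loss comes first when positions are given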
2 changes: 1 addition & 1 deletion examples/seq2seq/test_seq2seq_examples.py
@@ -144,7 +144,7 @@ def test_distill_checkpointing_with_teacher(self):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
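BART_TINY is a constant defined elsewhere in the test suite; below is a standalone approximation of the same setup, with the public sshleifer/bart-tiny-random checkpoint assumed as a stand-in, showing the named loss access that return_dict=True enables:

import torch
from transformers import AutoModelForSeq2SeqLM

# "sshleifer/bart-tiny-random" is an assumption standing in for the BART_TINY constant.
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random", return_dict=True)
input_ids = model.dummy_inputs["input_ids"]
mask = model.dummy_inputs["attention_mask"]

target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous()  # teacher forcing: feed the targets shifted right
labels = target_ids[:, 1:].clone()                    # predict the next token at each position

outputs = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=labels)
print(outputs.loss)  # the LM loss is a named field when return_dict=True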
15 changes: 8 additions & 7 deletions src/transformers/configuration_utils.py
@@ -49,8 +49,9 @@ class PretrainedConfig(object):
Whether or not the model should returns all attentions.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a
plain tuple.
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -133,7 +134,7 @@ class PretrainedConfig(object):

def __init__(self, **kwargs):
# Attributes with defaults
self.return_tuple = kwargs.pop("return_tuple", False)
self.return_dict = kwargs.pop("return_dict", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
@@ -194,12 +195,12 @@ def __init__(self, **kwargs):
raise err

@property
def use_return_tuple(self) -> bool:
def use_return_dict(self) -> bool:
"""
:obj:`bool`: Whether or not the model should return a tuple.
:obj:`bool`: Whether or not return :class:`~transformers.file_utils.ModelOutput` instead of tuples.
"""
# If torchscript is set, force return_tuple to avoid jit errors
return self.return_tuple or self.torchscript
# If torchscript is set, force `return_dict=False` to avoid jit errors
return self.return_dict and not self.torchscript

@property
def num_labels(self) -> int:
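The renamed property is what the model code consults: it honors return_dict but always falls back to tuples when torchscript is set. A small sketch of the resulting behavior, assuming a bare PretrainedConfig:

from transformers import PretrainedConfig

config = PretrainedConfig(return_dict=True)
print(config.use_return_dict)  # True -> forward passes return ModelOutput objects

config = PretrainedConfig(return_dict=True, torchscript=True)
print(config.use_return_dict)  # False -> torchscript tracing gets plain tuples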
120 changes: 92 additions & 28 deletions src/transformers/file_utils.py
@@ -13,14 +13,17 @@
import sys
import tarfile
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import numpy as np
import requests
from filelock import FileLock
from tqdm.auto import tqdm
@@ -190,8 +193,8 @@ def docstring_decorator(fn):
RETURN_INTRODUCTION = r"""
Returns:
:class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`:
A :class:`~{full_output_type}` or a tuple of :obj:`torch.FloatTensor` (if ``return_tuple=True`` is passed or
when ``config.return_tuple=True``) comprising various elements depending on the configuration
A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a
tuple of :obj:`torch.FloatTensor` comprising various elements depending on the configuration
(:class:`~transformers.{config_class}`) and inputs.
"""
@@ -257,7 +260,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
@@ -274,7 +277,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> start_positions = torch.tensor([1])
@@ -293,7 +296,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
@@ -309,7 +312,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
@@ -325,7 +328,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
@@ -340,7 +343,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> import torch
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
@@ -362,7 +365,7 @@ def _prepare_output_docstrings(output_type, config_class):
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
@@ -900,30 +903,91 @@ def wrapper(*args, **kwargs):
return wrapper


class ModelOutput:
def is_tensor(x):
""" Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. """
if is_torch_available():
import torch

if isinstance(x, torch.Tensor):
return True
if is_tf_available():
import tensorflow as tf

if isinstance(x, tf.Tensor):
return True
return isinstance(x, np.ndarray)


class ModelOutput(OrderedDict):
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
a tuple) or strings (like a dictionnary) that will ignore the ``None`` attributes.
a tuple) or strings (like a dictionnary) that will ignore the ``None`` attributes. Otherwise behaves like a
regular python dictionary.
.. warning::
You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
method to convert it to a tuple before.
"""

def to_tuple(self):
"""
Converts :obj:`self` to a tuple.
def __post_init__(self):
class_fields = fields(self)

# Safety and consistency checks
assert len(class_fields), f"{self.__class__.__name__} has no fields."
assert all(
field.default is None for field in class_fields[1:]
), f"{self.__class__.__name__} should not have more than one required field."

first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

if other_fields_are_none and not is_tensor(first_field):
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False

# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for element in iterator:
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v

Return: A tuple containing all non-:obj:`None` attributes of the :obj:`self`.
"""
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

def to_dict(self):
"""
Converts :obj:`self` to a Python dictionary.
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

Return: A dictionary containing all non-:obj:`None` attributes of the :obj:`self`.
"""
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

def __getitem__(self, i):
return self.to_dict()[i] if isinstance(i, str) else self.to_tuple()[i]
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
return inner_dict[k]
else:
return self.to_tuple()[k]

def __len__(self):
return len(self.to_tuple())
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not ``None``.
"""
return tuple(self[k] for k in self.keys())
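Taken together, the new class behaves like a read-only ordered dict over the non-None fields while keeping attribute, string, and integer access. A toy subclass (not from the library) to illustrate, assuming the file_utils import path used at this commit:

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.file_utils import ModelOutput


@dataclass
class ToyOutput(ModelOutput):
    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None


out = ToyOutput(logits=torch.ones(2, 2))

print(out.logits)        # attribute access
print(out["logits"])     # string key access
print(out[0])            # integer index, like a tuple
print(out.to_tuple())    # (tensor,) -- the unset `loss` field is dropped
print("loss" in out)     # False: None fields are never stored as keys
print(list(out.keys()))  # ['logits']
# out.pop("logits")      # would raise: pop/update/setdefault/del are disabled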