Skip to content

Commit

Permalink
Fixes to make life easier with the nlp library (huggingface#6423)
Browse files Browse the repository at this point in the history
* allow using tokenizer.pad as a collate_fn in pytorch

* allow using tokenizer.pad as a collate_fn in pytorch

* Add documentation and tests

* Make attention mask the right shape

* Better test

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
  • Loading branch information
2 people authored and fabiocapsouza committed Nov 15, 2020
1 parent 1d17ff7 commit 79dd29d
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,7 @@ def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Option
max_len = max([len(item) for item in outputs])
outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
outputs = BatchEncoding(
{"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework
{"input_ids": outputs, "attention_mask": [[1] * len(outputs)]}, tensor_type=self.framework,
)
return outputs

Expand Down
92 changes: 50 additions & 42 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,12 @@ def convert_to_tensors(

tensor = as_tensor(value)

# at-least2d
if tensor.ndim > 2:
tensor = tensor.squeeze(0)
elif tensor.ndim < 2:
tensor = tensor[None, :]
# Removing this for now in favor of controling the shape with `prepend_batch_axis`
# # at-least2d
# if tensor.ndim > 2:
# tensor = tensor.squeeze(0)
# elif tensor.ndim < 2:
# tensor = tensor[None, :]

self[key] = tensor
except: # noqa E722
Expand Down Expand Up @@ -589,43 +590,6 @@ def to(self, device: str) -> "BatchEncoding":
return self


# class AddedToken(UserString):
# """ AddedToken represents a token to be added to a Tokenizer

# An AddedToken can have special options defining the way it should behave.

# Args:
# content: str:
# The content of the token

# single_word: bool
# Whether this token should only match against single word. If True,
# this token will never match inside of a word.

# lstrip: bool
# Whether this token should strip all potential whitespaces on the left side.
# If True, this token will greedily match any whitespace on the left and then strip
# them out.

# rstrip: bool
# Whether this token should strip all potential whitespaces on the right side.
# If True, this token will greedily match any whitespace on the right and then strip
# them out.
# """

# def __init__(
# self, data: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False,
# ):
# super().__init__(data)

# self._single_word = single_word
# self._lstrip = lstrip
# self._rstrip = rstrip

# def lower(self):
# return AddedToken(self.data.lower(), self._single_word, self._lstrip, self._rstrip)


class SpecialTokensMixin:
"""
A mixin derived by :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`
Expand Down Expand Up @@ -2225,12 +2189,21 @@ def pad(
Padding side (left/right) padding token ids are defined at the tokenizer level
(with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``)
.. note::
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
case of PyTorch tensors, you will lose the specific device of your tensors however.
Args:
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or
:obj:`Dict[str, List[int]]`) or a batch of tokenized inputs (list of
:class:`~transformers.BatchEncoding`, `Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`) so
you can use this method during preprocessing as well as in a PyTorch Dataloader collate function.
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
see the note above for the return type.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
Expand Down Expand Up @@ -2263,6 +2236,7 @@ def pad(
Whether or not to print informations and warnings.
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

Expand All @@ -2277,6 +2251,40 @@ def pad(
encoded_inputs["attention_mask"] = []
return encoded_inputs

# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
# and rebuild them afterwards if no return_tensors is specified
# Note that we lose the specific device the tensor may be on for PyTorch
first_element = encoded_inputs["input_ids"][0]
if isinstance(first_element, (list, tuple)) and first_element:
first_element = first_element[0]
if not isinstance(first_element, int):
if is_tf_available() and isinstance(first_element, tf.Tensor):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and isinstance(first_element, torch.Tensor):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
f"Should be one of a python, numpy, pytorch or tensorflow object."
)

def to_py_obj(obj):
if isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj]
elif is_tf_available() and isinstance(obj, tf.Tensor):
return obj.numpy().tolist()
elif is_torch_available() and isinstance(obj, torch.Tensor):
return obj.cpu().tolist()
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return obj

for key, value in encoded_inputs.items():
encoded_inputs[key] = to_py_obj(value)

# Convert padding_strategy in PaddingStrategy
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
padding=padding, max_length=max_length, verbose=verbose
Expand Down
76 changes: 76 additions & 0 deletions tests/test_tokenization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import unittest
from typing import Callable, Optional

import numpy as np

from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
from transformers.testing_utils import require_tf, require_torch, slow
from transformers.tokenization_gpt2 import GPT2Tokenizer
Expand Down Expand Up @@ -135,3 +137,77 @@ def test_batch_encoding_is_fast(self):

with self.subTest("Rust Tokenizer"):
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)

def test_batch_encoding_with_labels(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="np")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))

batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))

@require_torch
def test_batch_encoding_with_labels_pt(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))

batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))

@require_tf
def test_batch_encoding_with_labels_tf(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))

batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))

def test_padding_accepts_tensors(self):
features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

batch = tokenizer.pad(features, padding=True)
self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
batch = tokenizer.pad(features, padding=True, return_tensors="np")
self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

@require_torch
def test_padding_accepts_tensors_pt(self):
import torch

features = [{"input_ids": torch.tensor([0, 1, 2])}, {"input_ids": torch.tensor([0, 1, 2, 3])}]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

batch = tokenizer.pad(features, padding=True)
self.assertTrue(isinstance(batch["input_ids"], torch.Tensor))
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
batch = tokenizer.pad(features, padding=True, return_tensors="pt")
self.assertTrue(isinstance(batch["input_ids"], torch.Tensor))
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

@require_tf
def test_padding_accepts_tensors_tf(self):
import tensorflow as tf

features = [{"input_ids": tf.constant([0, 1, 2])}, {"input_ids": tf.constant([0, 1, 2, 3])}]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

batch = tokenizer.pad(features, padding=True)
self.assertTrue(isinstance(batch["input_ids"], tf.Tensor))
self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
batch = tokenizer.pad(features, padding=True, return_tensors="tf")
self.assertTrue(isinstance(batch["input_ids"], tf.Tensor))
self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

0 comments on commit 79dd29d

Please sign in to comment.