
Fixes to make life easier with the nlp library #6423

Merged · 6 commits · Aug 12, 2020
2 changes: 1 addition & 1 deletion src/transformers/pipelines.py
@@ -2318,7 +2318,7 @@ def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Option
        max_len = max([len(item) for item in outputs])
        outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        outputs = BatchEncoding(
-            {"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework
+            {"input_ids": outputs, "attention_mask": [[1] * len(outputs)]}, tensor_type=self.framework,
@sgugger (Collaborator, Author) commented on Aug 11, 2020:
This is the only place where the change of dim in `BatchEncoding.convert_to_tensors` breaks something, but in this case it was a bit magical that the dimension was automatically added, so I don't think this is a serious failure.

        )
        return outputs

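The comment above is easier to see with shapes in hand: `convert_to_tensors` used to force everything to at least 2-D, which silently gave the hand-built attention mask a batch axis; after this PR the caller writes it out. A minimal sketch of the new call, using the NumPy backend for illustration (the pipeline itself passes `self.framework`):

```python
import numpy as np

from transformers import BatchEncoding

# Two histories padded by hand to the same length, as _concat_inputs_history does.
outputs = [[5, 6, 7, 0], [8, 9, 0, 0]]

# The extra brackets around the attention mask are the fix: the batch axis is
# now written explicitly instead of being added by convert_to_tensors.
batch = BatchEncoding(
    {"input_ids": outputs, "attention_mask": [[1] * len(outputs)]},
    tensor_type="np",
)
print(batch["input_ids"].shape)       # (2, 4)
print(batch["attention_mask"].shape)  # (1, 2)
```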
92 changes: 50 additions & 42 deletions src/transformers/tokenization_utils_base.py
@@ -553,11 +553,12 @@ def convert_to_tensors(

                tensor = as_tensor(value)

-                # at-least2d
-                if tensor.ndim > 2:
-                    tensor = tensor.squeeze(0)
-                elif tensor.ndim < 2:
-                    tensor = tensor[None, :]
+                # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
+                # # at-least2d
+                # if tensor.ndim > 2:
+                #     tensor = tensor.squeeze(0)
+                # elif tensor.ndim < 2:
+                #     tensor = tensor[None, :]

                self[key] = tensor
            except:  # noqa E722
@@ -589,43 +590,6 @@ def to(self, device: str) -> "BatchEncoding":
        return self


-# class AddedToken(UserString):
-#     """ AddedToken represents a token to be added to a Tokenizer
-
-#     An AddedToken can have special options defining the way it should behave.
-
-#     Args:
-#         content: str:
-#             The content of the token
-
-#         single_word: bool
-#             Whether this token should only match against single word. If True,
-#             this token will never match inside of a word.
-
-#         lstrip: bool
-#             Whether this token should strip all potential whitespaces on the left side.
-#             If True, this token will greedily match any whitespace on the left and then strip
-#             them out.
-
-#         rstrip: bool
-#             Whether this token should strip all potential whitespaces on the right side.
-#             If True, this token will greedily match any whitespace on the right and then strip
-#             them out.
-#     """
-
-#     def __init__(
-#         self, data: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False,
-#     ):
-#         super().__init__(data)
-
-#         self._single_word = single_word
-#         self._lstrip = lstrip
-#         self._rstrip = rstrip
-
-#     def lower(self):
-#         return AddedToken(self.data.lower(), self._single_word, self._lstrip, self._rstrip)


class SpecialTokensMixin:
    """
    A mixin derived by :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`
@@ -2164,12 +2128,21 @@ def pad(
        Padding side (left/right) padding token ids are defined at the tokenizer level
        (with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``)

+        .. note::
+
+            If the ``encoded_inputs`` passed is a dictionary of numpy arrays, PyTorch tensors or
+            TensorFlow tensors, the result will use the same type unless you provide a different
+            tensor type with ``return_tensors``. In the case of PyTorch tensors, however, you will
+            lose the specific device of your tensors.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or
                :obj:`Dict[str, List[int]]`) or a batch of tokenized inputs (list of
                :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[List[int]]]` or
                :obj:`List[Dict[str, List[int]]]`), so you can use this method during preprocessing
                as well as in a PyTorch DataLoader collate function.

+                Instead of :obj:`List[int]`, you can have tensors (numpy arrays, PyTorch tensors or
+                TensorFlow tensors); see the note above for the return type.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
@@ -2202,6 +2175,7 @@ def pad(
                Whether or not to print information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
+        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

@@ -2216,6 +2190,40 @@
            encoded_inputs["attention_mask"] = []
        return encoded_inputs

+        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
+        # and rebuild them afterwards if no return_tensors is specified
+        # Note that we lose the specific device the tensor may be on for PyTorch
+        first_element = encoded_inputs["input_ids"][0]
+        if isinstance(first_element, (list, tuple)) and first_element:
+            first_element = first_element[0]
+        if not isinstance(first_element, int):
+            if is_tf_available() and isinstance(first_element, tf.Tensor):
+                return_tensors = "tf" if return_tensors is None else return_tensors
+            elif is_torch_available() and isinstance(first_element, torch.Tensor):
+                return_tensors = "pt" if return_tensors is None else return_tensors
+            elif isinstance(first_element, np.ndarray):
+                return_tensors = "np" if return_tensors is None else return_tensors
+            else:
+                raise ValueError(
+                    f"type of {first_element} unknown: {type(first_element)}. "
+                    f"Should be one of a python, numpy, pytorch or tensorflow object."
+                )
+
+            def to_py_obj(obj):
+                if isinstance(obj, (list, tuple)):
+                    return [to_py_obj(o) for o in obj]
+                elif is_tf_available() and isinstance(obj, tf.Tensor):
+                    return obj.numpy().tolist()
+                elif is_torch_available() and isinstance(obj, torch.Tensor):
+                    return obj.cpu().tolist()
+                elif isinstance(obj, np.ndarray):
+                    return obj.tolist()
+                else:
+                    return obj
+
+            for key, value in encoded_inputs.items():
+                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
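The block above is what lets `pad` double as a collate function: tensor inputs are sniffed, cast back to Python lists, padded, and re-tensorized. A sketch of the intended use with a PyTorch `DataLoader` — the feature values below are hypothetical stand-ins for what an `nlp` dataset formatted as torch tensors would yield:

```python
import torch
from torch.utils.data import DataLoader

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# Hypothetical pre-tokenized examples with ragged lengths, already as tensors.
features = [
    {"input_ids": torch.tensor([101, 7592, 102])},
    {"input_ids": torch.tensor([101, 7592, 2088, 999, 102])},
]

def collate_fn(examples):
    # pad() detects the torch tensors, pads to the longest example,
    # and returns torch tensors again (return_tensors made explicit here).
    return tokenizer.pad(examples, padding=True, return_tensors="pt")

loader = DataLoader(features, batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([2, 5])
```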
76 changes: 76 additions & 0 deletions tests/test_tokenization_utils.py
@@ -16,6 +16,8 @@
import unittest
from typing import Callable, Optional

+import numpy as np
+
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
from transformers.testing_utils import require_tf, require_torch, slow
from transformers.tokenization_gpt2 import GPT2Tokenizer
@@ -135,3 +137,77 @@ def test_batch_encoding_is_fast(self):

        with self.subTest("Rust Tokenizer"):
            self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
+
+    def test_batch_encoding_with_labels(self):
+        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
+        tensor_batch = batch.convert_to_tensors(tensor_type="np")
+        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (2,))
+
+        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
+        tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
+        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (1,))
+
+    @require_torch
+    def test_batch_encoding_with_labels_pt(self):
+        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
+        tensor_batch = batch.convert_to_tensors(tensor_type="pt")
+        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (2,))
+
+        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
+        tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
+        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (1,))
+
+    @require_tf
+    def test_batch_encoding_with_labels_tf(self):
+        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
+        tensor_batch = batch.convert_to_tensors(tensor_type="tf")
+        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (2,))
+
+        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
+        tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
+        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
+        self.assertEqual(tensor_batch["labels"].shape, (1,))
+
+    def test_padding_accepts_tensors(self):
+        features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
+        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+
+        batch = tokenizer.pad(features, padding=True)
+        self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
+        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
+        batch = tokenizer.pad(features, padding=True, return_tensors="np")
+        self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
+        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
+
+    @require_torch
+    def test_padding_accepts_tensors_pt(self):
+        import torch
+
+        features = [{"input_ids": torch.tensor([0, 1, 2])}, {"input_ids": torch.tensor([0, 1, 2, 3])}]
+        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+
+        batch = tokenizer.pad(features, padding=True)
+        self.assertTrue(isinstance(batch["input_ids"], torch.Tensor))
+        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
+        batch = tokenizer.pad(features, padding=True, return_tensors="pt")
+        self.assertTrue(isinstance(batch["input_ids"], torch.Tensor))
+        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
+
+    @require_tf
+    def test_padding_accepts_tensors_tf(self):
+        import tensorflow as tf
+
+        features = [{"input_ids": tf.constant([0, 1, 2])}, {"input_ids": tf.constant([0, 1, 2, 3])}]
+        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+
+        batch = tokenizer.pad(features, padding=True)
+        self.assertTrue(isinstance(batch["input_ids"], tf.Tensor))
+        self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
+        batch = tokenizer.pad(features, padding=True, return_tensors="tf")
+        self.assertTrue(isinstance(batch["input_ids"], tf.Tensor))
+        self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])