
Make Trainer evaluation handle dynamic seq_length #8336

Merged: 7 commits, Nov 5, 2020
Changes from 4 commits
10 changes: 8 additions & 2 deletions src/transformers/trainer.py
@@ -1333,6 +1333,12 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`

.. note::

If your predictions or labels have different sequence length (for instance because you're doing dynamic
Review comment (Member):
Suggested change
If your predictions or labels have different sequence length (for instance because you're doing dynamic
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic

padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.

Returns: `NamedTuple` A namedtuple with the following keys:

- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
@@ -1412,9 +1418,9 @@ def prediction_loop(
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, pad_idx=-100)
if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, pad_idx=-100)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
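Usage note (an illustrative sketch, not part of the PR): because of the docstring note added above, predictions and labels returned by `Trainer.predict` can be right-padded with -100 when evaluation batches have different sequence lengths, so downstream code may want to strip that padding before computing metrics. This assumes a `trainer` and `test_dataset` already exist and that -100 only marks padding:

```python
import numpy as np

output = trainer.predict(test_dataset)  # `trainer` / `test_dataset` are assumed to exist

# Rows shorter than the longest sequence are right-padded with -100, so the
# padding can be dropped per sample before computing metrics.
unpadded_labels = [row[row != -100] for row in output.label_ids]
unpadded_preds = [
    pred[: label.shape[0]] for pred, label in zip(output.predictions, unpadded_labels)
]
```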
84 changes: 74 additions & 10 deletions src/transformers/trainer_pt_utils.py
@@ -42,17 +42,50 @@
logger = logging.get_logger(__name__)


def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
def torch_pad_and_concatenate(tensor1, tensor2, pad_idx=-100):
Review comment (Member):
Can the argument be padding_id, as that's the wording used elsewhere? (this applies to all methods used here)

"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
return torch.cat((tensor1, tensor2), dim=0)

# Let's figure out the new shape
new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]

# Now let's fill the result tensor
result = tensor1.new_full(new_shape, pad_idx)
result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
return result
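As a quick illustration (not part of the PR diff), the helper above pads the shorter batch on the second dimension before concatenating on the first; shapes and values here are made up, and the `pad_idx` keyword is as it stands in this revision:

```python
import torch
from transformers.trainer_pt_utils import torch_pad_and_concatenate

batch1 = torch.zeros(2, 3)  # 2 samples, sequence length 3
batch2 = torch.ones(2, 5)   # 2 samples, sequence length 5

merged = torch_pad_and_concatenate(batch1, batch2, pad_idx=-100)
print(merged.shape)  # torch.Size([4, 5]); the first two rows end with -100 padding
```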


def numpy_pad_and_concatenate(array1, array2, pad_idx=-100):
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
return np.concatenate((array1, array2), axis=0)

# Let's figure out the new shape
new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

# Now let's fill the result array
result = np.full_like(array1, pad_idx, shape=new_shape)
result[: array1.shape[0], : array1.shape[1]] = array1
result[array1.shape[0] :, : array2.shape[1]] = array2
return result


def nested_concat(tensors, new_tensors, pad_idx=-100):
"""
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
nested list/tuples of tensors.
"""
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return type(tensors)(nested_concat(t, n, pad_idx=pad_idx) for t, n in zip(tensors, new_tensors))
elif isinstance(tensors, torch.Tensor):
return torch.cat((tensors, new_tensors), dim=dim)
return torch_pad_and_concatenate(tensors, new_tensors, pad_idx=pad_idx)
elif isinstance(tensors, np.ndarray):
return np.concatenate((tensors, new_tensors), axis=dim)
return numpy_pad_and_concatenate(tensors, new_tensors, pad_idx=pad_idx)
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
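For illustration (not from the diff), `nested_concat` now pads before concatenating and recurses into lists/tuples, so a (predictions, labels) pair with mismatched sequence lengths can be merged in one call; a small sketch with made-up shapes:

```python
import numpy as np
from transformers.trainer_pt_utils import nested_concat

preds1, labels1 = np.zeros((2, 3, 4)), np.zeros((2, 3))  # sequence length 3
preds2, labels2 = np.ones((2, 5, 4)), np.ones((2, 5))    # sequence length 5

preds, labels = nested_concat((preds1, labels1), (preds2, labels2), pad_idx=-100)
print(preds.shape, labels.shape)  # (4, 5, 4) (4, 5); shorter rows are padded with -100
```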

@@ -190,11 +223,21 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples):
def nested_new_like(arrays, num_samples, pad_idx=-100):
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)
return np.full_like(arrays, pad_idx, shape=(num_samples, *arrays.shape[1:]))


def nested_expand_like(arrays, new_seq_length, pad_idx=-100):
""" Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `pad_idx` for padding."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_expand_like(x, new_seq_length, pad_idx=pad_idx) for x in arrays)

result = np.full_like(arrays, pad_idx, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
result[:, : arrays.shape[1]] = arrays
return result
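A tiny sketch (not from the diff) of the expansion helper above, assuming the assignment writes back the original `arrays.shape[1]` columns:

```python
import numpy as np
from transformers.trainer_pt_utils import nested_expand_like

arrays = np.arange(6, dtype=np.float32).reshape(2, 3)
expanded = nested_expand_like(arrays, 5, pad_idx=-100)
print(expanded.shape)  # (2, 5)
print(expanded[0])     # [0. 1. 2. -100. -100.]
```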


def nested_truncate(tensors, limit):
@@ -204,6 +247,13 @@ def nested_truncate(tensors, limit):
return tensors[:limit]


def _get_first_shape(arrays):
"""Return the shape of the first array found in the nested struct `arrays`."""
if isinstance(arrays, (list, tuple)):
return _get_first_shape(arrays[0])
return arrays.shape


class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
@@ -247,16 +297,19 @@ class DistributedTensorGatherer:
make_multiple_of (:obj:`int`, `optional`):
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
(by adding samples).
pad_idx (:obj:`int`, `optional`, defaults to -100):
The padding index to use if the arrays don't all have the same sequence length.
"""

def __init__(self, world_size, num_samples, make_multiple_of=None):
def __init__(self, world_size, num_samples, make_multiple_of=None, pad_idx=-100):
self.world_size = world_size
self.num_samples = num_samples
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
self.process_length = self.total_samples // world_size
self._storage = None
self._offsets = None
self.pad_idx = pad_idx

def add_arrays(self, arrays):
"""
@@ -266,8 +319,14 @@ def add_arrays(self, arrays):
if arrays is None:
return
if self._storage is None:
self._storage = nested_new_like(arrays, self.total_samples)
self._storage = nested_new_like(arrays, self.total_samples, pad_idx=self.pad_idx)
self._offsets = list(range(0, self.total_samples, self.process_length))
else:
storage_shape = _get_first_shape(self._storage)
arrays_shape = _get_first_shape(arrays)
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
# If we get new arrays that are too big to fit, we expand the shape of the storage
self._storage = nested_expand_like(self._storage, arrays_shape[1], pad_idx=self.pad_idx)
slice_len = self._nested_set_tensors(self._storage, arrays)
for i in range(self.world_size):
self._offsets[i] += slice_len
@@ -283,7 +342,12 @@ def _nested_set_tensors(self, storage, arrays):

slice_len = arrays.shape[0] // self.world_size
for i in range(self.world_size):
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
if len(arrays.shape) == 1:
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
else:
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
i * slice_len : (i + 1) * slice_len
]
return slice_len

def finalize(self):
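To show how these pieces fit together, here is an illustrative sketch (not from the PR; it assumes `finalize` returns the gathered arrays truncated to `num_samples`) of a single-process gatherer receiving two batches with different sequence lengths:

```python
import numpy as np
from transformers.trainer_pt_utils import DistributedTensorGatherer

gatherer = DistributedTensorGatherer(world_size=1, num_samples=4, pad_idx=-100)
gatherer.add_arrays(np.zeros((2, 3)))  # first batch, sequence length 3
gatherer.add_arrays(np.ones((2, 7)))   # longer batch -> storage is expanded to length 7

preds = gatherer.finalize()
print(preds.shape)  # (4, 7); the first two rows end in -100 padding
```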
54 changes: 53 additions & 1 deletion tests/test_trainer.py
@@ -73,6 +73,22 @@ def __getitem__(self, i):
return result


class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
np.random.seed(seed)
sizes = np.random.randint(1, 20, (length // batch_size,))
# For easy batching, we make every batch_size consecutive samples the same size.
self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_x": self.xs[i], "labels": self.ys[i]}


class AlmostAccuracy:
def __init__(self, thresh=0.25):
self.thresh = thresh
@@ -282,7 +298,7 @@ def test_train_and_eval_dataloaders(self):
self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))

# Check passing a new dataset for evaluation wors
# Check passing a new dataset for evaluation works
new_eval_dataset = RegressionDataset(length=128)
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))

@@ -340,6 +356,42 @@ def test_predict(self):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)
args = TrainingArguments("./regression")
trainer = Trainer(model, args, eval_dataset=eval_dataset)

# Check evaluation can run to completion
_ = trainer.evaluate()

# Check predictions
preds = trainer.predict(eval_dataset)
for expected, seen in zip(eval_dataset.ys, preds.label_ids):
self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

for expected, seen in zip(eval_dataset.xs, preds.predictions):
self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

# Same tests with eval accumulation
args = TrainingArguments("./regression", eval_accumulation_steps=2)
trainer = Trainer(model, args, eval_dataset=eval_dataset)

# Check evaluation can run to completion
_ = trainer.evaluate()

# Check predictions
preds = trainer.predict(eval_dataset)
for expected, seen in zip(eval_dataset.ys, preds.label_ids):
self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

for expected, seen in zip(eval_dataset.xs, preds.predictions):
self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

@require_datasets
def test_trainer_with_datasets(self):
import datasets