diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bbe10a5ee30e9d..62200d5976d54c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1333,6 +1333,12 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput: Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` + .. note:: + + If your predictions or labels have different sequence length (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + Returns: `NamedTuple` A namedtuple with the following keys: - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. @@ -1412,9 +1418,9 @@ def prediction_loop( losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) if logits is not None: - preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: - labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index f19edba609ab7e..cb3d4a5bfe5b7b 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -42,17 +42,50 @@ logger = logging.get_logger(__name__) -def nested_concat(tensors, new_tensors, dim=0): - "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors." +def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100): + """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary.""" + if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]: + return torch.cat((tensor1, tensor2), dim=0) + + # Let's figure out the new shape + new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:] + + # Now let's fill the result tensor + result = tensor1.new_full(new_shape, padding_index) + result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1 + result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2 + return result + + +def numpy_pad_and_concatenate(array1, array2, padding_index=-100): + """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary.""" + if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]: + return np.concatenate((array1, array2), dim=0) + + # Let's figure out the new shape + new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:] + + # Now let's fill the result tensor + result = np.full_like(array1, padding_index, shape=new_shape) + result[: array1.shape[0], : array1.shape[1]] = array1 + result[array1.shape[0] :, : array2.shape[1]] = array2 + return result + + +def nested_concat(tensors, new_tensors, padding_index=-100): + """ + Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or + nested list/tuples of tensors. + """ assert type(tensors) == type( new_tensors ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors)) + return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors)) elif isinstance(tensors, torch.Tensor): - return torch.cat((tensors, new_tensors), dim=dim) + return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index) elif isinstance(tensors, np.ndarray): - return np.concatenate((tensors, new_tensors), axis=dim) + return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index) else: raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}") @@ -190,11 +223,21 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset): return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) -def nested_new_like(arrays, num_samples): +def nested_new_like(arrays, num_samples, padding_index=-100): """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`.""" if isinstance(arrays, (list, tuple)): return type(arrays)(nested_new_like(x, num_samples) for x in arrays) - return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype) + return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:])) + + +def nested_expand_like(arrays, new_seq_length, padding_index=-100): + """ Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding.""" + if isinstance(arrays, (list, tuple)): + return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays) + + result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:]) + result[:, : arrays.shape[1]] = arrays + return result def nested_truncate(tensors, limit): @@ -204,6 +247,13 @@ def nested_truncate(tensors, limit): return tensors[:limit] +def _get_first_shape(arrays): + """Return the shape of the first array found in the nested struct `arrays`.""" + if isinstance(arrays, (list, tuple)): + return _get_first_shape(arrays[0]) + return arrays.shape + + class DistributedTensorGatherer: """ A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks. @@ -247,9 +297,11 @@ class DistributedTensorGatherer: make_multiple_of (:obj:`int`, `optional`): If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument (by adding samples). + padding_index (:obj:`int`, `optional`, defaults to -100): + The padding index to use if the arrays don't all have the same sequence length. """ - def __init__(self, world_size, num_samples, make_multiple_of=None): + def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100): self.world_size = world_size self.num_samples = num_samples total_size = world_size if make_multiple_of is None else world_size * make_multiple_of @@ -257,6 +309,7 @@ def __init__(self, world_size, num_samples, make_multiple_of=None): self.process_length = self.total_samples // world_size self._storage = None self._offsets = None + self.padding_index = padding_index def add_arrays(self, arrays): """ @@ -266,8 +319,14 @@ def add_arrays(self, arrays): if arrays is None: return if self._storage is None: - self._storage = nested_new_like(arrays, self.total_samples) + self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index) self._offsets = list(range(0, self.total_samples, self.process_length)) + else: + storage_shape = _get_first_shape(self._storage) + arrays_shape = _get_first_shape(arrays) + if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]: + # If we get new arrays that are too big too fit, we expand the shape fo the storage + self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index) slice_len = self._nested_set_tensors(self._storage, arrays) for i in range(self.world_size): self._offsets[i] += slice_len @@ -283,7 +342,12 @@ def _nested_set_tensors(self, storage, arrays): slice_len = arrays.shape[0] // self.world_size for i in range(self.world_size): - storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len] + if len(arrays.shape) == 1: + storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len] + else: + storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[ + i * slice_len : (i + 1) * slice_len + ] return slice_len def finalize(self): diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 497c6c8b4daa73..a040d1cb16940c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -73,6 +73,22 @@ def __getitem__(self, i): return result +class DynamicShapesDataset: + def __init__(self, length=64, seed=42, batch_size=8): + self.length = length + np.random.seed(seed) + sizes = np.random.randint(1, 20, (length // batch_size,)) + # For easy batching, we make every batch_size consecutive samples the same size. + self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)] + self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)] + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"input_x": self.xs[i], "labels": self.ys[i]} + + class AlmostAccuracy: def __init__(self, thresh=0.25): self.thresh = thresh @@ -282,7 +298,7 @@ def test_train_and_eval_dataloaders(self): self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu)) self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu)) - # Check passing a new dataset for evaluation wors + # Check passing a new dataset for evaluation works new_eval_dataset = RegressionDataset(length=128) self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu)) @@ -340,6 +356,42 @@ def test_predict(self): self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_dynamic_shapes(self): + eval_dataset = DynamicShapesDataset(batch_size=self.batch_size) + model = RegressionModel(a=2, b=1) + args = TrainingArguments("./regression") + trainer = Trainer(model, args, eval_dataset=eval_dataset) + + # Check evaluation can run to completion + _ = trainer.evaluate() + + # Check predictions + preds = trainer.predict(eval_dataset) + for expected, seen in zip(eval_dataset.ys, preds.label_ids): + self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]])) + self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) + + for expected, seen in zip(eval_dataset.xs, preds.predictions): + self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]])) + self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) + + # Same tests with eval accumulation + args = TrainingArguments("./regression", eval_accumulation_steps=2) + trainer = Trainer(model, args, eval_dataset=eval_dataset) + + # Check evaluation can run to completion + _ = trainer.evaluate() + + # Check predictions + preds = trainer.predict(eval_dataset) + for expected, seen in zip(eval_dataset.ys, preds.label_ids): + self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]])) + self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) + + for expected, seen in zip(eval_dataset.xs, preds.predictions): + self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]])) + self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) + @require_datasets def test_trainer_with_datasets(self): import datasets