-
Notifications
You must be signed in to change notification settings - Fork 25.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Trainer evaluation handle dynamic seq_length #8336
Merged
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
f77a25f
Make Trainer evaluation handle dynamic seq_length
sgugger a2b9dba
Document behavior.
sgugger 7729a2b
Fix test
sgugger ab3ca4c
Better fix
sgugger 666ed2e
Fixes for realsies this time
sgugger ee87672
Address review comments
sgugger e94a1f9
Without forgetting to save...
sgugger File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,17 +42,50 @@ | |
logger = logging.get_logger(__name__) | ||
|
||
|
||
def nested_concat(tensors, new_tensors, dim=0): | ||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors." | ||
def torch_pad_and_concatenate(tensor1, tensor2, pad_idx=-100): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can the argument be |
||
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary.""" | ||
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]: | ||
return torch.cat((tensor1, tensor2), dim=0) | ||
|
||
# Let's figure out the new shape | ||
new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:] | ||
|
||
# Now let's fill the result tensor | ||
result = tensor1.new_full(new_shape, pad_idx) | ||
result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1 | ||
result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2 | ||
return result | ||
|
||
|
||
def numpy_pad_and_concatenate(array1, array2, pad_idx=-100): | ||
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary.""" | ||
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]: | ||
return np.concatenate((array1, array2), dim=0) | ||
|
||
# Let's figure out the new shape | ||
new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:] | ||
|
||
# Now let's fill the result tensor | ||
result = np.full_like(array1, pad_idx, shape=new_shape) | ||
result[: array1.shape[0], : array1.shape[1]] = array1 | ||
result[array1.shape[0] :, : array2.shape[1]] = array2 | ||
return result | ||
|
||
|
||
def nested_concat(tensors, new_tensors, pad_idx=-100): | ||
""" | ||
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or | ||
nested list/tuples of tensors. | ||
""" | ||
assert type(tensors) == type( | ||
new_tensors | ||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." | ||
if isinstance(tensors, (list, tuple)): | ||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors)) | ||
return type(tensors)(nested_concat(t, n, pad_idx=pad_idx) for t, n in zip(tensors, new_tensors)) | ||
elif isinstance(tensors, torch.Tensor): | ||
return torch.cat((tensors, new_tensors), dim=dim) | ||
return torch_pad_and_concatenate(tensors, new_tensors, pad_idx=pad_idx) | ||
elif isinstance(tensors, np.ndarray): | ||
return np.concatenate((tensors, new_tensors), axis=dim) | ||
return numpy_pad_and_concatenate(tensors, new_tensors, pad_idx=pad_idx) | ||
else: | ||
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}") | ||
|
||
|
@@ -190,11 +223,21 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset): | |
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) | ||
|
||
|
||
def nested_new_like(arrays, num_samples): | ||
def nested_new_like(arrays, num_samples, pad_idx=-100): | ||
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`.""" | ||
if isinstance(arrays, (list, tuple)): | ||
return type(arrays)(nested_new_like(x, num_samples) for x in arrays) | ||
return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype) | ||
return np.full_like(arrays, pad_idx, shape=(num_samples, *arrays.shape[1:])) | ||
|
||
|
||
def nested_expand_like(arrays, new_seq_length, pad_idx=-100): | ||
""" Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `pad_idx` for padding.""" | ||
if isinstance(arrays, (list, tuple)): | ||
return type(arrays)(nested_expand_like(x, new_seq_length, pad_idx=pad_idx) for x in arrays) | ||
|
||
result = np.full_like(arrays, pad_idx, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:]) | ||
result[:, :new_seq_length] = arrays | ||
return result | ||
|
||
|
||
def nested_truncate(tensors, limit): | ||
|
@@ -204,6 +247,13 @@ def nested_truncate(tensors, limit): | |
return tensors[:limit] | ||
|
||
|
||
def _get_first_shape(arrays): | ||
"""Return the shape of the first array found in the nested struct `arrays`.""" | ||
if isinstance(arrays, (list, tuple)): | ||
return _get_first_shape(arrays[0]) | ||
return arrays.shape | ||
|
||
|
||
class DistributedTensorGatherer: | ||
""" | ||
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks. | ||
|
@@ -247,16 +297,19 @@ class DistributedTensorGatherer: | |
make_multiple_of (:obj:`int`, `optional`): | ||
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument | ||
(by adding samples). | ||
pad_idx (:obj:`int`, `optional`, defaults to -100): | ||
The padding index to use if the arrays don't all have the same sequence length. | ||
""" | ||
|
||
def __init__(self, world_size, num_samples, make_multiple_of=None): | ||
def __init__(self, world_size, num_samples, make_multiple_of=None, pad_idx=-100): | ||
self.world_size = world_size | ||
self.num_samples = num_samples | ||
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of | ||
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size | ||
self.process_length = self.total_samples // world_size | ||
self._storage = None | ||
self._offsets = None | ||
self.pad_idx = pad_idx | ||
|
||
def add_arrays(self, arrays): | ||
""" | ||
|
@@ -266,8 +319,14 @@ def add_arrays(self, arrays): | |
if arrays is None: | ||
return | ||
if self._storage is None: | ||
self._storage = nested_new_like(arrays, self.total_samples) | ||
self._storage = nested_new_like(arrays, self.total_samples, pad_idx=self.pad_idx) | ||
self._offsets = list(range(0, self.total_samples, self.process_length)) | ||
else: | ||
storage_shape = _get_first_shape(self._storage) | ||
arrays_shape = _get_first_shape(arrays) | ||
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]: | ||
# If we get new arrays that are too big too fit, we expand the shape fo the storage | ||
self._storage = nested_expand_like(self._storage, arrays_shape[1], pad_idx=self.pad_idx) | ||
slice_len = self._nested_set_tensors(self._storage, arrays) | ||
for i in range(self.world_size): | ||
self._offsets[i] += slice_len | ||
|
@@ -283,7 +342,12 @@ def _nested_set_tensors(self, storage, arrays): | |
|
||
slice_len = arrays.shape[0] // self.world_size | ||
for i in range(self.world_size): | ||
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len] | ||
if len(arrays.shape) == 1: | ||
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len] | ||
else: | ||
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[ | ||
i * slice_len : (i + 1) * slice_len | ||
] | ||
return slice_len | ||
|
||
def finalize(self): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.