Skip to content

Commit

Permalink
Adding an option to toggle drop_remainder=True in datasets during bat…
Browse files Browse the repository at this point in the history
…ching

PiperOrigin-RevId: 337989180
  • Loading branch information
shreyaspadhy authored and Copybara-Service committed Oct 20, 2020
1 parent a8b6113 commit 6e96526
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions uncertainty_baselines/datasets/base.py
Expand Up @@ -136,7 +136,9 @@ def _create_process_example_fn(self, split: Split) -> Optional[PreProcessFn]:
raise NotImplementedError(
'Must override dataset _create_process_example_fn!')

def _batch(self, split: Split, dataset: tf.data.Dataset) -> tf.data.Dataset:
def _batch(self, split: Split,
dataset: tf.data.Dataset,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Get the batched version of `dataset`."""
# `uneven_datasets` is a list of datasets with a number of validation and/or
# test examples that is not evenly divisible by commonly used batch sizes.
Expand All @@ -163,12 +165,13 @@ def _batch(self, split: Split, dataset: tf.data.Dataset) -> tf.data.Dataset:
'examples: %d', batch_size, self._num_test_examples)
# Note that we always drop the last batch when the batch size does not
# evenly divide the number of examples.
return dataset.batch(batch_size, drop_remainder=True)
return dataset.batch(batch_size, drop_remainder=drop_remainder)

def build(
self,
split: Union[str, Split],
as_tuple: bool = False,
drop_remainder: bool = True,
ood_split: Optional[Union[str, OodSplit]] = None) -> tf.data.Dataset:
"""Transforms the dataset from self._read_examples() to batch, repeat, etc.
Expand All @@ -183,6 +186,9 @@ def build(
with at least the keys ['features', 'labels'], or a tuple of
(feature, label). If there are keys besides 'features' and 'labels' in
the Dict then this ignore them.
drop_remainder: whether or not to drop the last batch of data if the
number of points is not exactly equal to the batch size. This option
needs to be True for running on TPUs.
ood_split: an optional OOD split, either one of the OodSplit enum or
their associated strings.
Expand Down Expand Up @@ -226,7 +232,7 @@ def build(
dataset = dataset.shuffle(self._shuffle_buffer_size)
dataset = dataset.repeat()

dataset = self._batch(split, dataset)
dataset = self._batch(split, dataset, drop_remainder=drop_remainder)

dataset = dataset.prefetch(-1)

Expand Down

0 comments on commit 6e96526

Please sign in to comment.