From 0be0faa58fac98a69babb1125e73b5e759a52e7e Mon Sep 17 00:00:00 2001
From: Matt Gardner
Date: Wed, 30 Sep 2020 22:24:59 -0700
Subject: [PATCH] added some tests

---
 .../data_loaders/multitask_data_loader.py     | 10 +++-
 .../data/data_loaders/multitask_scheduler.py  |  8 ++-
 .../multitask_data_loader_test.py             | 55 +++++++++++++++++++
 .../data_loaders/multitask_scheduler_test.py  | 54 ++++++++++++++++++
 4 files changed, 124 insertions(+), 3 deletions(-)
 create mode 100644 tests/data/data_loaders/multitask_data_loader_test.py
 create mode 100644 tests/data/data_loaders/multitask_scheduler_test.py

diff --git a/allennlp/data/data_loaders/multitask_data_loader.py b/allennlp/data/data_loaders/multitask_data_loader.py
index b20e8b7a97b..00887abf420 100644
--- a/allennlp/data/data_loaders/multitask_data_loader.py
+++ b/allennlp/data/data_loaders/multitask_data_loader.py
@@ -137,7 +137,13 @@ def __init__(
             # iterator, and it will call the lambda function each time it runs out of instances,
             # which will produce a new shuffling of the dataset.
             key: util.cycle_iterator_function(
-                lambda: util.shuffle_iterable(loader.iter_instances())
+                # This default argument to the lambda function is necessary to create a new scope
+                # for the `loader` variable, so a _different_ loader gets saved for every iterator.
+                # Dictionary comprehensions don't give each iteration its own scope in Python.  If
+                # you don't have this default argument, you end up with `loader` always referring
+                # to the last loader in the iteration...  mypy also doesn't know what to do with
+                # this, for some reason I can't figure out.
+                lambda l=loader: util.shuffle_iterable(l.iter_instances())  # type: ignore
             )
             for key, loader in self._loaders.items()
         }
@@ -178,6 +184,8 @@ def __iter__(self) -> Iterator[TensorDict]:
                 yield batch.as_tensor_dict()
                 batch_instances = [instance]
                 current_batch_size = self._batch_size_multiplier.get(dataset, 1)
+            else:
+                batch_instances.append(instance)
         # Based on how we yield batches above, we are guaranteed to always have leftover instances,
         # so we don't need a check for that here.
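The comment added in the hunk above documents Python's late-binding closure behavior: every lambda created in the comprehension closes over the same `loader` variable, so without the default argument they would all end up iterating the last loader. A minimal standalone sketch of the gotcha and of the default-argument fix the patch applies (the names `broken` and `fixed` are illustrative, not part of the patch):

# Without a default argument, all three lambdas share one `name` cell and
# therefore all see its final value once the comprehension finishes.
broken = {name: (lambda: name.upper()) for name in ["a", "b", "c"]}
print([f() for f in broken.values()])  # ['C', 'C', 'C']

# A default argument is evaluated when the lambda is defined, so each lambda
# keeps its own copy of the current `name` -- the same fix as `lambda l=loader: ...`.
fixed = {name: (lambda n=name: n.upper()) for name in ["a", "b", "c"]}
print([f() for f in fixed.values()])  # ['A', 'B', 'C']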
diff --git a/allennlp/data/data_loaders/multitask_scheduler.py b/allennlp/data/data_loaders/multitask_scheduler.py
index b637849538e..1e59ac7012b 100644
--- a/allennlp/data/data_loaders/multitask_scheduler.py
+++ b/allennlp/data/data_loaders/multitask_scheduler.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+import itertools
 from typing import Any, Dict, Iterable, Tuple, Union
 
 import more_itertools
@@ -49,7 +50,10 @@ class RoundRobinScheduler(MultiTaskScheduler):
     def order_epoch_instances(
         self, epoch_instances: Dict[str, Iterable[Instance]]
     ) -> Iterable[Tuple[str, Instance]]:
-        iterators = list(epoch_instances.values())
+        iterators = [
+            zip(itertools.cycle([dataset]), iterator)
+            for dataset, iterator in epoch_instances.items()
+        ]
         return more_itertools.roundrobin(*iterators)
 
 
@@ -91,7 +95,7 @@ def order_epoch_instances(
         self, epoch_instances: Dict[str, Iterable[Instance]]
     ) -> Iterable[Tuple[str, Instance]]:
         grouped_iterators = [
-            util.lazy_groups_of(iterator, self.batch_size[dataset])
+            util.lazy_groups_of(zip(itertools.cycle([dataset]), iterator), self.batch_size[dataset])
             for dataset, iterator in epoch_instances.items()
         ]
         batch_iterator = more_itertools.roundrobin(*grouped_iterators)
diff --git a/tests/data/data_loaders/multitask_data_loader_test.py b/tests/data/data_loaders/multitask_data_loader_test.py
new file mode 100644
index 00000000000..36cf662fd8c
--- /dev/null
+++ b/tests/data/data_loaders/multitask_data_loader_test.py
@@ -0,0 +1,55 @@
+import torch
+
+from allennlp.data import DatasetReader, Instance, Vocabulary
+from allennlp.data.fields import LabelField
+from allennlp.data.dataset_readers import MultiTaskDatasetReader
+from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader
+from allennlp.data.data_loaders.multitask_scheduler import RoundRobinScheduler
+from allennlp.data.data_loaders.multitask_epoch_sampler import UniformSampler
+
+
+class FakeDatasetReaderA(DatasetReader):
+    def _read(self, file_path: str):
+        while True:
+            yield Instance({"label": LabelField("A")})
+
+
+class FakeDatasetReaderB(DatasetReader):
+    def _read(self, file_path: str):
+        while True:
+            yield Instance({"label": LabelField("B")})
+
+
+class MultiTaskDataLoaderTest:
+    def test_loading(self):
+        reader = MultiTaskDatasetReader(
+            readers={"a": FakeDatasetReaderA(), "b": FakeDatasetReaderB()}
+        )
+        data_path = {"a": "ignored", "b": "ignored"}
+        batch_size = 4
+        scheduler = RoundRobinScheduler()
+        sampler = UniformSampler()
+        instances_per_epoch = 8
+        batch_size_multiplier = {"a": 1, "b": 2}
+        loader = MultiTaskDataLoader(
+            reader=reader,
+            data_path=data_path,
+            batch_size=batch_size,
+            scheduler=scheduler,
+            sampler=sampler,
+            instances_per_epoch=instances_per_epoch,
+            batch_size_multiplier=batch_size_multiplier,
+            max_instances_in_memory={"a": 10, "b": 10},
+        )
+        vocab = Vocabulary()
+        vocab.add_tokens_to_namespace(["A", "B"], "labels")
+        loader.index_with(vocab)
+        iterator = iter(loader)
+        batch = next(iterator)
+        assert torch.all(batch["label"] == torch.IntTensor([0, 1, 0]))
+        batch = next(iterator)
+        assert torch.all(batch["label"] == torch.IntTensor([1, 0]))
+        batch = next(iterator)
+        assert torch.all(batch["label"] == torch.IntTensor([1, 0]))
+        batch = next(iterator)
+        assert torch.all(batch["label"] == torch.IntTensor([1]))
diff --git a/tests/data/data_loaders/multitask_scheduler_test.py b/tests/data/data_loaders/multitask_scheduler_test.py
new file mode 100644
index 00000000000..b8af32ae2f8
--- /dev/null
+++ b/tests/data/data_loaders/multitask_scheduler_test.py
@@ -0,0 +1,54 @@
+from allennlp.data.data_loaders.multitask_scheduler import (
+    RoundRobinScheduler,
+    HomogeneousRoundRobinScheduler,
+)
+
+
+class RoundRobinSchedulerTest:
+    def test_order_instances(self):
+        scheduler = RoundRobinScheduler()
+        epoch_instances = {
+            "a": [1] * 5,
+            "b": [2] * 3,
+        }
+        flattened = scheduler.order_epoch_instances(epoch_instances)
+        assert list(flattened) == [
+            ("a", 1),
+            ("b", 2),
+            ("a", 1),
+            ("b", 2),
+            ("a", 1),
+            ("b", 2),
+            ("a", 1),
+            ("a", 1),
+        ]
+
+
+class HomogeneousRoundRobinSchedulerTest:
+    def test_order_instances(self):
+        scheduler = HomogeneousRoundRobinScheduler({"a": 2, "b": 3})
+        epoch_instances = {
+            "a": [1] * 9,
+            "b": [2] * 9,
+        }
+        flattened = scheduler.order_epoch_instances(epoch_instances)
+        assert list(flattened) == [
+            ("a", 1),
+            ("a", 1),
+            ("b", 2),
+            ("b", 2),
+            ("b", 2),
+            ("a", 1),
+            ("a", 1),
+            ("b", 2),
+            ("b", 2),
+            ("b", 2),
+            ("a", 1),
+            ("a", 1),
+            ("b", 2),
+            ("b", 2),
+            ("b", 2),
+            ("a", 1),
+            ("a", 1),
+            ("a", 1),
+        ]
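
The scheduler change tags each instance with its dataset name, which is what lets the data loader's batch-size-multiplier logic above know where every instance came from. A minimal standalone sketch of the pattern, runnable on its own, that reproduces the RoundRobinScheduler test expectations:

import itertools
import more_itertools

# `zip` against an infinite `itertools.cycle` pairs every item with its
# dataset name and stops once the finite instance iterator is exhausted.
epoch_instances = {"a": [1] * 5, "b": [2] * 3}
tagged = [
    zip(itertools.cycle([dataset]), iterator)
    for dataset, iterator in epoch_instances.items()
]
# `roundrobin` interleaves the streams and keeps draining the longer ones
# after the shorter ones run out, hence the trailing ('a', 1) entries.
print(list(more_itertools.roundrobin(*tagged)))
# [('a', 1), ('b', 2), ('a', 1), ('b', 2), ('a', 1), ('b', 2), ('a', 1), ('a', 1)]

The homogeneous variant applies the same tagging but first chunks each tagged stream into `batch_size[dataset]`-sized groups (via `util.lazy_groups_of`) before round-robining, which produces the alternating runs of two ("a", 1)s and three ("b", 2)s asserted in the second test.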