Commit
added some tests
matt-gardner committed Oct 1, 2020
1 parent b88b876 commit 0be0faa
Showing 4 changed files with 124 additions and 3 deletions.
10 changes: 9 additions & 1 deletion allennlp/data/data_loaders/multitask_data_loader.py
@@ -137,7 +137,13 @@ def __init__(
# iterator, and it will call the lambda function each time it runs out of instances,
# which will produce a new shuffling of the dataset.
key: util.cycle_iterator_function(
-            lambda: util.shuffle_iterable(loader.iter_instances())
+            # The default argument to the lambda is necessary to bind the current
+            # loader, so a _different_ loader gets saved for every iterator. The dict
+            # comprehension shares a single scope across iterations, so without it
+            # every lambda would close over the same `loader` variable and always see
+            # the last loader in the iteration. mypy also doesn't know what to do with
+            # this, for a reason I can't figure out.
+            lambda l=loader: util.shuffle_iterable(l.iter_instances())  # type: ignore
)
for key, loader in self._loaders.items()
}
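
A standalone illustration of the pitfall the new comment describes (toy names, not from the codebase): every lambda created inside a comprehension closes over the same loop variable, and binding the value as a default argument captures it at creation time instead.

late = {key: (lambda: key) for key in ["a", "b", "c"]}
print([f() for f in late.values()])   # ['c', 'c', 'c'] -- all see the final key

bound = {key: (lambda k=key: k) for key in ["a", "b", "c"]}
print([f() for f in bound.values()])  # ['a', 'b', 'c'] -- defaults bind per iteration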
@@ -178,6 +184,8 @@ def __iter__(self) -> Iterator[TensorDict]:
yield batch.as_tensor_dict()
batch_instances = [instance]
current_batch_size = self._batch_size_multiplier.get(dataset, 1)
+            else:
+                batch_instances.append(instance)

# Based on how we yield batches above, we are guaranteed to always have leftover instances,
# so we don't need a check for that here.
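
For context, a simplified reconstruction of the loop this hunk patches (a hypothetical standalone function; the real method tensorizes batches and reads its configuration from self): each instance costs its dataset's multiplier, a batch is flushed just before it would exceed the budget, and the current instance starts the next batch, which is why leftovers are guaranteed.

def _batches(instance_stream, batch_size, batch_size_multiplier):
    batch_instances, current_batch_size = [], 0
    for dataset, instance in instance_stream:
        current_batch_size += batch_size_multiplier.get(dataset, 1)
        if current_batch_size > batch_size:
            yield batch_instances
            batch_instances = [instance]
            current_batch_size = batch_size_multiplier.get(dataset, 1)
        else:
            batch_instances.append(instance)
    yield batch_instances  # the loop always leaves at least one instance behind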
8 changes: 6 additions & 2 deletions allennlp/data/data_loaders/multitask_scheduler.py
@@ -1,4 +1,5 @@
from collections import defaultdict
+import itertools
from typing import Any, Dict, Iterable, Tuple, Union

import more_itertools
@@ -49,7 +50,10 @@ class RoundRobinScheduler(MultiTaskScheduler):
def order_epoch_instances(
self, epoch_instances: Dict[str, Iterable[Instance]]
) -> Iterable[Tuple[str, Instance]]:
-        iterators = list(epoch_instances.values())
+        iterators = [
+            zip(itertools.cycle([dataset]), iterator)
+            for dataset, iterator in epoch_instances.items()
+        ]
return more_itertools.roundrobin(*iterators)
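
The zip(itertools.cycle([dataset]), iterator) idiom tags every instance with its dataset name before interleaving, which is what lets downstream code know each instance's origin. A minimal sketch with toy integers standing in for Instance objects:

import itertools
import more_itertools

epoch_instances = {"a": [1, 1, 1], "b": [2, 2]}
iterators = [
    zip(itertools.cycle([dataset]), iterator)
    for dataset, iterator in epoch_instances.items()
]
print(list(more_itertools.roundrobin(*iterators)))
# [('a', 1), ('b', 2), ('a', 1), ('b', 2), ('a', 1)]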


@@ -91,7 +95,7 @@ def order_epoch_instances(
self, epoch_instances: Dict[str, Iterable[Instance]]
) -> Iterable[Tuple[str, Instance]]:
grouped_iterators = [
-            util.lazy_groups_of(iterator, self.batch_size[dataset])
+            util.lazy_groups_of(zip(itertools.cycle([dataset]), iterator), self.batch_size[dataset])
for dataset, iterator in epoch_instances.items()
]
batch_iterator = more_itertools.roundrobin(*grouped_iterators)
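
A toy sketch of the grouped interleaving, substituting more_itertools.chunked for allennlp's util.lazy_groups_of (the chunking behavior is the same for this input). It reproduces the homogeneous-batch ordering that the scheduler test below asserts:

import itertools
import more_itertools

epoch_instances = {"a": [1] * 4, "b": [2] * 3}
batch_size = {"a": 2, "b": 3}
grouped_iterators = [
    more_itertools.chunked(zip(itertools.cycle([dataset]), iterator), batch_size[dataset])
    for dataset, iterator in epoch_instances.items()
]
batch_iterator = more_itertools.roundrobin(*grouped_iterators)
print([pair for group in batch_iterator for pair in group])
# [('a', 1), ('a', 1), ('b', 2), ('b', 2), ('b', 2), ('a', 1), ('a', 1)]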
55 changes: 55 additions & 0 deletions tests/data/data_loaders/multitask_data_loader_test.py
@@ -0,0 +1,55 @@
import torch

from allennlp.data import DatasetReader, Instance, Vocabulary
from allennlp.data.fields import LabelField
from allennlp.data.dataset_readers import MultiTaskDatasetReader
from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader
from allennlp.data.data_loaders.multitask_scheduler import RoundRobinScheduler
from allennlp.data.data_loaders.multitask_epoch_sampler import UniformSampler


class FakeDatasetReaderA(DatasetReader):
def _read(self, file_path: str):
while True:
yield Instance({"label": LabelField("A")})


class FakeDatasetReaderB(DatasetReader):
def _read(self, file_path: str):
while True:
yield Instance({"label": LabelField("B")})


class MultiTaskDataLoaderTest:
def test_loading(self):
reader = MultiTaskDatasetReader(
readers={"a": FakeDatasetReaderA(), "b": FakeDatasetReaderB()}
)
data_path = {"a": "ignored", "b": "ignored"}
batch_size = 4
scheduler = RoundRobinScheduler()
sampler = UniformSampler()
instances_per_epoch = 8
batch_size_multiplier = {"a": 1, "b": 2}
loader = MultiTaskDataLoader(
reader=reader,
data_path=data_path,
batch_size=batch_size,
scheduler=scheduler,
sampler=sampler,
instances_per_epoch=instances_per_epoch,
batch_size_multiplier=batch_size_multiplier,
max_instances_in_memory={"a": 10, "b": 10},
)
vocab = Vocabulary()
vocab.add_tokens_to_namespace(["A", "B"], "labels")
loader.index_with(vocab)
iterator = iter(loader)
batch = next(iterator)
assert torch.all(batch["label"] == torch.IntTensor([0, 1, 0]))
batch = next(iterator)
assert torch.all(batch["label"] == torch.IntTensor([1, 0]))
batch = next(iterator)
assert torch.all(batch["label"] == torch.IntTensor([1, 0]))
batch = next(iterator)
assert torch.all(batch["label"] == torch.IntTensor([1]))
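
The expected batches follow from the multipliers: the scheduler interleaves instances as A, B, A, B, ..., and with costs a=1 and b=2 against a budget of 4, the running size first overflows after A, B, A (cost 4, labels 0, 1, 0), then after each subsequent B, A pair (cost 3, since the next B would push it to 5), leaving a single B as the guaranteed leftover batch.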
54 changes: 54 additions & 0 deletions tests/data/data_loaders/multitask_scheduler_test.py
@@ -0,0 +1,54 @@
from allennlp.data.data_loaders.multitask_scheduler import (
RoundRobinScheduler,
HomogeneousRoundRobinScheduler,
)


class RoundRobinSchedulerTest:
def test_order_instances(self):
scheduler = RoundRobinScheduler()
epoch_instances = {
"a": [1] * 5,
"b": [2] * 3,
}
flattened = scheduler.order_epoch_instances(epoch_instances)
assert list(flattened) == [
("a", 1),
("b", 2),
("a", 1),
("b", 2),
("a", 1),
("b", 2),
("a", 1),
("a", 1),
]


class HomogeneousRoundRobinSchedulerTest:
def test_order_instances(self):
scheduler = HomogeneousRoundRobinScheduler({"a": 2, "b": 3})
epoch_instances = {
"a": [1] * 9,
"b": [2] * 9,
}
flattened = scheduler.order_epoch_instances(epoch_instances)
assert list(flattened) == [
("a", 1),
("a", 1),
("b", 2),
("b", 2),
("b", 2),
("a", 1),
("a", 1),
("b", 2),
("b", 2),
("b", 2),
("a", 1),
("a", 1),
("b", 2),
("b", 2),
("b", 2),
("a", 1),
("a", 1),
("a", 1),
]
