From e91fe238f81490f0d8e9da4ef67157193a7d72cd Mon Sep 17 00:00:00 2001
From: erjia
Date: Wed, 29 Sep 2021 22:18:32 +0000
Subject: [PATCH] Convert generator in Sampler back to lazy construction

ghstack-source-id: bf4b2badb725d51e23b570f8c43e903fc6d9bb71
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63646
---
 test/test_dataloader.py     | 22 ++++++++++++++++++++++
 torch/utils/data/sampler.py | 18 +++++++++++-------
 2 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 9b18cfe73bf..a813b870ec2 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -1524,6 +1524,28 @@ def test_sampler_reproducibility(self):
         ):
             self.assertEqual(list(fn()), list(fn()))
 
+        for sampler in (
+            RandomSampler(self.dataset, num_samples=5, replacement=True),
+            RandomSampler(self.dataset, replacement=False),
+            WeightedRandomSampler(weights, num_samples=5, replacement=True),
+            WeightedRandomSampler(weights, num_samples=5, replacement=False),
+            SubsetRandomSampler(range(10)),
+        ):
+            torch.manual_seed(0)
+            l1 = list(sampler) + list(sampler)
+
+            torch.manual_seed(0)
+            l2 = list(sampler) + list(sampler)
+            self.assertEqual(l1, l2)
+
+            its = (iter(sampler), iter(sampler))
+            ls = ([], [])
+            for idx in range(len(sampler)):
+                for i in range(2):
+                    if idx == 0:
+                        torch.manual_seed(0)
+                    ls[i].append(next(its[i]))
+            self.assertEqual(ls[0], ls[1])
 
     def _test_sampler(self, **kwargs):
         indices = range(2, 12)  # using a regular iterable
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 79033470995..232302e53c5 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -112,15 +112,18 @@ def num_samples(self) -> int:
 
     def __iter__(self) -> Iterator[int]:
         n = len(self.data_source)
         if self.generator is None:
-            self.generator = torch.Generator()
-            self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+        else:
+            generator = self.generator
         if self.replacement:
             for _ in range(self.num_samples // 32):
-                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
-            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
+            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
         else:
-            yield from torch.randperm(n, generator=self.generator).tolist()
+            yield from torch.randperm(n, generator=generator).tolist()
 
     def __len__(self) -> int:
         return self.num_samples
@@ -140,7 +143,8 @@ def __init__(self, indices: Sequence[int], generator=None) -> None:
         self.generator = generator
 
     def __iter__(self) -> Iterator[int]:
-        return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
+        for i in torch.randperm(len(self.indices), generator=self.generator):
+            yield self.indices[i]
 
     def __len__(self) -> int:
         return len(self.indices)
@@ -183,7 +187,7 @@ def __init__(self, weights: Sequence[float], num_samples: int,
 
     def __iter__(self) -> Iterator[int]:
         rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
-        return iter(rand_tensor.tolist())
+        yield from iter(rand_tensor.tolist())
 
     def __len__(self) -> int:
         return self.num_samples