In [1]:
from torch.utils.data import IterableDataset
import torch
import math

In [10]:
# From PyTorch doc
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields the integers in ``[start, end)``.

    When iterated in the main process (no worker info available) the whole
    range is produced. When iterated inside a ``DataLoader`` worker process,
    the range is cut into contiguous chunks of ``ceil(n / num_workers)`` items
    and each worker yields only the chunk matching its worker id, so every
    integer is produced exactly once across all workers.

    Diagnostic lines are printed on every ``__iter__`` call so the sharding
    can be observed in the notebook output.

    Args:
        start: First integer to yield (inclusive).
        end: One past the last integer to yield (exclusive). Must be greater
            than ``start``.

    Raises:
        AssertionError: If ``end <= start``.
    """

    def __init__(self, start: int, end: int) -> None:
        # Zero-argument super() runs the parent initialiser. The previous
        # one-argument form, super(MyIterableDataset), only built an unbound
        # super object and called *its* __init__, never the parent's.
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        """Return an iterator over this process's share of ``[start, end)``."""
        worker_info = torch.utils.data.get_worker_info()
        print(f"Worker information:: {worker_info}\n")
        if worker_info is None:
            # Single-process data loading: this process owns the full range.
            iter_start = self.start
            iter_end = self.end
        else:
            # Inside a worker process: split the workload into equal chunks.
            # Rounding up means the leading workers take `per_worker` items
            # each and trailing workers get whatever is left (possibly none).
            per_worker = math.ceil(
                (self.end - self.start) / worker_info.num_workers
            )
            print(f"PER WORKER:: {per_worker}\n")
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            # Clamp the upper bound to `end`. For surplus workers iter_start
            # can reach or pass iter_end, in which case range() below is
            # simply empty.
            iter_end = min(iter_start + per_worker, self.end)
            print(
                f"WORKER ID: {worker_id} ITER START: {iter_start} ITER END: {iter_end}\n"
            )

        print("=" * 50)
        return iter(range(iter_start, iter_end))

In [42]:
# Every run below should give the same set of data as range(3, 7),
# i.e., [3, 4, 5, 6]; only the order may differ between worker counts.
ds = MyIterableDataset(start=3, end=7)

# num_workers=0 -> single-process loading in the main process.
# num_workers=2 -> multi-process loading with two worker processes:
#                  worker 0 fetches [3, 4], worker 1 fetches [5, 6].
# num_workers=4 -> multi-process loading with four worker processes.
# Append 20 to the tuple to try even more workers.
for num_workers in (0, 2, 4):
    print(list(torch.utils.data.DataLoader(ds, num_workers=num_workers)))

Worker information:: None

[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
Worker information:: WorkerInfo(id=1, num_workers=2, seed=7518864737868168320, dataset=<__main__.MyIterableDataset object at 0x7ff720349c70>)
Worker information:: WorkerInfo(id=0, num_workers=2, seed=7518864737868168319, dataset=<__main__.MyIterableDataset object at 0x7ff720349c70>)


PER WORKER:: 2
PER WORKER:: 2


WORKER ID: 1 ITER START: 5 ITER END: 7
WORKER ID: 0 ITER START: 3 ITER END: 5



[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
Worker information:: WorkerInfo(id=0, num_workers=4, seed=6570111619759283533, dataset=<__main__.MyIterableDataset object at 0x7ff720349c70>)

PER WORKER:: 1

Worker information:: WorkerInfo(id=1, num_workers=4, seed=6570111619759283534, dataset=<__main__.MyIterableDataset object at 0x7ff720349c70>)
WORKER ID: 0 ITER START: 3 ITER END: 4
Worker information:: WorkerInfo(id=2, num_workers=4, seed=6570111619759283535, dataset=<__main__.MyIterableDataset object at 0x7

### Doubt
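- How does the split work when the number of items is not evenly divisible by the number of workers?
    - Example: `num_workers=3` on a dataset that should give the same set of data as range(3, 7), i.e., [3, 4, 5, 6].
    - Answer: `per_worker = ceil(4 / 3) = 2`, so workers 0 and 1 take two items each and worker 2 is left with an empty range (start 7, end 7), as the log below shows.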
```bash
Worker information:: WorkerInfo(id=1, num_workers=3, seed=955517934870473897, dataset=<__main__.MyIterableDataset object at 0x7ff728525190>)
Worker information:: WorkerInfo(id=2, num_workers=3, seed=955517934870473898, dataset=<__main__.MyIterableDataset object at 0x7ff728525190>)
Worker information:: WorkerInfo(id=0, num_workers=3, seed=955517934870473896, dataset=<__main__.MyIterableDataset object at 0x7ff728525190>)

PER WORKER:: 2
PER WORKER:: 2
PER WORKER:: 2

WORKER ID: 1 ITER START: 5 ITER END: 7
WORKER ID: 0 ITER START: 3 ITER END: 5
WORKER ID: 2 ITER START: 7 ITER END: 7
```

In [40]:
# Iterate the dataset directly in the main process (no DataLoader involved);
# the for-loop obtains the iterator from the dataset itself.
for item in ds:
    print(item)

Worker information:: None

3
4
5
6


In [27]:
from torch.utils.data import DataLoader

# Two worker processes, with samples collated into batches of two.
batched_loader = DataLoader(ds, batch_size=2, num_workers=2)
for batch in batched_loader:
    print(batch)

Worker information:: WorkerInfo(id=0, num_workers=2, seed=3559251123838318224, dataset=<__main__.MyIterableDataset object at 0x7ff728525970>)
Worker information:: WorkerInfo(id=1, num_workers=2, seed=3559251123838318225, dataset=<__main__.MyIterableDataset object at 0x7ff728525970>)


PER WORKER:: 2
PER WORKER:: 2


WORKER ID: 0 ITER START: 3 ITER END: 5
WORKER ID: 1 ITER START: 5 ITER END: 7



tensor([3, 4])
tensor([5, 6])
