Description
Describe the bug
I have an IterableDataset (created with streaming=True) and I am trying to create batches by passing it to the Torch DataLoader class. This throws the error pasted below. The same approach works with a Torch IterableDataset. One thing I noticed: in the former case dataloader.sampler is a torch.utils.data.sampler.SequentialSampler, while in the latter it is a torch.utils.data.dataloader._InfiniteConstantSampler.
I am not sure if this is how it is meant to be used, but that's what seemed reasonable to me.
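To illustrate the sampler difference without downloading anything, here is a minimal sketch. It assumes DataLoader picks its sampler based on whether the dataset subclasses torch.utils.data.IterableDataset, which is what the behaviour above suggests. PlainIterable and TorchIterable are hypothetical names made up for this example; PlainIterable only stands in for the streaming dataset and is not the datasets class itself.

```python
from torch.utils.data import DataLoader, IterableDataset


class PlainIterable:
    """Hypothetical stand-in: iterable, no __len__, not a torch IterableDataset."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


class TorchIterable(IterableDataset):
    """Same behaviour, but subclasses torch.utils.data.IterableDataset."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


plain_loader = DataLoader(PlainIterable(range(12)), batch_size=4)
torch_loader = DataLoader(TorchIterable(range(12)), batch_size=4)

# Compare which sampler each DataLoader ended up with.
plain_sampler = type(plain_loader.sampler).__name__
torch_sampler = type(torch_loader.sampler).__name__
print(plain_sampler)
print(torch_sampler)

# The index-based sampler needs len(dataset), which PlainIterable lacks.
try:
    next(iter(plain_loader))
    plain_error = None
except TypeError as exc:
    plain_error = str(exc)
print(plain_error)
```

If this assumption holds, the streaming dataset is being treated as a map-style dataset, which would explain why the sampler ends up calling len() on it.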
Steps to reproduce the bug
- Does not work.
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
>>> dataloader.sampler
<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>
>>> for batch in dataloader:
...     print(batch)
- Works.
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader

class CustomIterableDataset(IterableDataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, data):
        'Initialization'
        self.data = data

    def __iter__(self):
        return iter(self.data)

data = list(range(12))
dataset = CustomIterableDataset(data)
dataloader = DataLoader(dataset, batch_size=4)
print("dataloader: ", dataloader.sampler)
for batch in dataloader:
    print(batch)
Expected results
Batches of data with a batch size of 4. The output below is from the second (working) example; the data source is different there, so the actual data would differ.
dataloader: <torch.utils.data.dataloader._InfiniteConstantSampler object at 0x7f1cc29e2c50>
tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8,  9, 10, 11])
Actual results
<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 474, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 427, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 227, in __iter__
    for idx in self.sampler:
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 67, in __iter__
    return iter(range(len(self.data_source)))
TypeError: object of type 'IterableDataset' has no len()
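Since the Torch IterableDataset version works, one possible workaround is to wrap the streaming dataset in a thin Torch IterableDataset subclass. The sketch below is untested against the real oscar stream: TorchIterableWrapper is a hypothetical adapter, and FakeStream is a made-up stand-in that only mimics "iterable of dicts with no len()".

```python
from torch.utils.data import DataLoader, IterableDataset


class TorchIterableWrapper(IterableDataset):
    """Hypothetical adapter: exposes any iterable as a torch IterableDataset."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)


class FakeStream:
    """Made-up stand-in for a streaming dataset: iterable, no __len__."""

    def __iter__(self):
        for i in range(10):
            yield {"id": i, "text": f"document {i}"}


# Wrapping makes DataLoader treat the stream as iterable-style,
# so no index sampler (and no len() call) is involved.
dataloader = DataLoader(TorchIterableWrapper(FakeStream()), batch_size=4)
batches = list(dataloader)
for batch in batches:
    print(batch)
```

With the real dataset, the wrapper would presumably be built as TorchIterableWrapper(dataset). Note that with num_workers > 0 each worker iterates the full stream unless it is sharded per worker, so this sketch is only meant for the single-process case.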
Environment info
- datasets version: '1.8.1.dev0'
- Platform: Linux
- Python version: 3.6.8
- PyArrow version: '3.0.0'