Error iteration over IterableDataset using Torch DataLoader #2583

Closed
@LeenaShekhar

Describe the bug

I have an IterableDataset (created with streaming=True) and I am trying to create batches by passing it to the Torch DataLoader class. This throws the error pasted below. The same approach works with a custom subclass of Torch's IterableDataset. One thing I noticed: in the failing case, dataloader.sampler is torch.utils.data.sampler.SequentialSampler, while in the working case it is torch.utils.data.dataloader._InfiniteConstantSampler.

I am not sure if this is how it is meant to be used, but that's what seemed reasonable to me.
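
The sampler difference points at the likely cause. As far as I can tell from the PyTorch source, DataLoader only takes the iterable-style path (with _InfiniteConstantSampler) when isinstance(dataset, torch.utils.data.IterableDataset) is true; any other object is treated as a map-style dataset, so it gets a SequentialSampler over range(len(dataset)). The traceback below suggests the datasets IterableDataset in this version is not recognised as a torch IterableDataset. The sketch reproduces that behaviour with two toy classes; PlainStream and TorchStream are hypothetical names, not part of either library.

```python
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class PlainStream:
    """Iterable, but NOT a torch IterableDataset (hypothetical stand-in)."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


class TorchStream(IterableDataset):
    """Same iteration logic, but subclasses torch's IterableDataset."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


data = list(range(12))
plain_loader = DataLoader(PlainStream(data), batch_size=4)
torch_loader = DataLoader(TorchStream(data), batch_size=4)

# Not recognised as iterable-style, so DataLoader treats it as map-style.
print(type(plain_loader.sampler).__name__)  # SequentialSampler
# Recognised as iterable-style: the sampler yields no real indices.
print(type(torch_loader.sampler).__name__)  # _InfiniteConstantSampler

failure = None
try:
    # The first batch request makes SequentialSampler call len(dataset).
    next(iter(plain_loader))
except TypeError as err:
    failure = str(err)
print(failure)  # object of type 'PlainStream' has no len()
```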

Steps to reproduce the bug

  1. Does not work.
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
>>> dataloader.sampler
<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>
>>> for batch in dataloader:
...     print(batch)
  2. Works.
import torch
from torch.utils.data import IterableDataset, DataLoader


# Subclassing torch's IterableDataset makes DataLoader use the iterable-style path.
class CustomIterableDataset(IterableDataset):
    """Characterizes a dataset for PyTorch."""

    def __init__(self, data):
        """Initialization."""
        self.data = data

    def __iter__(self):
        return iter(self.data)


data = list(range(12))
dataset = CustomIterableDataset(data)
dataloader = DataLoader(dataset, batch_size=4)
print("dataloader: ", dataloader.sampler)
for batch in dataloader:
    print(batch)
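
Following the working example, one possible workaround is to wrap the streaming dataset in a thin torch.utils.data.IterableDataset subclass so that DataLoader recognises it as iterable-style. This is a sketch under assumptions, not an official datasets API: StreamingWrapper is a hypothetical name, and a small list of {"id", "text"} dicts stands in for the OSCAR stream so the code runs offline.

```python
from torch.utils.data import DataLoader, IterableDataset


class StreamingWrapper(IterableDataset):
    """Hypothetical adapter exposing any re-iterable stream as a torch IterableDataset."""

    def __init__(self, stream):
        self.stream = stream

    def __iter__(self):
        # Delegate to the wrapped stream; DataLoader never needs len() here.
        return iter(self.stream)


# Offline stand-in for a streaming split: one dict per example.
# The "id"/"text" layout is an assumption loosely modeled on OSCAR records.
records = [{"id": i, "text": f"document {i}"} for i in range(10)]

dataloader = DataLoader(StreamingWrapper(records), batch_size=4)
batches = list(dataloader)

for batch in batches:
    # default_collate turns a list of dicts into a dict of batched values:
    # integer fields become tensors, string fields stay lists of str.
    print(batch["id"].tolist(), batch["text"])
```

With the real stream you would pass the object returned by load_dataset(..., streaming=True) to StreamingWrapper instead of records. Note that with num_workers > 0 every worker process iterates the full stream, so batches are duplicated unless __iter__ splits the work (for example using torch.utils.data.get_worker_info()).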

Expected results

Batches of data with a batch size of 4. Below is the output of the second (working) example; its data source is different, so the actual values differ.
dataloader: <torch.utils.data.dataloader._InfiniteConstantSampler object at 0x7f1cc29e2c50>
tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8, 9, 10, 11])
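
For an iterable-style dataset, batching conceptually just groups consecutive items of the stream into lists of batch_size, without ever needing indices or len(). The stdlib-only sketch below models that grouping; it is a simplification, not PyTorch's implementation, and batch_stream is a hypothetical helper. It also shows the shorter final batch produced when the stream length is not a multiple of the batch size and drop_last is false (DataLoader's default).

```python
from itertools import islice


def batch_stream(stream, batch_size, drop_last=False):
    """Group consecutive items of any iterable into lists of batch_size.

    Simplified model of iterable-style batching: no indices, no len().
    """
    iterator = iter(stream)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return  # stream exhausted
        if drop_last and len(batch) < batch_size:
            return  # discard the incomplete trailing batch
        yield batch


print(list(batch_stream(range(12), 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(list(batch_stream(range(10), 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(batch_stream(range(10), 4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```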

Actual results

<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>

...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 474, in _next_data
    index = self._next_index() # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 427, in _next_index
    return next(self._sampler_iter) # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 227, in __iter__
    for idx in self.sampler:
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 67, in __iter__
    return iter(range(len(self.data_source)))
TypeError: object of type 'IterableDataset' has no len()
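
The last frame is the key one: SequentialSampler.__iter__ evaluates len(self.data_source), and Python's built-in len() raises TypeError for any object whose class defines __iter__ but not __len__. The stdlib-only snippet below reproduces that failure in isolation; StreamOnly is a hypothetical class standing in for the streaming dataset.

```python
from collections.abc import Iterable, Sized


class StreamOnly:
    """Iterable without __len__, like a streaming dataset (hypothetical)."""

    def __iter__(self):
        yield from range(3)


stream = StreamOnly()
print(isinstance(stream, Iterable), isinstance(stream, Sized))  # True False

message = ""
try:
    # Mirrors the expression in the last traceback frame.
    indices = iter(range(len(stream)))
except TypeError as err:
    message = str(err)

print(message)  # object of type 'StreamOnly' has no len()
```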

Environment info

  • datasets version: '1.8.1.dev0'
  • Platform: Linux
  • Python version: Python 3.6.8
  • PyArrow version: '3.0.0'

Labels: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions