
Error iteration over IterableDataset using Torch DataLoader #2583

Closed
LeenaShekhar opened this issue Jul 2, 2021 · 2 comments
Labels
bug Something isn't working

Comments

@LeenaShekhar

LeenaShekhar commented Jul 2, 2021

Describe the bug

I have an IterableDataset (created using streaming=True) and I am trying to create batches with the Torch DataLoader class by passing this IterableDataset to it. This throws the error pasted below. Doing the same with a Torch IterableDataset works. One thing I noticed is that in the former case, dataloader.sampler is a torch.utils.data.sampler.SequentialSampler, while in the latter it is a torch.utils.data.dataloader._InfiniteConstantSampler.

I am not sure if this is how it is meant to be used, but that's what seemed reasonable to me.
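For context, a minimal sketch (not from the issue, assuming current torch behavior) of why the two sampler classes differ: DataLoader only builds an index-based sampler for map-style datasets, and skips index sampling for subclasses of torch.utils.data.IterableDataset.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

# A toy iterable-style dataset: no __len__ or __getitem__,
# so DataLoader cannot draw indices for it.
class Tiny(IterableDataset):
    def __iter__(self):
        return iter(range(8))

loader = DataLoader(Tiny(), batch_size=4)
# For iterable-style datasets DataLoader does not use an index-based
# sampler such as SequentialSampler.
print(type(loader.sampler).__name__)
batches = [b.tolist() for b in loader]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```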

Steps to reproduce the bug

  1. Does not work.
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
>>> dataloader.sampler
<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>
>>> for batch in dataloader:
...     print(batch)
  2. Works.
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader

class CustomIterableDataset(IterableDataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, data):
        'Initialization'
        self.data = data

    def __iter__(self):
        return iter(self.data)

data = list(range(12))
dataset = CustomIterableDataset(data)
dataloader = DataLoader(dataset, batch_size=4)
print("dataloader: ", dataloader.sampler)
for batch in dataloader:
    print(batch)

Expected results

To get batches of data with a batch size of 4, as in the output of the second (working) example below. The actual data differs because the data source is different.
dataloader: <torch.utils.data.dataloader._InfiniteConstantSampler object at 0x7f1cc29e2c50>
tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8, 9, 10, 11])

Actual results

<torch.utils.data.sampler.SequentialSampler object at 0x7f245a510208>

...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 474, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 427, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 227, in __iter__
    for idx in self.sampler:
  File "/data/leshekha/lib/HFDatasets/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 67, in __iter__
    return iter(range(len(self.data_source)))
TypeError: object of type 'IterableDataset' has no len()

Environment info

  • datasets version: '1.8.1.dev0'
  • Platform: Linux
  • Python version: Python 3.6.8
  • PyArrow version: '3.0.0'
@LeenaShekhar LeenaShekhar added the bug Something isn't working label Jul 2, 2021
@lhoestq
Member

lhoestq commented Jul 5, 2021

Hi ! This is because you first need to format the dataset for pytorch:

>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> torch_iterable_dataset = dataset.with_format("torch")
>>> assert isinstance(torch_iterable_dataset, torch.utils.data.IterableDataset)
>>> dataloader = torch.utils.data.DataLoader(torch_iterable_dataset, batch_size=4)
>>> next(iter(dataloader))
{'id': tensor([0, 1, 2, 3]), 'text': ['Mtendere Village was inspired...]}

The pytorch dataloader expects a subclass of torch.utils.data.IterableDataset: you can't pass an arbitrary iterable to it, so you first need to build an object that inherits from torch.utils.data.IterableDataset, using with_format("torch") for example.
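To illustrate the same point with plain torch, a minimal sketch of wrapping an arbitrary Python iterable so the dataloader accepts it (the IterableWrapper class below is a hypothetical helper, not a datasets API):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

# Hypothetical helper: any object inheriting from
# torch.utils.data.IterableDataset is accepted by DataLoader,
# which then iterates it directly instead of sampling indices.
class IterableWrapper(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

# A stand-in for a streamed dataset: a generator of dict examples.
stream = ({"id": i, "text": f"doc {i}"} for i in range(8))
loader = DataLoader(IterableWrapper(stream), batch_size=4)
batch = next(iter(loader))
print(batch["id"])  # tensor([0, 1, 2, 3])
```

The default collate function turns the list of dicts into a dict of batched columns, which is why the `id` field comes back as a tensor.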

@LeenaShekhar
Author

Thank you for that and the example!

What you said makes total sense; I just somehow missed that and assumed HF IterableDataset was a subclass of Torch IterableDataset.
