Based on the tutorial at https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd

In [4]:
from torch.utils.data import Dataset, IterableDataset, DataLoader
from itertools import cycle, islice, chain

## Stream from one data source

In [19]:
class MyIterableDataset(IterableDataset):
    """Endless stream over a single data source.

    Every item of ``data`` is passed through :meth:`process_data`, and the
    resulting sequence is repeated forever with ``itertools.cycle``. An
    empty ``data`` produces an empty stream.

    Note: ``cycle`` saves each item produced on the first pass and replays
    those saved items afterwards. As a result, ``process_data`` runs only
    once per ``__iter__`` call, and memory grows to hold one full pass.
    """

    def __init__(self, data):
        # The single source of samples; read once per __iter__ call.
        self.data = data

    def process_data(self, data):
        """Yield every item of ``data`` as-is."""
        yield from data

    def get_stream(self, data):
        """Return an infinite iterator that repeats ``process_data(data)``."""
        processed = self.process_data(data)
        return cycle(processed)

    def __iter__(self):
        return self.get_stream(self.data)

# Ten labelled samples from one source: 'd1:0' ... 'd1:9'.
data1 = ['d1:{}'.format(idx) for idx in range(10)]

iterable_dataset = MyIterableDataset(data1)
# batch_size=4 collates consecutive stream items into lists of four strings.
loader = DataLoader(iterable_dataset, batch_size=4)

# The dataset cycles indefinitely, so only the first eight batches are taken.
for batch in islice(loader, 8):
    print(batch)

['d1:0', 'd1:1', 'd1:2', 'd1:3']
['d1:4', 'd1:5', 'd1:6', 'd1:7']
['d1:8', 'd1:9', 'd1:0', 'd1:1']
['d1:2', 'd1:3', 'd1:4', 'd1:5']
['d1:6', 'd1:7', 'd1:8', 'd1:9']
['d1:0', 'd1:1', 'd1:2', 'd1:3']
['d1:4', 'd1:5', 'd1:6', 'd1:7']
['d1:8', 'd1:9', 'd1:0', 'd1:1']


## Stream from multiple data sources

In [26]:
class MyIterableDataset(IterableDataset):
    """Endless stream that reads several data sources one after another.

    Each round walks the sources of ``data_list`` in order. For every source
    it yields each item that :meth:`process_data` produces. Rounds repeat
    indefinitely, and ``process_data`` is called afresh for every source on
    every round.

    The stream ends, instead of spinning forever without yielding, as soon
    as a whole round produces no items. This happens when ``data_list`` is
    empty, when every source is empty, or when the sources are one-shot
    iterators that were exhausted on an earlier round.
    """

    def __init__(self, data_list):
        # Collection of sources; each source is read through process_data.
        self.data_list = data_list

    def process_data(self, data):
        """Yield every item of a single source ``data`` as-is."""
        for x in data:
            yield x

    def _one_round(self, sources, remember=None):
        """Yield the processed items of every source in ``sources``, in order.

        If ``remember`` is a list, each visited source is appended to it.
        This lets a one-shot ``data_list`` be replayed on later rounds.
        The generator's return value (obtained via ``yield from``) is True
        when at least one item was yielded.
        """
        yielded_any = False
        for data in sources:
            if remember is not None:
                remember.append(data)
            for item in self.process_data(data):
                yielded_any = True
                yield item
        return yielded_any

    def get_stream(self, data_list):
        """Return an iterator that cycles through ``data_list`` round after round."""
        # Like itertools.cycle, consume data_list only once and replay the
        # saved sources afterwards. Unlike chaining over cycle(), stop when a
        # full round yields nothing, so all-empty sources cannot hang the
        # consumer (e.g. a DataLoader waiting on next()).
        seen_sources = []
        keep_going = yield from self._one_round(data_list, remember=seen_sources)
        while keep_going:
            keep_going = yield from self._one_round(seen_sources)

    def __iter__(self):
        return self.get_stream(self.data_list)

# Three sources of different lengths (10, 6 and 8 items), so a batch of
# four can mix items from two adjacent sources.
data1 = ['d1:{}'.format(idx) for idx in range(10)]
data2 = ['d2:{}'.format(idx) for idx in range(6)]
data3 = ['d3:{}'.format(idx) for idx in range(8)]
data_list = [data1, data2, data3]

iterable_dataset = MyIterableDataset(data_list)
# batch_size=4 collates consecutive stream items into lists of four strings.
loader = DataLoader(iterable_dataset, batch_size=4)

# The stream is unbounded; print only the first eight batches.
for batch in islice(loader, 8):
    print(batch)

['d1:0', 'd1:1', 'd1:2', 'd1:3']
['d1:4', 'd1:5', 'd1:6', 'd1:7']
['d1:8', 'd1:9', 'd2:0', 'd2:1']
['d2:2', 'd2:3', 'd2:4', 'd2:5']
['d3:0', 'd3:1', 'd3:2', 'd3:3']
['d3:4', 'd3:5', 'd3:6', 'd3:7']
['d1:0', 'd1:1', 'd1:2', 'd1:3']
['d1:4', 'd1:5', 'd1:6', 'd1:7']
