In [None]:
#default_exp data.load

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# Indexed with the boolean `num_workers==0` in `DataLoader.__iter__`:
# False (0) -> multi-process iterator, True (1) -> single-process iterator.
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
def twoepochs(d):
    "Iterate `d` twice, join each yielded batch into one string, and separate batches by spaces"
    words = []
    for _epoch in range(2):
        for batch in d: words.append(''.join(batch))
    return ' '.join(words)
bs = 4
letters = list(string.ascii_lowercase)

## DataLoader

In [None]:
#export
def _wif(worker_id):
    "Worker init hook: copy the worker count, id and seed onto the wrapped loader, then call its `wif`"
    worker = get_worker_info()
    # `worker.dataset` is the `_FakeLoader`; its `d` attribute is the loader it wraps
    loader = worker.dataset.d
    loader.nw = worker.num_workers
    loader.offs = worker.id
    loader.seed = worker.seed
    loader.wif()

class _FakeLoader(GetAttr):
    "Object handed to the PyTorch iterators in `_loaders`; it iterates over the wrapped loader's `_iter`"
    _auto_collation = False
    collate_fn = noops
    drop_last = False
    dataset_kind = _DatasetKind.Iterable
    _index_sampler = Inf.count
    def __init__(self, d, pin_memory, num_workers, timeout):
        # This object is its own `dataset`; `_wif` reaches the wrapped loader via `info.dataset.d`
        self.dataset = self
        # presumably `GetAttr` forwards unknown attribute lookups to `default` -- TODO confirm
        self.default = d
        self.worker_init_fn = _wif
        store_attr(self, 'd,pin_memory,num_workers,timeout')
    def __iter__(self):
        "Return an iterator over whatever the wrapped loader's `_iter` produces"
        return iter(self.d._iter())

In [None]:
#export
class DataLoader():
    """Batch loader that delegates worker handling to the PyTorch iterators in `_loaders`.

    Parameters
    - `items`: source collection or iterable (may be `None` when `batches`/`item` are overridden)
    - `bs`: batch size; `None` means items are passed on one by one without chunking
    - `drop_last`: passed to `chunked` and used by `__len__`; only valid together with `bs`
    - `shuffle`: shuffle the sampler's indices (only when the length of `items` is known)
    - `indexed`: use `items[i]` lookups rather than sequential iteration; defaults to
      whether `items` has `__getitem__`
    - `num_workers`, `pin_memory`, `timeout`: stored on the `_FakeLoader` given to PyTorch
    - `tfm`: applied to every collated batch
    - `**kwargs`: replacements for the hooks listed in `_methods`
    """
    # `reset` and `wif` default to `noop`; `_iter` calls `reset`, `_wif` calls `wif`
    reset = wif = noop
    _methods = 'collate_fn batches reset wif sampler item'.split()
    @kwargs(_methods)
    def __init__(self, items=None, bs=None, drop_last=False, shuffle=False, indexed=None,
                 num_workers=0, pin_memory=False, timeout=0, tfm=noop, **kwargs):
        replace_methods(self, kwargs)
        if indexed is None: indexed = items is not None and hasattr(items,'__getitem__')
        store_attr(self, 'items,tfm,bs,drop_last,shuffle,indexed')
        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout)
        # `seed`, `nw` and `offs` are overwritten per worker by `_wif`
        self.lock,self.seed,self.rng,self.nw,self.offs = Lock(),None,random.Random(),1,0
        try: self.n = len(self.items)
        except TypeError: self.n = None
        # The previous check (`not kwargs or ...`) was always true when no hooks were overridden,
        # so `drop_last=True` without a batch size was accepted and then ignored.
        assert not (bs is None and drop_last), "`drop_last=True` requires `bs` to be set"

    def __iter__(self):
        "Create the single-process iterator when `num_workers==0`, else the multi-process one"
        return _loaders[self.fake_l.num_workers==0](self.fake_l)

    def _iter(self):
        "Build the stream of transformed, collated batches; called from `_FakeLoader.__iter__`"
        if self.seed is not None: set_seed(self.seed)
        # Compare against None instead of relying on truthiness: multi-element tensors/arrays
        # raise on `bool()`, and an empty container is still a valid (empty) source.
        self.it = iter(self.items) if self.items is not None else None
        self.reset()
        # Keep every `nw`-th sampler entry starting at `offs`, so each worker gets its own share
        idxs = (b for i,b in enumerate(self.sampler()) if i%self.nw==self.offs)
        return map(self.tfm, map(self.collate_fn, self.batches(idxs)))

    def __len__(self):
        "Number of batches (or of items when `bs` is None); raises `TypeError` if `items` has no length"
        if self.n is None: raise TypeError("length is unknown because `items` has no `len`")
        if self.bs is None: return self.n
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def batches(self, idxs):
        "Map `item` over `idxs`; group the results with `chunked` when `bs` is set"
        res = map(self.item, idxs)
        return res if self.bs is None else chunked(res, self.bs, self.drop_last)

    def sampler(self):
        "Return the indices to visit: `Inf.count` when indexed, else `Inf.nones`, cut to `n` if known"
        res = Inf.count if self.indexed else Inf.nones
        if self.n is None: return res
        res = list(itertools.islice(res, self.n))
        # Sampling all `len(res)` elements without replacement yields a shuffled copy
        return self.rng.sample(res,len(res)) if self.shuffle else res

    def collate_fn(self, b):
        "Use `default_convert` when `bs` is None, otherwise `default_collate`"
        return (default_collate,default_convert)[self.bs is None](b)

    def item(self, s):
        "Return `items[s]`, or the next element of the running iterator when `s` is None"
        return next(self.it) if s is None else self.items[s]

In [None]:
# 26 letters in batches of 4 with `drop_last=True`: the incomplete last batch ('yz') is
# absent from the expected string, and a second pass yields the same batches again.
ds1 = DataLoader(letters,4,drop_last=True)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

# Integers are collated into one tensor per batch.
ds1 = DataLoader(range(12), bs=4)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

# Strings are collated into plain lists; without `drop_last` the short last batch is kept.
ds1 = DataLoader([str(i) for i in range(11)], bs=4)
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# A `map` object (no `__getitem__`, no `len`) can be batched too; only three batches are pulled.
it = iter(DataLoader(map(noop,range(20)), bs=4))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

Override `batches` to return some specific finite iterable.

In [None]:
class LettersDS(DataLoader):
    "Loader whose `batches` ignores `idxs` and yields the lowercase alphabet in slices of 4"
    def batches(self, idxs):
        alphabet = string.ascii_lowercase
        for start in range(0, 26, 4): yield alphabet[start:start+4]

test_eq(L(LettersDS()), 'abcd,efgh,ijkl,mnop,qrst,uvwx,yz'.split(','))

Use `idxs` to get indexes of samples of this batch, if needed. 

In [None]:
class RandDS(DataLoader):
    "Loader whose `batches` produces random floats via `gen`, with `lt(0.95)` as the condition"
    def batches(self, idxs):
        # The index handed to `draw` is ignored; `idxs` only drives the stream
        draw = lambda o: random.random()
        return gen(draw, idxs, lt(0.95))

L(RandDS())

(#2) [0.7488540808888661,0.16333032484275523]

In [None]:
def _batches(self, idxs):
    "Replacement for `DataLoader.batches`, passed as a keyword argument instead of subclassing"
    draw = lambda o: random.random()
    return gen(draw, idxs, lt(0.95))
L(DataLoader(batches=_batches))

(#1) [0.6297980998854832]

Override `item` and use the default infinite sampler to get a stream of unknown length (call `stop()` or `raise StopIteration` when you want to stop the stream).

In [None]:
class RandDS(DataLoader):
    "Loader whose `item` returns random floats and calls `stop()` once a draw reaches 0.95"
    def item(self, s):
        draw = random.random()
        if draw >= 0.95: return stop()
        return draw

L(RandDS())

(#8) [0.47506196625440245,0.527832059224568,0.09633094519959395,0.746975721525937,0.5657155395086615,0.5240405515186721,0.6958857697534353,0.7077600665601746]

If you don't set `bs`, then `items` is assumed to provide an iterator or a `__getitem__` that returns a batch.

In [None]:
# No `bs`: elements of `items` are yielded one by one and `len` is the number of items.
ds1 = DataLoader(letters)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# `test_shuffled` presumably checks same elements, different order -- verify in `local.test`.
test_shuffled(L(DataLoader(letters, shuffle=True)), letters)

# `indexed=False` makes `item` pull from the iterator instead of using `items[i]`; same result.
ds1 = DataLoader(letters, indexed=False)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# Tensors come back unchanged.
t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = DataLoader(t2)
test_eq_type(L(ds2), t2)

# The expected value is the tensor list `t2`: numpy arrays come back as tensors.
t3 = L(array([0,1,2]),array([3,4,5]))
ds3 = DataLoader(t3)
test_eq_type(L(ds3), t2)

# With `collate_fn=noops` the arrays are returned as they are.
ds4 = DataLoader(t3, collate_fn=noops)
test_eq_type(L(ds4), t3)

In [None]:
# NOTE(review): this cell repeats, line for line, the batching checks in the first test cell
# after the `DataLoader` definition -- consider keeping only one copy.
# `drop_last=True` removes the incomplete last batch ('yz'); two passes give identical batches.
ds1 = DataLoader(letters,4,drop_last=True)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

# Integers are collated into one tensor per batch.
ds1 = DataLoader(range(12), bs=4)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

# Strings are collated into plain lists; the short last batch is kept.
ds1 = DataLoader([str(i) for i in range(11)], bs=4)
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# Batching a `map` object; only the first three batches are pulled.
it = iter(DataLoader(map(noop,range(20)), bs=4))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

In [None]:
class RandBatchDS(DataLoader):
    "Loader whose `item` returns random floats and raises `StopIteration` on a draw above 0.9"
    def item(self, s):
        draw = random.random()
        if draw <= 0.9: return draw
        raise StopIteration

L(RandBatchDS(bs=4)).mapped(len)

(#1) [2]

In [None]:
# Same random stream with 4 workers and `drop_last=True`; every batch in the recorded output has 4 items.
L(RandBatchDS(bs=4, num_workers=4, drop_last=True)).mapped(len)

(#8) [4,4,4,4,4,4,4,4]

In [None]:
class SleepyDS(list):
    "A `list` whose item access first sleeps for a random time of up to 1/50 s"
    def __getitem__(self,i):
        delay = random.random()/50
        time.sleep(delay)
        item = super().__getitem__(i)
        return item

t = SleepyDS(letters)

# Same content for every worker count; in the recorded timings the wall time drops as workers are added.
%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_eq(DataLoader(t, num_workers=4), letters)
# Shuffling combined with several workers.
test_shuffled(L(DataLoader(t, shuffle=True, num_workers=4)), letters)

CPU times: user 3.29 ms, sys: 0 ns, total: 3.29 ms
Wall time: 239 ms
CPU times: user 13.4 ms, sys: 12.4 ms, total: 25.8 ms
Wall time: 164 ms
CPU times: user 7.48 ms, sys: 26.8 ms, total: 34.3 ms
Wall time: 116 ms


In [None]:
class SleepyQueue():
    "Simulate a queue with varying latency"
    def __init__(self, q):
        self.q = q
    def __iter__(self):
        # One handler around the whole loop: an empty queue ends the iteration
        try:
            while True:
                time.sleep(random.random()/100)
                yield self.q.get_nowait()
        except queues.Empty:
            return

# Fill a queue with 0..9 and read it through the slow iterable using 4 workers.
q = Queue()
for o in range(10): q.put(o)
it = SleepyQueue(q)

# NOTE(review): the recorded output below contains all ten values, but not in insertion order.
%time L(DataLoader(it, num_workers=4))

CPU times: user 10.3 ms, sys: 20.8 ms, total: 31.1 ms
Wall time: 56.4 ms


(#10) [3,0,2,1,6,5,4,9,8,7]

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
# Run the export over all notebooks (`all_fs=True`); the recorded output lists each converted file.
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_dataloader.ipynb.
Converted 01a_script.ipynb.
Converted 02_transforms.ipynb.
Converted 03_pipeline.ipynb.
Converted 04_data_external.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_source.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_rect_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_test_models_core.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 35_tutorial_wikitex