In [None]:
#default_exp data.load

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# Indexed by the bool `num_workers==0` (see `DataLoader.__iter__`): False -> multi-process iterator, True -> single-process one.
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
def twoepochs(d):
    "Iterate `d` twice, join each batch's elements into one string, and separate the batches with spaces."
    joined = [''.join(batch) for _ in range(2) for batch in d]
    return ' '.join(joined)
bs = 4

## DataLoader

In [None]:
#export
class Dataset():
    "Iterate over `items` (by index or as a stream), optionally shuffling, batching (`bs`) and collating them."
    # Methods that can be replaced per instance by passing a same-named callable as a kwarg (see `replace_methods`).
    # NOTE(review): no `indexes` method is defined on this class -- confirm whether that entry is still used.
    _methods = 'collate_fn indexes batches reset wif sampler'.split()
    @kwargs(_methods)
    def __init__(self, items=None, bs=None, drop_last=False, shuffle=False, indexed=None, **kwargs):
        # Default to index-based access whenever `items` supports `__getitem__`; otherwise treat it as a stream.
        if indexed is None: indexed = items is not None and hasattr(items,'__getitem__')
        self.items,self.bs,self.drop_last,self.shuffle,self.indexed = items,bs,drop_last,shuffle,indexed
        # Best-effort back-reference from `items` to this dataset; many types (e.g. `list`, `None`) reject new attributes.
        # `Exception` rather than a bare `except` so KeyboardInterrupt/SystemExit still propagate.
        try: self.items.dataset = self
        except Exception: pass
        # `nw`/`offs` (worker count / this worker's index) default to a single worker; a worker init fn may overwrite them.
        self.lock,self.seed,self.rng,self.nw,self.offs = Lock(),None,random.Random(),1,0
        replace_methods(self, kwargs)
        # Streams without `__len__` have an unknown length.
        try: self.n = len(self.items)
        except TypeError: self.n = None
        # NOTE(review): this only rejects `drop_last` without `bs` when method overrides were passed -- confirm intent.
        assert not kwargs or not (bs is None and drop_last)

    def __iter__(self):
        "Return an iterator of collated batches (or of single converted items when `bs` is None)."
        if self.seed is not None: set_seed(self.seed)
        # Compare to None instead of testing truthiness: `bool()` of a multi-element numpy array or tensor raises.
        self.it = iter(self.items) if self.items is not None else None
        # Shard samples across workers: keep every `nw`-th sample starting at offset `offs`.
        # NOTE(review): shards are only disjoint if every worker sees the same `sampler()` order (shuffle uses the
        # private `self.rng`, not the per-worker `set_seed`) -- confirm this holds for all start methods.
        idxs = (b for i,b in enumerate(self.sampler()) if i%self.nw==self.offs)
        self.reset()
        # `idxs` is already a generator, so it can be passed straight to `batches`.
        return map(self.collate_fn, self.batches(idxs))

    def __len__(self):
        "Number of batches (or of items when `bs` is None); raises `TypeError` when the length is unknown."
        if self.n is None: raise TypeError("Dataset has no known length: `items` does not support `len`")
        if self.bs is None: return self.n
        # Ceil-divide, unless the trailing partial batch is dropped.
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def batches(self, idxs):
        "Map `idxs` to items and, when `bs` is set, group them with `chunked` (forwarding `drop_last`)."
        res = map(self.item, idxs)
        return res if self.bs is None else chunked(res, self.bs, self.drop_last)

    def sampler(self):
        "Sample keys: indexes if `indexed`, else `None`s for stream access; finite (and optionally shuffled) when the length is known."
        # presumably `Inf.count` is an endless 0,1,2,... and `Inf.nones` an endless run of `None`s
        res = Inf.count if self.indexed else Inf.nones
        if self.n is None: return res
        res = list(itertools.islice(res, self.n))
        return self.rng.sample(res,len(res)) if self.shuffle else res

    # Hooks, no-ops by default: `reset` runs at the start of every `__iter__`; `wif` is called by the worker init fn.
    reset = wif = noop
    def collate_fn(self, b):
        "Collate a batch with `default_collate`, or convert a single item with `default_convert` when `bs` is None."
        return default_convert(b) if self.bs is None else default_collate(b)
    def item(self, s):
        "Fetch one item: `items[s]`, or the next element of the stream when `s` is None."
        return next(self.it) if s is None else self.items[s]
    def item(self, s): return next(self.it) if s is None else self.items[s]

Override `batches` to return some specific finite iterable.

In [None]:
class LettersDS(Dataset):
    def batches(self, idxs):
        "Ignore `idxs` and yield the alphabet as consecutive 4-letter strings (the last one shorter)."
        alphabet = string.ascii_lowercase
        return (alphabet[start:start+4] for start in range(0, len(alphabet), 4))

test_eq(L(LettersDS()), 'abcd,efgh,ijkl,mnop,qrst,uvwx,yz'.split(','))

Use `idxs`, the stream of sample indexes produced by `sampler` (already restricted to this worker's share), if your `batches` needs them.

In [None]:
class RandDS(Dataset):
    def batches(self, idxs):
        "One random float per index in `idxs`, presumably cut off by `gen` at the first draw failing `lt(0.95)`."
        def draw(_): return random.random()
        return gen(draw, idxs, lt(0.95))

L(RandDS())

(#1) [0.2192423459964089]

In [None]:
def _batches(self, idxs):
    "Replacement `batches` passed as a kwarg: random floats, presumably stopping once `lt(0.95)` fails."
    def draw(_): return random.random()
    return gen(draw, idxs, lt(0.95))
L(Dataset(batches=_batches))

(#36) [0.9106970425376758,0.35600176946891315,0.8238095646167429,0.8796711639590988,0.7131970026709382,0.819739166663936,0.3696885689216126,0.2785539368256724,0.2563947158799942,0.26743777640004407...]

Override `item` and use the default infinite sampler to get a stream of unknown length (`raise StopIteration` when you want to stop the stream).

In [None]:
# Named differently from the `batches`-based `RandDS` above so that class is not silently shadowed.
class RandItemDS(Dataset):
    "Per-item random stream driven by the default (infinite) sampler."
    def item(self, s):
        "Ignore `s` and return a random float; presumably `stop()` raises StopIteration to end the stream."
        r = random.random()
        return r if r<0.95 else stop()

L(RandItemDS())

(#11) [0.7713139674771592,0.42880737781046263,0.8834340033055438,0.589390980317075,0.17985707124870853,0.4100714318094145,0.458700785956627,0.4427079004702963,0.5566877117139156,0.677099402208...]

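If `items` supports `__getitem__` it is indexed directly; otherwise it must be iterable, and each `next` on its iterator returns one item (items are then grouped into batches when `bs` is set).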

In [None]:
letters = [*string.ascii_lowercase]

In [None]:
# Indexed access (the default for a `list`, which has `__getitem__`) returns the items in order.
ds1 = Dataset(letters)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# `shuffle=True`: `test_shuffled` presumably checks same elements in a different order.
test_shuffled(L(Dataset(letters, shuffle=True)), letters)

# Forcing stream access (`next` on an iterator over `items`) gives the same items, and `len` still works.
ds1 = Dataset(letters, indexed=False)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# With `bs=None` each item goes through `default_convert`: tensors come back as tensors...
t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = Dataset(t2)
test_eq_type(L(ds2), t2)

# ...and numpy arrays are converted to tensors (the expected value is `t2`, not `t3`).
t3 = L(array([0,1,2]),array([3,4,5]))
ds3 = Dataset(t3)
test_eq_type(L(ds3), t2)

# Overriding `collate_fn` via a kwarg (here `noops`, presumably identity) skips that conversion.
ds4 = Dataset(t3, collate_fn=noops)
test_eq_type(L(ds4), t3)

In [None]:
# `drop_last=True` discards the trailing partial batch ('yz'); a second epoch repeats the same batches.
ds1 = Dataset(letters,4,drop_last=True)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

# Batches of ints are collated into tensors by `default_collate`.
ds1 = Dataset(range(12), bs=4)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

# Batches of strings stay lists, and the final partial batch is kept (`drop_last` defaults to False).
ds1 = Dataset([str(i) for i in range(11)], bs=4)
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# A one-shot stream (a `map`, which has no `len`) can still be batched; only the first 3 batches are consumed.
it = iter(Dataset(map(noop,range(20)), bs=4))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

In [None]:
class RandBatchDS(Dataset):
    def item(self, s):
        "Ignore `s` and return a random float; end the stream with StopIteration once a draw exceeds 0.9."
        val = random.random()
        if val <= 0.9: return val
        raise StopIteration

ds = RandBatchDS(bs=4)
L(ds)

(#5) [tensor([0.7370, 0.4806, 0.1867, 0.6093], dtype=torch.float64),tensor([0.1975, 0.5111, 0.4837, 0.1394], dtype=torch.float64),tensor([0.5290, 0.2578, 0.2953, 0.5098], dtype=torch.float64),tensor([0.4525, 0.2465, 0.2318, 0.3966], dtype=torch.float64),tensor([0.2007, 0.4606, 0.7058, 0.8673], dtype=torch.float64)]

In [None]:
#export
def _wif(worker_id):
    info = get_worker_info()
    ds = info.dataset
    ds.nw,ds.offs,ds.seed = info.num_workers,info.id,info.seed
    ds.wif()

In [None]:
#export
class DataLoader(GetAttr):
    _auto_collation,collate_fn,drop_last,dataset_kind,_index_sampler = False,noops,False,_DatasetKind.Iterable,Inf.count
    @delegates(Dataset.__init__)
    def __init__(self, items, num_workers=0, pin_memory=False, timeout=0, tfm=noop, **kwargs):
        self.default = self.dataset = items if isinstance(items, Dataset) else Dataset(items, **kwargs) 
        self.pin_memory,self.tfm,self.worker_init_fn = pin_memory,tfm,_wif
        self.num_workers = 0 if num_workers < 0 else num_workers
        self.timeout = 0 if timeout < 0 else timeout
        self.dataset.lock = Lock()

    def __iter__(self):  return map(self.tfm, _loaders[self.num_workers==0](self))
    def __len__(self): return len(self.dataset)

In [None]:
# `ds` (RandBatchDS) ends its stream at a random point, so the length differs on every pass.
list(len(L(DataLoader(ds))) for _epoch in range(4))

[0, 1, 0, 20]

In [None]:
# Same with 4 workers; presumably each worker draws its own random stream, so totals tend to be larger.
list(len(L(DataLoader(ds, num_workers=4))) for _epoch in range(4))

[31, 57, 61, 38]

In [None]:
class SleepyDS(list):
    "A `list` whose item lookup first sleeps a random 0-20ms, to simulate slow sample loading."
    def __getitem__(self, i):
        delay = random.random() / 50
        time.sleep(delay)
        return super().__getitem__(i)

In [None]:
# Every lookup in `SleepyDS` is slow, so adding workers should cut wall time (see the timings below).
t = SleepyDS(letters)

# Results must match serial loading for any number of workers.
%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_eq(DataLoader(t, num_workers=4), letters)
# Shuffling across several workers should still yield each letter once (presumably what `test_shuffled` checks).
test_shuffled(L(DataLoader(t, shuffle=True, num_workers=4)), letters)

CPU times: user 3.65 ms, sys: 0 ns, total: 3.65 ms
Wall time: 235 ms
CPU times: user 19 ms, sys: 5.08 ms, total: 24.1 ms
Wall time: 155 ms
CPU times: user 16.5 ms, sys: 18.1 ms, total: 34.6 ms
Wall time: 116 ms


In [None]:
class SleepyQueue():
    "Iterable that drains a queue, sleeping a random 0-10ms before each read to simulate varying latency."
    def __init__(self, q): self.q = q
    def __iter__(self):
        # NOTE(review): on a `multiprocessing` queue, `get_nowait` can raise `Empty` spuriously right after a `put`,
        # which would end the stream early.
        while True:
            time.sleep(random.random()/100)
            try: nxt = self.q.get_nowait()
            except queues.Empty: return
            yield nxt

In [None]:
# Fill a queue with 100 ints and consume it as a stream through `SleepyQueue` with 4 workers.
# NOTE(review): presumably `Queue` is `multiprocessing.Queue`, so all workers share it and each item is read once -- confirm.
q = Queue()
for o in range(100): q.put(o)
it = SleepyQueue(q)

%time L(DataLoader(it, num_workers=4))

## Export -

In [None]:
#hide
# Convert every notebook (`all_fs=True`) into scripts; the output lists each converted notebook.
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_dataloader.ipynb.
Converted 01a_script.ipynb.
Converted 02_transforms.ipynb.
Converted 03_pipeline.ipynb.
Converted 04_data_external.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_source.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_rect_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_test_models_core.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 35_tutorial_wikitex