In [None]:
#default_exp data.load

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# PyTorch's two private loader iterators. `DataLoader.__iter__` indexes this tuple with `num_workers==0`:
# False -> 0 -> multiprocess iterator, True -> 1 -> single-process iterator
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

## DataLoader

In [None]:
class SleepyDS():
    "Sequence wrapper that sleeps a random 0-10ms on every item access, to simulate a slow dataset"
    def __init__(self, coll):
        self.coll = coll
        self.rng = random.Random()
    def __len__(self): return len(self.coll)
    def __getitem__(self, i):
        delay = self.rng.random()/100
        time.sleep(delay)
        return self.coll[i]

def twoepochs(d):
    "Iterate `d` twice; join each batch's items into a string and return all batches separated by spaces"
    pieces = [''.join(batch) for _ in range(2) for batch in d]
    return ' '.join(pieces)

# The 26 letters behind a randomly-sleeping `SleepyDS`, used by the tests below
testds = SleepyDS(string.ascii_lowercase)    
bs = 4

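- Set `bs`, `drop_last` and `sampler` after init
- Get `collate_fn`, `kind`, `sampler` and `auto_collate` from the dataset
  - `auto_collate` to be replaced by … (**TODO**: note left unfinished, replacement not specified)
- Work out the dataset type from its attributes, not from inheritance
- Transforms and reset
- Define appropriate init params via subclass params, not boolean checks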

In [None]:
#export
def _wif(worker_id):
    "Worker init fn: record the worker count and this worker's id on the dataset copy, then call its `wif` hook"
    # NOTE(review): `_wif` is defined again further down this notebook (also `#export`); that later definition
    # replaces this one, including for `DataLoader`, which looks `_wif` up when it is constructed.
    info = get_worker_info()
    dset = info.dataset
    dset.nw = info.num_workers
    dset.offs = info.id
    dset.wif()

In [None]:
def chunked(it, cs, drop_last=False):
    "Yield successive lists of `cs` items from `it`; a short final chunk is yielded unless `drop_last`"
    # NOTE(review): this cell is not marked `#export`, yet the exported `BaseDataset.batches` calls `chunked` —
    # confirm it is provided by `local.core`, or export this cell.
    if not isinstance(it, Iterator): it = iter(it)
    while True:
        res = list(itertools.islice(it, cs))
        if res and (len(res)==cs or not drop_last): yield res
        # A short (or empty) chunk means the source signalled exhaustion, so stop instead of reading again.
        # A `map` whose function raises `StopIteration` (e.g. `item` in a `BaseDataset` subclass) would
        # otherwise resume on the next read, and the stream would carry on after a partial batch.
        if not res or len(res)<cs: return

In [None]:
# `chunked` splits an iterable into lists of `cs` items; `drop_last=True` drops the short final chunk
t = L.range(10)
test_eq(chunked(t,3),      [[0,1,2], [3,4,5], [6,7,8], [9]])
test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8],    ])

In [None]:
#export
class BaseDataset():
    """Iterable dataset over `items`: fetches items (by index, or with `next` on `iter(items)`), optionally groups
    them into batches of `bs`, and collates each batch. Customise it by subclassing, or by passing functions as
    kwargs (see `_methods`), which are bound as methods."""
    _methods = 'collate_fn indexes batches reset wif'
    @kwargs(_methods, keep=True)
    def __init__(self, items=None, bs=None, drop_last=False, shuffle=False, indexed=False, sampler=None, **kwargs):
        self.items,self.bs,self.drop_last,self.shuffle = items,bs,drop_last,shuffle
        self.indexed,self.sampler,self.rng = indexed,sampler,random.Random()
        # `n` is None for items without a length (e.g. a `map` object): the stream length is then unknown
        try: self.n = len(self.items)
        except TypeError: self.n = None
        # Every extra kwarg is bound as a method of this instance, so it is called with `self` first (e.g. `noops`)
        for k in copy(kwargs): setattr(self, k, types.MethodType(kwargs.pop(k),self))
        assert not kwargs
        assert not (bs is None and drop_last)
        # shuffling needs a known length and the default sampler
        assert not self.shuffle or (self.n is not None and self.sampler is None)

    def __iter__(self):
        torch.manual_seed(self.rng.randint(0,sys.maxsize))
        # Compare with None rather than testing truthiness: `bool()` of a multi-element array/tensor raises,
        # and a one-element zero tensor is falsy (which used to leave `self.it` as None)
        self.it = iter(self.items) if self.items is not None else None
        self.reset()
        return map(self.collate_fn, self.batches())

    def __len__(self):
        # unknown-length streams: `stop(TypeError)` presumably raises, like `len()` on a generator
        n = stop(TypeError) if self.n is None else self.n
        if self.bs is None: return n
        return n//self.bs + (0 if self.drop_last or n%self.bs==0 else 1)

    def batches(self):
        "Items in sampler order, chunked into lists of `bs` unless `bs` is None"
        res = map(self.item, self.mk_sampler())
        return res if self.bs is None else chunked(res, self.bs, self.drop_last)

    def mk_sampler(self):
        "The explicit `sampler` if given, else indices (`indexed`) or `None`s, truncated to `n` (and shuffled) when known"
        if self.sampler is not None: return self.sampler
        res = Inf.count if self.indexed else Inf.nones
        if self.n is not None:
            res = list(itertools.islice(res, self.n))
            if self.shuffle: random.shuffle(res)
        return iter(res)

    # hooks: `reset` runs at the start of every `__iter__`; `wif` is called by `_wif` in each worker
    reset = wif = noop
    def collate_fn(self, b): return (default_collate,default_convert)[self.bs is None](b)
    def item(self, s): return next(self.it) if s is None else self.items[s]

Override `batches` to return some specific finite iterable.

In [None]:
class LettersDS(BaseDataset):
    "Example subclass: `batches` yields the alphabet four letters at a time"
    def batches(self):
        alphabet = string.ascii_lowercase
        return (alphabet[start:start+4] for start in range(0, len(alphabet), 4))

test_eq(L(LettersDS()), 'abcd,efgh,ijkl,mnop,qrst,uvwx,yz'.split(','))

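`mk_sampler` is still available when overriding `batches`. Here `items` is `None`, so it returns the default unbounded sampler (`Inf.nones`).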

In [None]:
class RandDS(BaseDataset):
    "Example: `batches` maps each element of the default sampler to a random float, filtered by `gen` with `lt(0.95)`"
    def batches(self):
        # the sampler's elements are ignored; presumably `gen` ends the stream at the first value >= 0.95
        rand = lambda _: random.random()
        return gen(rand, self.mk_sampler(), lt(0.95))

L(RandDS())

(#10) [0.20387192762369488,0.3985282381412417,0.18992398700195434,0.9378714396248389,0.18563209824033744,0.4227552895863914,0.3886418127933633,0.5823938265157906,0.10078752761895038,0.526057197290042]

Override `item` and use the default infinite sampler to get a stream of unknown length. Raise `StopIteration` (e.g. via `stop()`) when you want to end the stream.

In [None]:
class RandDS(BaseDataset):
    "Example: override `item` to produce random floats, ending the stream at the first value >= 0.95"
    def item(self, s):
        val = random.random()
        # `stop()` presumably raises `StopIteration`, which ends the `map` built in `batches`
        if val >= 0.95: return stop()
        return val

L(RandDS())

(#17) [0.8874313343998564,0.19782911066294018,0.6102335054945686,0.5107968334575369,0.5834613827928582,0.01478423856315303,0.020841334904318942,0.35038440472148213,0.2677479971018841,0.9255407362024884...]

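When `indexed` is `False` (the default), elements are pulled from `iter(items)` with `next`, so `items` only needs to be iterable. When `bs` is set, they are grouped into batches of `bs` before being collated.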

In [None]:
# No `bs`: items are yielded one at a time and converted with `default_convert`
ds1 = BaseDataset(testds)
test_eq(''.join(ds1), string.ascii_lowercase)
test_eq(len(ds1), 26)

t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = BaseDataset(t2)
test_eq_type(L(ds2), t2)

# numpy arrays come out as tensors after `default_convert`
t3 = L(array([0,1,2]),array([3,4,5]))
ds3 = BaseDataset(t3)
test_eq_type(L(ds3), t2)

# a custom `collate_fn` (bound as a method, so it receives `self` first) can skip the conversion
ds4 = BaseDataset(t3, collate_fn=noops)
test_eq_type(L(ds4), t3)

In [None]:
# With `bs`: items are chunked into lists of `bs` and collated with `default_collate`;
# `drop_last=True` drops the short final batch
ds1 = BaseDataset(testds,bs=4,drop_last=True)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

ds1 = BaseDataset(range(12), bs=4)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

# batches of strings stay lists of strings
ds1 = BaseDataset([str(i) for i in range(11)], bs=4)
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# items without a length (a `map` object) are still batched lazily
it = iter(BaseDataset(map(noop,range(20)), bs=4))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

In [None]:
class RandBatchDS(BaseDataset):
    "Example: random floats batched by `bs`; `item` raises `StopIteration` once a draw exceeds 0.9"
    def item(self, s):
        val = random.random()
        if val <= 0.9: return val
        raise StopIteration

ds = RandBatchDS(bs=4)
L(ds)

(#18) [tensor([0.2928], dtype=torch.float64),tensor([0.1940, 0.2347], dtype=torch.float64),tensor([0.2381, 0.3685, 0.2362, 0.4815], dtype=torch.float64),tensor([0.2945], dtype=torch.float64),tensor([0.5506, 0.5095, 0.8379, 0.2772], dtype=torch.float64),tensor([0.0525, 0.3875, 0.3879, 0.5744], dtype=torch.float64),tensor([0.1093, 0.2447, 0.5660, 0.3921], dtype=torch.float64),tensor([0.5968, 0.5351, 0.1871, 0.3360], dtype=torch.float64),tensor([0.5119, 0.6163, 0.5396, 0.7278], dtype=torch.float64),tensor([0.7152, 0.5996, 0.6750, 0.6917], dtype=torch.float64)...]

In [None]:
def delegate_attr(k, o, to):
    """Return attribute `k` of `getattr(o, to)`, for use in `__getattr__`.
    Private names and `to` itself are refused so a half-built object cannot recurse; any lookup
    failure is re-raised as a plain `AttributeError(k)`."""
    if k == to or k.startswith('_'): raise AttributeError(k)
    try:
        inner = getattr(o, to)
        return getattr(inner, k)
    except AttributeError:
        raise AttributeError(k) from None

In [None]:
#export
class DataLoader:
    "Minimal stand-in for `torch.utils.data.DataLoader` that drives PyTorch's private loader iterators over a `BaseDataset`"
    # Attributes read by PyTorch's loader iterators: the dataset is always treated as iterable, with no
    # auto-collation (batching/collation happen inside `BaseDataset`); `noops` presumably passes items through.
    # NOTE(review): these names must match the installed PyTorch version's private loader internals — confirm.
    _auto_collation,collate_fn,drop_last,dataset_kind = False,noops,False,_DatasetKind.Iterable
    def __init__(self, dataset, num_workers=0, pin_memory=False, timeout=0, tfm=noop, **kwargs):
        # Raw datasets are wrapped; `kwargs` (e.g. `bs`, `shuffle`) are only used for that wrapping and are
        # ignored when `dataset` is already a `BaseDataset`
        self.dataset = dataset if isinstance(dataset, BaseDataset) else BaseDataset(dataset, **kwargs) 
        # `_wif` is looked up when this runs (a later redefinition of `_wif` changes which one is used);
        # `Inf.count` presumably serves as an unbounded index stream for PyTorch's iterator
        self.pin_memory,self.tfm,self.worker_init_fn,self._index_sampler = pin_memory,tfm,_wif,Inf.count
        # negative values are clamped to 0
        self.num_workers = 0 if num_workers < 0 else num_workers
        self.timeout = 0 if timeout < 0 else timeout

    # single-process iterator when `num_workers==0`, multiprocess otherwise; `tfm` is applied to every batch
    def __iter__(self):  return map(self.tfm, _loaders[self.num_workers==0](self))
    # unknown public attributes are forwarded to `self.dataset` (`delegate_attr` refuses `dataset` and `_` names)
    def __getattr__(self,k): return delegate_attr(k,self,'dataset')
    def __len__(self): return len(self.dataset)

In [None]:
# `ds` (a `RandBatchDS`) ends at a random point, so the number of batches differs from run to run
for _ in range(4): print(len(L(DataLoader(ds))))

7
0
17
11


In [None]:
# with 4 workers each worker iterates its own copy of the dataset, and their batches are combined
for _ in range(4): print(len(L(DataLoader(ds, num_workers=4))))

42
17
46
36


## Incomplete below

In [None]:
# `ds1` was last rebound above to a batched dataset of digit strings, so build the unbatched letters dataset here
test_eq(''.join(DataLoader(BaseDataset(testds), num_workers=0)), string.ascii_lowercase)
test_eq(L(DataLoader(ds2, num_workers=1)), t2)
# n workers means n copies of the iter, in some arbitrary order
test_eq(L(DataLoader(ds4, num_workers=2)).mapped(list).sorted(), (t3*2).mapped(list).sorted())

In [None]:
# def mk_collate_fn(auto_collation): return (default_convert,default_collate)[auto_collation]

In [None]:
#export
class IndexedDataset(Dataset):
    "Iterable over batches of `items` chosen by a batch sampler; with several workers each keeps every `nw`-th batch starting at `offs`"
    # NOTE(review): this relies on a base `Dataset.__init__(items, collate_fn)` and on `self._delegate_items`, neither
    # of which is defined in the visible code (`_delegate_items` only appears commented out further down). With
    # `torch.utils.data.Dataset` as the base, `__init__` would presumably fail — confirm which `Dataset` is intended.
    def __init__(self, items ,bs=1, shuffle=False, sampler=None, batch_sampler=None, drop_last=False,
                 sampler_cls=None, batch_sampler_cls=BatchSampler, collate_fn=default_collate):
        super().__init__(items,collate_fn)
        # an explicit (truthy) `batch_sampler` takes precedence over `bs`/`shuffle`/`sampler`/`drop_last`
        self.sampler = batch_sampler
        # nw/offs: number of workers and this copy's worker index; single-process defaults
        self.rng,self.nw,self.offs = random.Random(),1,0
        self._delegate_items("get_batches","get_batch","collate")
        if self.sampler: return
        # default: a sequential or random sampler over `items`, grouped into batches of `bs`
        if not sampler: sampler = ifnone(sampler_cls, (SequentialSampler,RandomSampler)[shuffle])(items)
        self.sampler = batch_sampler_cls(sampler, bs, drop_last)

    def __iter__(self):
        # reseed torch from this copy's own RNG before drawing the batch indices
        torch.manual_seed(self.rng.randint(0,sys.maxsize))
        samps = list(enumerate(self.sampler))
        # keep only the batches assigned to this worker
        idxs = (b for i,b in samps if i%self.nw==self.offs)
        return self.get_batches(idxs)
    
    def get_batch(self, b): return [self.items[j] for j in b]
    def get_batches(self, idxs): return map(self.get_batch, idxs)
    # per-worker hook: give this copy its own copy of the inner sampler
    def wif(self) : self.sampler.sampler = copy(self.sampler.sampler)

In [None]:
def get_dl(bs=1, collate_fn=default_collate, num_workers=0, **kwargs):
    "`DataLoader` over `testds` for the tests below; `collate_fn` is a plain function of one batch (None: dataset default)"
    # `BaseDataset` binds kwarg functions as methods, calling them as `f(self, b)`. Passing the plain one-argument
    # `default_collate` straight through would therefore fail, so adapt it to the method form.
    # NOTE(review): the lambda is not picklable, so this assumes fork-started workers when `num_workers>0`.
    if collate_fn is not None: kwargs['collate_fn'] = lambda self, b: collate_fn(b)
    return DataLoader(testds, num_workers=num_workers, bs=bs, **kwargs)

In [None]:
# single process: two passes in order; 27 distinct characters = 26 letters + the separating space
dl = get_dl(bs=4, num_workers=0)
t = twoepochs(dl)
test_eq(t, 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')
test_eq(len(set(t)), 27)

In [None]:
# shuffled with 4 workers: the order differs from the sequential one, but every letter still appears
dl = get_dl(bs=4, num_workers=4, shuffle=True)
t = twoepochs(dl)
test_ne(t, 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')
test_eq(len(set(t)), 27)
t

In [None]:
# class _NextFetcher:
#     def __init__(self, dataset): self.dataset_iter = iter(dataset)
#     def fetch(self, possibly_batched_index): return next(self.dataset_iter)
# def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): return _NextFetcher(dataset)
# _DatasetKind.create_fetcher = create_fetcher

#     def _delegate_items(self, *attrs):
#         for attr in attrs:
#             if hasattr(self.items,attr): setattr(self, attr, getattr(self.items, attr))
# 
#     def batch(self, s):
#         if self.done:
#             self.done=False
#             raise StopIteration
#         res = []
#         try:
#             for _ in range(self.bs): res.append(self.item(s))
#         except StopIteration:
#             if res==[]: raise StopIteration
#             self.done=True
#         return res

## BatchDS

In [None]:
#export
class BaseDS(GetAttr):
    "Wraps `ds`: the attributes in `_xtra` are (presumably, via `GetAttr`) forwarded to it, and its `reset` is used if it has one"
    _xtra = ['show', 'decode', 'show_at', 'decode_at', 'decode_batch']
    def __init__(self, ds):
        self.default = ds
        self.ds = ds
        # give the wrapped dataset a back-reference to this wrapper
        ds.wrapper = self
        self._delegate_ds("reset")

    def _delegate_ds(self, attr):
        "Copy `attr` from the wrapped `ds` onto this instance, when `ds` has it"
        if not hasattr(self.ds, attr): return
        setattr(self, attr, getattr(self.ds, attr))

    def reset(self):
        "No-op default; replaced by the wrapped dataset's `reset` when it has one"
        return None

In [None]:
#export
class BatchDS(BaseDS, IterableDataset):
    "Iterable dataset of batches of `ds` items, with indices from a batch sampler; with several workers each keeps every `nw`-th batch from `offs`"
    _xtra = ['show', 'decode', 'show_at', 'decode_at', 'decode_batch']
    def __init__(self, ds, bs=1, shuffle=False, sampler=None, batch_sampler=None, drop_last=False,
                 collate_fn=default_collate, sampler_cls=None, batch_sampler_cls=BatchSampler):
        self.default = ds
        self.ds = ds
        self.samp = batch_sampler
        self.collate_fn = collate_fn
        # per-dataset RNG (reseeds torch each epoch); single-worker defaults for nw/offs
        self.rng = random.Random()
        self.nw, self.offs = 1, 0
        self.is_iterable = True
        # the wrapped dataset may provide its own batch fetching/collation
        for name in ("get_batches","get_batch","collate"): self._delegate_ds(name)
        # an explicit (truthy) batch sampler takes precedence over bs/shuffle/sampler/drop_last
        if self.samp: return
        if not sampler:
            sampler_type = ifnone(sampler_cls, (SequentialSampler,RandomSampler)[shuffle])
            sampler = sampler_type(ds)
        self.samp = batch_sampler_cls(sampler, bs, drop_last)

    def __iter__(self):
        # reseed torch from this copy's RNG before materialising the batch indices
        torch.manual_seed(self.rng.randint(0,sys.maxsize))
        all_idxs = list(self.samp)
        # keep only this worker's batches: every `nw`-th one, starting at `offs`
        mine = (idxs for pos,idxs in enumerate(all_idxs) if pos%self.nw==self.offs)
        return self.get_batches(mine)

    def get_batch(self, b): return [self.ds[k] for k in b]
    def get_batches(self, idxs): return map(self.get_batch, idxs)
    def collate(self, idxs): return self.collate_fn(self.get_batches(idxs))
    def __len__(self): return len(self.samp)

In [None]:
#export
def _wif(worker_id):
    """Worker init fn used by `DataLoader` and `dataloader`. Records the worker count (`nw`) and this worker's id
    (`offs`) on the dataset copy. Then it either gives a dataset with a `samp` batch sampler its own copy of the
    inner sampler, or calls the dataset's `wif` hook (as `BaseDataset` expects)."""
    info = get_worker_info()
    ds = info.dataset
    ds.nw,ds.offs = info.num_workers,info.id
    # This definition replaces the earlier `_wif`, which `DataLoader` also uses with a `BaseDataset` (no `samp`),
    # so support both kinds of dataset instead of assuming `ds.samp` exists
    samp = getattr(ds, 'samp', None)
    if samp is not None: samp.sampler = copy(samp.sampler)
    elif hasattr(ds, 'wif'): ds.wif()

In [None]:
class SleepyDS():
    "Sequence wrapper that sleeps 20ms on every item access, to simulate a slow dataset"
    def __init__(self, coll):
        self.coll = coll
    def __len__(self):
        return len(self.coll)
    def __getitem__(self, i):
        time.sleep(0.02)
        return self.coll[i]

In [None]:
letters = list(string.ascii_lowercase)
def twoepochs(d):
    "Print two passes over `d`: each batch's items joined into one string, batches separated by spaces"
    pieces = []
    for _ in range(2):
        for batch in d: pieces.append(''.join(batch))
    print(' '.join(pieces))
bs = 4

In [None]:
#export
def dataloader(ds, bs=1, num_workers=0, collate_fn=default_collate, **kwargs):
    """Return a PyTorch `DataLoader` over `ds`, first wrapping it in a `BatchDS` (with `bs` and `kwargs`) unless it
    is already an `IterableDataset`. `collate_fn` is applied to every batch yielded by the dataset."""
    # Use PyTorch's loader explicitly: the `DataLoader` class defined earlier in this notebook shadows it, and it
    # does not take `batch_size`/`worker_init_fn`/`collate_fn`
    from torch.utils.data import DataLoader as TorchDataLoader
    if not isinstance(ds, IterableDataset): ds = BatchDS(ds, bs, **kwargs)
    # `batch_size=None` turns off PyTorch's auto-batching: each element of `ds` is already a batch. `collate_fn`
    # used to be ignored here (hard-coded `noop`), which silently discarded e.g. a custom `rev_collate`.
    return TorchDataLoader(ds, num_workers=num_workers, batch_size=None,
                           worker_init_fn=_wif, collate_fn=collate_fn)

In [None]:
def get_dl(bs=1, collate_fn=default_collate, num_workers=0, **kwargs):
    "Test helper: `dataloader` over a fresh `SleepyDS` of the 26 letters"
    letters_ds = SleepyDS(string.ascii_lowercase)
    return dataloader(letters_ds, bs, num_workers, collate_fn=collate_fn, **kwargs)

In [None]:
# shuffled, 4 workers; `sum(dl,[])` concatenates the (list) batches, and every letter must appear
dl = get_dl(bs=4, num_workers=4, shuffle=True)
twoepochs(dl)
test_eq(len(set(sum(dl,[]))), 26)

gdxl ubiy znwj ecra spkh mtfq vo gdxl ubiy znwj ecra spkh mtfq vo


In [None]:
# 4 workers, each keeping every 4th batch (see `BatchDS.__iter__`); compare the wall time with the single-process run below
dl = get_dl(bs=4, num_workers=4)
%time twoepochs(dl)
test_eq(len(set(sum(dl,[]))), 26)

abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz
CPU times: user 21.1 ms, sys: 46.4 ms, total: 67.5 ms
Wall time: 396 ms


In [None]:
# single process; `len(dl)` is the number of batches (26 letters in batches of 4 -> 7)
dl = get_dl(bs=4, num_workers=0)
%time twoepochs(dl)
len(dl)

abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz
CPU times: user 5.71 ms, sys: 0 ns, total: 5.71 ms
Wall time: 1.05 s


7

In [None]:
# `drop_last=True` drops the short final batch ("yz"), leaving 6 batches
dl = get_dl(bs=4, num_workers=4, drop_last=True)
twoepochs(dl)
len(dl)

abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx


6

In [None]:
# shuffled in a single process: each pass uses a new order, and every letter must appear
dl = get_dl(bs=4, num_workers=0, shuffle=True)
twoepochs(dl)
test_eq(len(set(sum(dl,[]))), 26)

lahp kzyn rgmb sfdt xvco ueij wq svid tckw phjz raeu gfqy mlnb xo


In [None]:
# an explicit `sampler` replaces the default one, and `bs` still groups its indices into batches
# (`ds` is rebound here and reused by the next cell)
ds = SleepyDS(string.ascii_lowercase)
dl = get_dl(bs=4, num_workers=4, sampler=SequentialSampler(ds))
twoepochs(dl)

abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz


In [None]:
# an explicit `batch_sampler` overrides `bs`/`shuffle`/`sampler`: here random batches of 8
dl = get_dl(num_workers=4, batch_sampler=BatchSampler(RandomSampler(ds), 8, False))
twoepochs(dl)
test_eq(len(set(sum(dl,[]))), 26)

ypojsdrz hkvwnilc xqmbfgtu ae ypojsdrz hkvwnilc xqmbfgtu ae


In [None]:
def rev_collate(s):
    "Collate function that reverses the sample order of a batch, then applies `default_collate`"
    reversed_samples = list(reversed(s))
    return default_collate(reversed_samples)
dl = get_dl(bs=4, num_workers=4, collate_fn=rev_collate)
twoepochs(dl)

abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz


In [None]:
class SleepyDS2(SleepyDS):
    "`SleepyDS` with its own `get_batch` (picked up by `BatchDS`): each batch becomes one string ending in '/'"
    def get_batch(self, b):
        joined = "".join([self[j] for j in b])
        return joined + '/'

dl = dataloader(SleepyDS2(string.ascii_lowercase), bs=4, num_workers=4)
twoepochs(dl)

abcd/ efgh/ ijkl/ mnop/ qrst/ uvwx/ yz/ abcd/ efgh/ ijkl/ mnop/ qrst/ uvwx/ yz/
