In [1]:
#default_exp data.load

In [2]:
#export
from fastai2.vision.all import *

# Data load
> `Bucket`, a list-like batch container, plus the `Transform` and `DataLoader` changes needed to collate samples into it

## Bucket

In [3]:
#export
class Bucket:
    "Wrapper around a sequence of items; `_bucket_collate` builds one per batch field."
    def __init__(self, items): self.items = items

    def __getitem__(self, idx):
        "Return the raw element when `is_indexer(idx)` (e.g. an int), otherwise a new bucket of the selected elements."
        # `type(self)` (not `Bucket`) so subclasses survive indexing, matching `map` and `to_device`.
        return self._get(idx) if is_indexer(idx) else type(self)(self._get(idx))

    def _get(self, i):
        "Index `self.items` by int/slice, or by a list/mask of positions converted with `mask2idxs`."
        if is_indexer(i) or isinstance(i,slice): return getattr(self.items,'iloc',self.items)[i]
        i = mask2idxs(i)
        # pandas objects use `.iloc`, array-likes use fancy indexing, anything else is gathered one by one.
        return (self.items.iloc[list(i)] if hasattr(self.items,'iloc')
                else self.items.__array__()[(i,)] if hasattr(self.items,'__array__')
                else [self.items[i_] for i_ in i])

    def tolist(self): return list(self.items)

    def map(self, f):
        "Apply `f` to every item, rebuilding a container of the same type as `self.items`."
        return type(self)(type(self.items)(f(o) for o in self.items))

    def to_device(self, device=None):
        "Move the wrapped items to `device` with `to_device`."
        # `device` must be forwarded; previously it was ignored.
        return type(self)(to_device(self.items, device))

    @property
    def shape(self): return (len(self.items),) # Needed for find_bs

    def __eq__(self, other):
        # Compare underlying items; unwrap another Bucket explicitly instead of relying on a reflected `__eq__`.
        return self.items == (other.items if isinstance(other, Bucket) else other)
    def __len__(self): return len(self.items)
    def __repr__(self): return f'<{self.__class__.__name__}: {self.items!r}>'

In [4]:
b1 = Bucket(['a', 'b'])
b2 = Bucket(['c', 'd'])
# Each object is a `Bucket` that exposes the wrapped list unchanged through `.items`.
test_eq(type(b1), Bucket)
test_eq(type(b2), Bucket)
test_eq(b2.items, ['c', 'd'])

In [5]:
# An integer index returns the raw element itself, not a Bucket
test_eq_type(b1[0], 'a')
test_eq_type(b2[1], 'd')

In [6]:
# Several positions at once give back a new Bucket holding those elements
test_eq_type(b1[0, 1], Bucket(['a', 'b']))

In [7]:
# A Bucket must not look iterable to fastai, so it is kept as a single unit
assert not is_iter(b1)

In [8]:
test_eq(b1.map(lambda s: f'{s}foo'), Bucket(['afoo', 'bfoo']))

In [9]:
# `shape` only reports the number of wrapped items
test_eq((2,), b1.shape)

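`Transform`s must be applied to each element of a `Bucket` rather than to the `Bucket` as a whole, so `Transform._do_call` is patched to map over it: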

In [10]:
#export
# Keep a handle on the original `Transform._do_call`. If this cell (or the exported module) runs again,
# `Transform._do_call` is already the patch below. In that case recover the original from the attribute
# stashed on the patch instead of capturing the patch itself, which would make it recurse forever.
old_do_call = getattr(Transform._do_call, '_orig_do_call', Transform._do_call)
def _do_call(self, f, x, **kwargs):
    "Apply `f` to each element of a `Bucket` separately; defer everything else to the original `_do_call`."
    if not isinstance(x, Bucket): return old_do_call(self, f, x, **kwargs)
    if f is None: return x
    # fastai helper `retain_type` presumably keeps each result typed like its input element `o`.
    def _apply(o): return retain_type(f(o, **kwargs), o, f.returns_none(o))
    return x.map(_apply)
_do_call._orig_do_call = old_do_call
Transform._do_call = _do_call

In [11]:
@Transform
def Neg(x:int):
    "Negate an `int`; values of other types pass through unchanged."
    return -1 * x

In [12]:
# `Neg` is applied per element: the ints are negated, the float is left alone
test_eq(Neg(Bucket([1, 2, 4.2])), Bucket([-1, -2, 4.2]))

## DataLoader

In [13]:
#export
def _bucket_collate(t):
    "Collate samples `t` field by field: field `i` of every sample goes into one `Bucket`."
    return Tuple(Bucket(o) for o in zip(*t))
def _bucket_convert(t):
    "Pre-batched input is not supported yet; fail with an explanatory message."
    raise NotImplementedError("Bucket collation of pre-batched data (`prebatched=True`) is not implemented")

In [14]:
#export
def detect_batch_to_samples(b, max_n=10):
    "Split a collated batch `b` (one sequence per field) back into at most `max_n` per-sample `Tuple`s."
    # A batch with no fields has nothing to split; indexing `b[0]` would raise IndexError.
    if len(b) == 0: return []
    # The sample count comes from the first field; every field is assumed to be at least that long.
    n_samples = min(len(b[0]), max_n)
    return [Tuple([field[i] for field in b]) for i in range(n_samples)]

In [15]:
#export
class DetectDataLoader(TfmdDL):
    "A `TfmdDL` whose batches hold one `Bucket` per field, built by `_bucket_collate`."
    def create_batch(self, b):
        # `self.prebatched` is used as a 0/1 index: False -> `_bucket_collate`, True -> `_bucket_convert`.
        return (_bucket_collate,_bucket_convert)[self.prebatched](b)

    def _decode_batch(self, b, max_n=9, full=True):
        "Split `b` into at most `max_n` samples and decode each with `after_item.decode` composed with the dataset's `decode` (or `noop` if it has none)."
        f = self.after_item.decode
        f = compose(f, partial(getattr(self.dataset,'decode',noop), full = full))
        return L(detect_batch_to_samples(b, max_n=max_n)).map(f)

    def show_batch(self, b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs):
        "Show up to `max_n` samples of `b` (or of a fresh batch); with `show=False` return the decoded samples instead."
        if unique:
            old_get_idxs = self.get_idxs
            # `Inf.zeros` presumably yields index 0 forever, so the batch repeats a single item.
            self.get_idxs = lambda: Inf.zeros
        try:
            if b is None: b = self.one_batch()
            if not show: return self._pre_show_batch(b, max_n=max_n)
            # Dispatch the `show_batch` implementation on the types of the first element of each field.
            show_fn = show_batch[type(b[0][0]), type(b[1][0])]
            pb = self._pre_show_batch(b, max_n=max_n)
            show_fn(*pb, ctxs=ctxs, max_n=max_n, **kwargs)
        finally:
            # Restore the index generator on every exit path, including the `show=False` return and errors.
            if unique: self.get_idxs = old_get_idxs

In [16]:
# Buckets can also wrap numpy arrays: selecting several rows goes through the array's fancy indexing.
arr_items = np.arange(20).reshape(10,2)
test_eq(Bucket(arr_items)[[0,2]].items, arr_items[[0,2]])

In [17]:
class Letter:
    "Minimal non-tensor item type used to build the toy dataset below."
    def __init__(self, v):
        self.v = v

    def __repr__(self):
        return '<Letter: ' + repr(self.v) + '>'

In [18]:
# One `Letter` per lowercase ASCII character, each paired with its position
letters = L([Letter(c) for c in string.ascii_lowercase])
numbers = L(range_of(letters))
dset = [(letter, number) for letter, number in zip(letters, numbers)]

In [19]:
dls = DataLoaders.from_dsets(
    dset, bs=3, shuffle=False, num_workers=0,
    dl_type=DetectDataLoader)

In [20]:
# Each field of the first batch is a Bucket holding the first three values of that field
test_eq(first(dls.train), tuple(Bucket(Tuple(field[:3])) for field in (letters, numbers)))

In [21]:
b = dls.one_batch()
# Only the first two samples are rebuilt, in order
test_eq(detect_batch_to_samples(b, max_n=2), [(letters[i], i) for i in range(2)])

In [22]:
# Needs a CUDA device, so it stays commented out: checks that the items inside a batch's `Bucket` end up on the requested device
# dset = [[tensor(1), tensor(2)], [tensor(1), tensor(2)]]
# dls = DataLoaders.from_dsets(dset, bs=2, dl_type=DetectDataLoader, device='cuda')
# xb,yb = dls.one_batch()
# xb[0].is_cuda

In [23]:
#hide
for b in progress_bar(dls.train):
    pass  # only checks that every training batch can be produced

## Export -

In [24]:
from nbdev.export import notebook2script
# Export the `#export` cells of the project's notebooks to their Python modules (see the "Converted ..." output below).
notebook2script()

Converted 00_core.ipynb.
Converted 03_data.load.ipynb.
Converted 04_data.core.ipynb.
Converted index.ipynb.
