In [1]:
#default_exp data.load

# Data load
> Core modifications for making the dataloader work

## TODOs:
* Do not modify the main dataloader in place; subclass it instead.

In [2]:
#export
import fastai2.imports
import fastcore.imports

In [3]:
#export
# def is_iter(o):
#     "Test whether `o` can be used in a `for` loop"
#     res = True
#     try: iter(o)
#     except TypeError: res = False
#     #Rank 0 tensors in PyTorch are not really iterable
#     return res and getattr(o,'ndim',1)
# fastai2.imports.is_iter = is_iter
# fastcore.imports.is_iter = is_iter

In [4]:
#export
from fastai2.vision.all import *

In [5]:
#export
class Bucket:
    "Wrap a collection of `items` so it is handled as a single object instead of being iterated"
    def __init__(self, items): self.items = items

    def __repr__(self): return f'<{type(self).__name__}: {self.items!r}>'

#     def __getitem__(self, i): return L.__getitem__(self, i)
    def __getitem__(self, idx):
        "Return the raw item for a single index, otherwise a new bucket holding the selection"
        # `type(self)` (rather than a hard-coded `Bucket`) keeps subclasses intact,
        # matching what `map` and `to_device` already do.
        return self._get(idx) if is_indexer(idx) else type(self)(self._get(idx))

    def _get(self, i):
        "Select from `items` using a single index, a slice, or a collection of indices"
        # Single indices and slices are forwarded to `.iloc` when the container has one
        if is_indexer(i) or isinstance(i,slice): return getattr(self.items,'iloc',self.items)[i]
        # presumably converts a boolean mask into positions -- confirm against `mask2idxs`
        i = mask2idxs(i)
        return (self.items.iloc[list(i)] if hasattr(self.items,'iloc')
                else self.items.__array__()[(i,)] if hasattr(self.items,'__array__')
                else [self.items[i_] for i_ in i])

#     def __iter__(self): raise TypeError(f"'{self.__class__.__name__}' object is not iterable")
    def __eq__(self, other):
        "Compare the wrapped `items` with `other`"
        return self.items == other

    def __len__(self): return len(self.items)

    def tolist(self): return list(self.items)

    def map(self, f):
        "Apply `f` to every item; the result is rebuilt with the same container type and bucket type"
        # NOTE: `type(self.items)` must accept an iterable of items (true for list/tuple)
        return type(self)(type(self.items)(f(o) for o in self.items))

    def to_device(self, device):
        "Return a new bucket with `items` moved to `device`"
        # The requested device is forwarded; it used to be dropped, so the global
        # `to_device` silently fell back to its default device.
        return type(self)(to_device(self.items, device))

    @property
    def shape(self): return (len(self.items),) # Needed for find_bs

In [6]:
#export
def create_bucket(items):
    "Return the bucket class to use for `items`; currently always the generic `Bucket`"
    # A per-type subclass was prototyped but is disabled:
    #     name = items[0].__class__.__name__ + 'Bucket'
    #     return type(name, (Bucket,), {})
    bucket_cls = Bucket
    return bucket_cls

In [7]:
#export
class _Buckets(dict):
    def __getitem__(self, k):
        try: return super().__getitem__(type(k[0]))(k)
        except KeyError:
            v = self[type(k[0])] = create_bucket(k)
            return v(k)

In [8]:
# Both lists hold `str` items, so both lookups use the same registry key (`type(k[0])`)
buckets = _Buckets()
b1 = buckets[['a', 'b']]
b2 = buckets[['c', 'd']]
test_eq(type(b1), type(b2))

In [9]:
# A single integer index returns the stored item itself
test_eq(b1[0], 'a')
test_eq(b2[1], 'd')

In [32]:
# Indexing with several positions yields a new `Bucket` holding the selected items
test_eq(b1[0,1], Bucket(['a','b']))

In [10]:
# test_fail(lambda: [o for o in b1])

In [11]:
# A `Bucket` must not be reported as iterable (it defines no `__iter__`).
# NOTE(review): presumably `is_iter` checks for `__iter__` rather than calling `iter()`,
# which would succeed here through `__getitem__` -- confirm.
test_eq(is_iter(b1), False)

In [12]:
# `map` applies the function to each item and returns the results in a new `Bucket`
test_eq(b1.map(lambda o: o+'foo'), Bucket(['afoo', 'bfoo']))

In [13]:
# `shape` is a 1-tuple containing the number of items
test_eq(b1.shape, (2,))

In [14]:
#export
# Module-level registry that `bucketify` looks collections up in
_buckets = _Buckets()

In [15]:
#export
def bucketify(items):
    "Wrap `items` using the module-level `_buckets` registry"
    wrapped = _buckets[items]
    return wrapped

In [16]:
#export
def _bucket_collate(t): return Tuple(bucketify(o) for o in zip(*t))
def _bucket_convert(t): raise NotImplementedError

In [17]:
#export
def detect_batch_to_samples(b, max_n=10):
    "Split batch `b` (one indexable per field) into at most `max_n` per-sample `Tuple`s"
    # The number of samples is taken from the first field
    n_samples = min(len(b[0]), max_n)
    return [Tuple([field[idx] for field in b]) for idx in range(n_samples)]
#     return L(b).zip()[:max_n]

In [18]:
#export
class DetectDataLoader(TfmdDL):
    "`TfmdDL` whose batches are built with `_bucket_collate`, i.e. each field is wrapped by `bucketify`"
    def create_batch(self, b):
        "Collate the samples `b`; prebatched input goes to `_bucket_convert`, which is not implemented"
        return (_bucket_collate,_bucket_convert)[self.prebatched](b)

    def _decode_batch(self, b, max_n=9, full=True):
        "Split `b` into at most `max_n` samples and decode each with `after_item` and the dataset's `decode`"
        f = self.after_item.decode
        f = compose(f, partial(getattr(self.dataset,'decode',noop), full = full))
        return L(detect_batch_to_samples(b, max_n=max_n)).map(f)

#     def _one_pass(self):
#         res = super()._one_pass()
#         self._types = {Tuple: [tuple, tuple]} # HACK
#         return res

    def show_batch(self, b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs):
        """Show up to `max_n` samples of `b` (a fresh batch when `b` is None).

        With `show=False` the prepared samples from `_pre_show_batch` are returned instead of displayed.
        With `unique=True`, `get_idxs` is temporarily replaced while the batch is drawn
        (presumably so the same item is repeated -- confirm against `Inf.zeros`).
        """
        if unique:
            old_get_idxs = self.get_idxs
            self.get_idxs = lambda: Inf.zeros
        # `finally` guarantees the temporary `get_idxs` override is undone on the
        # `show=False` early return and when anything below raises.
        try:
            if b is None: b = self.one_batch()
            if not show: return self._pre_show_batch(b, max_n=max_n)
            # Pick the display function registered for the item types of the first two fields
            show_fn = show_batch[type(b[0][0]), type(b[1][0])]
            pb = self._pre_show_batch(b, max_n=max_n)
            show_fn(*pb, ctxs=ctxs, max_n=max_n, **kwargs)
        finally:
            if unique: self.get_idxs = old_get_idxs

In [19]:
# NOTE(review): this value is overwritten by the `dset = list(zip(...))` assignment
# below before anything reads it.
dset = np.arange(20).reshape(10,2)

In [20]:
class Letter:
    "Minimal non-tensor item type used to exercise the bucket collation"
    def __init__(self, v):
        self.v = v

    def __repr__(self):
        inner = self.v.__repr__()
        return f'<Letter: {inner}>'

In [21]:
# Toy dataset: one `Letter` per lowercase ASCII character, paired with `range_of(letters)`
letters = L(string.ascii_lowercase, use_list=True).map(Letter)
numbers = L(range_of(letters))
dset = list(zip(letters,numbers)) 

In [22]:
# Batches of 3, unshuffled and with `num_workers=0`, so the first batch is predictable
dls = DataLoaders.from_dsets(dset, bs=3, dl_type=DetectDataLoader, shuffle=False, num_workers=0)

In [23]:
# The first training batch should hold the first three letters and numbers, each field wrapped by `bucketify`
test_eq(first(dls.train), (bucketify(Tuple(letters[:3])), bucketify(Tuple(numbers[:3]))))

In [24]:
# Grab a single batch to inspect
b = dls.one_batch()

In [25]:
# Split the batch back into per-sample tuples, keeping only the first 2
detect_batch_to_samples(b, 2)

[(<Letter: 'a'>, 0), (<Letter: 'b'>, 1)]

In [26]:
# Tensor items this time. NOTE: this cell requests `device='cuda'`, so it needs a CUDA device to run.
dset = [[tensor(1), tensor(2)], [tensor(1), tensor(2)]]
dls = DataLoaders.from_dsets(dset, bs=2, dl_type=DetectDataLoader, device='cuda')

In [27]:
# Items inside the bucketed batch are expected to live on the GPU
xb,yb = dls.one_batch()
xb[0].is_cuda

True

In [28]:
# Smoke test: a full pass over the training loader should complete without raising
for b in progress_bar(dls.train): pass

In [29]:
#export
# `Transform._do_call` is patched so a `Bucket` is transformed item by item.
# `old_do_call` is resolved at call time, so if this cell is re-run (or the module
# reloaded) and the already patched function were captured as "old", the earlier
# wrapper would end up calling itself forever. The true original is therefore stored
# on the wrapper and reused on later runs.
old_do_call = getattr(Transform._do_call, '_unpatched', Transform._do_call)
def _do_call(self, f, x, **kwargs):
    "Apply `f` to each item of a `Bucket`, keeping item types; defer to the original `_do_call` otherwise"
    if not isinstance(x, Bucket): return old_do_call(self, f, x, **kwargs)
    if f is None: return x
    _f = lambda o: retain_type(f(o, **kwargs), o, f.returns_none(o))
    return x.map(_f)
_do_call._unpatched = old_do_call
Transform._do_call = _do_call

## Export -

In [33]:
# nbdev export step; the output below lists the notebooks it converted
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 03_data.load.ipynb.
Converted 04_data.core.ipynb.
Converted index.ipynb.
