In [None]:
# export
from fastai.datasets import URLs, untar_data
from pathlib import Path
import pandas as pd, numpy as np, torch, re, PIL, os, mimetypes, csv, itertools
import matplotlib.pyplot as plt
from collections import OrderedDict
from typing import *
from enum import Enum
from functools import partial,reduce
from torch import tensor
from IPython.core.debugger import set_trace

## Data block API from config class

### ItemList

In [None]:
# export
def noop(x, *args, **kwargs):
    "Return `x` unchanged, ignoring any extra positional or keyword arguments."
    return x
def range_of(x):
    "Return the list of valid indices of `x`, i.e. `[0, ..., len(x)-1]`."
    return [*range(len(x))]
# Expose `ndim` on tensors as a read-only alias of `Tensor.dim()`; `DataSource.get` reads it
# (via `getattr(idx,'ndim',1)`) to tell rank 0 tensors from collections of indices.
torch.Tensor.ndim = property(lambda x: x.dim())

import operator

def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b):    test(a,b,operator.eq,'==')
def test_ne(a,b):    test(a,b,operator.ne,'!=')
def test_equal(a,b): test(a,b,torch.equal,'==')

def compose(*funcs):
    "Return a function applying `funcs` from left to right (`compose(f,g)(x) == g(f(x))`); `noop` when no `funcs` are given."
    if not funcs: return noop
    def _composed(x):
        for f in funcs: x = f(x)
        return x
    return _composed

In [None]:
# test
test_eq(noop(1),1)

In [None]:
# export
def listify(o):
    """Make `o` a list.

    `None` becomes `[]`, a list is returned as is (not copied), a string or any
    non-iterable is wrapped in a one-element list, and any other sized iterable
    is converted with `list(o)`. Iterables without a usable length (rank 0
    tensors, generators) are wrapped rather than consumed.
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str) or not isinstance(o, Iterable): return [o]
    # Rank 0 tensors in PyTorch are Iterable but don't have a length: `len` raises
    # `TypeError`. Only that error is caught, so unrelated failures are not hidden.
    try: len(o)
    except TypeError: return [o]
    return list(o)

In [None]:
# test
test_eq(listify(None),[])
test_eq(listify([1,2,3]),[1,2,3])
test_ne(listify([1,2,3]),[1,2,])
test_eq(listify('abc'),['abc'])
test_eq(listify(range(0,3)),[0,1,2])
test_eq(listify(tensor(0)),[tensor(0)])
test_eq(listify([tensor(0),tensor(1)]),[tensor(0),tensor(1)])
test_eq(listify(tensor([0.,1.1])),[0,1.1])

In [None]:
# export
def double_listify(o):
    "Make `o` a list of lists: an empty list or a list starting with a list is kept, anything else is wrapped once."
    items = listify(o)
    if not items: return items
    return items if isinstance(items[0], list) else [items]

In [None]:
# test
test_eq(double_listify(None),[])
test_eq(double_listify([1,2,3]),[[1,2,3]])
test_eq(double_listify([[1,2,3]]),[[1,2,3]])
test_eq(double_listify([[1],[2],[3]]),[[1],[2],[3]])

In [None]:
# export
def order_sorted(funcs, order_key='_order'):
    "Return `funcs` as a list sorted by their `order_key` attribute (0 when missing)."
    def _order_of(f): return getattr(f, order_key, 0)
    return sorted(listify(funcs), key=_order_of)

def apply_all(x, funcs, *args, order_key='_order', **kwargs):
    "Feed `x` through every function of `funcs` (sorted by `order_key`), passing `args` and `kwargs` to each call."
    ordered = order_sorted(funcs, order_key=order_key)
    return reduce(lambda acc, f: f(acc, *args, **kwargs), ordered, x)

In [None]:
# test
# basic behavior
def _test_f1(x, a=2): return x**a
def _test_f2(x, a=2): return a*x
# No `_order` set: list order is kept, so f1 then f2 -> (2**2)*2 == 8
test_eq(apply_all(2, [_test_f1, _test_f2]),8)
# order
# `_order=1` moves f1 after f2 (which defaults to 0): (2*2)**2 == 16
_test_f1._order = 1
test_eq(apply_all(2, [_test_f1, _test_f2]),16)
#args
# The extra positional argument becomes `a` in both functions: (3*2)**3 == 216
test_eq(apply_all(2, [_test_f1, _test_f2], 3),216)
#kwargs
test_eq(apply_all(2, [_test_f1, _test_f2], a=3),216)

In [None]:
# export
class DataSource():
    """Items seen through one or more `filters` (subsets), each with its own list of transforms.

    `items` is any sized collection. `filters` holds one collection of indices into
    `items` per subset (default: a single subset with every item). `tfms` is a
    transform, a list of transforms shared by all subsets, or a list of lists with
    one pipeline per subset; each pipeline is sorted by `_order`.
    """
    def __init__(self, items, tfms=None, filters=None):
        if filters is None: filters = [range(len(items))]
        self.items,self.filters,self.tfms = listify(items),listify(filters),None
        tfms = double_listify(tfms) #We want a list of lists of transforms.
        # `self.tfms` is still None here, so `setup` sees the untransformed items.
        for tfm in sum(tfms, []): getattr(tfm, 'setup', noop)(self)
        self.tfms = [order_sorted(tfm) for tfm in tfms]
        if len(self.tfms)==1: self.tfms = self.tfms*len(self)

    def transformed(self, tfms):
        "Return a new object of the same class over the same items and filters, with `tfms` appended to each subset's transforms."
        tfms = double_listify(tfms)
        if len(tfms)==1: tfms = tfms*len(self)
        # An empty side must count as "one empty pipeline per subset", otherwise `zip` drops the other side.
        old = self.tfms if self.tfms else [[] for _ in range(len(self))]
        new = tfms      if tfms      else [[] for _ in range(len(self))]
        tfms = [a + b for a,b in zip(old,new)]
        return self.__class__(self.items, tfms, self.filters)

    def __len__(self): return len(self.filters)
    def len(self, filt=0): return len(self.filters[filt])
    def __getitem__(self, i): return FilteredList(self, i)

    def sublist(self, filt):
        "All the transformed items of subset `filt`, as a list."
        return [self.get(j,filt) for j in range(self.len(filt))]

    def get(self, idx, filt=0):
        "Transformed item(s) of subset `filt` at `idx`: a single index, a collection of indices or a list of bools (mask)."
        if hasattr(idx,'__len__') and getattr(idx,'ndim',1):
            # rank>0 collection
            if len(idx) and isinstance(idx[0],bool):
                assert len(idx)==self.len(filt) # bool mask
                return [self.get(i,filt) for i,m in enumerate(idx) if m]
            return [self.get(i,filt) for i in idx]  # index list (possibly empty)
        if self.filters: idx = self.filters[filt][idx]
        res = self.items[idx]
        if self.tfms: res = apply_all(res, self.tfms[filt])
        return res

    def decode(self, o, filt=0):
        "Undo the transforms of subset `filt` on `o` (last transform first); `o` is returned as is when there are none."
        if not self.tfms: return o
        return apply_all(o, [getattr(f, 'decode', noop) for f in reversed(self.tfms[filt])])

    def __iter__(self):
        for i in range_of(self.filters):
            yield (self.get(j,i) for j in range(self.len(i)))

    def __eq__(self,b):
        "Equal when `b` (wrapped in a `DataSource` if needed) has the same subsets with the same transformed items."
        if not isinstance(b,DataSource): b = DataSource(b)
        if len(b) != len(self): return False
        for i in range_of(self.filters):
            if b.len(i) != self.len(i): return False
            # Keep looping on success: every subset has to match, not only the first one.
            if not all(self.get(j,i)==b.get(j,i) for j in range_of(self.filters[i])): return False
        return True

    def __repr__(self):
        res = f'{self.__class__.__name__}\n'
        for i,o in enumerate(self):
            l = self.len(i)
            res += f'{i}: ({l} items) ['
            res += ','.join(itertools.islice(map(str,o), 10))
            if l>10: res += '...'
            res += ']\n'
        return res

    @property
    def train(self): return self[0]
    @property
    def valid(self): return self[1]

In [None]:
#export
class FilteredList:
    "View on one subset (`filt`) of `il`, delegating to `il.get`, `il.len` and `il.decode`."
    def __init__(self, il, filt):
        self.il = il
        self.filt = filt

    def __len__(self): return self.il.len(self.filt)
    def __getitem__(self,i): return self.il.get(i,self.filt)

    def __iter__(self):
        return (self.il.get(j,self.filt) for j in range(len(self)))

    def __repr__(self):
        parts = [f'({len(self)} items) [', ','.join(itertools.islice(map(str,self), 10))]
        if len(self)>10: parts.append('...')
        parts.append(']\n')
        return ''.join(parts)

    def decode(self, o):
        return self.il.decode(o, self.filt)

#### Tests

In [None]:
# test
il = DataSource(range(5))
test_eq(il,[0,1,2,3,4])
test_eq(il.sublist(0),[0,1,2,3,4])
test_ne(il,[0,1,2,3,5])
test_eq(il.get(2),2)
test_eq(il.get([1,2]),[1,2])
test_eq(il.get([True,False,False,True,False]),[0,3])

In [None]:
# test
il = DataSource(range(5), lambda x:x*2)
test_eq(il,[0,2,4,6,8])
test_eq(il.sublist(0),[0,2,4,6,8])
test_ne(il,[1,2,4,6,8])
test_eq(il.get(2), 4)
test_eq(il.get([1,2]), [2,4])
test_eq(il.get([True,False,False,True,False]), [0,6])

In [None]:
# test
il = DataSource(range(5), [[noop],[lambda x:x*2]], [[1,2],[0,3,4]])
test_eq(il.sublist(0),[1,2])
test_eq(il.sublist(1),[0,6,8])
test_eq(il.get(2,1), 8)
test_eq(il.get([1,2], 1), [6,8])
test_eq(il.get([False,True], 0), [2])

In [None]:
il

In [None]:
# test
fl = il[1]
test_eq(list(fl),[0,6,8])
test_eq(fl[2], 8)
test_eq(fl[[1,2]], [6,8])
test_eq(fl[[False,True,True]], [6,8])

In [None]:
fl

### Core helper functions

In [None]:
# export
def uniqueify(x, sort=False, bidir=False):
    "Unique elements of `x` in order of first appearance (or sorted if `sort`); with `bidir`, also return the element->index dict."
    uniq = list(OrderedDict.fromkeys(x))
    if sort: uniq = sorted(uniq)
    if not bidir: return uniq
    return uniq, {val:pos for pos,val in enumerate(uniq)}

In [None]:
# test
test_eq(set(uniqueify([1,1,0,5,0,3])),{0,1,3,5})
test_eq(uniqueify([1,1,0,5,0,3], sort=True),[0,1,3,5])

In [None]:
# export
def setify(o):
    "Make `o` a set: a set is returned as is, anything else goes through `listify` first."
    if isinstance(o, set): return o
    return set(listify(o))

In [None]:
# test
test_eq(setify(None),set())
test_eq(setify('abc'),{'abc'})
test_eq(setify([1,2,2]),{1,2})
test_eq(setify(range(0,3)),{0,1,2})
test_eq(setify({1,2}),{1,2})

In [None]:
# export
def onehot(x, c, a=1.):
    "Return the `a`-hot encoded tensor for `x` with `c` classes."
    encoded = torch.zeros(c)
    if a < 1:
        # Spread the remaining `1-a` over the other `c-1` classes.
        off_value = (1-a)/(c-1)
        encoded = encoded + off_value
    encoded[x] = a
    return encoded

In [None]:
# test
test_equal(onehot(1,5), tensor([0.,1.,0.,0.,0.]))
test_equal(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))
test_equal(onehot([],5), tensor([0.,0.,0.,0.,0.]))

test_equal(onehot(1,5,0.9), tensor([0.025,0.9,0.025,0.025,0.025]))

In [None]:
# export
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=False, include=None):
    """Get all the files in `path` with optional `extensions`.

    `extensions` (a string or a collection, compared case-insensitively) keeps only
    the files with one of those suffixes. With `recurse`, sub-folders are searched
    too: hidden ones (leading '.') are skipped and, if `include` is given (a folder
    name or a collection of names), only those first-level folders are entered.
    """
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    # A bare string would otherwise be matched by substring (`'rain' in 'train'`).
    if include is not None: include = setify(include)
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
            # Pruning `d` in place tells `os.walk` which sub-folders to descend into.
            if include is not None and i==0: d[:] = [o for o in d if o in include]
            else:                            d[:] = [o for o in d if not o.startswith('.')]
            res += _get_files(p, f, extensions)
    else:
        # The context manager closes the directory handle as soon as the listing is done.
        with os.scandir(path) as entries: f = [o.name for o in entries if o.is_file()]
        res = _get_files(path, f, extensions)
    return res

In [None]:
# test
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_files(path/'train'/'3')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.png')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.jpg')),0)
test_eq(len(get_files(path/'train', extensions='.png')),0)
test_eq(len(get_files(path/'train', extensions='.png', recurse=True)),709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])),709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])),729)

### Helper functions

#### Get image files

In [None]:
#export
def show_image(im, ax=None, figsize=None, **kwargs):
    """Show image `im` on `ax` with the axes hidden, and return `ax`.

    A new figure of size `figsize` is created when `ax` is None. Extra `kwargs`
    (e.g. `cmap`) are forwarded to `ax.imshow`.
    """
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    ax.axis('off')
    return ax

# Every file suffix that `mimetypes` maps to an 'image/...' MIME type.
image_extensions = {ext for ext,mime in mimetypes.types_map.items() if mime.startswith('image/')}

def get_image_files(path, include=None):
    "Recursively collect the files of `path` whose suffix is in `image_extensions`, restricted to the `include` folders if given."
    return get_files(path, recurse=True, include=include, extensions=image_extensions)

In [None]:
# test
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_image_files(path)),1428)
test_eq(len(get_image_files(path/'train')),709)
test_eq(len(get_image_files(path, include='train')),709)
test_eq(len(get_image_files(path, include=['train','valid'])),1408)

#### Split

In [None]:
# export
def random_splitter(items, valid_pct=0.2, seed=None):
    "Randomly split the indices of `items` into (train, valid) tensors, `valid_pct` of them going to valid; `seed` seeds torch's global RNG."
    if seed is not None: torch.manual_seed(seed)
    n = len(items)
    perm = torch.randperm(n)
    n_valid = int(valid_pct * n)
    valid_idx,train_idx = perm[:n_valid],perm[n_valid:]
    return train_idx,valid_idx

In [None]:
#test
trn,val = random_splitter([0,1,2,3,4,5], seed=42)
test_equal(trn, tensor([3, 2, 4, 1, 5]))
test_equal(val, tensor([0]))

In [None]:
# export
def _grandparent_mask(items, name):
    "Boolean mask over `items` (`Path`s or path strings): True where the grandparent folder is called `name`."
    # Going through `Path` treats strings like `Path` objects (the string branch used to look at the parent folder).
    return [Path(o).parent.parent.name == name for o in items]

def grandparent_splitter(items, train_name='train', valid_name='valid'):
    "Split `items` from the grand parent folder names (`train_name` and `valid_name`); returns the (train, valid) boolean masks."
    return _grandparent_mask(items, train_name),_grandparent_mask(items, valid_name)

In [None]:
path = untar_data(URLs.MNIST_TINY)

In [None]:
#test
#With path items (each one is built from `path` with the `/` operator)
path = untar_data(URLs.MNIST_TINY)
items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png', 
         path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png',  
         path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',
         path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']
trn,val = grandparent_splitter(items)
test_eq(trn,[True,False,False,True,True,False,True,False])
test_eq(val,[False,True,True,False,False,True,False,True])

#### Label

In [None]:
# export
def parent_label(o):
    "Label `item` (a `Path` or a path string) with the parent folder name."
    # Going through `Path` gives the folder for strings too (the last path component is the file name).
    return Path(o).parent.name

def re_labeller(pat):
    "Return a function labelling an item with the first group of regex `pat` searched in `str(item)`."
    pat = re.compile(pat)
    def _label(o):
        match = pat.search(str(o))
        # Fail loudly on items the pattern cannot label.
        assert match,f'Failed to find "{pat}" in "{o}"'
        return match.group(1)
    return _label

### Pets DataSource

In [None]:
#export
class TupleTransform():
    "Run several transform pipelines on the same item and return one result per pipeline."
    def __init__(self, *tfms):
        # One pipeline per argument, each sorted by `_order`.
        self.tfms = [order_sorted(listify(tfm)) for tfm in tfms]

    def __call__(self,o):
        return [apply_all(o, tfm) for tfm in self.tfms]

    def decode(self, o):
        "Decode each element of `o` with its pipeline, last transform first; transforms without `decode` are skipped."
        decoded = []
        for x,tfm in zip(o,self.tfms):
            decoders = [getattr(f, 'decode', noop) for f in reversed(tfm)]
            decoded.append(apply_all(x, decoders))
        return decoded

    def setup(self,items):
        "Call `setup(items)` on every transform that defines it."
        for pipeline in self.tfms:
            for tfm in pipeline: getattr(tfm, 'setup', noop)(items)
            
class Transform():
    "Base transform: `setup` does nothing, `__call__` and `decode` return their input unchanged."
    def setup(self, items):
        # 1-time setup
        return None

    def __call__(self,o):
        # transform
        return o

    def decode(self,o):
        # reverse transform
        return o

class Categorize(Transform):
    "Map `tfm(item)` to its index in a sorted vocab, built in `setup` from the training items unless `vocab` is already set."
    def __init__(self, tfm=noop): self.tfm,self.vocab,self.o2i = tfm,None,None
    def __call__(self,o): return self.o2i[self.tfm(o)]
    def decode(self, o): return self.vocab[o]
    def show(self, o, ax): ax.set_title(o)

    def setup(self, items):
        if self.vocab is None:
            vals = [self.tfm(o) for o in items.train]
            self.vocab = uniqueify(vals, sort=True)
        # Also derive the reverse mapping for a vocab assigned before `setup`;
        # without it `__call__` would fail on the missing `o2i`.
        if self.o2i is None: self.o2i = {v:k for k,v in enumerate(self.vocab)}

In [None]:
# Pets: the items are the image files found under `images`, split at random.
source = untar_data(URLs.PETS)/"images"
items = get_image_files(source)
split_idx = random_splitter(items)
# x: open the file as a PIL image.
xt = PIL.Image.open
# y: category index of the label captured by the regex (the part of the file name before `_<digits>.jpg`).
yt = Categorize(re_labeller(pat = r'/([^/]+)_\d+.jpg$'))

In [None]:
pets = DataSource(items, TupleTransform(xt,yt), split_idx)

In [None]:
x,y = pets.get(0,0)
x,y

In [None]:
(x,y) = pets.decode((x,y), 0)

In [None]:
ax = show_image(x)
yt.show(y, ax)

## Transforms

In [None]:
# export
# Kind of target a transform has to handle (`No`: leave the target untouched).
TfmY = Enum('TfmY', ['Mask', 'Image', 'Point', 'Bbox', 'No'])

In [None]:
# export
class ImageTransform():
    "Base class for transforms applied to an `(x, y)` pair: `apply` handles `x`, a `TfmY`-specific method handles `y`."
    _order=0
    # Name of the method applied to the target for each kind of `TfmY`.
    _tfm_y_func={
        TfmY.Image: 'apply_img',
        TfmY.Mask:  'apply_mask',
        TfmY.No:    'noop',
        TfmY.Point: 'apply_point',
        TfmY.Bbox:  'apply_bbox',
    }
    # Same thing for decoding.
    _decode_y_func={
        TfmY.Image: 'unapply_img',
        TfmY.Mask:  'unapply_mask',
        TfmY.No:    'noop',
        TfmY.Point: 'unapply_point',
        TfmY.Bbox:  'unapply_bbox',
    }

    def randomize(self):
        "Hook called once per item, before `x` and `y` are transformed."
        pass

    def __call__(self, o, **kwargs):
        x,y = o
        self.randomize() # Ensures we have the same state for x and y
        self.x = x # Saves the x in case it's needed in the apply for y
        new_x = self.apply(x)
        new_y = self.apply_y(y, **kwargs)
        return new_x,new_y

    def decode(self, o, **kwargs):
        (x,y) = o
        old_x = self.unapply(x)
        old_y = self.unapply_y(y, **kwargs)
        return old_x,old_y

    def noop(self,x):
        return x

    # By default an image target is treated like `x`, a mask like an image target,
    # points are left alone and bounding boxes are treated like points.
    def apply(self, x):
        return x
    def apply_img(self, y):
        return self.apply(y)
    def apply_mask(self, y):
        return self.apply_img(y)
    def apply_point(self, y):
        return y
    def apply_bbox(self, y):
        return self.apply_point(y)

    def apply_y(self, y, tfm_y=TfmY.No):
        "Transform the target `y` with the method registered for `tfm_y`."
        method = getattr(self, self._tfm_y_func[tfm_y])
        return method(y)

    def unapply(self, x):
        return x
    def unapply_img(self, y):
        return self.unapply(y)
    def unapply_mask(self, y):
        return self.unapply_img(y)
    def unapply_point(self, y):
        return y
    def unapply_bbox(self, y):
        return self.unapply_point(y)

    def unapply_y(self, y, tfm_y=TfmY.No):
        "Decode the target `y` with the method registered for `tfm_y`."
        method = getattr(self, self._decode_y_func[tfm_y])
        return method(y)

In [None]:
# test
import random
class FakeTransform(ImageTransform):
    def randomize(self): self.a = random.randint(1,10)
    def apply(self, x): return x + self.a
    def apply_mask(self, x): return x + 5
    def apply_point(self, x): return x + 2

tfm = FakeTransform()
(x,y) = (5,10)
#Basic behavior: x has changed, not y
t1 = tfm((x,y))
assert t1[0]!=x and t1[1]==y, t1
#Check the same random integer was used for x and y when transforming y
t1 = tfm((x,y), tfm_y=TfmY.Image)
test_eq(t1[0] - 5,t1[1] - 10)
#Check mask, point,bbox implementations
test_eq(tfm((x,y), tfm_y=TfmY.Mask) [1],15)
test_eq(tfm((x,y), tfm_y=TfmY.Point)[1],12)
test_eq(tfm((x,y), tfm_y=TfmY.Bbox) [1],12)

In [None]:
#export
def ifnone(a,b):
    "Return `a` unless it is None, in which case return `b`."
    if a is None: return b
    return a

class DecodeImg(ImageTransform):
    "Convert regular image to `mode_x` (RGB by default); image targets to `mode_y` (default `mode_x`), masks to `mode_y` (default L mode)."
    def __init__(self, mode_x='RGB', mode_y=None): self.mode_x,self.mode_y = mode_x,mode_y
    def apply(self, x):       return x.convert(self.mode_x)
    # Has to be called `apply_img`: that is the name `ImageTransform._tfm_y_func` looks up for
    # `TfmY.Image`. Under the old name `apply_image` it was never called and `mode_y` was ignored.
    def apply_img(self, y):   return y.convert(ifnone(self.mode_y,self.mode_x))
    apply_image = apply_img   # previous name, kept for existing callers
    def apply_mask(self, y):  return y.convert(ifnone(self.mode_y,'L'))

In [None]:
class ResizeFixed(ImageTransform):
    "Resize image to `size` using `mode_x` (and `mode_y` on targets; masks default to nearest neighbour)."
    _order=10
    def __init__(self, size, mode_x=PIL.Image.BILINEAR, mode_y=None):
        # `size` is an int (square) or a pair; the pair is swapped before being handed to PIL.
        if isinstance(size,int): size=(size,size)
        size = (size[1],size[0]) #PIL takes size the other way round
        self.size,self.mode_x,self.mode_y = size,mode_x,mode_y

    def apply(self, x):       return x.resize(self.size, self.mode_x)
    # Has to be called `apply_img`: that is the name `ImageTransform._tfm_y_func` looks up for
    # `TfmY.Image`. Under the old name `apply_image` it was never called and `mode_y` was ignored.
    def apply_img(self, y):   return y.resize(self.size, ifnone(self.mode_y,self.mode_x))
    apply_image = apply_img   # previous name, kept for existing callers
    def apply_mask(self, y):  return y.resize(self.size, ifnone(self.mode_y,PIL.Image.NEAREST))
In [None]:
class ToByteTensor(ImageTransform):
    "Turn an image into a channel-first byte tensor; decoding goes back to channel-last (or 2d for a single channel)."
    _order=20
    def apply(self, x):
        width,height = x.size
        flat = torch.ByteTensor(torch.ByteStorage.from_buffer(x.tobytes()))
        # The buffer is laid out row by row with the channels last.
        return flat.view(height,width,-1).permute(2,0,1)

    def unapply(self, x):
        if x.shape[0] == 1: return x[0]
        return x.permute(1,2,0)

In [None]:
class ToFloatTensor(ImageTransform):
    "Convert inputs to float tensors divided by `div_x`; masks become long tensors (divided in place by `div_y` if given)."
    _order=20
    def __init__(self, div_x=255., div_y=None):
        self.div_x = div_x
        self.div_y = div_y

    def apply(self, x):
        return x.float().div_(self.div_x)

    def apply_mask(self, x):
        as_long = x.long()
        if self.div_y is None: return as_long
        return as_long.div_(self.div_y)

## DataBunch

In [None]:
tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor(), ToFloatTensor()]

In [None]:
pets_t = pets.transformed(tfms)

In [None]:
x,y = pets_t.get(1)

In [None]:
(x,y) = pets_t.decode((x,y), 0)

In [None]:
ax = show_image(x)
yt.show(y, ax)

In [None]:
# export
from torch.utils.data.dataloader import DataLoader

def get_dls(il, bs=64):
    "Create one `DataLoader` of batch size `bs` per subset of `il`; only the first subset (index 0) is shuffled."
    dls = []
    for i in range_of(il):
        dls.append(DataLoader(il[i], bs, shuffle=(i==0)))
    return dls

In [None]:
dls = get_dls(pets_t, 9)

In [None]:
# export
class DataBunch():
    "Basic wrapper around several `DataLoader`s; `fns` holds one display function per element of an item."
    def __init__(self, *dls, fns): self.dls,self.fns = dls,fns
    def one_batch(self, i): return next(iter(self.dls[i]))

    @property
    def train_dl(self): return self.dls[0]
    @property
    def valid_dl(self): return self.dls[1]
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

    def show_batch(self, i, cols=3, figsize=None):
        "Show one batch of `self.dls[i]` on a grid of `cols` columns: each item is decoded by the dataset, then drawn by `fns`."
        b = list(zip(*self.one_batch(i)))
        # Ceiling division (at least one row) so that every item of the batch gets an axis.
        rows = max(1, -(-len(b) // cols))
        if figsize is None: figsize = (cols*3, rows*3)
        # `squeeze=False` always gives back a 2d array of axes, even for a 1x1 grid.
        fig,axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axs = axs.flatten()
        for it,ax in zip(b,axs):
            it = self.dls[i].dataset.decode(it)
            for o,fn in zip(it, self.fns): fn(o, ax)
        # Hide the cells of the grid that received no item.
        for ax in axs[len(b):]: ax.axis('off')

In [None]:
data = DataBunch(*dls, fns=(show_image,yt.show))

In [None]:
x,y = data.one_batch(0)

In [None]:
x.shape,x.type(),y.shape,y.type()

In [None]:
data.show_batch(0)

## Try different data

### MNIST

In [None]:
# NOTE(review): `DataBlock`, `ImageGetter`, `CategoryGetter` and `parent_labeller` are not defined in
# the visible part of this notebook (only `parent_label` is) — confirm where they come from before
# relying on this cell after a kernel restart.
class MnistData(DataBlock):
    "MNIST configuration: image inputs, category targets, train/valid split from the 'training'/'testing' grandparent folders."
    get_x_cls = ImageGetter
    get_y_cls = CategoryGetter
    
    def get_source(self):        return untar_data(URLs.MNIST)
    def get_items(self, source): return get_image_files(source)
    def split(self, items):      return grandparent_splitter(items, train_name='training', valid_name='testing')
    def label(self, items):      return parent_labeller(items)

In [None]:
data = MnistData(tfms=[ToByteTensor(), ToFloatTensor()]).databunch()

In [None]:
data.show_batch()

cmap is specified in the `item_get` for inputs.

In [None]:
data.train_ds.x.item_get.cmap='gray'

In [None]:
data.show_batch()

### Planet

In [None]:
path = untar_data(URLs.PLANET_SAMPLE)

In [None]:
path.ls()

In [None]:
df = pd.read_csv(path/'labels.csv')

In [None]:
df.head()

In [None]:
# export
class MultiCategoryProcessor(CategoryProcessor):
    "A Processor for multi-labeled categories."
    def proc1(self, item):
        # Labels missing from the vocab are dropped.
        known = self.otoi
        return [known[o] for o in item if o in known]

    def deproc1(self, idx):
        return [self.vocab[i] for i in idx]

    def create_vocab(self, items):
        "Build the sorted vocab (and its reverse mapping `otoi`) from every label found in `items`."
        labels = set()
        for c in items: labels.update(c)
        self.vocab = sorted(labels)
        self.otoi  = {label:i for i,label in enumerate(self.vocab)}

class MultiCategoryGetter(ItemGetter):
    "An `ItemGetter` suitable for multi-label classification targets."
    default_proc = MultiCategoryProcessor

    def get(self, o):
        # One-hot encode over the vocab of the first processor.
        n_classes = len(self.procs[0].vocab)
        return onehot(o, n_classes)

    def raw(self, o):
        # Positions equal to 1 in the encoded target.
        return [i for i,x in enumerate(o) if x == 1]

    def show(self, x, ax):
        ax.set_title(';'.join(x))

In [None]:
# test
proc = MultiCategoryProcessor()
#Even if 'c' is the first class, vocab is sorted for reproducibility
test_eq(proc([['c','a'], ['a','b'], ['b'], []]),[[2,0], [0,1], [1], []])
test_eq(proc([['a','c','b'], ['a']]),[[0,2,1],[0]])
test_eq(proc.vocab,['a','b','c'])
test_eq(proc.deprocess([[1,0], [2]]),[['b','a'], ['c']])
test_eq(proc.proc1(['b','a']),[1,0])
test_eq(proc.deproc1([2,0]),['c','a'])

In [None]:
# export
def get_str_column(df, col_name, prefix='', suffix='', delim=None):
    """Read `col_name` in `df` as strings, optionally adding `prefix` or `suffix`.

    With `delim`, every value is also split on it by `csv.reader`: rows with the
    same number of fields give a 2d string array, rows of different lengths give
    a 1d object array of lists.
    """
    values = df[col_name].values.astype(str)
    values = np.char.add(np.char.add(prefix, values), suffix)
    if delim is None: return values
    rows = list(csv.reader(values, delimiter=delim))
    if len({len(r) for r in rows}) <= 1: return np.array(rows)
    # Ragged rows: fill an object array by hand, recent NumPy refuses to build one implicitly.
    res = np.empty(len(rows), dtype=object)
    for i,r in enumerate(rows): res[i] = r
    return res

In [None]:
# test
df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})
# `test_equal` relies on `torch.equal`, which only accepts tensors: compare the arrays with NumPy instead.
def _test_array_eq(a,b): test(a,b,np.array_equal,'==')
_test_array_eq(get_str_column(df, 'a'), np.array(['cat', 'dog', 'car']))
_test_array_eq(get_str_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))
_test_array_eq(get_str_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))
_test_array_eq(get_str_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))

In [None]:
class PlanetData(DataBlock):
    "Planet sample configuration: image inputs, multi-category targets read from the `tags` column of `labels.csv`."
    get_x_cls = ImageGetter
    get_y_cls = MultiCategoryGetter

    def get_source(self):
        self.path = untar_data(URLs.PLANET_SAMPLE)
        # Use the path just stored on the instance, not whatever notebook global `path` currently holds.
        return pd.read_csv(self.path/'labels.csv')
    def get_items(self, source): return get_str_column(source, 'image_name', prefix=f'{self.path}/train/', suffix='.jpg')
    def split(self, items):      return random_splitter(items)
    def label(self, items):      return get_str_column(self.source, 'tags', delim=' ')

In [None]:
data = PlanetData(tfms=tfms).databunch()

In [None]:
data.show_batch()

### Camvid

In [None]:
# export
class SegmentMaskGetter(ImageGetter):
    "An `ImageGetter` for segmentation mask targets (`TfmY.Mask`), with its own default `cmap` and `alpha`."
    default_tfm = TfmY.Mask

    def __init__(self, procs=None, cmap='tab20', alpha=0.5):
        display_kwargs = dict(cmap=cmap, alpha=alpha)
        super().__init__(procs, **display_kwargs)

In [None]:
class CamvidData(DataBlock):
    "Camvid (tiny) configuration: image inputs, segmentation mask targets found in the `labels` folder."
    get_x_cls = ImageGetter
    get_y_cls = SegmentMaskGetter

    def get_source(self):        return untar_data(URLs.CAMVID_TINY)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        # The mask of `<stem><suffix>` is `labels/<stem>_P<suffix>`.
        # (`codes.txt` used to be loaded here into a local that was never used.)
        path_lbl = self.source/'labels'
        return func_labeller(items, lambda x: path_lbl/f'{x.stem}_P{x.suffix}')

In [None]:
data = CamvidData(tfms=tfms).databunch(bs=16)

In [None]:
data.show_batch()

### Biwi

In [None]:
import pickle

In [None]:
# export
class PointsGetter(ItemGetter):
    "An `ItemGetter` for points."
    default_tfm = TfmY.Point
    # `do_scale`: rescale the coordinates in `get` when an input `_x` is available.
    # `y_first`: when False (default), the two coordinate columns are swapped in `get`.
    def __init__(self, procs=None, do_scale=True, y_first=False): 
        super().__init__(procs)
        self.do_scale,self.y_first = do_scale,y_first
    
    def get(self, o):
        "Return `o` as a float tensor with 2 columns, optionally flipped and rescaled with `o * 2/sz - 1`."
        if not isinstance(o, torch.Tensor): o = tensor(o)
        o = o.view(-1, 2).float()
        if not self.y_first: o = o.flip(1)
        # NOTE(review): `_x` is never set in this class; presumably the owner stores the matching input there — confirm.
        if self.do_scale and hasattr(self, '_x') and self._x is not None: 
            # NOTE(review): `size` is read here while `raw` reads `shape[1:]`, so `_x` looks like a PIL image
            # here and a channel-first tensor there; check that the order of `sz` matches the flipped columns.
            sz = tensor(list(self._x.size)).float()
            o = o * 2/sz - 1
        return o
    
    def raw(self, o):
        "Reverse of `get`: flip the columns (whatever `y_first` is), then apply `(o + 1) * sz/2` if `_x` is set."
        o = o.flip(1)
        if hasattr(self, '_x') and self._x is not None: 
            sz = tensor([self._x.shape[1:]]).float()
            o = (o + 1) * sz/2
        return o
    
    def show(self, x, ax):
        "Scatter the points on `ax` as small red dots, column 1 on the horizontal axis and column 0 on the vertical one."
        params = {'s': 10, 'marker': '.', 'c': 'r'}
        ax.scatter(x[:, 1], x[:, 0], **params)

In [None]:
class BiwiData(DataBlock):
    "Biwi sample configuration: image inputs, point targets looked up by file name in `centers.pkl`."
    get_x_cls = ImageGetter
    get_y_cls = PointsGetter

    def get_source(self):        return untar_data(URLs.BIWI_SAMPLE)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        # NOTE: `pickle.load` can run arbitrary code, only use it on files from a trusted source.
        # The `with` block closes the file once the dictionary is loaded.
        with open(self.source/'centers.pkl', 'rb') as f: fn2ctr = pickle.load(f)
        return func_labeller(items, lambda o:fn2ctr[o.name])

In [None]:
data = BiwiData(tfms=tfms).databunch(bs=16)

In [None]:
data.show_batch()

### Coco

In [None]:
from fastai.vision.data import get_annotations

In [None]:
#export
class BBoxProcessor(MultiCategoryProcessor):
    "A processor for items whose second element is a list of labels; the first element is passed through untouched."
    def __call__(self, items):
        if self.vocab is None:
            vocab = set()
            for c in items: vocab = vocab.union(set(c[1]))
            # 'background' has to stay at index 0 (the default `pad_idx` of `bb_pad_collate`), so only
            # the real labels are sorted; sorting the whole list could push a label in front of it.
            vocab.discard('background')
            self.vocab = ['background'] + sorted(vocab)
            self.otoi  = {v:k for k,v in enumerate(self.vocab)}
        return [self.proc1(o) for o in items]
    def proc1(self, item):  return item[0],super().proc1(item[1])
    def deproc1(self, idx): return idx[0],super().deproc1(idx[1])

In [None]:
#export 
from matplotlib import patches, patheffects

def _draw_outline(o, lw):
    "Outline the artist `o` with a black stroke of width `lw`."
    effects = [patheffects.Stroke(linewidth=lw, foreground='black'), patheffects.Normal()]
    o.set_path_effects(effects)

def _draw_rect(ax, b, color='white', text=None, text_size=14):
    "Draw on `ax` the outlined rectangle `b` (origin in `b[:2]`, size in `b[-2:]`), with an optional outlined `text` at its origin."
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)
    _draw_outline(ax.add_patch(rect), 4)
    if text is None: return
    label = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label,1)

In [None]:
# export
class BBoxGetter(PointsGetter):
    "A `PointsGetter` for targets made of boxes (`o[0]`) and labels (`o[1]`): boxes go through the points logic, then back to 4 columns."
    default_proc = BBoxProcessor
    default_tfm = TfmY.Bbox

    def get(self, o):
        boxes = super().get(o[0]).view(-1,4)
        return boxes,o[1]

    def raw(self, o):
        boxes = super().raw(o[0].view(-1,2)).view(-1,4)
        return boxes,o[1]

    def show(self, x, ax):
        bbox,label = x
        for b,l in zip(bbox, label):
            if l == 'background': continue
            # `_draw_rect` takes the origin followed by the size.
            _draw_rect(ax, [b[1],b[0],b[3]-b[1],b[2]-b[0]], text=l)

In [None]:
def bb_pad_collate(samples, pad_idx=0):
    "Collate `(image, (bboxes, labels))` samples: images are stacked, boxes/labels are padded at the front (zeros / `pad_idx`) to the longest sample."
    n = len(samples)
    max_len = max(len(s[1][1]) for s in samples)
    bboxes = torch.zeros(n, max_len, 4)
    labels = torch.zeros(n, max_len).long() + pad_idx
    for i,s in enumerate(samples):
        bbs, lbls = s[1]
        if bbs.nelement() == 0: continue  # nothing to copy, the row stays padding
        bboxes[i,-len(lbls):] = bbs
        labels[i,-len(lbls):] = tensor(lbls)
    imgs = torch.stack([s[0] for s in samples], 0)
    return imgs, (bboxes,labels)

In [None]:
class CocoData(DataBlock):
    "Coco (tiny) configuration: image inputs, bounding box targets read from `train.json`, batches collated by `bb_pad_collate`."
    get_x_cls = ImageGetter
    get_y_cls = BBoxGetter

    def get_source(self):
        return untar_data(URLs.COCO_TINY)

    def get_items(self, source):
        return get_image_files(source/'train')

    def split(self, items):
        return random_splitter(items)

    def label(self, items):
        images, lbl_bbox = get_annotations(self.source/'train.json')
        # Annotations are looked up by image file name.
        img2bbox = {img:ann for img,ann in zip(images, lbl_bbox)}
        return func_labeller(items, lambda o:img2bbox[o.name])

    def databunch(self, bs=64, **kwargs):
        # Samples hold a variable number of boxes, so the padding collate function is always used.
        kwargs['collate_fn'] = bb_pad_collate
        return super().databunch(bs=bs, **kwargs)

In [None]:
data = CocoData(tfms=tfms).databunch(bs=16)

In [None]:
data.show_batch()

## Export

In [None]:
! python notebook2script.py "200_datablock_config.ipynb"