In [None]:
from fastai.datasets import URLs, untar_data
from pathlib import Path
import torch, re, PIL
import matplotlib.pyplot as plt

In [None]:
# Fetch the PETS dataset through fastai's `untar_data` and keep its local root.
# NOTE(review): this global `path` is not used by the cells shown here —
# `PetsData.download` calls `untar_data(URLs.PETS)` again on construction.
path = untar_data(URLs.PETS)

In [None]:
from typing import *

def listify(o):
    """Coerce `o` into a list.

    None gives []; a list is returned as-is (not copied); a str is wrapped rather
    than split into characters; any other iterable is materialised with `list()`;
    everything else is wrapped in a one-element list.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    is_text = isinstance(o, str)
    if isinstance(o, Iterable) and not is_text:
        return list(o)
    return [o]

In [None]:
def compose(x, funcs, *args, order_key='_order', **kwargs):
    """Thread `x` through `funcs` and return the final result.

    Args:
        x: the initial value.
        funcs: None, a single callable, or an iterable of callables (normalised with
            `listify`).
        *args, **kwargs: extra arguments forwarded to every call as
            `f(x, *args, **kwargs)`.
        order_key: name of the attribute used to order `funcs`; callables without it
            count as 0.

    Functions run in ascending `order_key` order; `sorted` is stable, so ties keep the
    order in which they were given.
    """
    key = lambda o: getattr(o, order_key, 0)
    # Forward positional extras too; previously they were accepted and silently dropped.
    for f in sorted(listify(funcs), key=key): x = f(x, *args, **kwargs)
    return x

In [None]:
class ItemList():
    """A list of raw `items` that are turned into samples on access.

    Indexing fetches the raw item, calls `get` on it (identity here; subclasses
    override it, e.g. to open a file) and then runs `tfms` through `compose`.

    Supported indices:
      * an int (or anything a list accepts as a single index) -> one sample;
      * a slice -> list of samples;
      * a list of bools as long as the list (mask) -> list of selected samples;
      * any other iterable of positions (list, tensor, ...) -> list of samples.
    Iteration, `__setitem__` and `__delitem__` operate on the raw items.
    """
    def __init__(self, items, tfms=None): self.items,self.tfms = listify(items),tfms
    def get(self, i): return i
    def _get(self, i): return compose(self.get(i), self.tfms)
    def __getitem__(self, idx):
        # Slices select several raw items: transform each, not the sub-list as a whole.
        if isinstance(idx, slice): return [self._get(o) for o in self.items[idx]]
        # Resolve the index before calling `_get`, so a TypeError raised by `get` or a
        # transform propagates instead of being mistaken for the "many indices" case.
        try: item = self.items[idx]
        except TypeError: pass
        else: return self._get(item)
        idx = list(idx)
        if not idx: return []
        if isinstance(idx[0],bool):
            assert len(idx)==len(self) # bool mask
            return [self._get(o) for m,o in zip(idx,self.items) if m]
        return [self._get(self.items[i]) for i in idx]
    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del(self.items[i])
    def __repr__(self):
        # Show at most the first 10 raw items, marking truncation with '...'.
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self)>10: res = res[:-1]+ '...]'
        return res
    def new(self, items, cls=None):
        "Build a new list of `cls` (default: same class) over `items`, sharing `tfms`."
        if cls is None: cls=self.__class__
        return cls(items, tfms=self.tfms)

In [None]:
class LabeledData():
    """Pairs an input `ItemList` `x` with a label `ItemList` `y` of the same length.

    `proc_x` / `proc_y` (None, one processor, or a list of them) are applied once to
    the raw items at construction, through `compose`. `x_obj` / `y_obj` undo them to
    recover the original objects.
    """
    def process(self, il, proc): return il.new(compose(il.items, proc))

    def __init__(self, x, y, proc_x=None, proc_y=None):
        self.x,self.y = self.process(x, proc_x),self.process(y, proc_y)
        self.proc_x,self.proc_y = proc_x,proc_y
        
    def __repr__(self): return f'{self.__class__.__name__}\nx: {self.x}\ny: {self.y}\n'
    def __getitem__(self,idx): return self.x[idx],self.y[idx]
    def __len__(self): return len(self.x)
    
    def x_obj(self, idx): return self.obj(self.x, idx, self.proc_x)
    def y_obj(self, idx): return self.obj(self.y, idx, self.proc_y)
    
    def obj(self, items, idx, procs):
        """Return `items[idx]` with `procs` undone.

        A single index (an int or any 0-d tensor) is decoded with each processor's
        `deproc1`; any other index with `deprocess`. Processors are undone in the exact
        reverse of the order `compose` applied them: ascending `_order`, ties in list
        order.
        """
        isint = isinstance(idx, int) or (isinstance(idx, torch.Tensor) and not idx.ndim)
        item = items[idx]
        key = lambda p: getattr(p, '_order', 0)
        for proc in reversed(sorted(listify(procs), key=key)):
            item = proc.deproc1(item) if isint else proc.deprocess(item)
        return item

In [None]:
class DataBlock():
    """Template that builds `train` and `valid` `LabeledData` from four subclass hooks.

    Subclasses implement:
      * `download()`       -> local dataset root; only called when no `path` is given,
      * `get_items(path)`  -> the raw items,
      * `split(items)`     -> `(train_idx, valid_idx)`, two index collections into the items,
      * `label(items)`     -> one label per item.
    Inputs are wrapped in `_input_cls` (carrying `tfms`); labels go in `_label_cls`.
    """
    _input_cls = ItemList
    _label_cls = ItemList
    def download(self):         raise NotImplementedError
    def get_items(self, path):  raise NotImplementedError
    def split(self, items):     raise NotImplementedError
    def label(self, items):     raise NotImplementedError
        
    def __init__(self, path=None, tfms=None, proc_x=None, proc_y=None):
        # Use the caller's local copy when given; only fall back to `download` otherwise.
        self.path = Path(path) if path else self.download()
        items = ItemList(self.get_items(self.path))
        split_idx = self.split(items)
        labels = ItemList(self.label(items))
        # Only the inputs get `tfms`; labels are passed through untransformed.
        x_train,x_valid = map(lambda o: self._input_cls(items[o], tfms=tfms), split_idx)
        y_train,y_valid = map(lambda o: self._label_cls(labels[o]), split_idx)
        # The same processor objects run on train first, then valid, so stateful
        # processors (e.g. a vocabulary built on first use) are fit on train.
        self.train = LabeledData(x_train, y_train, proc_x=proc_x, proc_y=proc_y)
        self.valid = LabeledData(x_valid, y_valid, proc_x=proc_x, proc_y=proc_y)

In [None]:
class ImageList(ItemList):
    "An `ItemList` whose raw items are image file paths; `get` opens each with PIL."
    def get(self, fn):
        # `import PIL` alone does not load the `PIL.Image` submodule. Import it here
        # rather than relying on another library (e.g. matplotlib) having done so.
        import PIL.Image
        return PIL.Image.open(fn)

In [None]:
import os 

def setify(o):
    "Coerce `o` to a set: sets pass through unchanged, anything else goes through `listify` first."
    if isinstance(o, set):
        return o
    return set(listify(o))

def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=False, include=None):
    """List non-hidden files under `path`, optionally filtered by extension.

    Args:
        path: directory to search.
        extensions: None, one extension, or a collection of extensions including the
            leading dot (e.g. '.jpg'); compared case-insensitively. Empty means no filter.
        recurse: walk subdirectories with `os.walk` when True; otherwise list only the
            files directly in `path`.
        include: when recursing, optional collection of top-level subdirectory names to
            descend into. It replaces the hidden-dir filter at the top level only; files
            directly in `path` are always listed, and deeper levels skip dot-directories.

    Returns:
        list of `Path`s, in the order the filesystem reports them.
    """
    root = Path(path)
    exts = {e.lower() for e in setify(extensions)}
    if not recurse:
        names = [entry.name for entry in os.scandir(root) if entry.is_file()]
        return _get_files(root, names, exts)
    found = []
    for step, (dirpath, dirnames, filenames) in enumerate(os.walk(root)):
        # Pruning `dirnames` in place decides which subdirectories os.walk visits next;
        # step 0 is the top-level directory itself.
        if include is not None and step == 0:
            dirnames[:] = [d for d in dirnames if d in include]
        else:
            dirnames[:] = [d for d in dirnames if not d.startswith('.')]
        found += _get_files(dirpath, filenames, exts)
    return found

In [None]:
import mimetypes
# Every file extension that the `mimetypes` table maps to an image/* MIME type.
image_extensions = {ext for ext, mime in mimetypes.types_map.items() if mime.startswith('image/')}

In [None]:
def get_image_files(path, include=None):
    "Recursively list the files under `path` whose extension is in `image_extensions`."
    return get_files(path, recurse=True, include=include, extensions=image_extensions)

In [None]:
def random_splitter(items, valid_pct=0.2, seed=None):
    """Randomly split the positions of `items` into training and validation indices.

    Args:
        items: any sized collection; only `len(items)` is used.
        valid_pct: fraction (0 to 1) of positions assigned to the validation set,
            rounded down.
        seed: optional int. When given, the split is reproducible and is drawn from a
            private `torch.Generator`, so the global torch RNG state is not modified.
            When None, the global torch RNG is used.

    Returns:
        `(train_idx, valid_idx)`: two 1-D int64 tensors of positions into `items`, in
        the order `DataBlock.__init__` unpacks them.
    """
    assert 0 <= valid_pct <= 1, f'valid_pct must be in [0, 1], got {valid_pct}'
    gen = None
    if seed is not None:
        gen = torch.Generator()
        gen.manual_seed(seed)
    rand_idx = torch.randperm(len(items), generator=gen)
    cut = int(valid_pct * len(items))
    # The first `cut` shuffled positions form the validation set; the rest train.
    return rand_idx[cut:],rand_idx[:cut]

In [None]:
def label_from_func(items, func):
    "Return a new list holding `func(item)` for every element of `items`, in order."
    return list(map(func, items))

In [None]:
def label_from_re(items, pat):
    """Label each item with capture group 1 of regex `pat`, searched in `str(item)`.

    Args:
        items: iterable of items (e.g. file paths).
        pat: regex string or compiled pattern with at least one capture group.

    Returns:
        list of the captured strings, one per item.

    Raises:
        AssertionError: if `pat` does not match an item.
    """
    pat = re.compile(pat)
    def _inner(o):
        res = pat.search(str(o))
        # Name the offending item (`o`) and the raw pattern text in the failure message.
        assert res,f'Failed to find "{pat.pattern}" in "{o}"'
        return res.group(1)
    return label_from_func(items, _inner)

In [None]:
class PetsData(DataBlock):
    """The `URLs.PETS` dataset as a `DataBlock`: images under `<path>/images`.

    Each label is the part of the file name before the final `_<digits>.jpg`; the
    pattern accepts both `/` and `\\` as path separators. Override `valid_pct` / `seed`
    (e.g. in a subclass) to control the random train/valid split; `seed=None` leaves
    it to the global torch RNG.
    """
    _input_cls = ImageList
    valid_pct = 0.2
    seed = None
    
    def download(self):        return untar_data(URLs.PETS)
    def get_items(self, path): return get_image_files(Path(path)/"images")
    def split(self, items):    return random_splitter(items, valid_pct=self.valid_pct, seed=self.seed)
    # `\.` matches a literal dot; `[^/\\]+` keeps the capture inside the last path component.
    def label(self, items):    return label_from_re(items, pat = r'([^/\\]+)_\d+\.jpg$')

In [None]:
# Build the train/valid split with no item transforms: each sample's input is whatever
# `ImageList.get` returns (the image opened from its file path with PIL).
# NOTE(review): no seed is passed to the splitter, so the split changes between runs
# unless the global torch RNG is seeded somewhere.
data = PetsData()

In [None]:
# First training sample: the opened image and its label string (no proc_y, so not encoded).
img,cls = data.train[0]

In [None]:
# Last expression of the cell, so Jupyter renders the raw image.
img

In [None]:
def make_rgb(item):
    "Return `item.convert('RGB')`; `item` is presumably a PIL image."
    rgb = item.convert('RGB')
    return rgb

In [None]:
class Transform():
    "Base class for item transforms; `compose` runs them in ascending `_order`."
    _order = 0

class ResizeFixed(Transform):
    """Resize an image to a fixed size with bilinear resampling (runs at `_order` 10).

    `size` may be an int (square output) or a pair passed straight to `item.resize`.
    """
    _order=10
    def __init__(self,size):
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, item):
        target = self.size
        return item.resize(target, PIL.Image.BILINEAR)

def to_byte_tensor(item):
    """Turn an image into a channels-first uint8 tensor of shape (channels, height, width).

    `item` must provide `tobytes()` (pixel bytes, row-major and channel-interleaved)
    and `size == (width, height)`, as a PIL image does.
    """
    width, height = item.size
    raw = torch.ByteStorage.from_buffer(item.tobytes())
    hwc = torch.ByteTensor(raw).view(height, width, -1)
    return hwc.permute(2, 0, 1)
to_byte_tensor._order=20

def to_float_tensor(item):
    "Cast a tensor to float and scale it by 1/255 (byte values 0..255 map to 0..1)."
    as_float = item.float()
    return as_float.div_(255.)
to_float_tensor._order=30

In [None]:
# Item transform pipeline. `compose` sorts by `_order` (make_rgb has none -> 0,
# ResizeFixed 10, to_byte_tensor 20, to_float_tensor 30), so they always run in that
# order regardless of list order. ResizeFixed(128) makes every image 128x128.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

In [None]:
# Rebuild the dataset with the transform pipeline; this replaces `data` from the earlier
# cell (and draws a new random split).
data = PetsData(tfms=tfms)

In [None]:
# `img` is now a float tensor laid out (3, 128, 128) with byte values scaled by 1/255.
img,cls = data.train[0]

In [None]:
def show_image(im, figsize=(3,3), ax=None):
    """Display an image with the axis hidden.

    Args:
        im: a channels-first tensor (presumably (C, H, W), as returned by
            `to_byte_tensor` / `to_float_tensor`), which is permuted to channels-last
            for matplotlib. Any other object (e.g. a PIL image or an (H, W, C) array)
            is passed to `imshow` unchanged.
        figsize: size of the new figure; only used when `ax` is None.
        ax: optional existing matplotlib Axes to draw on.
    """
    if ax is None: _, ax = plt.subplots(figsize=figsize)
    if isinstance(im, torch.Tensor): im = im.permute(1,2,0)
    ax.axis('off')
    ax.imshow(im)

In [None]:
# Plot the transformed sample; show_image permutes the (C, H, W) tensor to channels-last.
show_image(img)

In [None]:
class Processor():
    """Base class for list-level processors used as `proc_x` / `proc_y` in `LabeledData`.

    The pipeline runs a processor by calling it on the whole list of raw items (via
    `compose`), and `LabeledData.obj` undoes it with `deproc1` (one item) or
    `deprocess` (a list). Calling delegates to `process`; every hook defaults to the
    identity.
    """
    def process(self, items): return items
    def __call__(self, items): return self.process(items)
    def deprocess(self, items): return items
    def deproc1(self, item): return item

class CategoryProcessor(Processor):
    """Encode categorical values as integer indices into a vocabulary.

    The vocabulary is built on the first call as the distinct items in first-seen
    order; later calls reuse it, and an item missing from it raises KeyError.
    """
    def __init__(self): self.vocab=None
    
    def process(self, items):
        # The vocab is defined on the first use.
        if self.vocab is None:
            # Order-preserving de-duplication: dict keys keep insertion order.
            self.vocab = list(dict.fromkeys(items))
            self.otoi  = {v:k for k,v in enumerate(self.vocab)}
        return [self.proc1(o) for o in items]
    def proc1(self, item):  return self.otoi[item]
    
    def deprocess(self, idxs):
        assert self.vocab is not None
        return [self.deproc1(idx) for idx in idxs]
    def deproc1(self, idx): return self.vocab[idx]