In [None]:
#default_exp data.core

In [None]:
#export
from fastai_local.imports import *
from fastai_local.test import *
from fastai_local.core import *
from fastai_local.data.pipeline import *

# Helper functions for processing data

> Functions for getting, splitting, and labeling data, etc.

## Get, split, and label

NB: functions here whose names end in *er* (such as `image_getter` or `random_splitter`) return a function rather than a result.

### Get

In [None]:
# export
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=False, include=None):
    """Get all the files in `path` with optional `extensions`.

    - `extensions`: extension(s) such as `'.png'`, compared case-insensitively; with no
      extensions every file is kept.
    - `recurse`: when True walk sub-folders (hidden ones are skipped); otherwise only the
      files directly inside `path` are returned.
    - `include`: folder name(s); when recursing, only these top-level sub-folders of `path`
      are walked (files directly inside `path` are still returned). A single string is
      treated as one folder name.
    Returns a list of `Path`s; hidden files (leading '.') are never returned.
    """
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    # A bare string must mean one folder name: `o in 'train'` would otherwise be a
    # substring test and also accept folders such as 'rain' or 't'.
    if isinstance(include, str): include = [include]
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
            # Pruning `d` in place controls which sub-folders the (top-down) walk descends into.
            if include is not None and i==0: d[:] = [o for o in d if o in include]
            else:                            d[:] = [o for o in d if not o.startswith('.')]
            res += _get_files(p, f, extensions)
    else:
        with os.scandir(path) as entries:
            f = [o.name for o in entries if o.is_file()]
        res = _get_files(path, f, extensions)
    return res

In [None]:
# Root folder of the MNIST_TINY sample dataset; the get/split tests below count files under it.
# NOTE(review): `untar_data` presumably downloads and extracts the archive on first use — confirm.
path = untar_data(URLs.MNIST_TINY)

In [None]:
# (folder, keyword arguments for `get_files`, expected number of files)
get_files_cases = [
    (path/'train'/'3', {},                                                             346),
    (path/'train'/'3', dict(extensions='.png'),                                        346),
    (path/'train'/'3', dict(extensions='.jpg'),                                          0),
    (path/'train',     dict(extensions='.png'),                                          0),
    (path/'train',     dict(extensions='.png', recurse=True),                          709),
    (path,             dict(extensions='.png', recurse=True, include=['train']),         709),
    (path,             dict(extensions='.png', recurse=True, include=['train', 'test']), 729),
]
for folder,kw,n_files in get_files_cases:
    test_eq(len(get_files(folder, **kw)), n_files)

In [None]:
#export
# Every extension that the `mimetypes` registry maps to an `image/*` MIME type.
image_extensions = {ext for ext,mime in mimetypes.types_map.items() if mime.startswith('image/')}

def get_image_files(path, include=None, **kwargs):
    "Recursively collect the files under `path` whose extension is in `image_extensions`, optionally restricted to the top-level folders in `include`."
    # Extra keyword arguments are accepted but not used.
    return get_files(path, include=include, recurse=True, extensions=image_extensions)

def image_getter(suf='', **kwargs):
    "Return a function that gets the image files in sub-folder `suf` of the path it is given, using `kwargs` as default options."
    def _inner(o, **kw):
        # `Path(o)` lets plain-string roots work too (as `get_files` does); an empty `suf`
        # leaves the root unchanged. Call-time `kw` override the captured `kwargs`.
        return get_image_files(Path(o)/suf, **{**kwargs, **kw})
    return _inner

In [None]:
# (root, `include`, expected number of image files)
get_image_files_cases = [(path,         None,              1428),
                         (path/'train', None,               709),
                         (path,         'train',            709),
                         (path,         ['train','valid'], 1408)]
for root,incl,n_images in get_image_files_cases:
    test_eq(len(get_image_files(root, include=incl)), n_images)

In [None]:
# An empty suffix searches the whole dataset; 'train' restricts it to that sub-folder.
for suffix,n_images in [('', 1428), ('train', 709)]:
    test_eq(len(image_getter(suffix)(path)), n_images)

In [None]:
#export
def show_image(im, ax=None, figsize=None, title=None, **kwargs):
    """Show a PIL image, array, or image tensor on `ax` (a new figure of `figsize` if `ax` is None).

    Channel-first tensors `(C,H,W)` with fewer than 5 channels are moved to channel-last;
    a single channel is dropped so the image is drawn as 2-D. Extra `kwargs` go to `ax.imshow`.
    Returns the axes used.
    """
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    # Only 3-D tensors can be channel-first; 2-D tensors are already (H,W).
    if isinstance(im,Tensor) and im.ndim==3 and im.shape[0]<5:
        im = im.permute(1,2,0)
        # imshow rejects (H,W,1): drop the singleton channel.
        if im.shape[-1]==1: im = im[...,0]
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.axis('off')
    return ax

def show_title(o, ax=None):
    "Use `o` as the title of `ax`; without an `ax`, print `o` instead."
    if ax is not None:
        ax.set_title(o)
        return
    print(o)

### Split

In [None]:
# export
def random_splitter(valid_pct=0.2, seed=None, **kwargs):
    """Return a function that splits `items` randomly between train/val, with `valid_pct` going to validation.

    The returned function gives `(train_idxs, valid_idxs)` as index tensors. The first
    `int(valid_pct * len(items))` entries of a random permutation are the validation indices.
    With `seed` set, every call returns the same split, and the global torch RNG is left untouched.
    """
    def _inner(o, **kwargs):
        # A private generator seeded with `seed` gives the same permutation as
        # `torch.manual_seed(seed)` + `randperm`, without reseeding the global RNG on every call.
        gen = None if seed is None else torch.Generator().manual_seed(seed)
        rand_idx = torch.randperm(len(o), generator=gen)
        cut = int(valid_pct * len(o))
        return rand_idx[cut:],rand_idx[:cut]
    return _inner

In [None]:
trn,val = random_splitter(seed=42)(list(range(6)))
# 20% of 6 items -> int(1.2) == 1 validation index; the remaining 5 are training indices.
test_equal(val, tensor([0]))
test_equal(trn, tensor([3, 2, 4, 1, 5]))

In [None]:
# export
def _grandparent_mask(items, name):
    "Boolean mask: True for each item in `items` (a `str` or `Path`) whose grandparent folder is named `name`."
    # Normalising through `Path` gives `str` and `Path` items the same answer. Splitting a string
    # on `os.path.sep` and taking `[-2]` would give the parent folder, not the grandparent.
    # It would also raise IndexError on short paths.
    return [Path(o).parent.parent.name == name for o in items]

def grandparent_splitter(train_name='train', valid_name='valid'):
    "Split `items` from the grand parent folder names (`train_name` and `valid_name`)."
    def _inner(o, **kwargs):
        # Returns two boolean masks over `o`: (in `train_name`, in `valid_name`).
        return _grandparent_mask(o, train_name),_grandparent_mask(o, valid_name)
    return _inner

In [None]:
# With `Path` filenames: each item is assigned by its grandparent folder (`train` or `valid`).
# NOTE: only `Path` items are exercised here; `str` filenames are not covered by this cell.
items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png',
         path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png',
         path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',
         path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']
trn,val = grandparent_splitter()(items)
test_eq(trn,[True,False,False,True,True,False,True,False])
test_eq(val,[False,True,True,False,False,True,False,True])

### Label

In [None]:
# export
def parent_label(o, **kwargs):
    "Label `item` (a `str` or `Path`) with the parent folder name."
    # Normalising through `Path` gives `str` and `Path` inputs the same result. For a string,
    # the parent is the second-to-last path component, not the last one (the file name).
    return Path(o).parent.name

def re_labeller(pat):
    "Label `item` with regex `pat`: return the first capture group of the first match of `pat` in `str(item)`."
    pat = re.compile(pat)
    def _inner(o, **kwargs):
        res = pat.search(str(o))
        # Raise explicitly rather than `assert`, so the check still happens under `python -O`.
        # AssertionError is kept for existing callers. `pat.pattern` shows the regex text
        # instead of the `re.compile(...)` repr.
        if res is None: raise AssertionError(f'Failed to find "{pat.pattern}" in "{o}"')
        return res.group(1)
    return _inner