In [117]:
from fastai.vision import *
path = untar_data(URLs.MNIST_TINY)
tfms = get_transforms(do_flip=False)
#path.ls()

# Extending DataBlock API

## The problem

I like the current DataBlock API: it does a lot out of the box, saves time on writing repetitive code that loads data and turns it into a DataLoader, and has a pleasant, readable interface.

After using fastai and the DataBlock API several times in Kaggle competitions and other tasks, I ran into a few inconveniences, all of which come down to the imperative nature of the current API:

1. DataBunch creation is monolithic. It is inconvenient (and, in the current flow, unnecessary) to split the creation of a DataBunch through the DataBlock API into several stages, so that different experiments, or different steps of one experiment, could override just a couple of parameters and get a new DataBunch without redundant computation.

In [4]:
data = (ImageList.from_folder(path) #Where to find the data? -> in path and its subfolders
    .split_by_folder()              #How to split in train/valid? -> use the folders
    .label_from_folder()            #How to label? -> depending on the folder of the filenames
    .add_test_folder()              #Optionally add a test set (here default name is test)
    .transform(tfms, size=64)       #Data augmentation? -> use tfms with a size of 64
    .databunch())                   #Finally? -> use the defaults for conversion to ImageDataBunch

2. Related to point 1: each step of the DataBlock API returns an intermediate object with its own set of methods for the following steps. As a result, you cannot change the order in which the final DataBunch is configured, and you have to keep the branching of objects and methods between steps in your head.

In [23]:
def info(obj):
    attrs = len(dir(obj))  # number of attributes and methods exposed by this step
    print('attrs:', attrs, type(obj))

step = ImageList.from_folder(path); info(step)   #Where to find the data? -> in path and its subfolders
step = step.split_by_folder(); info(step)        #How to split in train/valid? -> use the folders
step = step.label_from_folder(); info(step)      #How to label? -> depending on the folder of the filenames
step = step.add_test_folder(); info(step)        #Optionally add a test set (here default name is test)
step = step.transform(tfms, size=64); info(step) #Data augmentation? -> use tfms with a size of 64
step = step.databunch(); info(step)              #Finally? -> use the defaults for conversion to ImageDataBunch

attrs: 92 <class 'fastai.vision.data.ImageList'>
attrs: 44 <class 'fastai.data_block.ItemLists'>
attrs: 53 <class 'fastai.data_block.LabelLists'>
attrs: 53 <class 'fastai.data_block.LabelLists'>
attrs: 53 <class 'fastai.data_block.LabelLists'>
attrs: 73 <class 'fastai.vision.data.ImageDataBunch'>


3. Computation happens eagerly, as soon as a method is called. If you made a mistake in the configuration, or need a different method, or anything else, you have to start from the beginning unless you saved the intermediate result.

In [26]:
step = (ImageList.from_folder(path)
    .split_by_rand_pct(valid_pct=0.2))
# ... some experiments later, you decide you need a different split percentage
# Uncomment the line below to see the error
# step.split_by_rand_pct(valid_pct=0.4)

4. There is no way to inspect the full state of a DataBunch: which parameters were used for the split, where the data came from, how it was labeled, and so on. All of that stays in the intermediate steps, and only some of it reaches the DataBunch itself. As a result, reusing the object requires the code that created it plus the parameter values that were in effect at that time.

In [28]:
VAL_PCT = 0.2
data = (ImageList.from_folder(path)
        .split_by_rand_pct(valid_pct=VAL_PCT)
        .label_from_folder()
        .add_test_folder()
        .transform(tfms, size=64)
        .databunch())
# ... a lot of experimenting, changing the VAL_PCT value along the way
# At this point we can't tell from `data` which settings were used, e.g. what valid_pct was set

## DataBlock API extension prototype as solution

To improve the API, I propose making it more declarative. This can be done cheaply even on top of the current API: it is enough to add wrappers.
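
One way to picture the wrapper idea: record each step as a method name plus keyword arguments, and replay the recorded steps against the real object only when a result is needed. The sketch below is a minimal illustration and not part of fastai; `LazyChain` and `FakeList` are hypothetical names, and `FakeList` merely stands in for an `ItemList`.

```python
class LazyChain:
    "Records method calls as (name, kwargs) pairs and replays them on demand."
    def __init__(self, target):
        self.target = target   # object the recorded calls will be replayed on
        self.steps = []        # ordered list of (method name, kwargs)

    def add(self, name, **kwargs):
        self.steps.append((name, kwargs))
        return self            # allow chaining

    def run(self):
        obj = self.target
        for name, kwargs in self.steps:
            obj = getattr(obj, name)(**kwargs)
        return obj


class FakeList:
    "Hypothetical stand-in for an ItemList, used only for this illustration."
    def __init__(self, items): self.items = list(items)
    def keep_first(self, n): return FakeList(self.items[:n])
    def scale(self, factor): return FakeList([i * factor for i in self.items])


chain = LazyChain(FakeList(range(10))).add('keep_first', n=4).add('scale', factor=10)
print(chain.steps)         # the configuration can be inspected before anything runs
result = chain.run()
print(result.items)

# Changing one setting and re-running does not require rebuilding the chain.
chain.steps[0] = ('keep_first', {'n': 2})
rerun = chain.run()
print(rerun.items)
```

Because nothing is computed until `run()` is called, the recorded steps double as a readable description of how the result was produced.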

In [394]:
# Imports
import inspect
from collections import OrderedDict
from fastai.data_block import PreProcessors # I don't know why it doesn't import within fastai.vision.*

# Util functions
def pp(d, indent=4, ljust=12, skip_none=True):
    res = []
    for key, value in d.items():
        if skip_none and value is None: continue
        val = value.__name__ if inspect.isclass(value) else value
        val = f'DataFrame {val.shape}' if isinstance(val, DataFrame) else val
        res.append(f'{" " * indent}{(str(key) + ":").ljust(ljust)}\t{val}')
    return "\n".join(res)

The base abstract block class. The meta object (called `DataChain` in the code further below) is built from these blocks. Each block stores the state needed to "assemble" it; if the state changes, the block can be "reassembled". It would be possible to detect state changes automatically, by watching the `settings` attribute and resetting the assembly when something changes, but I have left that out for now to keep the code simple.

In [406]:
class Block():
    """
    An abstract data block class that all block classes should extend.
    
    As I see it, the small blocks should be part of the private API, and users should not change them directly.
    A meta object will create the blocks, set them up, put them in the right order and,
    when needed, call `assemble` on each of them and pass the result on to the next blocks.
    
    Every block should have a previous block that it is based on.
    """
    def __init__(self, prev_block=None, assemble_fn=None, **kwargs):
        self.prev_block = prev_block
        self.assemble_fn = assemble_fn
        self.settings = kwargs
        self.assembly = None

    def _short_repr(self):
        assembled = self.assembly is not None
        return f'{self.__class__.__name__}{" (Assembled ✔)" if assembled else " (Assembled ✘)"}'

    def __repr__(self):
        "Standard method which we'll be using for status representation in metaobject"
        res = []
        res.append(f'{self._short_repr()}')
        if isinstance(self.prev_block, Block):
            res.append(f'{pp({"prev_block": self.prev_block._short_repr()})}')
        res.append(f'{pp({"assemble_fn": self.assemble_fn})}')
        res.append(f'{pp(self.settings)}')
        res = list(filter(None, res))
        return "\n".join(res)
    
    def assemble(self):
        "Will be called when we need actual result of block logic. It will cache its results"
        if self.assembly is not None: return self.assembly
        self.validate()
        return self._assemble()
    
    def _assemble(self):
        "Real implementation of assemble method which will be called from `assemble` if it's not already assembled"
        self.assembly = getattr(self.prev_block.assembly, self.assemble_fn)(**self.settings)
        return self.assembly

    def reassemble(self):
        "Will be called when something above in blocks chain was changed and we need to regenerate result"
        self.assembly = None
        return self.assemble()
    
    def validate(self):
        "Checks that every setting needed for assembly is present"
        assert self.prev_block is not None, 'Every block in the chain should have a prev block. For the first block, provide an InputBlock'
        assembly = self.prev_block.assembly
        assert assembly is not None, 'The prev block should be assembled before this one'
        afn = self.assemble_fn
        assert afn is not None, 'You need to provide `assemble_fn` so the block knows how to assemble from prev_block'
        assert hasattr(assembly, afn), f'{assembly} doesn\'t have method `{afn}`'
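
The automatic change detection mentioned before the `Block` class is not implemented here. A minimal sketch of how it could work, under the assumption that settings live in a dict subclass that notifies its owner on every write; `TrackedSettings` and `CachedStep` are hypothetical names and are not used elsewhere in this notebook.

```python
class TrackedSettings(dict):
    "Dict that calls `on_change` whenever a key is set or deleted."
    def __init__(self, on_change, **kwargs):
        super().__init__(**kwargs)   # initial values do not trigger on_change
        self._on_change = on_change

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._on_change()

    def __delitem__(self, key):
        super().__delitem__(key)
        self._on_change()
    # Note: dict.update() and in-place mutation of a stored value bypass
    # these hooks, so a complete version would need to cover them as well.


class CachedStep:
    "Hypothetical simplified block: caches a result until its settings change."
    def __init__(self, fn, **kwargs):
        self.fn = fn
        self.assembly = None
        self.settings = TrackedSettings(self._invalidate, **kwargs)

    def _invalidate(self):
        self.assembly = None         # drop the cached result

    def assemble(self):
        if self.assembly is None:
            self.assembly = self.fn(**self.settings)
        return self.assembly


step = CachedStep(lambda base, power: base ** power, base=2, power=3)
first = step.assemble()
step.settings['power'] = 5           # triggers _invalidate
stale = step.assembly                # the cache was reset by the write above
second = step.assemble()
print(first, stale, second)
```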

In [407]:
class InputBlock(Block):
    """
    Special kind of block that cannot actually be assembled: it only holds the base item list
    class to start from.
    """
    def __init__(self, item_list):
        self.prev_block, self.assemble_fn, self.settings = None, None, dict()
        self.assembly = item_list
        
    def _assemble(self): return self.assembly
    def reassemble(self): return self.assembly

    def validate(self):
        assert issubclass(self.assembly, ItemList), 'For now item list class for input block can be only subclass of ItemList'

In [416]:
class IdentityBlock(Block):
    """
    Special kind of block which will do nothing. It will propagate prev_block.assembly to self.assembly.
    """
    def _assemble(self):
        self.assembly = self.prev_block.assembly
        return self.assembly

    def validate(self):
        "Checks that every setting needed for assembly is present"
        assert self.prev_block is not None, 'Every block in the chain should have a prev block. For the first block, provide an InputBlock'
        assembly = self.prev_block.assembly
        assert assembly is not None, 'The prev block should be assembled before this one'

In [408]:
input_block = InputBlock(ItemList)
input_block.validate()
input_block

InputBlock (Assembled ✔)

In [417]:
identity_block = IdentityBlock(prev_block=input_block)
identity_block.assemble()
identity_block

IdentityBlock (Assembled ✔)
    prev_block: 	InputBlock (Assembled ✔)

### Source blocks

The first step of creating a DataBunch is to specify where the data comes from. Three options are currently available: from a folder, from a DataFrame, or from a CSV file. Below I define the `SourceBlock` class, a wrapper over the `ItemList` methods `from_folder`, `from_df` and `from_csv`; the method to use is selected through `assemble_fn`.

In [409]:
class SourceBlock(Block):
    """
    Source block
    """
    
    ASSEMBLE_FNS = ['from_folder', 'from_df', 'from_csv']
    def validate(self):
        super(SourceBlock, self).validate()
        assert self.assemble_fn in self.ASSEMBLE_FNS, f'You need to provide `assemble_fn` one of {self.ASSEMBLE_FNS}'
        if self.assemble_fn == 'from_folder':
            assert isinstance(self.settings.get('path', None), PathOrStr.__args__), f'To create item list from folder you should provide `path` arg'
        elif self.assemble_fn == 'from_df':
            assert isinstance(self.settings.get('df', None), DataFrame), f'To create item list from df you should provide `df` arg'
        else:
            assert isinstance(self.settings.get('path', None), PathOrStr.__args__), f'To create item list from csv you should provide `path` arg'
            assert isinstance(self.settings.get('csv_name', None), str), f'To create item list from csv you should provide `csv_name` arg'

In [410]:
block = SourceBlock(prev_block=input_block, assemble_fn='from_df', df=pd.DataFrame([1]), path='/path')
#block.prev_block = ItemList
block.validate()
print(block)
block.assemble()
print(block)
block.assemble()
print(block)
block.settings['df'] = pd.DataFrame([[1,2,3],[4,5,6]])
block.reassemble()
print(block)

SourceBlock (Assembled ✘)
    prev_block: 	InputBlock (Assembled ✔)
    assemble_fn:	from_df
    df:         	DataFrame (1, 1)
    path:       	/path
SourceBlock (Assembled ✔)
    prev_block: 	InputBlock (Assembled ✔)
    assemble_fn:	from_df
    df:         	DataFrame (1, 1)
    path:       	/path
SourceBlock (Assembled ✔)
    prev_block: 	InputBlock (Assembled ✔)
    assemble_fn:	from_df
    df:         	DataFrame (1, 1)
    path:       	/path
SourceBlock (Assembled ✔)
    prev_block: 	InputBlock (Assembled ✔)
    assemble_fn:	from_df
    df:         	DataFrame (2, 3)
    path:       	/path
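
`Block.validate` requires the previous block to be assembled already, so a longer chain has to be assembled from the root outwards. A minimal sketch of that ordering, using a hypothetical `MiniBlock` stand-in instead of the `Block` class above and a hypothetical helper `assemble_chain`:

```python
class MiniBlock:
    "Hypothetical stand-in for Block: applies `fn` to the previous block's result."
    def __init__(self, fn, prev_block=None, assembly=None):
        self.fn, self.prev_block, self.assembly = fn, prev_block, assembly

    def assemble(self):
        if self.assembly is None:                 # reuse the cached result if present
            self.assembly = self.fn(self.prev_block.assembly)
        return self.assembly


def assemble_chain(block):
    "Assemble every ancestor of `block` first (root to leaf), then `block` itself."
    chain = []
    while block is not None:                      # walk back to the root
        chain.append(block)
        block = block.prev_block
    for b in reversed(chain):                     # root first
        b.assemble()
    return chain[0].assembly                      # result of the block we started from


root = MiniBlock(fn=None, assembly=[3, 1, 2])     # plays the role of InputBlock
sort_step = MiniBlock(sorted, prev_block=root)
double_step = MiniBlock(lambda xs: [x * 2 for x in xs], prev_block=sort_step)
result = assemble_chain(double_step)
print(result)
```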


### Filter blocks

The second, optional step is filtering.

In [411]:
class FilterBlock(Block):
    """
    Filter block is an optional kind of block. It is assembled after the source block and before
    the split block.
    """
    
    ASSEMBLE_FNS = ['filter_by_func', 'filter_by_folder', 'filter_by_rand']
    def validate(self):
        super().validate()  # run the generic Block checks first
        assert self.assemble_fn in self.ASSEMBLE_FNS, f'You need to provide `assemble_fn` one of {self.ASSEMBLE_FNS}'
        if self.assemble_fn == 'filter_by_func':
            assert callable(self.settings.get('func', None)), 'To filter item list by func you should provide a callable `func` arg'
        elif self.assemble_fn == 'filter_by_rand':
            assert isinstance(self.settings.get('p', None), (int, float)), f'To filter item list randomly you should provide `p` arg'

### Split blocks

The third, required step is choosing the split. I think `split_by_rand_pct` should be built in as the default, since it is probably the most popular split.

In [421]:
class SplitBlock(Block):
    """
    Split block is a required block. It is assembled after the filter block and before
    the label block.
    """

    ASSEMBLE_FNS = ["split_none", "split_by_list", "split_by_idxs", "split_by_idx", "split_by_folder",
                    "split_by_rand_pct", "split_subsets", "split_by_valid_func", "split_by_files",
                    "split_by_fname_file", "split_from_df"]
    def validate(self):
        super().validate()  # run the generic Block checks first
        assert self.assemble_fn in self.ASSEMBLE_FNS, f'You need to provide `assemble_fn` one of `SplitBlock.ASSEMBLE_FNS`'
        if self.assemble_fn == 'split_by_list':
            assert self.settings.get('train', None) is not None, f'To split item list by list you should provide `train` list arg'
            assert self.settings.get('valid', None) is not None, f'To split item list by list you should provide `valid` list arg'
        elif self.assemble_fn == 'split_by_idxs':
            assert self.settings.get('train_idx', None) is not None, f'To split item list by idxs you should provide `train_idx` list arg'
            assert self.settings.get('valid_idx', None) is not None, f'To split item list by idxs you should provide `valid_idx` list arg'
        elif self.assemble_fn == 'split_by_idx':
            # valid_idx must be a collection (list, array, ...) of indices
            assert isinstance(self.settings.get('valid_idx', None), Collection), 'To split item list by idx you should provide `valid_idx` collection of ints arg'
        elif self.assemble_fn == 'split_subsets':
            assert isinstance(self.settings.get('train_size', None), float), 'To split item list by subsets you should provide `train_size` float arg'
            assert isinstance(self.settings.get('valid_size', None), float), 'To split item list by subsets you should provide `valid_size` float arg'
        elif self.assemble_fn == 'split_by_valid_func':
            assert callable(self.settings.get('func', None)), 'To split item list by valid func you should provide a callable `func` arg'
        elif self.assemble_fn == 'split_by_files':
            assert isinstance(self.settings.get('valid_names', None), ItemList), f'To split item list by files you should provide `valid_names` item list arg'
        elif self.assemble_fn == 'split_by_fname_file':
            assert isinstance(self.settings.get('fname', None), PathOrStr.__args__), f'To split item list by fname file you should provide `fname` arg'
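
The `if`/`elif` chain in `validate` could also be expressed declaratively, in the spirit of the rest of this proposal. The sketch below is one possible shape, not the notebook's implementation: `REQUIRED_SETTINGS` and `missing_settings` are hypothetical names, the setting names are taken from the checks above, and only presence is checked, not types.

```python
# Required setting names per split method, mirroring the checks in SplitBlock.validate.
REQUIRED_SETTINGS = {
    'split_by_list':       ('train', 'valid'),
    'split_by_idxs':       ('train_idx', 'valid_idx'),
    'split_by_idx':        ('valid_idx',),
    'split_subsets':       ('train_size', 'valid_size'),
    'split_by_valid_func': ('func',),
    'split_by_files':      ('valid_names',),
    'split_by_fname_file': ('fname',),
}

def missing_settings(assemble_fn, settings):
    "Return the required setting names that are absent or None for `assemble_fn`."
    required = REQUIRED_SETTINGS.get(assemble_fn, ())   # methods without requirements map to ()
    return [name for name in required if settings.get(name) is None]

print(missing_settings('split_by_idxs', {'train_idx': [0, 1]}))
print(missing_settings('split_by_rand_pct', {'valid_pct': 0.2}))
```

A table like this keeps the requirements in one place and could also feed the block's `__repr__` or error messages.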

In [425]:
class DataChain():
    """
    Meta object for data blocks.
    
    Usage:
    chain = DataChain(ItemList)
    chain.from_folder(path)
    # => creates a SourceBlock with assemble_fn='from_folder' and path=path
    chain.split_by_rand_pct(valid_pct=0.3)
    # => creates a SplitBlock with assemble_fn='split_by_rand_pct' and valid_pct=0.3
    ...
    chain.info()
    # => Will sort all existing blocks and gather their __repr__ info.
    # => If some required blocks haven't been declared yet, it will create them with the default
    # => configuration or, if that's impossible, report which specific blocks must be created
    data = chain.to_databunch()
    # => will 
    """
    
    DEFAULTS = [
        ('input',      None),
        ('source',     None),
        ('filter',     None),
        ('split',      None),
        ('label',      None),
        ('preprocess', None),
        ('transforms', None),
        ('test_set',   None),
        ('databunch',  None)
    ]
    def __init__(self, item_list=None):
        self.blocks = OrderedDict(self.DEFAULTS)
        if item_list is not None: self.input(item_list)
        
    def __repr__(self):
        "Gather all the blocks including default and show their representations"
        return f'{self.__class__.__name__}\n{pp(self.blocks, indent=2, skip_none=False)}'
    
    ## InputBlock methods
    def input(self, item_list_class:ItemList):
        "Create input block from item list subclass"
        assert issubclass(item_list_class, ItemList), 'DataChain#input supports only ItemList subclasses'
        self.blocks['input'] = InputBlock(item_list_class)
        return self
    
    ## SourceBlock methods
    def source(self, assemble_fn:str, **kwargs):
        "Base method for source block. Useful when you want to use config instead of writing every parameter by hand"
        self.blocks['source'] = SourceBlock(prev_block=self.blocks['input'], assemble_fn=assemble_fn, **kwargs)
        return self

    # The methods below are just wrappers around self.source. They have the same argument lists
    # as the original fastai methods. I've added them to make this API fully compatible with the current fastai API
    def from_folder(self, path:PathOrStr, extensions:Collection[str]=None, recurse:bool=True,
                    include:Optional[Collection[str]]=None, processor:PreProcessors=None, **kwargs)->'ItemList':
        self.source(assemble_fn='from_folder', path=path, extensions=extensions, recurse=recurse,
                    include=include, processor=processor, **kwargs)
        return self    
    def from_df(self, df:DataFrame, path:PathOrStr='.', cols:IntsOrStrs=0, processor:PreProcessors=None, **kwargs)->'ItemList':
        self.source(assemble_fn='from_df', df=df, path=path, cols=cols, processor=processor, **kwargs)
        return self    
    def from_csv(self, path:PathOrStr, csv_name:str, cols:IntsOrStrs=0, delimiter:str=None, header:str='infer', 
                 processor:PreProcessors=None, **kwargs)->'ItemList':
        self.source(assemble_fn='from_csv', path=path, csv_name=csv_name, cols=cols, delimiter=delimiter,
                    header=header, processor=processor, **kwargs)
        return self
    
    ## FilterBlock methods
    def filter(self, assemble_fn:str, **kwargs):
        "Base method for filter block"
        self.blocks['filter'] = FilterBlock(prev_block=self.blocks['source'], assemble_fn=assemble_fn, **kwargs)
        return self
    
    # The methods below are just wrappers around self.filter. They have the same argument lists
    # as the original fastai methods. I've added them to make this API fully compatible with the current fastai API
    def filter_by_func(self, func:Callable)->'ItemList':
        self.filter(assemble_fn='filter_by_func', func=func)
        return self
    def filter_by_folder(self, include=None, exclude=None):
        self.filter(assemble_fn='filter_by_folder', include=include, exclude=exclude)
        return self
    def filter_by_rand(self, p:float, seed:int=None):
        self.filter(assemble_fn='filter_by_rand', p=p, seed=seed)
        return self
    
    ## SplitBlock methods
    def split(self, assemble_fn:str, **kwargs):
        "Base method for split block"
        self.blocks['split'] = SplitBlock(prev_block=self.blocks['filter'], assemble_fn=assemble_fn, **kwargs)
        return self

    # The methods below are just wrappers around self.split. They have the same argument lists
    # as the original fastai methods. I've added them to make this API fully compatible with the current fastai API
    def split_none(self):
        self.split(assemble_fn='split_none')
        return self
    def split_by_list(self, train, valid):
        self.split(assemble_fn='split_by_list', train=train, valid=valid)
        return self
    def split_by_idxs(self, train_idx, valid_idx):
        self.split(assemble_fn='split_by_idxs', train_idx=train_idx, valid_idx=valid_idx)
        return self
    def split_by_idx(self, valid_idx:Collection[int])->'ItemLists':
        self.split(assemble_fn='split_by_idx', valid_idx=valid_idx)
        return self
    def split_by_folder(self, train:str='train', valid:str='valid')->'ItemLists':
        self.split(assemble_fn='split_by_folder', train=train, valid=valid)
        return self
    def split_by_rand_pct(self, valid_pct:float=0.2, seed:int=None)->'ItemLists':
        self.split(assemble_fn='split_by_rand_pct', valid_pct=valid_pct, seed=seed)
        return self
    def split_subsets(self, train_size:float, valid_size:float, seed=None) -> 'ItemLists':
        self.split(assemble_fn='split_subsets', train_size=train_size, valid_size=valid_size, seed=seed)
        return self
    def split_by_valid_func(self, func:Callable)->'ItemLists':
        self.split(assemble_fn='split_by_valid_func', func=func)
        return self
    def split_by_files(self, valid_names:'ItemList')->'ItemLists':
        self.split(assemble_fn='split_by_files', valid_names=valid_names)
        return self
    def split_by_fname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists':
        self.split(assemble_fn='split_by_fname_file', fname=fname, path=path)
        return self
    def split_from_df(self, col:IntsOrStrs=2):
        self.split(assemble_fn='split_from_df', col=col)
        return self

In [426]:
df = DataChain()
df.input(ItemList)
df.from_folder(path=Path('~/home/data/train'))
df.filter_by_rand(0.3)
df.split_by_rand_pct()
df

DataChain
  input:      	InputBlock (Assembled ✔)
  source:     	SourceBlock (Assembled ✘)
    prev_block: 	InputBlock (Assembled ✔)
    assemble_fn:	from_folder
    path:       	~/home/data/train
    recurse:    	True
  filter:     	FilterBlock (Assembled ✘)
    prev_block: 	SourceBlock (Assembled ✘)
    assemble_fn:	filter_by_rand
    p:          	0.3
  split:      	SplitBlock (Assembled ✘)
    prev_block: 	FilterBlock (Assembled ✘)
    assemble_fn:	split_by_rand_pct
    valid_pct:  	0.2
  label:      	None
  preprocess: 	None
  transforms: 	None
  test_set:   	None
  databunch:  	None
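
In `DataChain.split`, `prev_block` is taken from `self.blocks['filter']`, which is `None` when no filter has been declared, so `SplitBlock.validate` would reject a chain that skips the optional filter step. One possible way around this, sketched here with plain strings standing in for block objects, is to link each new block to the nearest earlier block that actually exists. `nearest_prev_block` is a hypothetical helper, not part of the class above.

```python
from collections import OrderedDict

def nearest_prev_block(blocks, name):
    "Return the closest non-None block that comes before `name` in `blocks`, or None."
    prev = None
    for key, block in blocks.items():
        if key == name:
            return prev
        if block is not None:
            prev = block          # remember the latest declared block seen so far
    raise KeyError(name)

# Strings stand in for real block instances; 'filter' was never declared.
blocks = OrderedDict([('input', 'InputBlock'), ('source', 'SourceBlock'),
                      ('filter', None), ('split', None)])
print(nearest_prev_block(blocks, 'split'))
```

An alternative with the same effect would be to fill undeclared optional steps with an `IdentityBlock`, which is what that class appears to be intended for.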