In [1]:
#hide
#skip
# IPython completer settings: jedi is switched off and greedy completion is switched on.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# The shell test `[ -e /content ]` gates everything after it: pip and apt-get only
# run when the path /content exists. Only the apt-get output is redirected to /dev/null.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab

if in_colab():
    # A virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()
else:
    # Colab still requires tornado<6, so nbdev is only imported when we are not running there
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # The sanity checks below are skipped whenever the IN_TEST environment variable is set (non-empty)
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [3]:
# default_exp fastai.data.block

In [43]:
# export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import *
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
from torchdata.datapipes.iter import *
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.data.transforms import *
# Local modules
from fastrl.fastai.loop import *
from fastrl.fastai.data.load import *

# DataBlock
> High-level API to quickly get your data into `DataLoaders`

In [9]:
# Toy values: a batch size and the lowercase ASCII letters as a list
bs = 4
letters = [*string.ascii_lowercase]

In [27]:
#export
class TransformBlock():
    "Holds the default type/item/batch transforms (and loader settings) that one block contributes to the data block API"
    def __init__(
        self,
        type_tfms:Transform=None,  # Executed when the DataPipe is initialized / wif is run. Intended as a 1 time transform.
        item_tfms:Transform=None,  # Executed on individual elements.
        batch_tfms:Transform=None, # Executed over a batch.
        dl_type:MinimumDataLoader=None, # It's recommended not to set this, all custom behaviors should be done via callbacks.
        dls_kwargs:dict=None       # Extra keyword arguments kept for the loaders; `None` is stored as an empty dict.
    ):
        # Every transform argument is wrapped in an `L`, so `None` or a single transform are both accepted
        self.type_tfms = L(type_tfms)
        # `ToTensor` is placed ahead of whatever item transforms were passed in
        self.item_tfms = ToTensor + L(item_tfms)
        self.batch_tfms = L(batch_tfms)
        self.dl_type = dl_type
        self.dls_kwargs = ifnone(dls_kwargs, {})

In [28]:
#export
def _merge_grouper(o):
    if isinstance(o, LambdaType): return id(o)
    elif isinstance(o, type): return o
    elif (isfunction(o) or ismethod(o)): return o.__qualname__
    return o.__class__

def _merge_tfms(*tfms):
    "Flatten `tfms` into one list, keep only the last entry of each `_merge_grouper` group, then `instantiate` each survivor"
    grouped = groupby(concat(*tfms), _merge_grouper)
    # Within a group the last transform wins
    survivors = L(members[-1] for members in grouped.values())
    return survivors.map(instantiate)

def _zip(x):
    "Wrap `x` in an `L` and return the result of its `zip` method"
    wrapped = L(x)
    return wrapped.zip()

In [44]:
#export
class DataBlock():
    "Generic container to quickly build `Datasets` and `DataLoaders`"
    # Class-level fallbacks: `__init__` reads `self.blocks` when `blocks=None`, and
    # `self.dl_type` must exist even when `dl_type=None` and no block provides one.
    blocks,dl_type = (TransformBlock,),MinimumDataLoader
    _msg = """If you wanted to compose several transforms in your getter don't 
    forget to wrap them in a `Pipeline`."""
    def __init__(self, 
                 blocks=TransformBlock, # One block or a collection of blocks (classes or instances); `None` uses `self.blocks`
                 dl_type=MinimumDataLoader, # Loader type; when not `None` it overrides any `dl_type` set by a block
                 getters=None, # Stored as `self.getters`; `None` becomes one `noop` per block
                 item_tfms=None, # Extra item transforms merged with the blocks' defaults
                 batch_tfms=None, # Extra batch transforms merged with the blocks' defaults
                 get_items=None, # Stored as `self.get_items`
                 splitter=None): # Stored as `self.splitter`
        blocks = L(self.blocks if blocks is None else blocks)
        # Block classes (anything callable) are instantiated, instances are kept untouched
        blocks = L(b() if callable(b) else b for b in blocks)
        # Per-block (unmerged) type transforms, read by `_combine_type_tfms`
        self.type_tfms          = blocks.attrgot('type_tfms', L())
        self.default_type_tfms  = _merge_tfms(*blocks.attrgot('type_tfms',  L()))
        self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))
        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        # The last block with a non-None `dl_type` wins among the blocks...
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        # ...but an explicit argument wins over the blocks. Since the argument defaults to
        # `MinimumDataLoader`, block-level values only apply when the caller passes `dl_type=None`.
        if dl_type is not None: self.dl_type = dl_type
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))
        # Keep the remaining constructor arguments instead of silently dropping them;
        # `_combine_type_tfms` zips `self.getters` with `self.type_tfms`.
        self.getters = ifnone(getters, [noop]*len(self.type_tfms))
        self.get_items,self.splitter = get_items,splitter
        self.new(item_tfms, batch_tfms)

    def _combine_type_tfms(self):
        "Pair each getter with the matching block's `type_tfms`; a `Pipeline` getter contributes its `fs`"
        return L([self.getters, self.type_tfms]).map_zip(
            lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)

    def new(self, item_tfms=None, batch_tfms=None):
        "Recompute `item_tfms`/`batch_tfms` from the defaults plus the given extras; updates and returns `self`"
        self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)
        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
        return self
    
    def datapipes(self, 
                  source:Union[L,Callable] # Absolute initial items for create the `IterDataPipe`s from.
                  # These should be picklable/probably uninitialized. 
                 )->List[IterDataPipe]:
        "Not implemented yet: raises `NotImplementedError` (previously failed with a `NameError` on an undefined `dp`)"
        # TODO: sketch of the intended implementation, kept from the original stub:
        #   items = source() if callable(source) else source
        #   dps = IterPipe(items)
        #   if splitter:
        #       # if splitter is callable
        #       dps = dp.fork(splitter.n_outputs)
        raise NotImplementedError("`DataBlock.datapipes` is not implemented yet")

    def dataloaders(self, source, verbose=False, path='.', **kwargs)->List[DataLoader2]:
        "Build the datasets for `source` and turn them into loaders using the block's item/batch transforms"
        # NOTE(review): no `datasets` method is defined on this class in the code shown here;
        # unless it is added elsewhere this call raises `AttributeError`. Confirm.
        dsets = self.datasets(source, verbose=verbose)
        # Caller kwargs override the block-level `dls_kwargs`; `verbose` always comes from the argument
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        # `path` used to be an undefined name here; it is now a keyword parameter defaulting to '.'
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

In [48]:
from torchdata.datapipes.iter import Filter

In [51]:
Filter??

[0;31mInit signature:[0m [0mFilter[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwds[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m        
[0;32mclass[0m [0mFilterIterDataPipe[0m[0;34m([0m[0mIterDataPipe[0m[0;34m[[0m[0mT_co[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""[0m
[0;34m    Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).[0m
[0;34m[0m
[0;34m    Args:[0m
[0;34m        datapipe: Iterable DataPipe being filtered[0m
[0;34m        filter_fn: Customized function mapping an element to a boolean.[0m
[0;34m        drop_empty_batches: By default, drops a batch if it is empty after filtering instead of keeping an empty list[0m
[0;34m[0m
[0;34m    Example:[0m
[0;34m        >>> from torchdata.datapipes.iter import IterableWrapper[0m
[0;34m        >>> def is_even(n):[0m
[0;34m        ...     return n % 2 == 0[0m
[0;34m      

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Everything below only runs outside of colab.
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    # NOTE(review): these three functions come from the star imports above; judging by
    # their names they regenerate the README, the exported module(s) and the HTML docs -- confirm.
    make_readme()
    notebook2script()
    notebook2html()