In [1]:
#hide
#skip
# Completer config: disable jedi and enable greedy completion in the IPython completer.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# (the install only runs when `/content` exists, which is presumably only the case on Colab)
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment unless the `IN_TEST` env var is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    # NOTE(review): this global `display` shadows IPython's `display()` helper in the notebook namespace.
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
# default_exp fastai.data.block

In [127]:
# export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import *
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
from torchdata.datapipes.iter import *
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.data.transforms import *
# Local modules
from fastrl.fastai.loop import *
from fastrl.fastai.data.load import *

# DataBlock
> High-level API to quickly get your data into `DataLoaders`

In [5]:
# NOTE(review): `bs` is not passed to the `DataBlock` example below, which keeps its own default `bs=1`.
bs = 4
letters = [*string.ascii_lowercase]

In [6]:
#export
class TransformBlock():
    "Bundle of default transforms (plus optional loader settings) consumed by the data block API"
    def __init__(self, 
                 type_tfms:Transform=None, # Executed when the DataPipe is initialized / wif is run;
                 # intended as a one-time transform.
                 item_tfms:Transform=None, # Executed on each individual element.
                 batch_tfms:Transform=None, # Executed over each batch.
                 dl_type:MinimumDataLoader=None, # It's recommended not to set this;
                 # custom behavior should be implemented via callbacks instead.
                 dls_kwargs:dict=None # Extra loader keyword arguments; `None` becomes `{}`.
                ):
        self.type_tfms = L(type_tfms)
        # `ToTensor` is always prepended to the item transforms.
        self.item_tfms = ToTensor + L(item_tfms)
        self.batch_tfms = L(batch_tfms)
        self.dl_type = dl_type
        self.dls_kwargs = {} if dls_kwargs is None else dls_kwargs

In [26]:
#export
def CategoryBlock(vocab=None, sort=True, add_na=False):
    "Build a `TransformBlock` whose only type transform is `Categorize`, for single-label categorical targets"
    categorize = Categorize(vocab=vocab, sort=sort, add_na=add_na)
    return TransformBlock(type_tfms=categorize)

In [7]:
#export
def _merge_grouper(o):
    if isinstance(o, LambdaType): return id(o)
    elif isinstance(o, type): return o
    elif (isfunction(o) or ismethod(o)): return o.__qualname__
    return o.__class__

def _merge_tfms(*tfms):
    "Flatten `tfms` into one list, keep only the last tfm for each `_merge_grouper` key, and instantiate them"
    groups = groupby(concat(*tfms), _merge_grouper)
    survivors = L(members[-1] for members in groups.values())
    return survivors.map(instantiate)

def _zip(x):
    "Wrap `x` in an `L` and zip its elements together (transpose)"
    items = L(x)
    return items.zip()

In [491]:
#export
class DataBlock():
    "Generic container to quickly build batched `IterDataPipe`s (and `DataLoader2`s) from raw items"
    # Class-level fallbacks (same values as the `__init__` defaults): `blocks=None` used to
    # raise AttributeError because `self.blocks` did not exist, and `self.dl_type` was never
    # set when neither the caller nor any block provided one.
    blocks,dl_type = TransformBlock,MinimumDataLoader
    _msg = """If you wanted to compose several transforms in your getter don't 
    forget to wrap them in a `Pipeline`."""
    def __init__(self, 
                 blocks=TransformBlock, # Block(s) providing default tfms; callables are called with no args
                 dl_type=MinimumDataLoader, # Takes precedence over any `dl_type` set by `blocks`
                 get_items=None, # Applied (wrapped in a `Pipeline`) to `source` by `datapipes`
                 type_tfms=None, # Merged with the blocks' default `type_tfms`
                 item_tfms=None, # Merged with the blocks' default `item_tfms`
                 batch_tfms=None, # Merged with the blocks' default `batch_tfms`
                 bs=1, # Batch size used by `datapipes`
                 splitter:Optional[Union[IterDataPipe,Callable]]=None, # If a callable, it is 
                 # assumed to split the datapipe into 2: it receives each raw item and returns
                 # 0/1 (or False/True). If you want more than 2, pass a custom IterDataPipe
                 # class with `__len__` for the number of splits; it is built from the items datapipe.
                 shuffle:bool=False # Shuffle each split before batching
                ):
        blocks = L(self.blocks if blocks is None else blocks)
        blocks = L(b() if callable(b) else b for b in blocks)
        # Each tfm kind is merged across all blocks into one de-duplicated, flat list.
        self.default_type_tfms  = _merge_tfms(*blocks.attrgot('type_tfms',  L()))
        self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))
        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        if dl_type is not None: self.dl_type = dl_type
        self.get_items = get_items
        self.splitter = splitter
        self.shuffle = shuffle
        self.bs = bs
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))
        self.new(item_tfms, batch_tfms, type_tfms)

    # NOTE(review): `self.getters` is never assigned in this class, so calling this raises
    # AttributeError; presumably left over from fastai's `DataBlock`.
    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(
        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)

    def new(self, item_tfms=None, batch_tfms=None, type_tfms=None):
        "Reset the active tfms to the block defaults merged with the given ones; returns `self`"
        self.type_tfms  = _merge_tfms(self.default_type_tfms,  type_tfms)
        self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)
        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
        return self

    def _split(self, dp:IterDataPipe)->L:
        "Split `dp` according to `self.splitter`, always returning an `L` of datapipes"
        splitter = self.splitter
        # No splitter: one split. `L(dp)` would iterate `dp` and list its items instead.
        if splitter is None: return L([dp])
        # A custom IterDataPipe class (checked before `callable`, since classes are callable):
        # `L(...)` iterates the instance, so it should yield one datapipe per split.
        if isinstance(splitter, type) and issubclass(splitter, IterDataPipe): return L(splitter(dp))
        # Any other callable classifies each item into split 0 or 1.
        if callable(splitter): return L(dp.demux(2, splitter))
        raise TypeError(f"`splitter` must be None, a callable or an `IterDataPipe` subclass, got {type(splitter)}")

    def datapipes(self, 
                  source:Union[L,Callable] # Absolute initial items to create the `IterDataPipe`s from.
                  # These should be picklable/probably uninitialized. If callable, it is called with no args.
                 )->List[IterDataPipe]:
        "Build one datapipe per split: split -> type tfms -> cache -> item tfms -> (shuffle) -> batch -> batch tfms"
        # `callable(source)`: the old `source is callable` compared against the builtin and was always False.
        items = source() if callable(source) else source
        # The old `ifnone(Pipeline(...), noop)` was redundant: a `Pipeline` instance is never None.
        items = Pipeline(self.get_items)(items)

        # Split the *raw* items first: splitters such as `GrandparentSplitter` read
        # `o.parent.parent.name`, which the outputs of the type tfms (e.g. images) do not have.
        dps = self._split(IterableWrapper(items))

        # NOTE(review): this `Pipeline` is never `setup` on the items, so tfms that learn state
        # from the data (e.g. `Categorize(vocab=None)`) presumably fail on their first item.
        dps = dps.map(Self.map(Pipeline(self.type_tfms)))
        dps = dps.map(Cacher)
        dps = dps.map(Self.map(Pipeline(self.item_tfms)))
        if self.shuffle: dps = dps.map(Self.shuffle())
        dps = dps.map(Self.batch(self.bs))
        dps = dps.map(Self.map(Pipeline(self.batch_tfms)))
        return dps

    def dataloaders(self, source, verbose=False, **kwargs)->List[DataLoader2]:
        "Build one `DataLoader2` per datapipe returned by `datapipes(source)`"
        # Previously called the nonexistent `self.datasets` and an undefined `path`.
        # `verbose` is kept for backward compatibility; it is not forwarded to `DataLoader2`.
        dps = self.datapipes(source)
        # `datapipes` already batches by `self.bs`; default `batch_size=None` so `DataLoader2`
        # does not batch a second time. NOTE(review): confirm against the installed torch version.
        kwargs = {'batch_size': None, **self.dls_kwargs, **kwargs}
        return L(DataLoader2(dp, **kwargs) for dp in dps)

In [492]:
# export
T_co = TypeVar("T_co", covariant=True)

class Cacher(IterDataPipe[T_co]):
    "Yield `source_datapipe` once while caching its elements, then cycle over the cache forever"
    def __init__(self, 
                 source_datapipe, # Iterable that is read exactly once per `__iter__` call
                 **kwargs # Stored on `self.kwargs`; not otherwise used
                ) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs
    
    def __iter__(self) -> Iterator[T_co]:
        import itertools
        # Bug fix: the previous version waited for a `StopIteration` that a `for` loop never
        # raises, so it re-read the source forever and grew the cache without bound.
        cached_entries = []
        for v in self.source_datapipe:
            cached_entries.append(v)
            yield v
        # An empty source used to spin forever without yielding anything; stop instead.
        if not cached_entries: return
        yield from itertools.cycle(cached_entries)

In [493]:
#For example, so not exported
from fastai.vision.core import *
from fastai.vision.data import *
from fastai.data.external import *

In [494]:
# Make the MNIST_TINY sample available locally (presumably downloading it on first use); the output is its path.
untar_data(URLs.MNIST_TINY)

Path('/home/fastrl_user/.fastai/data/mnist_tiny')

In [495]:
#export
def GrandparentSplitter(train_name='train', valid_name='valid'):
    "Return a classifier mapping an item to 1 (True) if its grandparent folder is `valid_name`, else 0 (False)."
    def _inner(o, negate=False):
        # With `negate=True` the comparison is against `train_name` instead.
        # NOTE(review): items under neither folder are classified as 0 (train).
        target = train_name if negate else valid_name
        return o.parent.parent.name == target
    return _inner

In [496]:
from torchdata.datapipes.iter import Batcher

In [497]:
# Example: MNIST_TINY image files, split by grandparent folder (train/valid), shuffled.
# TODO: the original sketch also had `get_y=parent_label`, which `DataBlock` does not accept yet.
mnist = DataBlock(
    blocks=(ImageBlock, ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(),
    shuffle=True,
)
dsets = mnist.datapipes(untar_data(URLs.MNIST_TINY))

In [502]:
# The type tfms of all blocks, merged into one flat list (see output below).
mnist.type_tfms

(#2) [<bound method PILBase.create of <class 'fastai.vision.core.PILImage'>>,Categorize -- {'vocab': None, 'sort': True, 'add_na': False}:
encodes: (object,object) -> encodes
decodes: (object,object) -> decodes
]

In [498]:
# NOTE(review): rebinds `dsets` (the list of split datapipes) to an iterator over split 1 (valid);
# re-running this cell, or the `dsets.train` check below, then fails.
dsets=iter(dsets[1])

In [499]:

# Pull the first batch from the validation split.
# NOTE(review): currently raises (see output), presumably because `Categorize` was never set up and its vocab is None.
next(dsets)

AttributeError: 'NoneType' object has no attribute 'o2i'

In [457]:

# test_eq(mnist.n_inp, 2)
# NOTE(review): `dsets` was rebound to an iterator by `dsets=iter(dsets[1])`, so it has no `.train`
# and this check raises (see output); `datapipes` returns a list of datapipes, not a `Datasets`.
test_eq(len(dsets.train[0]), 3)

AttributeError: 'generator' object has no attribute 'train'

In [503]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 02_fastai.exception_test.ipynb.
Converted 02a_fastai.loop.ipynb.
Converted 02b_fastai.data.load.ipynb.
Converted 02c_fastai.data.block.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 06f_memory.tensorboard.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/10c_agents.dqn.double.ipynb
converting: /home/fastrl_user/fastrl/nbs/10e_agents.dqn.categorical.ipynb
c