In [1]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
# default_exp fastai.data.block

In [4]:
# export
# Python native modules
import os
from typing import Any,Callable
from inspect import isfunction,ismethod
# Third party libs
from fastcore.all import *
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.torch_core import *
from fastai.data.transforms import *
import torchdata.datapipes as dp
from collections import deque
from fastai.imports import *
# Local modules
from fastrl.fastai.data.loop.core import *
from fastrl.fastai.data.load import *

# Data Block
> High level API to quickly get your data in a `DataLoader`s

## Transform Block

> Note: We will first validate the lower level API on a dqn before making the data block. This is going to be a naive implimentation.

In [5]:
#|export
def _merge_grouper(o):
    if isinstance(o, LambdaType): return id(o)
    elif isinstance(o, type): return o
    elif (isfunction(o) or ismethod(o)): return o.__qualname__
    return o.__class__

def _merge_tfms(*tfms):
    "Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating"
    g = groupby(concat(*tfms), _merge_grouper)
    return L(v[-1] for k,v in g.items()).map(instantiate)

def _zip(x): return L(x).zip()

In [6]:
#|export
class TransformBlock():
    "A basic wrapper that links defaults transforms for the data block API"
    def __init__(self, 
        type_tfms:list=None, # One or more `Transform`s for converting types. These will be re-called if workers!=0 for the dataloader.
        item_tfms:list=None, # `ItemTransform`s, applied on an item
        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch
        cbs:list=None, # `Callback`s for use in dataloaders
        dl_type:DataLoader2=None, # Task specific `TfmdDL`, defaults to `TfmdDL`
        dls_kwargs:dict=None, # Additional arguments to be passed to `DataLoaders`
    ):
        self.type_tfms  =            L(type_tfms)
        self.item_tfms  =            L(item_tfms)
        self.batch_tfms =            L(batch_tfms)
        self.cbs        =            L(cbs)
        self.dl_type,self.dls_kwargs = dl_type,({} if dls_kwargs is None else dls_kwargs)

In [7]:
TransformBlock().__dict__

{'type_tfms': (#0) [],
 'item_tfms': (#0) [],
 'batch_tfms': (#0) [],
 'cbs': (#0) [],
 'dl_type': None,
 'dls_kwargs': {}}

In [8]:
# export
def simple_iter_loader_loop(
    items:Iterable,
    cbs:Optional[List[Callback]]=None,
    type_tfms:Optional[Transform]=None,
    item_tfms:Optional[Transform]=None,
    batch_tfms:Optional[Transform]=None,
    bs:int=2,
):
    type_tfms = ifnone(type_tfms,L())
    pipe = dp.map.SequenceWrapper(items)
    pipe = TypeTransformLoop(pipe, type_tfms=type_tfms)
    pipe = dp.iter.MapToIterConverter(pipe) # Will intialize the gym object, which will be an issue when doing multiproc
    pipe = ItemTransformLoop(pipe, item_tfms=ifnone(item_tfms,L()))
    pipe = pipe.batch(bs)
    pipe = BatchTransformLoop(pipe, batch_tfms=ifnone(batch_tfms,L()))
    return pipe

In [9]:
# export
class DataBlock(object):
    def __init__(
        self,
        blocks:List[TransformBlock]=None, # Transform blocks to use 
        loader_loop:Callable=None,
        dl_type=None
    ):
        store_attr(but='loader_loop')
        self.loader_loop = ifnone(loader_loop,default_loader_loop)
        blocks = L(self.blocks if blocks is None else blocks)
        blocks = L(b() if callable(b) else b for b in blocks)
        self.type_tfms = blocks.attrgot('type_tfms', L())

        self.cbs = blocks.attrgot('cbs', L())
        self.item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))
        self.batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        if dl_type is not None: self.dl_type = dl_type
        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))

    def datapipes(
        self,
        source:Any,
        bs=2,
        n=1,
        **kwargs,
    ):
        return L(self.loader_loop(
            source,
            cbs=cbs,
            type_tfms=type_tfms,
            item_tfms=self.item_tfms,
            batch_tfms=self.batch_tfms,
            bs=bs,
            n=n,
            **kwargs
        ) for type_tfms,cbs in zip(self.type_tfms,self.cbs))
        
    def dataloaders(
        self,
        source:Any,
        n_workers=0,
        **kwargs
    ):
        pipes = self.datapipes(source,**kwargs)
        return L(pipes).map(DataLoader2,num_workers=n_workers,**self.dls_kwargs)

## Example

In [10]:
import gym

In [11]:
# export
def simple_iter_loader_loop(
    items:Iterable,
    cbs:Optional[List[Callback]]=None,
    type_tfms:Optional[Transform]=None,
    item_tfms:Optional[Transform]=None,
    batch_tfms:Optional[Transform]=None,
    bs:int=2,
    n:int=1,
):
    pipe = dp.map.SequenceWrapper(items)
    pipe = TypeTransformLoop(pipe, type_tfms=ifnone(type_tfms,L()))
    pipe = dp.map.InMemoryCacheHolder(pipe)
    pipe = dp.iter.MapToIterConverter(pipe) # Will intialize the gym object, which will be an issue when doing multiproc
    pipe = pipe.cycle(count=n)
    pipe = ItemTransformLoop(pipe, item_tfms=ifnone(item_tfms,L())).attach_callbacks(cbs)
    pipe = pipe.batch(bs)
    pipe = BatchTransformLoop(pipe, batch_tfms=ifnone(batch_tfms,L()))
    
    if len(cbs)!=0: default_constructor(pipe,cbs)
    
    return pipe

In [88]:
# export
class Flatten(dp.iter.IterDataPipe):
    
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs
    
    @callback_iter
    def __iter__(self):
        for o in self.source_datapipe:
            if not is_listy(o):
                raise Exception(f'Expected listy object got {type(o)}\n{o}')
            for oo in o: yield oo
            
class NStepPipe(dp.iter.IterDataPipe):
    
    def __init__(self, source_datapipe, n=1, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.n = n
        self.kwargs = kwargs
    
    @callback_iter
    def __iter__(self):
        buffer = []
        for o in self.source_datapipe:
            if not type(o)==dict:
                raise Exception(f'Expected dict object generated from `make_step` got {type(o)}\n{o}')
        
            buffer.append(o)
            if not o['done'] and len(buffer)<self.n: continue
            
            while o['done'] and len(buffer)!=0:
                yield tuple(buffer)
                buffer.pop(0)
                
            if not o['done']:
                yield tuple(buffer)
                buffer.pop(0)
                
class NSkipPipe(dp.iter.IterDataPipe):
    
    def __init__(self, source_datapipe, n=1, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.n = n
        self.kwargs = kwargs
    
    @callback_iter
    def __iter__(self):
        skip_idx = 0
        for o in self.source_datapipe:
            if not type(o)==dict:
                raise Exception(f'Expected dict object generated from `make_step` got {type(o)}\n{o}')
            
            skip_idx += 1 # Be aware of the ordering here. we want to always show the first step when we can. 
            if skip_idx%self.n==0 or o['done']: 
                yield o
                if o['done']: skip_idx = 0

In [89]:
# export
class NStepCallback(Callback):
    "A list of data pipes that have an associated job."
    call_on=L(ItemTransformLoop)
    exclude_under=L()
    
    def __init__(self,nsteps=1,nskip=1):
        store_attr()
        self.pipes = L(partial(NSkipPipe,n=nskip),
                       partial(NStepPipe, n=nsteps),
                       Flatten
                       )

In [90]:
# export
def make_step(
    state,
    next_state,
    done,
    action,
    env_id
):
    return dict(state=state,next_state=next_state,done=done,action=action,env_id=env_id)

class GymTypeTransform(Transform):
    def encodes(self,o): return gym.make(o)
    
class GymStepTransform(Transform):
    def encodes(self,o:gym.Env):
        if getattr(o,'is_done',True):
            state = o.reset(seed=getattr(self,'seed',0))
            o.is_done = False
        else:
            state = o.state
        next_state,action,done,_ = o.step(0)
        
        if done: o.is_done = True
        o.state = next_state
        
        return make_step(state,next_state,done,action,env_id=id(o))
    
GymTransformBlock = TransformBlock(
    type_tfms = GymTypeTransform,
    item_tfms = (GymStepTransform,ToTensor),
    cbs = NStepCallback(nsteps=3,nskip=2)
)

In [91]:

block = DataBlock(
    blocks = GymTransformBlock,
    loader_loop=simple_iter_loader_loop
)
pipes = block.datapipes(['CartPole-v1']*1,n=60)

[<__main__.NStepCallback object at 0x7f0469d6bd50>]


In [92]:

from torch.utils.data.graph import traverse

In [93]:
traverse(pipes[0])

{BatchTransformLoop: {BatcherIterDataPipe: {Flatten: {NStepPipe: {NSkipPipe: {ItemTransformLoop[<__main__.NStepCallback object at 0x7f0469d6bd50>]: {CyclerIterDataPipe: {MapToIterConverterIterDataPipe: {InMemoryCacheHolderMapDataPipe: {TypeTransformLoop: {SequenceWrapperMapDataPipe: {}}}}}}}}}}}}

In [94]:
for i,o in enumerate(pipes[0]):
    print(o)
    if i>10:break

[{'state': array([ 0.01323574, -0.21745604, -0.04686959,  0.22950698], dtype=float32), 'next_state': array([ 0.00888662, -0.411878  , -0.04227945,  0.50704503], dtype=float32), 'done': False, 'action': 1.0, 'env_id': 139656932269648}, {'state': array([ 6.4906175e-04, -6.0637945e-01, -3.2138553e-02,  7.8611010e-01],
      dtype=float32), 'next_state': array([-0.01147853, -0.8010455 , -0.01641635,  1.0685112 ], dtype=float32), 'done': False, 'action': 1.0, 'env_id': 139656932269648}]
[{'state': array([-0.02749944, -0.9959465 ,  0.00495387,  1.3559971 ], dtype=float32), 'next_state': array([-0.04741837, -1.1911303 ,  0.03207381,  1.6502256 ], dtype=float32), 'done': False, 'action': 1.0, 'env_id': 139656932269648}, {'state': array([ 6.4906175e-04, -6.0637945e-01, -3.2138553e-02,  7.8611010e-01],
      dtype=float32), 'next_state': array([-0.01147853, -0.8010455 , -0.01641635,  1.0685112 ], dtype=float32), 'done': False, 'action': 1.0, 'env_id': 139656932269648}]
[{'state': array([-0.02749

In [95]:
dls = block.dataloaders(['CartPole-v1']*1,n=10,n_workers=0)

[<__main__.NStepCallback object at 0x7f0469d6bd50>]


In [96]:
for o in dls[0]:
    print(o)

[{'state': tensor([[ 0.0132, -0.2175, -0.0469,  0.2295]]), 'next_state': tensor([[ 0.0089, -0.4119, -0.0423,  0.5070]]), 'done': tensor([False]), 'action': tensor([1.], dtype=torch.float64), 'env_id': tensor([139656926889296])}, {'state': tensor([[ 6.4906e-04, -6.0638e-01, -3.2139e-02,  7.8611e-01]]), 'next_state': tensor([[-0.0115, -0.8010, -0.0164,  1.0685]]), 'done': tensor([False]), 'action': tensor([1.], dtype=torch.float64), 'env_id': tensor([139656926889296])}]
[{'state': tensor([[-0.0275, -0.9959,  0.0050,  1.3560]]), 'next_state': tensor([[-0.0474, -1.1911,  0.0321,  1.6502]]), 'done': tensor([False]), 'action': tensor([1.], dtype=torch.float64), 'env_id': tensor([139656926889296])}, {'state': tensor([[ 6.4906e-04, -6.0638e-01, -3.2139e-02,  7.8611e-01]]), 'next_state': tensor([[-0.0115, -0.8010, -0.0164,  1.0685]]), 'done': tensor([False]), 'action': tensor([1.], dtype=torch.float64), 'env_id': tensor([139656926889296])}]
[{'state': tensor([[-0.0275, -0.9959,  0.0050,  1.3560

In [97]:
# #|export
# @docs
# @funcs_kwargs
# class DataBlock():
#     "Generic container to quickly build `Datasets` and `DataLoaders`"
#     blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL
#     _methods = 'get_items splitter get_y get_x'.split()
#     _msg = "If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`."
#     def __init__(self, 
#         blocks:list=None, # One or more `TransformBlock`s
#         dl_type:DataLoader2=None, # Task specific `TfmdDL`, defaults to `block`'s dl_type or`TfmdDL`
#         n_inp:int=None, # Number of inputs
#         item_tfms:list=None, # `ItemTransform`s, applied on an item 
#         batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch
#         **kwargs, 
#     ):
#         blocks = L(self.blocks if blocks is None else blocks)
#         blocks = L(b() if callable(b) else b for b in blocks)
#         self.type_tfms = blocks.attrgot('type_tfms', L())
#         self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))
#         self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
#         for b in blocks:
#             if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
#         if dl_type is not None: self.dl_type = dl_type
#         self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)
#         self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))

#         self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))

#         if kwargs: raise TypeError(f'invalid keyword arguments: {", ".join(kwargs.keys())}')
#         self.new(item_tfms, batch_tfms)

#     def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(
#         lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)

#     def new(self, 
#         item_tfms:list=None, # `ItemTransform`s, applied on an item
#         batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch 
#     ):
#         self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)
#         self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
#         return self

#     @classmethod
#     def from_columns(cls, 
#         blocks:list =None, # One or more `TransformBlock`s
#         getters:list =None, # Getter functions applied to results of `get_items`
#         get_items:callable=None, # A function to get items
#         **kwargs,
#     ):
#         if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))
#         get_items = _zip if get_items is None else compose(get_items, _zip)
#         return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)

#     def datasets(self, 
#         source, # The data source
#         verbose:bool=False, # Show verbose messages
#     ) -> Datasets:
#         self.source = source                     ; pv(f"Collecting items from {source}", verbose)
#         items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
#         splits = (self.splitter or RandomSplitter())(items)
#         pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
#         return default_loader_loop(items=source)

#     def dataloaders(self, 
#         source, # The data source
#         path:str='.', # Data source and default `Learner` path 
#         verbose:bool=False, # Show verbose messages
#         **kwargs
#     ) -> DataLoaders:
#         dsets = self.datasets(source, verbose=verbose)
#         kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
#         return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

#     _docs = dict(new="Create a new `DataBlock` with other `item_tfms` and `batch_tfms`",
#                  datasets="Create a `Datasets` object from `source`",
#                  dataloaders="Create a `DataLoaders` object from `source`")

In [98]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    make_readme()
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
