In [None]:
# default_exp data

In [2]:
#export
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict,fields
from typing import List,Any,Dict,Callable
from collections import deque
import gym

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [3]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Experience Blocks

> Iterable datasets for returning environment outputs

We need `TfmdSourceDL` to trigger some cleanup before doing an iteration. 

TODO: (Josiah): Is there a way to override the `before_iter` in the DataBlock instead? The main issue is that we need to be able to reference `self` which isn't possible when passing methods through the `DataLoader` params.

In [4]:
# export
def is_single_nested_tuple(b):return isinstance(b,tuple) and len(b)==1 and isinstance(b[0],tuple)
    
class TfmdSourceDL(TfmdDL):
    def before_iter(self):
        super().before_iter()
        self.dataset.reset_src()
        
    def create_item(self,b):
        b=super().create_item(b)
        return b[0] if is_single_nested_tuple(b) else b

    def after_iter(self):
        super().after_iter()
        self.dataset.close_src()

A `TfmdSource` has an adjustable `__len__`. Unlike the `TfmdLists`, `TfmdSource` iters on a single item until the item raises a `SourceExhausted` exception. This means that the soruces `items` are being tracked by a separate index.

In [145]:
# export
class SourceExhausted(Exception):pass

@delegates(TfmdLists)
class TfmdSource(TfmdLists):
    "A `Pipeline` of `tfms` applied to a collection of sources called `items`. Only swtches between them if they get exhausted."
    def __init__(self,items, tfms,n:int=None,cycle_srcs=True,verbose=False,**kwargs):
        self.n=n;self.cycle_srcs=cycle_srcs;self.source_idx=0;self.verbose=verbose;self.res_buffer=deque([]);self.extra_len=0;self.n_exhausted_envs=0
        super().__init__(items,tfms,**kwargs)
#         store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
    
    def __enter__(self):                             self.cycle_srcs=False
    def __exit__(self,exc_type,exc_value,traceback): self.cycle_srcs=True
        
    def __repr__(self): return f"{self.__class__.__name__}: Cycling sources: {self.cycle_srcs}\n{self.items}\ntfms - {self.tfms.fs}"
    def close_src(self):
        [t.close(self) for t in self.tfms if hasattr(t,'close')]
        self.res_buffer.clear()
        
    def reset_src(self): 
        [t.reset(self) for t in self.tfms if hasattr(t,'reset')]
        print('clearing buffer: ',self.res_buffer)
        self.res_buffer.clear()
        
    def setup(self,train_setup=True):super().setup(train_setup);self.reset_src()
     
    def __len__(self):
#         return ifnone(self.n,super().__len__()) TODO (Josiah): self.n is not settable in DataBlock, and since TfmdLists gets reinit, this will not persist
        if len(self.items)!=0 and isinstance(self.items[0],gym.Env) and self.cycle_srcs:
            self.reset_src()
            return self.items[0].spec.max_episode_steps+self.extra_len # TODO(Josiah): This is the only OpenAI dependent code. How do we have htis set in setup?
        if self.n is not None: return self.n
        if len(self.items)!=0 and hasattr(self.items[0],'n'):
            return self.items[0].n # TODO(Josiah): Possible solution to make this more generic?
        return super().__len__()
    
    def __getitem__(self,idx):
#         print(self.res_buffer)
        if len(self.res_buffer)!=0:return self.res_buffer.popleft()
        if len(self.items)<=self.n_exhausted_envs:
            self.reset_src()
            self.n_exhausted_envs=0
        
        try:
            res=super().__getitem__(self.source_idx if self.cycle_srcs else idx)
            self.source_idx+=1
        except (IndexError,SourceExhausted) as e:
            if not self.cycle_srcs:raise
            if type(e)==SourceExhausted: 
                self.source_idx+=1;  pv(f'SourceExhausted, incrementing to idx {self.source_idx}',verbose=self.verbose) 
                self.n_exhausted_envs+=1
            if type(e)==IndexError: self.source_idx=0;  
            res=self.__getitem__(self.source_idx)
        
        if is_listy(res):
#             print(res)
            self.res_buffer=deque([])
#             print('RES: ',res)
            if type(res)==tuple:
                for e in res: 
#                     print('loadin e: ',e)
                    self.res_buffer.append([tuple(e)])
            else:
                self.res_buffer.append([tuple(res)])
            return self.res_buffer.popleft()
        return res

In [146]:
# export
class IterableDataBlock(DataBlock):
    tls_type=TfmdSource
    def datasets(self, source, verbose=False):
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        tls=L([self.tls_type(items, t,verbose=verbose) for t in L(ifnone(self._combine_type_tfms(),[None]))])
        return Datasets(items,tls=tls,splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)

### ExperienceBlock

In [159]:
# export
class SeedZeroWrapper(gym.Wrapper):
    def reset(self,*args,**kwargs):
        self.seed(0)
        return super().reset(*args,**kwargs)

class MakeTfm(Transform):
    def setup(self,items:TfmdSource,train_setup=False):
        with items: 
            for i in range(len(items)):items[i]=SeedZeroWrapper(gym.make(items[i]))
        return super().setup(items,train_setup)

In [160]:
%matplotlib inline

In [161]:
# export    
def env_display(env:gym.Env):
    img=env.render('rgb_array')
    try:display.clear_output(wait=True)
    except AttributeError:pass
    new_im=PIL.Image.fromarray(img)
    display.display(new_im)

In [195]:
# export
import ptan


class TestAgent(ptan.agent.BaseAgent):
    def __call__(self,s,ss):return [0]*len(s),[0]*len(s)


@dataclass
class Experience():
    d:bool;s:np.ndarray;sp:np.ndarray;r:float;a:Any;eid:int=0;episode_r:float=0;absolute_end:bool=False
                                    
    def __str__(self):return self.__repr__()
    def __repr__(self):return f'Experience({",".join([f"{f.name}={getattr(self,f.name).numpy() if isinstance(getattr(self,f.name),Tensor) else getattr(self,f.name)}" for f in fields(self)])})'
    @classmethod
    def from_batch(cls,b):
        if isinstance(b,tuple):b=b[0]
        bs=max(len(e) for e in b.values())
        return L([cls(**{k:v[i] for k,v in b.items()}) for i in range(bs)])
                                    
def envlen(o:gym.Env):return o.spec.max_episode_steps

@dataclass
class ResetAndStepTfm(Transform):
    def __init__(self,seed:int=None,agent:object=None,n_steps:int=1,steps_delta:int=1,a:Any=None,histories:Dict[str,deque]=None,
                 s:dict=None,steps:dict=None,maxsteps:int=None,display:bool=False,hist2dict:bool=True):
        self.seed=seed;self.agent=agent;self.n_steps=n_steps;self.steps_delta=steps_delta;self.a=a;self.histories=histories;self.hist2dict=hist2dict
        self.maxsteps=maxsteps;self.display=display
        self.s=ifnone(s,{})
        self.steps=ifnone(steps,{})
        self._exhausted=False
        self.exp_src=None
        # store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
            
    def setup(self,items:TfmdSource,train_setup=False):
#         self.reset(items)
        self.exp_src=iter(ptan.experience.ExperienceSource(items.items, TestAgent(), steps_count=self.n_steps,steps_delta=self.steps_delta))
        return super().setup(items,train_setup)
    
    def reset(self,items):
        print('reset')
        if len(items.items)==0:return
        if items.extra_len==0:
            items.extra_len=items.items[0].spec.max_episode_steps*(self.n_steps-1) # Extra steps to unwrap done
        self.exp_src=iter(ptan.experience.ExperienceSource(items.items, TestAgent(), steps_count=self.n_steps,steps_delta=self.steps_delta))
#         with items:
#             self.s={id(o):o.reset() for o in items.items if o.seed(self.seed) or True}
#             self.steps={id(o):0 for o in items.items}
#             self.maxsteps=ifnone(self.maxsteps,envlen(items.items[0]))
#             if self.histories is not None and self.histories:
#                 for o in items.items: self.histories[id(o)].clear()
#             else:
#                  self.histories={id(o):deque(maxlen=self.n_steps) for o in items}
        
#     def queue2dict(self,q:deque):return [asdict(hist) for hist in tuple(copy(q))]
#     def encodes(self,o:gym.Env):
#         if self._exhausted:
#             self._exhausted=False
#             raise SourceExhausted()
#         # If history has finished, then instead we try emptying the environment
#         if self.histories[id(o)] and self.histories[id(o)][-1].d:
            
#             if len(self.histories[id(o)])==1:
#                 self.histories[id(o)][-1].absolute_end=True
# #                 print('returning')
#                 h=deepcopy(self.histories[id(o)])
#                 self.histories[id(o)].popleft()
#                 self._exhausted=True
#                 return self.queue2dict(h) if self.hist2dict else copy(h)
#             else:
#                 self.histories[id(o)].popleft()
# #             print('hello2',self.histories[id(o)])
#             if len(self.histories[id(o)])==0:
# #                 print('resiting')
#                 raise SourceExhausted
#             return self.queue2dict(self.histories[id(o)]) if self.hist2dict else copy(self.histories[id(o)])
        
#         while True:
#             a=ifnone(self.a,o.action_space.sample()) if self.agent is None else self.agent(self.s[id(o)])[0]
#             sp,r,d,_=o.env.step(a)
# #             print(self.s[id(o)])
#             if self.display:env_display(o)
                
#             self.steps[id(o)]+=1
#             d=self.steps[id(o)]>=self.maxsteps if not d else d
# #             if d:print('done detected trololol')
    
#             self.histories[id(o)].append(Experience(d=d,s=self.s[id(o)].copy(),sp=sp.copy(),r=r,a=a,eid=id(o),
#                                          episode_r=r+(self.histories[id(o)][-1].episode_r if self.histories[id(o)] else 0)))
#             self.s[id(o)]=sp.copy()
            
#             if self.steps[id(o)]%self.steps_delta!=0:          continue # TODO(Josiah): if `steps_delta`!=1, it may skip the first state. Is this ok?
# #             print(len(self.histories[id(o)]))
#             if len(self.histories[id(o)])!=self.n_steps:       continue
#             break
        
#         if self.histories[id(o)][-1].d and len(self.histories[id(o)])==1:self.histories[id(o)][-1].absolute_end=True
# #         print('hello1',self.histories[id(o)])
#         return self.queue2dict(self.histories[id(o)]) if self.hist2dict else copy(self.histories[id(o)])

    def try_hist2dict(self,o,n_pop=1): 
        out=self.queue2dict(self.histories[id(o)]) if self.hist2dict else copy(self.histories[id(o)])
        for _ in range(n_pop):self.histories[id(o)].popleft()
        return out
    
    
    def encodes(self,o:gym.Env):
        exps=next(self.exp_src)
#         print(exps)
        
#         print(exps)
        return exps
#         a=ifnone(self.a,o.action_space.sample()) if self.agent is None else self.agent(self.s[id(o)])[0]
#         sp,r,d,_=o.env.step(a)
#         if d:raise SourceExhausted()
#         self.histories[id(o)].append(Experience(d=d,s=self.s[id(o)].copy(),sp=sp.copy(),r=r,a=a,eid=id(o),
#                                      episode_r=r+(self.histories[id(o)][-1].episode_r if self.histories[id(o)] else 0)))
#         self.s[id(o)]=sp.copy()
        
#         return self.try_hist2dict(o)

In [196]:
# export
@delegates(ResetAndStepTfm)
def ExperienceBlock(dls_kwargs=None,**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(**kwargs)],dl_type=TfmdSourceDL,dls_kwargs=dls_kwargs)

In [197]:
# !/bin/bash -c "source activate fastrl && pip install ptan --no-dependencies"

In [198]:
n_steps=1
steps_delta=1

blk=IterableDataBlock(blocks=(ExperienceBlock(n_steps=n_steps,steps_delta=steps_delta,a=0,seed=0)),
                              splitter=FuncSplitter(lambda x:False))
dls=blk.dataloaders(['CartPole-v1'],bs=3,num_workers=0,verbose=False,
                      indexed=True,shuffle_train=False,n=40)
print('starting')
for e in dls[0]:
    print(e)

reset
clearing buffer:  deque([])
reset
clearing buffer:  deque([])
starting
reset
clearing buffer:  deque([])
reset
clearing buffer:  deque([])
([((tensor([[-0.0446,  0.0465,  0.0133, -0.0210],
        [-0.0436, -0.1488,  0.0128,  0.2758],
        [-0.0466, -0.3441,  0.0184,  0.5725]], dtype=torch.float64), tensor([0, 0, 0]), tensor([1., 1., 1.], dtype=torch.float64), tensor([False, False, False])),)],)
([((tensor([[-0.0535, -0.5394,  0.0298,  0.8710],
        [-0.0643, -0.7350,  0.0472,  1.1729],
        [-0.0790, -0.9307,  0.0707,  1.4800]], dtype=torch.float64), tensor([0, 0, 0]), tensor([1., 1., 1.], dtype=torch.float64), tensor([False, False, False])),)],)
([((tensor([[-0.0976, -1.1266,  0.1003,  1.7939],
        [-0.1201, -1.3227,  0.1362,  2.1160],
        [-0.1466, -1.5189,  0.1785,  2.4474]], dtype=torch.float64), tensor([0, 0, 0]), tensor([1., 1., 1.], dtype=torch.float64), tensor([False, False,  True])),)],)
([((tensor([[-0.0446,  0.0465,  0.0133, -0.0210],
        [-0.0436

In [289]:
n_steps=1
steps_delta=1


blk=IterableDataBlock(blocks=(ExperienceBlock(n_steps=n_steps,steps_delta=steps_delta,a=0,seed=0)),
                              splitter=FuncSplitter(lambda x:False))

envs=[gym.make('CartPole-v1')]
envs=[SeedZeroWrapper(e) for e in envs]
exp_src=ptan.experience.ExperienceSource(envs, TestAgent(), steps_count=n_steps)

dls=blk.dataloaders(['CartPole-v1'],bs=n_steps,num_workers=0,verbose=False,
                      indexed=True,shuffle_train=False,n=40)

fastrl_exp=[]
ptan_exps=[]

for i,(x,e) in enumerate(zip(dls[0],exp_src)):
    un_batch_x=[]
#     print(x[0][0][0][1])
#     print(x[0][0][0][1].shape[0])
    for i in range(x[0][0][0][1].shape[0]):
#         print(i,x[0][0])
        un_batch_x.append(ptan.experience.Experience(*tuple(el[i].numpy() for el in x[0][0][0])))
#     print(tuple(un_batch_x),'\n',e)
    for fastrl_e,ptan_e in zip(un_batch_x,e):
        print(fastrl_e.state,ptan_e.state)
        test_eq(fastrl_e.state,ptan_e.state)

reset
clearing buffer:  deque([])
reset
clearing buffer:  deque([])
reset
clearing buffer:  deque([])
reset
clearing buffer:  deque([])
[-0.04456399  0.04653909  0.01326909 -0.02099827] [-0.04456399  0.04653909  0.01326909 -0.02099827]
[-0.04363321 -0.14877061  0.01284913  0.2758415 ] [-0.04363321 -0.14877061  0.01284913  0.2758415 ]
[-0.04660862 -0.3440735   0.01836596  0.5725492 ] [-0.04660862 -0.3440735   0.01836596  0.5725492 ]
[-0.05349009 -0.5394481   0.02981694  0.87096095] [-0.05349009 -0.5394481   0.02981694  0.87096095]
[-0.06427906 -0.73496263  0.04723616  1.17286728] [-0.06427906 -0.73496263  0.04723616  1.17286728]
[-0.07897831 -0.93066572  0.07069351  1.47997674] [-0.07897831 -0.93066572  0.07069351  1.47997674]
[-0.09759162 -1.12657568  0.10029304  1.79387427] [-0.09759162 -1.12657568  0.10029304  1.79387427]
[-0.12012314 -1.32266817  0.13617053  2.11597167] [-0.12012314 -1.32266817  0.13617053  2.11597167]
[-0.1465765  -1.51886144  0.17848996  2.44744788] [-0.1465765  -

# FirstLastExperienceBlock

In [None]:
# export
class FirstLastTfm(Transform):
    def __init__(self,discount=0.99):self.discount=discount
    
    def reset(self,items):
        if items.extra_len!=0:items.extra_len=0
    
    def encodes(self,o):
        first_o=o[0]
        first_o.sp=o[-1].sp
        total_reward=first_o.r
        elms=list(o)[:-1]

        for exp in reversed(elms):
            total_reward*=self.discount
            total_reward+=exp.r
        first_o.r=total_reward
        return asdict(first_o)


@delegates(ResetAndStepTfm)
def FirstLastExperienceBlock(dls_kwargs=None,**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(hist2dict=False,**kwargs),FirstLastTfm],dl_type=TfmdSourceDL,dls_kwargs=dls_kwargs)

In [None]:
import pytest
@delegates(ResetAndStepTfm)
def fl_test_block(n_steps,block=FirstLastExperienceBlock,**kwargs):
    for env in ['MountainCar-v0','CartPole-v1']:
        blk=IterableDataBlock(blocks=(block(n_steps=n_steps,**kwargs)),
                    splitter=FuncSplitter(lambda x:False),batch_tfms=lambda x:(x['s'],x))

        dls=blk.dataloaders([env]*5,bs=1,num_workers=0,verbose=False,
                              indexed=True,shuffle_train=False)


        states={
            'MountainCar-v0':['tensor([[-0.5891,  0.0000]], dtype=torch.float64)','tensor([[-0.7105,  0.0043]], dtype=torch.float64)'],
            'CartPole-v1':['tensor([[-0.0446,  0.0465,  0.0133, -0.0210]], dtype=torch.float64)',
                           'tensor([[-0.1770, -1.7150,  0.2274,  2.7892]], dtype=torch.float64)']
        }

        print('Starting Iteration')
        counter=0
        counters=[]
        for epoch in range(3):
            dones=0
            for x in dls[0]:
                print('\nResult',x)
                if counter==0: test_eq(str(x[0]['s']),states[env][0])
                if x[0]['d']:
#                     print(len(dls[0]),dones)
                    dones+=1
                    test_eq(str(x[0]['sp']),states[env][1])
#                     print(counter)
                    # TODO(Josiah): Figure out why this differs between runs.
                    if env=='MountainCar-v0':test_eq(counter,pytest.approx(200*n_steps-n_steps*(0+1),1))
                    counters.append(counter)
                    counter=0
                else:
                    counter+=1
            test_ne(counters[-1],0)
fl_test_block(a=0,seed=0,n_steps=3)

# Export

In [None]:
# hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)