In [9]:
# default_exp data

In [10]:
#export
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [11]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Experience Blocks

> Iterable datasets for returning environment outputs

We need `TfmdSourceDL` to trigger some cleanup before doing an iteration. 

TODO: (Josiah): Is there a way to override the `before_iter` in the DataBlock instead? The main issue is that we need to be able to reference `self` which isn't possible when passing methods through the `DataLoader` params.

In [12]:
# export
def is_single_nested_tuple(b):return isinstance(b,tuple) and len(b)==1 and isinstance(b[0],tuple)
    
class TfmdSourceDL(TfmdDL):
    def before_iter(self):
        super().before_iter()
        self.dataset.reset_src()
        
    def create_item(self,b):
        b=super().create_item(b)
        return b[0] if is_single_nested_tuple(b) else b

    def after_iter(self):
        super().after_iter()
        self.dataset.close_src()

A `TfmdSource` has an adjustable `__len__`. Unlike the `TfmdLists`, `TfmdSource` iters on a single item until the item raises a `SourceExhausted` exception. This means that the soruces `items` are being tracked by a separate index.

In [13]:
# export
class SourceExhausted(Exception):pass

@delegates(TfmdLists)
class TfmdSource(TfmdLists):
    "A `Pipeline` of `tfms` applied to a collection of sources called `items`. Only swtches between them if they get exhausted."
    def __init__(self,items, tfms,n:int=None,cycle_srcs=True,verbose=False,**kwargs):
        self.n=n;self.cycle_srcs=cycle_srcs;self.source_idx=0;self.verbose=verbose;self.res_buffer=deque([]);self.extra_len=0
        super().__init__(items,tfms,**kwargs)
#         store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
    
    def __enter__(self):                             self.cycle_srcs=False
    def __exit__(self,exc_type,exc_value,traceback): self.cycle_srcs=True
        
    def __repr__(self): return f"{self.__class__.__name__}: Cycling sources: {self.cycle_srcs}\n{self.items}\ntfms - {self.tfms.fs}"
    def close_src(self):
        [t.close(self) for t in self.tfms if hasattr(t,'close')]
        self.res_buffer.clear()
        
    def reset_src(self): 
        [t.reset(self) for t in self.tfms if hasattr(t,'reset')]
        self.res_buffer.clear()
        
    def setup(self,train_setup=True):super().setup(train_setup);self.reset_src()
     
    def __len__(self):
#         return ifnone(self.n,super().__len__()) TODO (Josiah): self.n is not settable in DataBlock, and since TfmdLists gets reinit, this will not persist
        if len(self.items)!=0 and isinstance(self.items[0],gym.Env) and self.cycle_srcs:
            self.reset_src()
            return self.items[0].spec.max_episode_steps+self.extra_len # TODO(Josiah): This is the only OpenAI dependent code. How do we have htis set in setup?
        if self.n is not None: return self.n
        if len(self.items)!=0 and hasattr(self.items[0],'n'):
            return self.items[0].n # TODO(Josiah): Possible solution to make this more generic?
        return super().__len__()
    
    def __getitem__(self,idx):
        if len(self.res_buffer)!=0:return self.res_buffer.popleft()
        
        try:res=super().__getitem__(self.source_idx if self.cycle_srcs else idx)
        except (IndexError,SourceExhausted) as e:
            if not self.cycle_srcs:raise
            if type(e)==SourceExhausted: 
                self.source_idx+=1;  pv(f'SourceExhausted, incrementing to idx {self.source_idx}',verbose=self.verbose) 
                if len(self.items)<=self.source_idx:e=IndexError(f'Index {self.source_idx} from SourceExhausted except is out of bounds.')
            if type(e)==IndexError:      
                self.source_idx=0;   pv(f'IndexError, setting idx to {self.source_idx}',verbose=self.verbose)
                self.reset_src()
            res=self.__getitem__(self.source_idx)
        
        if is_listy(res):
            self.res_buffer=deque(res)
            return self.res_buffer.popleft()
        return res

In [14]:
# export
class IterableDataBlock(DataBlock):
    tls_type=TfmdSource
    def datasets(self, source, verbose=False):
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        tls=L([self.tls_type(items, t,verbose=verbose) for t in L(ifnone(self._combine_type_tfms(),[None]))])
        return Datasets(items,tls=tls,splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)

In [15]:
# export
class MakeTfm(Transform):
    def setup(self,items:TfmdSource,train_setup=False):
        with items: 
            for i in range(len(items)):items[i]=gym.make(items[i])
        return super().setup(items,train_setup)

In [16]:
%matplotlib inline

In [17]:
# export    
def env_display(env:gym.Env):
    img=env.render('rgb_array')
    try:display.clear_output(wait=True)
    except AttributeError:pass
    new_im=PIL.Image.fromarray(img)
    display.display(new_im)

In [18]:
# export
@dataclass
class Experience():
    d:bool;s:np.ndarray;sp:np.ndarray;r:float;a:Any;eid:int=0;episode_r:float=0;absolute_end:bool=False
                                    
    def __repr__(self):return f'Experience({",".join([f"{f.name}={getattr(self,f.name).numpy()}" for f in fields(self)])})'
    @classmethod
    def from_batch(cls,b):
        if isinstance(b,tuple):b=b[0]
        bs=max(len(e) for e in b.values())
        return L([cls(**{k:v[i] for k,v in b.items()}) for i in range(bs)])
                                    
def envlen(o:gym.Env):return o.spec.max_episode_steps

@dataclass
class ResetAndStepTfm(Transform):
    def __init__(self,seed:int=None,agent:object=None,n_steps:int=1,steps_delta:int=1,a:Any=None,history:deque=None,
                 s:dict=None,steps:dict=None,maxsteps:int=None,display:bool=False,hist2dict:bool=True):
        self.seed=seed;self.agent=agent;self.n_steps=n_steps;self.steps_delta=steps_delta;self.a=a;self.history=history;self.hist2dict=hist2dict
        self.maxsteps=maxsteps;self.display=display
        self.s=ifnone(s,{})
        self.steps=ifnone(steps,{})
        # store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
            
    def setup(self,items:TfmdSource,train_setup=False):
        self.reset(items)
        self.history=deque(maxlen=self.n_steps)
        return super().setup(items,train_setup)
    
    def reset(self,items):
        if len(items.items)==0:return
        if items.extra_len==0:
            items.extra_len=items.items[0].spec.max_episode_steps*(self.n_steps-1) # Extra steps to unwrap done
        with items:
            self.s={id(o):o.reset() for o in items.items if o.seed(self.seed) or True}
            self.steps={id(o):0 for o in items.items}
            self.maxsteps=ifnone(self.maxsteps,envlen(items.items[0]))
            if self.history is not None:self.history.clear()
        
    def queue2dict(self,q:deque):return [asdict(hist) for hist in tuple(copy(q))]
    def encodes(self,o:gym.Env):
        # If history has finished, then instead we try emptying the environment
        if self.history and self.history[-1].d:
            self.history.popleft()
            if len(self.history)==0:raise SourceExhausted
            if len(self.history)==1:self.history[-1].absolute_end=True
            return self.queue2dict(self.history) if self.hist2dict else copy(self.history)
        
        while True:
            a=ifnone(self.a,o.action_space.sample()) if self.agent is None else self.agent(self.s[id(o)])[0]
            sp,r,d,_=o.env.step(a)
            if self.display:env_display(o)
                
            self.steps[id(o)]+=1
            d=self.steps[id(o)]>=self.maxsteps if not d else d
    
            self.history.append(Experience(d=d,s=self.s[id(o)].copy(),sp=sp.copy(),r=r,a=a,eid=id(o),
                                episode_r=r+(self.history[-1].episode_r if self.history else 0)))
            self.s[id(o)]=sp.copy()
            
            if self.steps[id(o)]%self.steps_delta!=0: continue # TODO(Josiah): if `steps_delta`!=1, it may skip the first state. Is this ok?
            if len(self.history)!=self.n_steps:       continue
            break
        
        if self.history[-1].d and len(self.history)==1:self.history[-1].absolute_end=True
        return self.queue2dict(self.history) if self.hist2dict else copy(self.history)

In [19]:
# export
@delegates(ResetAndStepTfm)
def ExperienceBlock(dls_kwargs=None,**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(**kwargs)],dl_type=TfmdSourceDL,dls_kwargs=dls_kwargs)

In [20]:
import ptan

In [155]:
# envs=[gym.make('MountainCar-v0') for _ in range(5)]
envs=[gym.make('CartPole-v1') for _ in range(5)]
for e in envs:e.seed(0)
class TestAgent(ptan.agent.BaseAgent):
    def __call__(self,s,ss):return [0]*len(s),[0]*len(s)

exp_src=ptan.experience.ExperienceSourceFirstLast(envs, TestAgent(), gamma=0.99, steps_count=1)
# exp_src=ptan.experience.ExperienceSource(envs, TestAgent(), steps_count=1)

for i,e in enumerate(exp_src):
    print(e)
    if i==38:break

ExperienceFirstLast(state=array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), action=0, reward=1.0, last_state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]))
ExperienceFirstLast(state=array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), action=0, reward=1.0, last_state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]))
ExperienceFirstLast(state=array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), action=0, reward=1.0, last_state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]))
ExperienceFirstLast(state=array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), action=0, reward=1.0, last_state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]))
ExperienceFirstLast(state=array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]), action=0, reward=1.0, last_state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]))
ExperienceFirstLast(state=array([-0.04363321, -0.14877061,  0.01284913,  0.2758415 ]), action=0

In [134]:
@dataclass
class Experience():
    d:bool;s:np.ndarray;sp:np.ndarray;r:float;a:Any;eid:int=0;episode_r:float=0;absolute_end:bool=False
                                    
    def __repr__(self):return f'Experience({",".join([f"{f.name}={getattr(self,f.name).numpy()}" for f in fields(self)])})'
    @classmethod
    def from_batch(cls,b):
        if isinstance(b,tuple):b=b[0]
        bs=max(len(e) for e in b.values())
        return L([cls(**{k:v[i] for k,v in b.items()}) for i in range(bs)])

In [151]:
blk=IterableDataBlock(blocks=(ExperienceBlock(n_steps=1,steps_delta=1,a=0,seed=0)),
            splitter=FuncSplitter(lambda x:False))

dls=blk.dataloaders(['CartPole-v1']*1,bs=1,num_workers=0,verbose=False,
                      indexed=True,shuffle_train=False)
for i,x in enumerate(dls[0]):
    print(Experience.from_batch(x))
    if i==38:break

(#1) [Experience(d=False,s=[-0.04456399  0.04653909  0.01326909 -0.02099827],sp=[-0.04363321 -0.14877061  0.01284913  0.2758415 ],r=1.0,a=0,eid=140362763841488,episode_r=1.0,absolute_end=False)]
(#1) [Experience(d=False,s=[-0.04363321 -0.14877061  0.01284913  0.2758415 ],sp=[-0.04660862 -0.3440735   0.01836596  0.5725492 ],r=1.0,a=0,eid=140362763841488,episode_r=2.0,absolute_end=False)]
(#1) [Experience(d=False,s=[-0.04660862 -0.3440735   0.01836596  0.5725492 ],sp=[-0.05349009 -0.5394481   0.02981694  0.87096095],r=1.0,a=0,eid=140362763841488,episode_r=3.0,absolute_end=False)]
(#1) [Experience(d=False,s=[-0.05349009 -0.5394481   0.02981694  0.87096095],sp=[-0.06427906 -0.73496263  0.04723616  1.17286728],r=1.0,a=0,eid=140362763841488,episode_r=4.0,absolute_end=False)]
(#1) [Experience(d=False,s=[-0.06427906 -0.73496263  0.04723616  1.17286728],sp=[-0.07897831 -0.93066572  0.07069351  1.47997674],r=1.0,a=0,eid=140362763841488,episode_r=5.0,absolute_end=False)]
(#1) [Experience(d=False,

In [32]:
import pytest
@delegates(ResetAndStepTfm)
def test_block(n_steps=1,steps_delta=1,block=ExperienceBlock,**kwargs):
    for env in ['MountainCar-v0','CartPole-v1']:
        blk=IterableDataBlock(blocks=(block(n_steps=n_steps,steps_delta=steps_delta,**kwargs)),
                    splitter=FuncSplitter(lambda x:False))

        dls=blk.dataloaders([env]*5,bs=1,num_workers=0,verbose=False,
                              indexed=True,shuffle_train=False)


        states={
            'MountainCar-v0':['tensor([[-0.5891,  0.0000]], dtype=torch.float64)','tensor([[-0.7105,  0.0043]], dtype=torch.float64)'],
            'CartPole-v1':['tensor([[-0.0446,  0.0465,  0.0133, -0.0210]], dtype=torch.float64)',
                           'tensor([[-0.1770, -1.7150,  0.2274,  2.7892]], dtype=torch.float64)']
        }

        print('Starting Iteration')
        counter=0
        counters=[]
        for epoch in range(3):
            dones=0
            for x in dls[0]:
#                 print('\nResult',x)
                if counter==0 and steps_delta==0: test_eq(str(x[0]['s']),states[env][0])
                if x[0]['d']:
#                     print(len(dls[0]),dones)
                    dones+=1
                    if dones==n_steps:
                        test_eq(str(x[0]['sp']),states[env][1])
                        print(counter)
                        # TODO(Josiah): Figure out why this differs between runs.
                        if env=='MountainCar-v0':test_eq(counter,pytest.approx(200*n_steps-n_steps*(0+1),10))
                        counters.append(counter)
                        counter=0
                else:
                    counter+=1
            test_ne(counters[-1],0)

Check that all defaults work as expected. That an episode completes, and that the expected final state gets reached.

In [33]:
test_block(a=0,seed=0,n_steps=1)

Starting Iteration
199
199
199
Starting Iteration
8
445
445


In [34]:
test_block(a=0,seed=0,n_steps=3)

Starting Iteration
594
597
597
Starting Iteration
21
1314
1314


In [35]:
test_block(a=0,seed=0,n_steps=3,steps_delta=3)

Starting Iteration
199
590
590
Starting Iteration
9
1125
1125


In [36]:
# export
class FirstLastTfm(Transform):
    def __init__(self,discount=0.99):self.discount=discount
    
    def reset(self,items):
        if items.extra_len!=0:items.extra_len=0
    
    def encodes(self,o):
        first_o=o[0]
        first_o.sp=o[-1].sp
        total_reward=first_o.r
        elms=list(o)[:-1]

        for exp in reversed(elms):
            total_reward*=self.discount
            total_reward+=exp.r
        first_o.r=total_reward
        return asdict(first_o)


@delegates(ResetAndStepTfm)
def FirstLastExperienceBlock(dls_kwargs=None,**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(hist2dict=False,**kwargs),FirstLastTfm],dl_type=TfmdSourceDL,dls_kwargs=dls_kwargs)

In [37]:
import pytest
@delegates(ResetAndStepTfm)
def fl_test_block(n_steps,block=FirstLastExperienceBlock,**kwargs):
    for env in ['MountainCar-v0','CartPole-v1']:
        blk=IterableDataBlock(blocks=(block(n_steps=n_steps,**kwargs)),
                    splitter=FuncSplitter(lambda x:False),batch_tfms=lambda x:(x['s'],x))

        dls=blk.dataloaders([env]*5,bs=1,num_workers=0,verbose=False,
                              indexed=True,shuffle_train=False)


        states={
            'MountainCar-v0':['tensor([[-0.5891,  0.0000]], dtype=torch.float64)','tensor([[-0.7105,  0.0043]], dtype=torch.float64)'],
            'CartPole-v1':['tensor([[-0.0446,  0.0465,  0.0133, -0.0210]], dtype=torch.float64)',
                           'tensor([[-0.1770, -1.7150,  0.2274,  2.7892]], dtype=torch.float64)']
        }

        print('Starting Iteration')
        counter=0
        counters=[]
        for epoch in range(3):
            dones=0
            for x in dls[0]:
                print('\nResult',x)
                if counter==0: test_eq(str(x[0]['s']),states[env][0])
                if x[0]['d']:
#                     print(len(dls[0]),dones)
                    dones+=1
                    test_eq(str(x[0]['sp']),states[env][1])
#                     print(counter)
                    # TODO(Josiah): Figure out why this differs between runs.
                    if env=='MountainCar-v0':test_eq(counter,pytest.approx(200*n_steps-n_steps*(0+1),1))
                    counters.append(counter)
                    counter=0
                else:
                    counter+=1
            test_ne(counters[-1],0)
fl_test_block(a=0,seed=0,n_steps=3)

Starting Iteration

Result ((tensor([[-0.5891,  0.0000]], dtype=torch.float64), {'d': tensor([False]), 's': tensor([[-0.5891,  0.0000]], dtype=torch.float64), 'sp': tensor([[-0.5922, -0.0015]], dtype=torch.float64), 'r': tensor([-2.9701], dtype=torch.float64), 'a': tensor([0]), 'eid': tensor([140262333237328]), 'episode_r': tensor([-1.], dtype=torch.float64), 'absolute_end': tensor([False])}),)


TypeError: tuple indices must be integers or slices, not str

# Export

In [81]:
# hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_learner.ipynb.
Converted 05a_data.ipynb.
Converted 05b_async_data.ipynb.
Converted 06_basic_train.ipynb.
Converted 13_metrics.ipynb.
Converted 14_actorcritic.sac.ipynb.
Converted 15_actorcritic.a3c_data.ipynb.
Converted 16_actorcritic.a2c.ipynb.
Converted index.ipynb.


converting: /opt/project/fastrl/nbs/05a_data.ipynb
