In [1]:
# default_exp data

In [2]:
#export
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [3]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# ExperienceSourceDatasets

> Iterable datasets for returning environment outputs

We need `TfmdSourceDL` to trigger some cleanup before doing an iteration. 

TODO: (Josiah): Is there a way to override the `before_iter` in the DataBlock instead? The main issue is that we need to be able to reference `self` which isn't possible when passing methods through the `DataLoader` params.

In [4]:
# export
class TfmdSourceDL(TfmdDL):
    def before_iter(self):
        super().before_iter()
        self.dataset.reset_src()

A `TfmdSource` has an adjustable `__len__`. Unlike the `TfmdLists`, `TfmdSource` iters on a single item until the item raises a `SourceExhausted` exception. This means that the soruces `items` are being tracked by a separate index.

In [438]:
# export
class SourceExhausted(Exception):pass

@delegates(TfmdLists)
class TfmdSource(TfmdLists):
    "A `Pipeline` of `tfms` applied to a collection of sources called `items`. Only swtches between them if they get exhausted."
    def __init__(self,items, tfms,n:int=None,cycle_srcs=True,verbose=False,**kwargs):
        self.n=n;self.cycle_srcs=cycle_srcs;self.source_idx=0;self.verbose=verbose;self.res_buffer=deque([]);self.extra_len=0
        super().__init__(items,tfms,**kwargs)
#         store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
        
    def __repr__(self): return f"{self.__class__.__name__}: Cycling sources: {self.cycle_srcs}\n{self.items}\ntfms - {self.tfms.fs}"
    def reset_src(self): 
        [t.reset(self) for t in self.tfms if hasattr(t,'reset')]
        self.res_buffer.clear()
        
    def setup(self,train_setup=True):super().setup(train_setup);self.reset_src()
     
    def __len__(self):
#         return ifnone(self.n,super().__len__()) TODO (Josiah): self.n is not settable in DataBlock, and since TfmdLists gets reinit, this will not persist
        if len(self.items)!=0 and isinstance(self.items[0],gym.Env) and self.cycle_srcs:
            self.reset_src()
            return self.items[0].spec.max_episode_steps+self.extra_len # TODO(Josiah): This is the only OpenAI dependent code. How do we have htis set in setup?
        return super().__len__()
    
    def __getitem__(self,idx):
        if len(self.res_buffer)!=0:
#             print('\nBuffer Not Empty:',self.res_buffer)
            return self.res_buffer.popleft()
        
        try:res=super().__getitem__(self.source_idx if self.cycle_srcs else idx)
        except (IndexError,SourceExhausted) as e:
            if not self.cycle_srcs:raise
            if type(e)==SourceExhausted: 
                self.source_idx+=1;  pv(f'SourceExhausted, incrementing to idx {self.source_idx}',verbose=self.verbose) 
                if len(self.items)<=self.source_idx:e=IndexError(f'Index {self.source_idx} from SourceExhausted except is out of bounds.')
            if type(e)==IndexError:      
                self.source_idx=0;   pv(f'IndexError, setting idx to {self.source_idx}',verbose=self.verbose)
                self.reset_src()
            res=self.__getitem__(self.source_idx)
        
        if is_listy(res):
            self.res_buffer=deque(res)
#             print('\nBuffer Empty:',self.res_buffer)
            return self.res_buffer.popleft()
        return res

In [398]:
# export
class IterableDataBlock(DataBlock):
    def datasets(self, source, verbose=True):
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        tls=L([TfmdSource(items, t,verbose=verbose) for t in L(ifnone(self._combine_type_tfms(),[None]))])
        return Datasets(items,tls=tls,splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)

In [399]:
# export
class MakeTfm(Transform):
    def setup(self,items:TfmdSource,train_setup=False):
        items.cycle_srcs=False
        for i in range(len(items)):items[i]=gym.make(items[i])
        return super().setup(items,train_setup)

In [400]:
%matplotlib inline

In [401]:
# export    
def env_display(env:gym.Env):
    img=env.render('rgb_array')
    try:display.clear_output(wait=True)
    except AttributeError:pass
    new_im=PIL.Image.fromarray(img)
    display.display(new_im)

In [491]:
# export
@dataclass
class Experience:d:bool;s:np.ndarray;sp:np.ndarray;r:float;a:Any;eid:int=0
def envlen(o:gym.Env):return o.spec.max_episode_steps

@dataclass
class ResetAndStepTfm(Transform):
    def __init__(self,seed:int=None,agent:object=None,n_steps:int=1,steps_delta:int=1,a:Any=None,history:deque=None,
                 s:dict=None,steps:dict=None,maxsteps:int=None,display:bool=False,hist2dict:bool=True):
        self.seed=seed;self.agent=agent;self.n_steps=n_steps;self.steps_delta=steps_delta;self.a=a;self.history=history;self.hist2dict=hist2dict
        self.maxsteps=maxsteps;self.display=display
        self.s=ifnone(s,{})
        self.steps=ifnone(steps,{})
        # store_attr('n,cycle_srcs', self) TODO (Josiah): Does not seem to work?
            
    def setup(self,items:TfmdSource,train_setup=False):
        self.reset(items)
        self.history=deque(maxlen=self.n_steps)
        return super().setup(items,train_setup)
    
    def reset(self,items):
        if items.extra_len==0:
            items.extra_len=items.items[0].spec.max_episode_steps*(self.n_steps-1) # Extra steps to unwrap done
        items.cycle_srcs=False
        self.s={id(o):o.reset() for o in items.items if o.seed(self.seed) or True}
        self.steps={id(o):0 for o in items.items}
        self.maxsteps=ifnone(self.maxsteps,envlen(items.items[0]))
        if self.history is not None:self.history.clear()
        items.cycle_srcs=True
        
    def queue2dict(self,q:deque):return [asdict(hist) for hist in tuple(copy(q))]
    def encodes(self,o:gym.Env):
        # If history has done, then instead we try emptying the environment
        if self.history and self.history[-1].d:
            self.history.popleft()
            if len(self.history)==0:raise SourceExhausted
            return self.queue2dict(self.history) if self.hist2dict else copy(self.history)
        
        while True:
            a=ifnone(self.a,o.action_space.sample()) if self.agent is None else self.agent(self.s[id(o)])
            sp,r,d,_=o.env.step(a)
            if self.display:env_display(o)
                
            self.steps[id(o)]+=1
            d=self.steps[id(o)]>=self.maxsteps if not d else d
    
            self.history.append(Experience(d=d,s=self.s[id(o)].copy(),sp=sp.copy(),r=r,a=a,eid=id(o)))
            self.s[id(o)]=sp.copy()
            
            if self.steps[id(o)]%self.steps_delta!=0: continue # TODO(Josiah): if `steps_delta`!=1, it may skip the first state. Is this ok?
            if len(self.history)!=self.n_steps:       continue
            break
        
        return self.queue2dict(self.history) if self.hist2dict else copy(self.history)

In [500]:
# export
@delegates(ResetAndStepTfm)
def ExperienceBlock(**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(**kwargs)],dl_type=TfmdSourceDL)

In [558]:
import pytest
@delegates(ResetAndStepTfm)
def test_block(n_steps,steps_delta,block=ExperienceBlock,**kwargs):
    for env in ['MountainCar-v0','CartPole-v1']:
        blk=IterableDataBlock(blocks=(block(n_steps=n_steps,steps_delta=steps_delta,**kwargs)),
                    splitter=FuncSplitter(lambda x:False))

        dls=blk.dataloaders([env]*5,bs=1,num_workers=0,verbose=False,
                              indexed=True,shuffle_train=False)


        states={
            'MountainCar-v0':['tensor([[-0.5891,  0.0000]], dtype=torch.float64)','tensor([[-0.7105,  0.0043]], dtype=torch.float64)'],
            'CartPole-v1':['tensor([[-0.0446,  0.0465,  0.0133, -0.0210]], dtype=torch.float64)',
                           'tensor([[-0.1770, -1.7150,  0.2274,  2.7892]], dtype=torch.float64)']
        }

        print('Starting Iteration')
        counter=0
        counters=[]
        for epoch in range(3):
            dones=0
            for x in dls[0]:
                print('\nResult',x)
                if counter==0: test_eq(str(x[0]['s']),states[env][0])
                if x[0]['d']:
#                     print(len(dls[0]),dones)
                    dones+=1
                    if dones==n_steps:
                        test_eq(str(x[0]['sp']),states[env][1])
                        print(counter)
                        # TODO(Josiah): Figure out why this differs between runs.
                        if env=='MountainCar-v0':test_eq(counter,pytest.approx(200*n_steps-n_steps*(0+1),10))
                        counters.append(counter)
                        counter=0
                else:
                    counter+=1
            test_ne(counters[-1],0)

Check that all defaults work as expected. That an episode completes, and that the expected final state gets reached.

In [504]:
test_block(a=0,seed=0,n_steps=1)

Collecting items from ['CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration
8
445
445
Collecting items from ['MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration
199
199
199


In [505]:
test_block(a=0,seed=0,n_steps=3)

Collecting items from ['CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration
21
1314
1314
Collecting items from ['MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration
594
597
597


In [559]:
test_block(a=0,seed=0,n_steps=3,steps_delta=3)

Collecting items from ['MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration

Result ({'d': tensor([False]), 's': tensor([[-0.5891,  0.0000]], dtype=torch.float64), 'sp': tensor([[-5.8964e-01, -5.1169e-04]], dtype=torch.float64), 'r': tensor([-1.], dtype=torch.float64), 'a': tensor([0]), 'eid': tensor([140317729972496])},)

Result ({'d': tensor([False]), 's': tensor([[-5.8964e-01, -5.1169e-04]], dtype=torch.float64), 'sp': tensor([[-0.5907, -0.0010]], dtype=torch.float64), 'r': tensor([-1.], dtype=torch.float64), 'a': tensor([0]), 'eid': tensor([140317729972496])},)

Result ({'d': tensor([False]), 's': tensor([[-0.5907, -0.0010]], dtype=torch.float64), 'sp': tensor([[-0.5922, -0.0015]], dtype=torch.float64), 'r': tensor([-1.], dtype=torch.float64), 'a': tensor([0]), 'eid': tensor([140317729972496])},)

Result ({'d': tensor([False]), 's': tensor([[-0.5922, -0.0015]], dtype=torch.float64), 'sp': 

AssertionError: ==:
tensor([[-0.7105,  0.0043]], dtype=torch.float64)
tensor([[-0.5891,  0.0000]], dtype=torch.float64)

In [541]:
# export
class FirstLastTfm(Transform):
    def __init__(self,discount=0.99):self.discount=discount
    
    def reset(self,items):
        if items.extra_len!=0:items.extra_len=0
    
    def encodes(self,o):
        total_reward=0.0
        first_o=o[0]
        first_o.sp=o[-1].sp
        for exp in reversed(list(o)):
            total_reward*=self.discount
            total_reward+=exp.r
        first_o.r=total_reward
        return asdict(first_o)


@delegates(ResetAndStepTfm)
def FirstLastExperienceBlock(**kwargs):
    return TransformBlock(type_tfms=[MakeTfm(),ResetAndStepTfm(hist2dict=False,**kwargs),FirstLastTfm],dl_type=TfmdSourceDL)

In [546]:
import pytest
@delegates(ResetAndStepTfm)
def fl_test_block(n_steps,block=FirstLastExperienceBlock,**kwargs):
    for env in ['MountainCar-v0','CartPole-v1']:
        blk=IterableDataBlock(blocks=(block(n_steps=n_steps,**kwargs)),
                    splitter=FuncSplitter(lambda x:False))

        dls=blk.dataloaders([env]*5,bs=1,num_workers=0,verbose=False,
                              indexed=True,shuffle_train=False)


        states={
            'MountainCar-v0':['tensor([[-0.5891,  0.0000]], dtype=torch.float64)','tensor([[-0.7105,  0.0043]], dtype=torch.float64)'],
            'CartPole-v1':['tensor([[-0.0446,  0.0465,  0.0133, -0.0210]], dtype=torch.float64)',
                           'tensor([[-0.1770, -1.7150,  0.2274,  2.7892]], dtype=torch.float64)']
        }

        print('Starting Iteration')
        counter=0
        counters=[]
        for epoch in range(3):
            dones=0
            for x in dls[0]:
#                 print('\nResult',x)
                if counter==0: test_eq(str(x[0]['s']),states[env][0])
                if x[0]['d']:
#                     print(len(dls[0]),dones)
                    dones+=1
                    test_eq(str(x[0]['sp']),states[env][1])
#                     print(counter)
                    # TODO(Josiah): Figure out why this differs between runs.
                    if env=='MountainCar-v0':test_eq(counter,pytest.approx(200*n_steps-n_steps*(0+1),1))
                    counters.append(counter)
                    counter=0
                else:
                    counter+=1
            test_ne(counters[-1],0)
fl_test_block(a=0,seed=0,n_steps=3)

Collecting items from ['MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0', 'MountainCar-v0']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration
Collecting items from ['CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1', 'CartPole-v1']
Found 5 items
2 datasets of sizes 5,0
Starting Iteration


# ExperienceSourceDatasets OLD

> Iterable datasets for returning environment outputs.

## EnvLists

> Iterable lists for returning environment outputs.

In [None]:
# export
@delegates(TfmdLists.__init__)
class EnvLists(TfmdLists):
    def __init__(self,items,tfms,**kwargs):
        self.is_set=False
        super().__init__(items,tfms,**kwargs)
    
    def __len__(self):
        if (len(self.items)!=0 and not issubclass(type(self.items[0]),gym.Env)) or not self.is_set: return len(self.items) 
        else:                                                                                       return self.items[0].spec.max_episode_steps

    def _get(self,i): return i if self.is_set else super()._get(i) 
        
    def setup(self,train_setup=True):
        super().setup(train_setup)
        for f in self.fs:
            if hasattr(f,'reset'):f.reset()
        self.is_set=True

### EnvMakeTfm

> Make environments into their openai gym versions.

In [None]:
# export
class EnvMakeTfm(Transform):    
    def setup(self,items=None,train_setup=False):
        for i,o in enumerate(items):
            print(o)
            items[i]=gym.make(o)
        return super().setup(items,train_setup)

Check that the correct environment was created.

In [None]:
tl=EnvLists(['CartPole-v1' for _ in range(5)],tfms=[EnvMakeTfm])
test_eq(len(tl),500)
for o in tl:test_stdout(lambda:print(tl.items[o%5]),'<TimeLimit<CartPoleEnv<CartPole-v1>>>')

### EnvResetTfms

> Handles environment resetting.

In [None]:
# export
@dataclass
class DoneStateEnv:d:bool;s:np.array;env:object
DoneStateEnv.__repr__=lambda self:str((self.d,self.s,self.env))

@dataclass
class EnvResetTfm(Transform):
    env_idx:int=0;seed:Optional[int]=None;s:Optional[np.ndarray]=None;d:bool=False;was_setup:bool=False
    callback:DoneStateEnv=field(default_factory=lambda:DoneStateEnv(True,None,None))
    items:List[gym.Env]=field(default_factory=list)

    def setup(self,items=None,train_setup=False):
        for o in items:self.items.append(o)
        self._env_idx=len(items)
        return self 
    
    def reset(self):
        self._env_idx=0
        if self.seed is not None:[o.seed(self.seed) for o in self.items]
        self.s=[_o.reset() for _o in self.items]
        self.d=[False for _ in self.items]
        self.callback=DoneStateEnv(d=self.d[self._env_idx],s=self.s[self._env_idx],env=self.items[self._env_idx])
    
    def encodes(self,o:int):        
        if self.callback.d:
#             print('was done lol')
            self._env_idx+=1
            if self._env_idx>=len(self.items):self.reset()
            self.callback=DoneStateEnv(d=self.d[self._env_idx],s=self.s[self._env_idx],env=self.items[self._env_idx])
#             print(self.callback.s)
        return self.callback

Check that the `_env_idx` is incremented per iter. The intention is that a generating `Transform` uses the value to generate many elements.

In [None]:
n_envs=5
tl=EnvLists(['CartPole-v1' for _ in range(n_envs)],tfms=[EnvMakeTfm,EnvResetTfm(seed=0)])

next_inc=0
for ii in range(50):
    for i,o in enumerate(tl):
        if ii==0:test_eq(str(o.s),str(np.array([-0.04456399,  0.04653909,  0.01326909, -0.02099827])))
        test_eq(tl.tfms[1]._env_idx,next_inc)
        test_eq(tl.tfms[1].items.index(o.env),next_inc)
        test_eq(str(o),str((False,np.array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]),tl.items[0])))
    o.d=True
    next_inc=next_inc+1 if next_inc<n_envs-1 else 0
# for o in tl:test_eq(str(o),str((False,np.array([-0.04456399,  0.04653909,  0.01326909, -0.02099827]),tl.items[0])))

### EnvGenUnwrapTfm
> Unwraps generators into 1d lists.

In [None]:
# export
@dataclass
class EnvGenUnwrapTfm(Transform):
    step:int=0
    
    def encodes(self,o:Generator): 
        result= tuple([xx for x in o for xx in x]) 
        self.step+=1
        return result

### EnvStepTfm

> Handles stepping through environments. 

In [None]:
%matplotlib inline

In [None]:
# export    
def env_display(env:gym.Env):
    img=env.render('rgb_array')
    try:display.clear_output(wait=True)
    except AttributeError:pass
    new_im=PIL.Image.fromarray(img)
    display.display(new_im)

In [None]:
# export
@dataclass
class Experience:d:bool;s:np.ndarray;sp:np.ndarray;r:float;a:Any

@dataclass
class EnvStepTfm(Transform):
    agent:Optional[object]=None;constant_action:Optional[int]=None;n_steps:int=1;steps_delta:int=1;step:int=-1
    _reset_step:bool=False;history:deque=None;display:bool=False;enc_set:bool=False
    
    def __post_init__(self):     self.history=deque(maxlen=self.n_steps)
    def bump_step_count(self,d): self.step=0 if d else self.step+1
    def reset(self):             self.step=0; self.history.clear()
        
    def encodes(self,o:DoneStateEnv):
        while True:
            if self.agent is None: a=ifnone(self.constant_action,o.env.action_space.sample()) 
            else:                  a=self.agent(o.s)
                
            sp,r,o.d,_=o.env.step(a)
            if self.display:env_display(o.env)
            self.history.append(Experience(d=o.d,s=o.s.copy(),sp=sp.copy(),r=r,a=a))
            o.s=sp.copy()
            self.bump_step_count(o.d)
            if o.d:
                while len(self.history)>1: # We allow the a single element left to be yielded by the broken while
                    yield tuple(self.history)
                    self.history.popleft()
                break
#             print(len(self.history),((self.step-1)%self.steps_delta),self.step)
            if len(self.history)==self.n_steps and ((self.step-1)%self.steps_delta)==0 and not o.d or \
                 (len(self.history)==self.step and len(self.history)==self.n_steps and not o.d):break
            
        history=tuple(copy(self.history))
        if o.d:self.history.clear();self.reset()
        yield tuple(history)
        return None

In [None]:
def validate_env(env_name,a,n_steps,steps_delta,n_envs,env_steps,dones,max_iter=800,n_episodes_break=-1,
                 initial_s=str(np.array([-0.58912799 , 0.        ])),
                 final_s=str(np.array([-0.71048047,  0.00427297]))):
    tl=EnvLists([env_name for _ in range(n_envs)],tfms=[EnvMakeTfm,EnvResetTfm(seed=0),
                                                        EnvStepTfm(constant_action=a,n_steps=n_steps,steps_delta=steps_delta,display=False),EnvGenUnwrapTfm])
    print('Starting loop')
    count=0
    for k in range(n_envs):
        print('\n\n')
        for i,o in enumerate(tl):
#             print(count,o)
            count+=1
#             print(o,i)
            if i==0:test_eq(str(o[0].s),initial_s)
            for exp in o:test_eq(exp.a,a)
            
            if any(_.d for _ in o):break
        if any(_.d for _ in o):
            test_eq(str(o[-1].sp),final_s)
            dones+=1
            if count>round((env_steps*n_envs)/steps_delta)-10:break
                
    
#     test_eq(count-(steps_delta!=1)*n_envs+((n_steps-1)*n_envs),round((env_steps*n_envs)/steps_delta))
    # If both are changed, the env will loop extra times at the start
    extra_steps_on_start=0 if n_steps==1 or steps_delta==1 else min((steps_delta,n_steps))-1
    
    
    test_eq(count-(steps_delta!=1)*n_envs+((n_steps-1)*n_envs),round((env_steps*n_envs)/steps_delta)+extra_steps_on_start)
    test_eq(n_envs,dones)

Check that `MountainCar-v0` with `steps_delta=2` has single episode lengths cut in half. Verify that each episode keeps its `done` signal.

In [None]:
validate_env(
    env_name='MountainCar-v0',
    a=0,
    n_steps=1,
    steps_delta=2,
    n_envs=1,
    env_steps=200,
    dones=0,
    n_episodes_break=1
)

Check that `MountainCar-v0` at defaults, runs the full 200 iterations.
Check that all actions are `a`

In [None]:
validate_env(
    env_name='MountainCar-v0',
    a=0,
    n_steps=1,
    steps_delta=1,
    n_envs=1,
    env_steps=200,
    dones=0
)

Check that `MountainCar-v0` at defaults, with 2 envs, fully resets between episodes, and that the starting and ending states are always output.

In [None]:
validate_env(
    env_name='MountainCar-v0',
    a=0,
    n_steps=1,
    steps_delta=1,
    n_envs=2,
    env_steps=200,
    dones=0
)

Check that `MountainCar-v0` with `n_steps=2` has single episode iterations is cut in half. Verify that each episode keeps its `done` signal.

In [None]:
validate_env(
    env_name='MountainCar-v0',
    a=0,
    n_steps=2,
    steps_delta=1,
    n_envs=1,
    env_steps=200,
    dones=0
)

Check that `CartPole-v1` iterates correctly also.

In [None]:
count=0
dones=0
env_steps=10
steps_delta=1
n_steps=2
n_envs=1

tl=EnvLists(['CartPole-v0' for _ in range(n_envs)],tfms=[EnvMakeTfm,EnvResetTfm(seed=0),
                                                    EnvStepTfm(constant_action=0,n_steps=n_steps,steps_delta=steps_delta,display=True),EnvGenUnwrapTfm])
print('Starting loop')
for k in range(n_envs):
    print('\n\n')
    for i,o in enumerate(tl):
#             print(count,o)
        count+=1
#             print(o,i)
        if i==0:test_eq(str(o[0].s),'[-0.04456399  0.04653909  0.01326909 -0.02099827]')
        for exp in o:test_eq(exp.a,0)

        if any(_.d for _ in o):break
    if any(_.d for _ in o):
        test_eq(str(o[-1].sp),'[-0.17695373 -1.71499924  0.22743892  2.78917835]')
        dones+=1
        if count>round((env_steps*n_envs)/steps_delta)-10:break

print(count)

In [None]:
@dataclass
class ExperienceToDictTfm(Transform):
    def encodes(self,o:Experience):
        print('out pooting ',o)
        return asdict(o)

# DataBlock 

> Generates DataBlock appropriate for running OpenAI envs.

In [None]:
def doer(o):print('Itemm dorer',o)
def NoopSplitter(o):return [o]
    
class TestDataBlock(DataBlock):

    def datasets(self, source, verbose=False):
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        splits = (self.splitter or RandomSplitter())(items)
        print('splits',splits,items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        return Datasets(None, tls=[EnvLists(items,tfms=self.item_tfms,splits=[[0]])], 
                        splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)
    
    def dataloaders(self, source, path='.', verbose=False, **kwargs):
        pv(source,verbose)
        dsets = self.datasets(source)
        for dset in dests:
        print('Splits: ',dsets.splits[0])
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path,after_batch=self.batch_tfms,**kwargs)

In [None]:
ds=block.datasets(['MountainCar-v0' for _ in range(1)]);ds.tls[0].splits

In [None]:
block=TestDataBlock(splitter=NoopSplitter,item_tfms=[EnvMakeTfm,EnvResetTfm(seed=0),
                                                 EnvStepTfm(constant_action=0,n_steps=n_steps,steps_delta=steps_delta,display=False),
                                                 EnvGenUnwrapTfm,
                                                 ExperienceToDictTfm])
dls=DataLoaders.from_dblock(block,['MountainCar-v0' for _ in range(1)],bs=20,verbose=True,shuffle_train=False,num_workers=0)

In [None]:
dls[0].do_batch([dls[0].do_item(0)])

In [None]:
dls[0].dataset.tls

In [None]:
len(dls[0].dataset.items)

In [None]:
len(dls[0])

In [None]:
dls[0].dataset.tls[0]

In [None]:
len(dls[0].dataset.tls[0])

In [None]:
for x in dls[0]:
    print(x)
    break

## Export

In [None]:
# hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)