In [1]:
#|hide
# Notebook bootstrap: run fastrl's shared setup helper before anything else.
# NOTE(review): what the helper configures is defined in `fastrl.test_utils`, not here.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [2]:
#|default_exp memory.experience_replay

In [3]:
#|export
# Python native modules
from copy import copy
# Third party libs
from fastcore.all import add_docs,ifnone
import torchdata.datapipes as dp
import numpy as np
import torch
# Local modules
from fastrl.core import StepType

# Experience Replay
> Experience Replay is likely the simplest form of memory used by RL agents. 

In [5]:
#|export
class ExperienceReplay(dp.iter.IterDataPipe):
    # NOTE: the class and public-method docstrings are attached after the class
    # body via `add_docs`, so the explanations in here are plain comments.

    # When True, every element pulled from `source_datapipe` is printed.
    debug=False
    def __init__(self,
            # Datapipe yielding either a single `StepType` or a list/tuple of them.
            source_datapipe,
            # Optional object; when given, this pipe registers itself on it as
            # `learner.experience_replay`.
            learner=None,
            # Default number of steps drawn by `sample()`.
            bs=1,
            # Capacity of `memory`. Once full, new steps overwrite the oldest ones.
            max_sz=100,
            # If True, `sample()` returns `(memory[idxs], idxs)` as stored, without
            # moving the steps to `self.device`.
            return_idxs=False,
            # If `store_as_cpu=True`, `add()` stores a detached cpu clone of every step and
            # `sample()` moves the sampled steps to `self.device` on the way out.
            # This can be slower, but can save vram.
            # If `store_as_cpu=False`, steps are stored exactly as received (no clone,
            # no device move in `add()`).
            #
            # If being run with n_workers>0, shared_memory, and fork, this MUST be true. This is needed because
            # otherwise the tensors in the memory will remain shared with the tensors created in the 
            # dataloader.
            store_as_cpu:bool=True
        ):
        # Object array used as a fixed-size circular buffer of steps.
        self.memory = np.array([None]*max_sz)
        self.source_datapipe = source_datapipe
        self.learner = learner
        if learner is not None:
            self.learner.experience_replay = self
        self.bs = bs
        self.max_sz = max_sz
        # Number of valid entries in `memory` (never exceeds `max_sz`).
        self._sz_tracker = 0
        # Slot the next step is written to; equals `max_sz` right after the last slot was written.
        self._idx_tracker = 0
        # Number of times the write index has wrapped back to slot 0.
        self._cycle_tracker = 0
        self.return_idxs = return_idxs
        self.store_as_cpu = store_as_cpu
        # Indices drawn by the most recent `sample()` call made with `return_idxs=False`.
        self._last_idx = None
        self.device = None

    def to(self,*args,**kwargs):
        # Only the target device matters here; nothing is moved until `sample()`.
        device = kwargs.get('device',None)
        if device is None:
            # Also honour the positional `to(device)` form instead of silently dropping it.
            device = next((a for a in args if isinstance(a,(str,torch.device))),None)
        self.device = device

    def sample(self,bs=None): 
        # Uniform draw without replacement over the filled part of the buffer. numpy raises
        # ValueError when more steps are requested than are stored.
        # Passing the int is equivalent to passing `range(n)` but avoids rebuilding an array per call.
        idxs = np.random.choice(self._sz_tracker,size=(ifnone(bs,self.bs),),replace=False)
        if self.return_idxs: return self.memory[idxs],idxs
        self._last_idx = idxs
        return [o.to(device=self.device) for o in self.memory[idxs]]
    
    def __repr__(self):
        # Show every attribute, but summarise `memory` by its element count.
        return str({k:v if k!='memory' else f'{len(self)} elements' for k,v in self.__dict__.items()})

    def __len__(self): return self._sz_tracker

    def show(self):
        # Imported lazily so the viewer is only loaded when it is actually requested.
        from fastrl.memory.memory_visualizer import MemoryBufferViewer
        return MemoryBufferViewer(self.memory)
    
    def __iter__(self):
        # For every element of `source_datapipe`: store it, then yield one sampled batch
        # as soon as at least `bs` steps are in memory.
        for b in self.source_datapipe:
            if self.debug: print('Experience Replay Adding: ',b)
            
            # `StepType` is checked first so a step that is itself a tuple is stored whole.
            if isinstance(b,StepType): self.add(b)
            elif isinstance(b,(list,tuple)): 
                for step in b: self.add(step)
            else:
                # Report the offending element itself (`b`).
                raise Exception(f'Expected typing.NamedTuple,list,tuple object got {type(b)}\n{b}')
        
            if self._sz_tracker<self.bs: continue
            yield self.sample()

    def add(self,step:StepType): 
        if self.store_as_cpu: 
            step = step.clone().detach().to(device='cpu')
        
        if self._sz_tracker<self.max_sz:
            # Still filling: append at the write index and grow.
            self.memory[self._idx_tracker] = step
            self._sz_tracker += 1
            self._idx_tracker += 1
        else:
            # Full: wrap the write index when it ran off the end, then overwrite
            # the oldest entry.
            if self._idx_tracker>=self.max_sz:
                self._idx_tracker = 0
                self._cycle_tracker += 1
            self.memory[self._idx_tracker] = step
            self._idx_tracker += 1
            
# Attach the class docstring and the public-method docstrings to `ExperienceReplay`.
add_docs(
ExperienceReplay,
"""Simplest form of memory. Takes steps from `source_datapipe` and stores them in `memory`.
It outputs `bs` steps.""",
sample="Returns `bs` steps from `memory` in a uniform distribution.",
add="Adds a new step to `memory`. Once `memory` holds `max_sz` steps, each new `step` overwrites the oldest stored step.",
to="Sets the `device` that `sample` moves steps to. Unlike `torch.Tensor.to`, nothing is moved when this is called and arguments other than the device are ignored.",
show="Displays an ipywidget to look at the steps in `self.memory`"
)

Let's generate some batches to test with...

In [6]:
from fastrl.envs.gym import GymDataPipe
from fastcore.all import delegates,test_eq,test_ne
from fastrl.core import test_len

In [7]:
def baseline_test(envs,total_steps,seed=0):
    "Run an unbatched `GymDataPipe` for `total_steps` and return `(collected_steps, pipe)`."
    unbatched = GymDataPipe(envs,n=total_steps,seed=seed).unbatch()
    collected = list(unbatched)
    return collected, unbatched

@delegates(ExperienceReplay)
def exp_replay_test(envs,total_steps,seed=0,**kwargs):
    """Same pipeline as `baseline_test`, with an `ExperienceReplay` (built from `kwargs`) on the end.

    Returns `(samples, replay_pipe)`. When `total_steps` is None the pipe is returned
    without being iterated, and `samples` is None."""
    source = GymDataPipe(envs,n=total_steps,seed=seed).unbatch()
    replay = ExperienceReplay(source,**kwargs)
    samples = None if total_steps is None else list(replay)
    return samples, replay

In [8]:
# Run the pipeline for 0 steps: the replay memory is expected to stay empty.
steps, experience_replay = exp_replay_test(['CartPole-v1'],0,bs=1)
test_eq(len(experience_replay),0)

  state=torch.tensor(step.next_state),


**What if we fill up ER?**
Let's add the batches; this process happens in place...

In [9]:
# Feed 10 steps into a buffer with room for 20: the size and the write index should
# both be 10, and the write index should not have wrapped around yet.
steps, experience_replay = exp_replay_test(['CartPole-v1'],10,max_sz=20)
test_eq(experience_replay._sz_tracker,10)
test_eq(experience_replay._idx_tracker,10)
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,10)

In [10]:
# Open the `MemoryBufferViewer` widget on the current contents of `memory`.
experience_replay.show()

VBox(children=(Label(value='Number of Elements in Memory: 10'), HBox(children=(Button(description='Previous', …

<fastrl.memory.memory_visualizer.MemoryBufferViewer at 0x7f16b10cbbb0>

If we run 10 more times, the total size should be 20...

In [11]:
# Pull exactly 10 more samples from the replay pipe. `range(10)` comes first in the zip,
# so it runs out before an 11th element is requested from `experience_replay`.
# `steps` only collects the loop counters; the sampled batches themselves are discarded.
steps = [n for n,_ in zip(range(10),experience_replay)]
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,20)

The `experience_replay` memory should contain the same steps we would get by running the pipeline without it...

In [12]:
# Collect 20 steps without the replay buffer, then 20 steps through a buffer of
# capacity 20, and compare them slot by slot.
steps, pipe = baseline_test(['CartPole-v1'],20,seed=0)
_, experience_replay = exp_replay_test(['CartPole-v1'],20,max_sz=20)

for i,(baseline_step,memory_step) in enumerate(zip(steps,experience_replay.memory)):
    # Slot `i` of the memory should hold the i-th step of the baseline run.
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)
    print('Step ',i)

Step  0
Step  1
Step  2
Step  3
Step  4
Step  5
Step  6
Step  7
Step  8
Step  9
Step  10
Step  11
Step  12
Step  13
Step  14
Step  15
Step  16
Step  17
Step  18
Step  19


Since `max_sz` is 20 and so far we have run a total of 20 steps, if we run another 10 steps
the `_cycle_tracker` should be 1 (since this is a new cycle), and `_idx_tracker` should be 10, since it should
have reset and stopped halfway through the memory. The `_sz_tracker` should still be 20.

In [13]:
# `total_steps=None` makes `exp_replay_test` hand back the pipe without iterating it.
_, experience_replay = exp_replay_test(['CartPole-v1'],None,max_sz=20)
# NOTE(review): the surrounding text talks about 20 steps while this asks for 19 items;
# presumably `header` pulls one extra element from its source before stopping. Confirm.
list(experience_replay.header(19))

# Pull exactly 10 more samples; `steps` only holds the loop counters.
steps = [n for n,_ in zip(range(10),experience_replay)]
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,10)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)

...and if we run the baseline, its last 10 steps should match the first 10 steps in memory,
since the buffer is in the middle of rewriting the memory after reaching its max size.

In [14]:
# Baseline run of 30 steps: steps 20..29 should be what now sits in memory slots 0..9.
steps, pipe = baseline_test(['CartPole-v1'],30)

for baseline_step,memory_step in zip(steps[20:],experience_replay.memory[:10]):
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)

Finally we want to finish writing over the memory in its entirety. 

In [15]:
# Pull 10 more samples so the second pass over the buffer completes: the write index
# should be back at 20 and the cycle count should stay at 1.
# `steps` only holds the loop counters; the sampled batches are discarded.
steps = [n for n,_ in zip(range(10),experience_replay)]
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)

In [16]:
# Baseline run of 40 steps: steps 20..39 should now fill the whole memory, in order.
steps, pipe = baseline_test(['CartPole-v1'],40)

for baseline_step,memory_step in zip(steps[20:],experience_replay.memory):
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)

Let's verify that the steps are what we expect...

**What if we sample the experience?**

In [17]:
# Iterate the replay pipe and check that each sampled step differs from the step
# seen immediately before it (this carries over from one batch to the next).
steps, experience_replay = exp_replay_test(['CartPole-v1'],1000,bs=300,max_sz=1000)
# `memory` holds a copy of the previously visited step; None until the first one is seen.
memory = None
for i,sample in enumerate(experience_replay):
    for s in sample:
        if memory is not None: test_ne(s,memory)
        memory = copy(s)
    # Stop after a bounded number of batches.
    if i>100:break

We should be able to sample enough times that we have sampled **everything**. 
So we test this by sampling, check if that sample has been seen before, and then record that.

In [18]:
# With `return_idxs=True`, `sample()` also returns the memory indices it drew.
steps, experience_replay = exp_replay_test(['CartPole-v1'],1000,bs=1,max_sz=30,return_idxs=True)
# One flag per memory slot, flipped to True once that slot has been sampled.
memory_hits = [False]*experience_replay.max_sz
# The draws are random and unseeded. With only 150 single-step draws over 30 slots, at
# least one slot is never hit in roughly 1 run out of 6, which made this check flaky.
# 1000 draws push the chance of missing any slot below 1e-12.
for _ in range(1000):
    _,idxs = experience_replay.sample()
    for idx in idxs: memory_hits[idx] = True
test_eq(all(memory_hits),True)

In [19]:
#|hide
#|eval: false
# Run the `nbdev_export` command-line tool from inside the notebook.
!nbdev_export

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
