In [108]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [109]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment, except when the test runner sets `IN_TEST`.
    # NOTE(review): `os` is not imported in this cell; presumably it comes from the star imports above.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display (backed by the xvfb installed in the first cell) is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [110]:
# default_exp memory.experience_replay

In [111]:
# export
# Python native modules
import os
from typing import *
from warnings import warn
# Third party libs
from fastcore.all import *
from fastai.learner import *
from fastai.torch_basics import *
from fastai.torch_core import *
from fastai.callback.all import *
from torch.utils.tensorboard import SummaryWriter
# Local modules
from fastrl.core import *
from fastrl.callback.core import *
from fastrl.data.block import *

# Experience Replay
> Experience Replay is likely the simplest form of memory used by RL agents. 

In [112]:
# export
class ExperienceReplayException(Exception):
    "Raised by `ExperienceReplay` for records it cannot store and for unsupported `+` usage."

class ExperienceReplay(object):
    "Fixed-capacity buffer of `BD` records; once full, new records overwrite the oldest (unless `freeze_at_max`)"
    def __init__(self,
                 bs:int=16,         # Number of entries to query from memory
                 max_sz:int=200,    # Maximum number of entries to hold. Will start overwriting after.
                 warmup_sz:int=100,  # Minimum number of entries needed to continue with a batch
                 # Used for testing. Once the memory has reached max size, it will not
                 # add any more data. This is useful for checking whether a model is training correctly.
                 freeze_at_max:bool=False, 
                 memory:Optional[BD]=None # Optionally, you can initialize a new `ExperienceReplay` with an existing dictionary
                 ):
        "Stores `BD`s in a rotating list `self.memory`"
        store_attr()
        # i.e. `warmup_sz<=max_sz`: the memory never holds more than `max_sz` entries,
        # so a larger warm-up size could never be reached and `sample` would never succeed
        test_lt(warmup_sz-1,max_sz)
        self.memory=memory
        # `pointer` counts the slots written in the current pass over the buffer; it is not an index.
        # With an initial `memory`, continue after its last entry instead of overwriting it from slot 0.
        # Capping at `max_sz` makes an oversized initial memory wrap (and be truncated) on the next add.
        self.pointer=0 if memory is None else min(memory.bs(),max_sz)
    
    def __add__(self,other:BD):
        "In-place add `other` to memory, overwriting the oldest entries once `self.max_sz` is reached"
        if isinstance(other,tuple) and len(other)==1: other=other[0]
        elif isinstance(other,tuple):                 raise ExperienceReplayException('records need to be `BD`s or 1 element tuples')
        if isinstance(other,dict):                    other=BD(other)
        elif isinstance(other,list):
            # `sum([])` is `0`, not a `BD`, so an empty list is a no-op instead of a crash below
            if len(other)==0: return self
            other=sum(other)
        
        # NOTE(review): when `other` is passed in as a `BD`, this may add `td_error` to the caller's object too
        if 'td_error' not in other: other['td_error']=TensorBatch(torch.zeros((other.bs(),1)))
        
        if self.memory is None: 
            if other.bs()>self.max_sz: 
                self.memory=other[:self.max_sz]
                self.pointer=0           # Keep the pointer 0 since we have basically replaced the memory
                self+other[self.max_sz:] # Recursively add the rest of the batch
            else:
                self.memory=other
                self.pointer=self.memory.bs() # remember that pointer is not an index but number of elements
        else:
            if self.freeze_at_max and self.memory.bs()>=self.max_sz: return self
            n_over=(other.bs()+self.pointer)-self.max_sz
            if n_over>0: # e.g.: max_sz 200, pointer 195, other is 10 -> n_over is 5
                # Fill the tail of the buffer with the first `max_sz-pointer` records, then wrap
                # around and recursively write the remaining `n_over` records starting at slot 0
                self.memory=self.memory[:self.pointer]+other[:-n_over]
                self.pointer=0
                self+other[other.bs()-n_over:]
            else:
                # Overwrite slots [pointer,next_pointer) and keep everything after them
                next_pointer=self.pointer+other.bs()
                self.memory=self.memory[:self.pointer]+other+self.memory[next_pointer:]
                self.pointer=next_pointer
        return self
    
    def __getitem__(self,i):
        "Returns a new `ExperienceReplay` with the same settings holding `self.memory[i]`"
        return ExperienceReplay(bs=self.bs,max_sz=self.max_sz,warmup_sz=self.warmup_sz,
                                freeze_at_max=self.freeze_at_max,memory=self.memory[i])
    
    def __radd__(self,other:BD): raise ExperienceReplayException('You can only do experience_reply+[some other element]')
    
    def __len__(self): return self.memory.bs() if self.memory is not None else 0
        
    def sample(self)->Tuple[BD,list]:
        "Returns a sample of size `self.bs` (drawn with replacement) and the memory indexes it came from"
        # Check before sampling. This avoids wasted indexing and device copies during warm-up.
        # It also makes an empty memory (`None`, or 0 entries with `warmup_sz<=0`) cancel the
        # batch instead of crashing.
        if len(self)==0 or len(self)<self.warmup_sz: raise CancelBatchException
        with torch.no_grad():
            idxs=np.random.randint(0,self.memory.bs(),self.bs).tolist()
            samples=self.memory[idxs].mapv(to_device)
        return samples,idxs
    
    def update_td(self,td_errors:Tensor,idxs:Tensor):
        "Overwrite the `td_error` of the entries at `idxs` (e.g. as returned by `sample`) with `td_errors`"
        if not isinstance(idxs,list):
            test_len(idxs.shape,1)      # tensor indexes must be 1-D
        test_len(td_errors.shape,2)     # errors must be 2-D, e.g. (len(idxs),1) like the initial zeros
        self.memory['td_error'][idxs]=to_detach(td_errors)

Let's generate some batches to test with...

In [113]:
from fastrl.data.gym import *
# A single-step CartPole source; each batch from the learner's dataloader becomes one `BD` below
source=Source(cbs=[GymLoop(env_name='CartPole-v1',steps_delta=1,steps_count=1,seed=0),FirstLast])
learn=fake_gym_learner(source,n=1000,bs=5)
batches=[BD(b[0]) for b in learn.dls[0]]

Could not do one pass in your dataloader, there is something wrong in it


In [114]:
experience_replay=ExperienceReplay(max_sz=20,warmup_sz=19)
test_len(experience_replay,0)

**What if we fill up ER?**
Let's add the batches one at a time; each addition happens in place...

In [115]:
experience_replay+batches[0]
test_eq(experience_replay.pointer,5)
test_len(experience_replay,5)

If we add again, the total size should be 10...

In [116]:
experience_replay+batches[1]
test_eq(experience_replay.pointer,10)
test_len(experience_replay,10)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1])['step'])

In [117]:
experience_replay+batches[2]
test_len(experience_replay,15)
test_eq(experience_replay.pointer,15)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1]+batches[2])['step'])

In [118]:
experience_replay+batches[3]
test_len(experience_replay,20)
test_eq(experience_replay.pointer,20)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1]+batches[2]+batches[3])['step'])

Let's verify that the steps are what we expect...

**What if ER is full and we add batches?** We are at the maximum memory size, so we expect the next batch added to completely
overwrite the first 5 entries...

In [119]:
experience_replay+batches[4]
test_len(experience_replay,20)
test_eq(experience_replay.pointer,5)
test_eq(experience_replay.memory['step'],(batches[4]+batches[1]+batches[2]+batches[3])['step'])

This overwrite should properly overwrite the rest of the entries...

In [120]:
experience_replay+batches[5]+batches[6]+batches[7]
test_eq(experience_replay.memory['step'],(batches[4]+batches[5]+batches[6]+batches[7])['step'])
test_eq(experience_replay.pointer,20)

So every slot of the memory has now been written twice (once when filling it, once when overwriting it), which shows that overwriting works. Let's 
see what happens when we keep adding numbered batches...

In [121]:
experience_replay+batches[8]+batches[9]+batches[10]
test_eq(experience_replay.pointer,15)
test_eq(experience_replay.memory['step'],(batches[8]+batches[9]+batches[10]+batches[7])['step'])

**What if we need to split a batch to fit at the end and beginning of the memory?** This is a possibly scary part where some of the dictionary needs to be split. Some of it needs to be allocated to the end of the memory, and
some of it needs to be allocated at the start.

In [122]:
single_large_batch=batches[11]+batches[12]
experience_replay+single_large_batch;

In [123]:
test_eq(experience_replay.pointer,5)
test_eq(experience_replay.memory['step'],(batches[12]+batches[9]+batches[10]+batches[11])['step'])

**What if we sample the experience?**

In [124]:
full_memory=(batches[12]+batches[9]+batches[10]+batches[11])
entry_ids=[str(o) for o in torch.hstack((full_memory['step'],full_memory['episode_id']))]
memory_hits=[False]*len(entry_ids)

We should be able to sample enough times that we have sampled **everything**. 
So we test this by sampling, check if that sample has been seen before, and then record that.

In [126]:
for i in range(8):
    res,idxs=experience_replay.sample()
    for o in torch.hstack((res['step'],res['episode_id'])):
        memory_hits[entry_ids.index(str(o))]=True
test_eq(all(memory_hits),True)

**What happens when we index the experience replay?**

In [127]:
test_eq(experience_replay[5:10].memory['step'],batches[9]['step'])

**What happens when we try to update the td_errors?**

In [128]:
TensorBatch(torch.full((5,1),1.0))

TensorBatch([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])

In [129]:
experience_replay.update_td(TensorBatch(torch.full((5,1),1.0)),torch.arange(5,10))

In [130]:
test_eq(experience_replay.memory['td_error'].sum(),5)
test_eq(experience_replay.memory['td_error'][torch.arange(5,10)].sum(),5)
test_eq(experience_replay.memory['td_error'][torch.arange(6,11)].sum(),4)

**What happens when we freeze the memory?**

We should expect that we can fill up the memory, then once it is at its max, it will not accept anything else.

In [131]:
experience_replay=ExperienceReplay(max_sz=20,warmup_sz=5,freeze_at_max=True)
experience_replay+batches[0:4]
test_eq(experience_replay.memory['step'],sum(batches[0:4])['step'])
experience_replay+batches[4]
test_eq(experience_replay.memory['step'],sum(batches[0:4])['step'])

In [132]:
# export
class ExperienceReplayCallback(Callback):
    "Feeds every batch into `learn.experience_replay` and replaces `learn.xb` with a sample drawn from it"
    @delegates(ExperienceReplay)
    def __init__(self,
                 verbose=False, # Will show warnings for recommended behavior.
                 **kwargs):
        "Remembers `kwargs` (e.g. `bs`, `max_sz`, `warmup_sz`) to build an `ExperienceReplay` in `before_fit`"
        store_attr()
        self._kwargs=kwargs
        
    def before_fit(self):
        "Creates `learn.experience_replay` unless the learner already holds an `ExperienceReplay`"
        # fastai's `Learner.add_cb` also stores each callback on the learner under `cb.name`, which
        # for this class is presumably `experience_replay`. A plain `hasattr` check could therefore
        # find this very callback and never create the memory, so only reuse a real memory object.
        if not isinstance(getattr(self.learn,'experience_replay',None),ExperienceReplay):
            self.learn.experience_replay=ExperienceReplay(**self._kwargs)
    
    def after_pred(self):
        "Adds `learn.xb` to memory, then sets `learn.xb,learn.sample_indexes=experience_replay.sample()`"
        memory=self.learn.experience_replay
        memory+BD(self.learn.xb[0]).mapv(to_detach)
        # `sample` raises `CancelBatchException` while the memory is still warming up
        self.learn.xb,self.learn.sample_indexes=memory.sample()
    
    def after_batch(self):
        "Writes `learn.td_error` back into the memory entries that were sampled in `after_pred`"
        if hasattr(self.learn,'td_error'):
            self.learn.experience_replay.update_td(
                self.learn.td_error,
                self.learn.sample_indexes
            )
        elif self.verbose:
            warn("""The learner does not have a `td_error` field. Produced logs
                    will not be useful unless `td_error` exists.""")

In [133]:
from fastrl.data.gym import *
source=Source(cbs=[GymLoop(env_name='CartPole-v1',steps_delta=1,steps_count=1,seed=0,mode='rgb_array'),
                   ResReduce(reduce_by=4),
                   FirstLast])
learn=fake_gym_learner(source,n=30,bs=10)

Could not do one pass in your dataloader, there is something wrong in it


In [134]:
experience_replay=ExperienceReplayCallback(bs=5,max_sz=20,warmup_sz=11,verbose=True)
experience_replay.learn=learn

In [135]:
experience_replay.before_fit()
for b in learn.dls[0]:
    learn.xb=b
    
    try:
        experience_replay.after_pred()
        print('memory sampled')
    except CancelBatchException:
        print('memory is not full yet!')

memory is not full yet!
memory sampled
memory sampled


In [136]:
experience_replay.experience_replay.update_td(TensorBatch(torch.rand((5,1))),torch.arange(5,10))

## Memory Exploration

In [137]:
# export
def snapshot_memory(writer:SummaryWriter,epoch:int,experience_replay,prefix='experience_replay'):
    "Log every stored frame (as a one-frame video) and every `td_error` in `experience_replay.memory` to `writer`"
    memory=experience_replay.memory
    if memory is None:
        warn('experience replay is empty, nothing to log.')
        return
    if 'image' not in memory: 
        warn('image is missing from the experience replay. This is needed to produce understandble logs.')
        return
        
    # assumes images are stored channels-last (N,H,W,C) -- TODO confirm; permuted to (N,C,H,W)
    for i,frame in enumerate(memory['image'].permute(0,3, 1, 2)):
        # `add_video` takes (N,T,C,H,W): here one video of one frame, indexed by `global_step=i`
        writer.add_video(f'{prefix}/{epoch}/video',frame.unsqueeze(0).unsqueeze(0),global_step=i)
        
    # Detach and move to CPU first so `.numpy()` also works for grad-requiring or GPU tensors
    for i,v in enumerate(memory['td_error'].detach().cpu().numpy().reshape(-1)):
        writer.add_scalar(f'{prefix}/{epoch}/td_error',v,i)

In [138]:
# export
class ExperienceReplayTensorboard(Callback):
    "Logs `learn.experience_replay` to TensorBoard (via `snapshot_memory`) every `every_epoch` epochs"
    def __init__(self,
                 writer=None,   # `SummaryWriter` to log to; a new one is only created when this is `None`
                 every_epoch=1  # Log after epochs where `epoch%every_epoch==0`
                ):
        store_attr()
        # Not `ifnone(writer,SummaryWriter())`: arguments are evaluated eagerly, so that would build
        # an unused `SummaryWriter` (creating its log directory) even when `writer` is given.
        self.writer=writer if writer is not None else SummaryWriter()
    
    def before_fit(self):
        if not hasattr(self.learn,'experience_replay'):
            warn('Learner does not have `experience_replay`, nothing will be logged.')
            
    def after_epoch(self):
        experience_replay=getattr(self.learn,'experience_replay',None)
        # Skip instead of raising `AttributeError` (`before_fit` already warned). This also skips
        # when the attribute is not a memory, e.g. a callback registered under the same name.
        if not isinstance(experience_replay,ExperienceReplay): return
        if self.epoch%self.every_epoch==0:
            snapshot_memory(self.writer,epoch=self.epoch,
                            experience_replay=experience_replay)

In [139]:
%%bash 
# hide
# -f: don't fail when `runs/` is empty or missing (e.g. on a fresh checkout / Restart & Run All)
rm -rf runs/*

In [140]:
experience_replay_logger=ExperienceReplayTensorboard()
experience_replay_logger.learn=learn
learn.epoch=0
experience_replay_logger.after_epoch()

TENSOR_BOARD_STARTED=False

In [141]:
# export
def run_tensorboard(port=6006, # The port to run tensorboard on/connect on
                    start_tag=None, # Starting regex e.g.: experience_replay/1
                    samples_per_plugin=None, # Sampling freq such as  images=0 (keep all)
                    extra_args=None, # Any additional arguments in the `--arg value` format
                    rm_glob=None, # Remove old logs via a pattern e.g.: '*' will remove all files: runs/* 
                    log_dir='runs' # Directory tensorboard reads logs from, and that `rm_glob` is matched in
                   ):
    "Display tensorboard in the notebook, starting a new instance unless something already listens on `port`"
    if rm_glob is not None:
        # `delete` is presumably fastcore's `Path` patch (it is not a stdlib `pathlib` method)
        for p in Path(log_dir).glob(rm_glob): p.delete()
    import socket
    from tensorboard import notebook
    # `connect_ex` returns 0 when the connection succeeds, i.e. a server is already on `port`.
    # Probe the requested port (not a hard-coded 6006) and always close the probe socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as a_socket:
        already_running=a_socket.connect_ex(('127.0.0.1',port))==0
    if already_running:
        notebook.display(port=port, height=1000)
    else:
        cmd=f'--logdir {log_dir} --port {port} --host=0.0.0.0'
        if samples_per_plugin is not None: cmd+=f' --samples_per_plugin {samples_per_plugin}'
        if start_tag is not None:          cmd+=f' --tag {start_tag}'
        if extra_args is not None:         cmd+=f' {extra_args}'
        notebook.start(cmd)

In [142]:
# hide
SHOW_TENSOR_BOARD=False
if not os.environ.get("IN_TEST", None) and SHOW_TENSOR_BOARD:
    run_tensorboard()

In [143]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/06a_memory.experience_replay.ipynb
