# Quickstart: Using the Prioritised Flat Buffer with Flashbax

This guide demonstrates how to use the Prioritised Flat Buffer for experience replay in reinforcement learning tasks. The Prioritised Flat Buffer operates like a uniform flat buffer, except that it samples batches of experience according to given priorities. This is akin to the buffer used in the [PER paper](https://arxiv.org/abs/1511.05952) by Schaul et al. (2015).

In [1]:
# If running locally as a dev then uncomment the below 2 lines. 
# import sys
# sys.path.insert(0, "../")

In [2]:
import chex
import jax
import jax.numpy as jnp

# Setup fake devices - we use this later with `jax.pmap`.
DEVICE_COUNT_MOCK = 2
chex.set_n_cpu_devices(DEVICE_COUNT_MOCK)

In [3]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

# Prioritised buffer

The prioritised buffer allows for experience to be saved to the buffer with a "priority" that determines how likely it is to be sampled. This is based on the paper [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952) by Schaul et al. (2015).
The prioritised buffer has the following key functionality:
- **init**: initialise the state of the buffer
- **add**: add a new batch of experience data to the buffer's state
- **can_sample**: check if the buffer's state is full enough to be sampled from
- **sample**: sample a batch from the buffer's state, with each item drawn with probability proportional to its priority (raised to `priority_exponent`)
- **set_priorities**: update the priorities of specific experience within the buffer state
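To make "probability proportional to priority" concrete, here is a plain-JAX sketch of the sampling distribution from the PER paper, $P(i) = p_i^\alpha / \sum_k p_k^\alpha$, where $\alpha$ plays the role of `priority_exponent` ($\alpha = 0$ gives uniform sampling). The helper names are hypothetical and this is not Flashbax's code: Flashbax draws samples through a sum-tree rather than by normalising a full array of priorities, but the resulting distribution is the same idea.

```python
import jax
import jax.numpy as jnp


def sampling_probabilities(priorities: jnp.ndarray, priority_exponent: float) -> jnp.ndarray:
    """P(i) = p_i**alpha / sum_k p_k**alpha (hypothetical helper)."""
    scaled = priorities ** priority_exponent
    return scaled / jnp.sum(scaled)


def sample_indices(key, priorities, priority_exponent, batch_size):
    """Draw `batch_size` indices with replacement, proportional to priority**exponent."""
    probs = sampling_probabilities(priorities, priority_exponent)
    return jax.random.choice(
        key, priorities.shape[0], shape=(batch_size,), replace=True, p=probs
    )


priorities = jnp.array([1.0, 1.0, 4.0, 9.0])
probs = sampling_probabilities(priorities, 0.5)  # sqrt -> [1, 1, 2, 3] / 7
idx = sample_indices(jax.random.PRNGKey(0), priorities, 0.5, batch_size=8)
print(probs)
print(idx)
```

Raising the exponent towards 1 makes high-priority items dominate; lowering it towards 0 flattens the distribution towards uniform replay.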

Below we go through how each of these can be used. In the code below we use these functions without `jax.pmap`, but they can easily be adapted for it. To see how this can be done, we refer to the `examples/quickstart_n_step_buffer` notebook and the `test_prioritised_buffer_does_not_smoke` function in `flashbax.buffers.prioritised_buffer_test.py`.
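The general `jax.pmap` pattern is to give every per-device quantity (buffer state, incoming batch, PRNG key) an extra leading axis of size equal to the number of devices, and then map the pure function over that axis. The sketch below shows the pattern with a hypothetical toy `toy_add` standing in for `buffer.add`; it is an assumption that the real buffer functions are wrapped the same way, so consult the references above for the exact details.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()


def toy_add(storage: jnp.ndarray, index: jnp.ndarray, batch: jnp.ndarray):
    """Hypothetical stand-in for `buffer.add`: write one batch at `index`, circularly."""
    storage = storage.at[index].set(batch)
    return storage, (index + 1) % storage.shape[0]


# Every per-device quantity gets a leading axis of size `num_devices`.
storage = jnp.zeros((num_devices, 5, 3))            # (devices, time, add_batch_size)
index = jnp.zeros((num_devices,), dtype=jnp.int32)  # one write pointer per device
batch = jnp.ones((num_devices, 3))                  # one incoming batch per device

pmapped_add = jax.pmap(toy_add)
storage, index = pmapped_add(storage, index, batch)
print(storage.shape, index)
```

Because the functions are pure, each device simply updates its own shard of the state; no cross-device communication is needed for adding.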

First, we provide the function `make_prioritised_flat_buffer`, which returns an instance of `PrioritisedTrajectoryBuffer` with wrapped sample and add functionality. This is a dataclass containing the aforementioned `init`, `add`, `can_sample`, `sample` and `set_priorities` pure functions.

In [4]:
# First define hyper-parameters of the buffer.
max_length = 32 # Maximum length of buffer (max number of experiences stored within the state).
min_length = 8 # Minimum number of experiences saved in the buffer state before we can sample.
sample_batch_size = 4 # Batch size of experience data sampled from the buffer.
# The buffer will be sampled from with probability proportional to priority**priority_exponent.
priority_exponent = 0.6

add_sequences = False # Are we adding data to the buffer in sequences (with a time axis)? Here we add single timesteps.
add_batch_size = 6    # Batch size of the data passed to each `add` call.
                      # It is possible to add data in both sequences and batches.

# Instantiate the prioritised buffer, which is a dataclass of pure functions.
buffer = fbx.make_prioritised_flat_buffer(
    max_length, min_length, sample_batch_size, add_sequences, add_batch_size, priority_exponent
)
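Because `max_length` is not a multiple of `add_batch_size`, integer division determines how much storage is actually allocated. The arithmetic below assumes a storage layout of `(add_batch_size, max_length // add_batch_size, ...)`, which is consistent with the `(6, 5, 2)` observation shape printed at the end of this notebook. The count of `add` calls is a rough one that ignores any extra timestep the flat buffer may need in order to form transition pairs.

```python
# Hyper-parameters copied from the cell above.
max_length = 32
min_length = 8
add_batch_size = 6

# Assumed storage layout: (add_batch_size, max_length // add_batch_size, ...).
time_axis_length = max_length // add_batch_size  # 32 // 6 = 5 timesteps per batch row
capacity = add_batch_size * time_axis_length     # 6 * 5 = 30 items stored in total

# Each `add` call writes `add_batch_size` items, so reaching `min_length`
# takes ceil(min_length / add_batch_size) calls.
adds_needed = -(-min_length // add_batch_size)   # ceil(8 / 6) = 2

print(time_axis_length, capacity, adds_needed)   # 5 30 2
```

This matches the two `add` calls made before `can_sample` returns `True` in the next cell.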



In [5]:
rng_key = jax.random.PRNGKey(0) # Setup source of randomness

# Initialise the buffer's state using the `init` function. 
# To do this we need a unit of experience data which is used to infer 
# the tree structure of the experience that will be added later to the buffer state.
# We create a fake timestep for the example.
fake_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)} 
state = buffer.init(fake_timestep)

# Now fill the buffer above its minimum length using the `add` function.
# The add function expects batches of experience - we create a fake batch by stacking
# timesteps.
# New experience added to the buffer is given the maximum priority currently in the buffer.
fake_batch = jax.tree_util.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]),
                                    fake_timestep)
state = buffer.add(state, fake_batch)
assert not buffer.can_sample(state)  # After one batch (6 items) the buffer has not reached `min_length`.
state = buffer.add(state, fake_batch)
assert buffer.can_sample(state)  # After two batches the buffer can be sampled from.


# Sample from the buffer. This returns a `PrioritisedTransitionSample`, a dataclass
# with the fields `experience`, `indices` and `priorities`. The `experience` field contains a
# `TransitionPair` whose `first` and `second` entries have the same structure as
# `fake_timestep`, but with an additional leading batch dimension.
rng_key, rng_subkey = jax.random.split(rng_key)
batch = buffer.sample(state, rng_subkey)


# Adjust priorities. These would commonly be set to abs(td_error) of the corresponding samples.
new_priorities = jnp.ones_like(batch.priorities) + 10007 # Fake new priorities
state = buffer.set_priorities(state, batch.indices, new_priorities)

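The cell above relies on two priority rules: newly added experience receives the maximum priority currently in the buffer, and `set_priorities` overwrites the priorities of specific items, commonly with the absolute TD error. The toy bookkeeping below illustrates both rules on a plain array. It is a sketch, not Flashbax's implementation (which keeps priorities in a sum-tree); the helper names, the default priority of 1.0 for an empty buffer, and the small `eps` offset (a common choice from the PER paper that keeps priorities non-zero) are all assumptions.

```python
import jax.numpy as jnp


def toy_add_priorities(priorities, valid, new_indices):
    """Hypothetical: give newly written slots the current max priority (1.0 if empty)."""
    current_max = jnp.max(jnp.where(valid, priorities, 0.0))
    max_priority = jnp.where(valid.any(), current_max, 1.0)
    priorities = priorities.at[new_indices].set(max_priority)
    valid = valid.at[new_indices].set(True)
    return priorities, valid


def toy_set_priorities(priorities, indices, td_errors, eps=1e-6):
    """Hypothetical: overwrite priorities with |td_error| + eps."""
    return priorities.at[indices].set(jnp.abs(td_errors) + eps)


priorities = jnp.zeros(8)
valid = jnp.zeros(8, dtype=bool)
priorities, valid = toy_add_priorities(priorities, valid, jnp.array([0, 1, 2]))
priorities = toy_set_priorities(priorities, jnp.array([1]), jnp.array([-3.0]))
# Slots 3 and 4 inherit the new maximum (about 3.0) set by the update above.
priorities, valid = toy_add_priorities(priorities, valid, jnp.array([3, 4]))
print(priorities)
```

Giving new data the maximum priority ensures every transition is likely to be replayed at least once before its priority is refined by a TD error. Note that the exponent is applied separately: the sampled priorities printed later are `new_priorities**priority_exponent`.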


To understand the specifics of how the prioritised buffer works, we recommend inspecting the objects returned by each of the key buffer functions above while reading the code and documentation in the `flashbax.buffers.prioritised_buffer.py` file. For example, inspecting `batch` shows that it is a dataclass with the fields `experience`, `indices` and `priorities`. The `experience` field is a `TransitionPair`, whose `first` and `second` attributes have the same structure as `fake_timestep` but with an additional leading batch dimension.
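A `TransitionPair` can be thought of as a timestep together with its successor in the same batch row. The plain-JAX sketch below builds such pairs from a toy `(add_batch_size, time, ...)` storage; the helper name is hypothetical, and it is an assumption (not confirmed by this notebook) that the flat buffer forms pairs from consecutive timesteps in exactly this way.

```python
import jax
import jax.numpy as jnp

# Toy storage of shape (add_batch_size=2, time=3, ...); values encode (row, time).
experience = {
    "obs": jnp.arange(2 * 3 * 2).reshape(2, 3, 2),
    "reward": jnp.arange(6.0).reshape(2, 3),
}


def toy_transition_pair(experience, rows, times):
    """Hypothetical: pair the timestep at (row, t) with its successor at (row, t + 1)."""
    first = jax.tree_util.tree_map(lambda x: x[rows, times], experience)
    second = jax.tree_util.tree_map(lambda x: x[rows, times + 1], experience)
    return first, second


first, second = toy_transition_pair(experience, jnp.array([0, 1]), jnp.array([0, 1]))
print(first["obs"].shape)  # (2, 2): leading batch dimension added to the (2,) obs
print(second["obs"])
```

Indexing with `tree_map` keeps `first` and `second` in the same dictionary structure as a single timestep, which is what the printouts in the next cell show.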

In [6]:
print(batch.__dict__.keys())
print(f"indices: {batch.indices}")
print(f"priorities: {batch.priorities}") 
print(f"experience keys: {batch.experience.first.keys()}")
print(f"experience keys: {batch.experience.second.keys()}")
print(f"obs shape: {batch.experience.first['obs'].shape}")

dict_keys(['experience', 'indices', 'priorities'])
indices: [ 0  5 15 20]
priorities: [1. 1. 1. 1.]
experience keys: dict_keys(['obs', 'reward'])
experience keys: dict_keys(['obs', 'reward'])
obs shape: (4, 2)


The batch above was sampled before we adjusted the priorities. If we sample again, the sampled priorities now equal the adjusted priorities raised to `priority_exponent`. Because the priorities we set are far higher than those of the other experiences (which remain at 1.0), these experiences are now very likely to be sampled.
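A quick calculation shows how likely "very likely" is. The stored priority is $10008^{0.6} \approx 251.31$, while untouched items keep $1^{0.6} = 1$. The number of other valid items below is a hypothetical value chosen for illustration; the exact count depends on the buffer's internals.

```python
priority_exponent = 0.6
new_priority = 1.0 + 10007.0                 # as set with `set_priorities` above
boosted = new_priority ** priority_exponent  # about 251.31, the stored (exponentiated) priority

num_boosted = 4  # the four indices whose priorities we raised
num_other = 8    # hypothetical number of other valid items, each still at priority 1.0

# Probability that a single draw lands on one of the boosted items.
p_boosted = num_boosted * boosted / (num_boosted * boosted + num_other * 1.0)
# Probability that every draw in a batch of `sample_batch_size` = 4 is a boosted item.
p_whole_batch = p_boosted ** 4
print(f"{boosted:.2f} {p_boosted:.4f} {p_whole_batch:.4f}")
```

Even with a handful of competing items, a single draw hits a boosted item with probability above 0.99, which is why the next cell resamples the same indices.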

In [7]:
rng_key, rng_subkey = jax.random.split(rng_key)
batch = buffer.sample(state, rng_subkey)
print(batch.__dict__.keys())
print(f"indices: {batch.indices}")
print(f"priorities: {batch.priorities} == new_priorities**priority_exponent == {new_priorities[0]**priority_exponent}")
print(f"experience keys: {batch.experience.first.keys()}")
print(f"experience keys: {batch.experience.second.keys()}")
print(f"obs shape: {batch.experience.first['obs'].shape}")

dict_keys(['experience', 'indices', 'priorities'])
indices: [ 0  5 15 20]
priorities: [251.30885 251.30885 251.30885 251.30885] == new_priorities**priority_exponent == 251.30885314941406
experience keys: dict_keys(['obs', 'reward'])
experience keys: dict_keys(['obs', 'reward'])
obs shape: (4, 2)




By inspecting the buffer state we see that it contains:

 -  `priority_state`, the state of the sum-tree, which is the data structure we use to store the priorities. The sum-tree allows sampling and priority updates with `O(log N)` complexity, where `N` is the max length of the buffer. We refer to the [Prioritized Experience Replay Paper](https://arxiv.org/abs/1511.05952), [Dopamine sum_tree.py code](https://github.com/google/dopamine/blob/master/dopamine/replay_memory/sum_tree.py) and [this blog](http://www.sefidian.com/2021/09/09/sumtree-data-structure-for-prioritized-experience-replay-per-explained-with-python-code/) as resources for understanding how the sum-tree works.

 -  `experience`, a pytree matching the structure of `fake_timestep` but with two extra leading axes of size `add_batch_size` and `max_length // add_batch_size`.

 -  `current_index`, which keeps track of where in the buffer new experience should be added.

 -  `is_full`, a boolean flag noting whether the buffer's time axis (of length `max_length // add_batch_size`) has been filled, after which newly added experience starts overwriting old experience.

In [8]:
print(state.__dict__.keys())
print(state.experience.keys())
print(state.experience['obs'].shape) # prints (6, 5, 2) = (add_batch_size, max_length // add_batch_size, *fake_timestep['obs'].shape)

dict_keys(['experience', 'current_index', 'is_full', 'priority_state'])
dict_keys(['obs', 'reward'])
(6, 5, 2)
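The sampled `indices` are flat positions into this `(6, 5)` grid of stored items. Assuming the two leading axes are flattened in row-major order (an assumption inferred from the sampled indices `[0 5 15 20]` all being multiples of 5, not something this notebook states), they can be decoded back into (batch row, time step) coordinates as follows. The helper name is hypothetical.

```python
import numpy as np

add_batch_size, max_length = 6, 32
time_axis_length = max_length // add_batch_size  # 5


def flat_to_coords(flat_index):
    """Hypothetical row-major mapping: flat index -> (batch row, time step)."""
    return flat_index // time_axis_length, flat_index % time_axis_length


indices = np.array([0, 5, 15, 20])  # the indices sampled earlier in this notebook
rows, times = flat_to_coords(indices)
print(rows, times)  # [0 1 3 4] [0 0 0 0]
```

Under this reading, every sampled item is the first timestep of a different batch row, which is consistent with a flat buffer that needs a following timestep to complete each transition pair.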
