# Quickstart: Using the Trajectory Buffer with Flashbax

This guide demonstrates how to use the Flashbax trajectory buffer for experience replay in reinforcement learning. The buffer stores batches of trajectories while preserving their temporal ordering, which makes it well suited to losses built on TD-lambda errors or multi-step returns.

## Prerequisites

If you are running this notebook locally from a development checkout, uncomment the two lines below.

In [1]:
# import sys
# sys.path.insert(0, "../")

In [2]:
import chex
import jax
import jax.numpy as jnp

# Setup fake devices - we use this later with `jax.pmap`.
DEVICE_COUNT_MOCK = 2
chex.set_n_cpu_devices(DEVICE_COUNT_MOCK)

In [3]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

## Key Functionality of the Trajectory Buffer

The trajectory buffer receives batches of trajectories and stores them with their temporal ordering intact, so sampling also returns trajectories. As with the flat buffer, data is stored in a first-in-first-out (FIFO) circular manner. Sampling is uniform and governed by a `period`, which controls how much sampled sequences can overlap. A common use case is a loss that uses TD-lambda errors or n-step returns rather than the 1-step TD error. Trajectories are also useful for RL with any form of recurrent network.
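
To illustrate why sampling whole sequences helps with multi-step losses, here is a minimal sketch that computes an n-step return from a single reward sequence. The helper `n_step_return`, the discount `gamma`, and the `bootstrap_value` argument are hypothetical names chosen for this illustration; they are not part of Flashbax.

```python
import jax.numpy as jnp


def n_step_return(rewards, bootstrap_value, gamma):
    """Discounted sum of `rewards` plus a discounted bootstrap value.

    rewards: array of shape (n,), the rewards r_t, ..., r_{t+n-1}.
    bootstrap_value: scalar value estimate for the state reached after n steps.
    gamma: discount factor (hypothetical hyper-parameter for this sketch).
    """
    n = rewards.shape[0]
    discounts = gamma ** jnp.arange(n)  # [1, gamma, gamma^2, ...]
    return jnp.sum(discounts * rewards) + gamma**n * bootstrap_value


# One sequence of 8 rewards, each equal to 1.0, with no bootstrap value.
rewards = jnp.ones(8)
g = n_step_return(rewards, bootstrap_value=0.0, gamma=0.5)
```

With a flat buffer of single transitions, the consecutive rewards needed for this sum would not be available together in one sample.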

The trajectory buffer has the following key functions:
1. `init`: Initialize the state of the buffer.
2. `add`: Add a new batch of experience data to the buffer.
3. `can_sample`: Check if the buffer is ready to be sampled.
4. `sample`: Sample a batch from the buffer.

Below we go through how each of these can be used. The buffer is compatible with `jax.pmap`; the `examples.quickstart_n_step_buffer.py` tutorial shows how to use Flashbax buffers with `jax.pmap`.

Flashbax provides the function `trajectory_buffer.make_trajectory_buffer`, which returns an instance of `TrajectoryBuffer`. This is a `NamedTuple` containing the `init`, `add`, `can_sample` and `sample` pure functions listed above. We instantiate it below.

In [4]:
# First define hyper-parameters of the buffer.
max_length_time_axis = 32 # Maximum length of the buffer along the time axis. 
min_length_time_axis = 16 # Minimum length across the time axis before we can sample.
sample_batch_size = 4 # Batch size of trajectories sampled from the buffer.
add_batch_size = 6 # Batch size of trajectories added to the buffer.
sample_sequence_length = 8 # Sequence length of trajectories sampled from the buffer.
add_sequence_length = 10 # Sequence length of trajectories added to the buffer.
period = 1 # Period at which we sample trajectories from the buffer.

# Instantiate the trajectory buffer, which is a NamedTuple of pure functions.
buffer = fbx.make_trajectory_buffer(
    max_length_time_axis=max_length_time_axis,
    min_length_time_axis=min_length_time_axis,
    sample_batch_size=sample_batch_size,
    add_batch_size=add_batch_size,
    sample_sequence_length=sample_sequence_length,
    period=period
)
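
The effect of `period` on overlap can be pictured with a simplified model. The sketch below assumes that candidate sequences start at multiples of `period` and ignores the circular wrap-around of the real buffer; the helper `candidate_windows` is hypothetical and is not how Flashbax implements sampling.

```python
def candidate_windows(filled_length, sequence_length, period):
    """List (start, end) index pairs of candidate sequences.

    Simplifying assumption: windows start at 0, period, 2 * period, ...
    and must fit entirely inside the first `filled_length` timesteps.
    """
    starts = range(0, filled_length - sequence_length + 1, period)
    return [(start, start + sequence_length) for start in starts]


# With period=1, neighbouring windows share all but one timestep.
windows_period_1 = candidate_windows(filled_length=16, sequence_length=8, period=1)

# With period equal to the sequence length, windows do not overlap.
windows_period_8 = candidate_windows(filled_length=16, sequence_length=8, period=8)
```

Under this model, neighbouring windows overlap by `sequence_length - period` timesteps, so a larger period yields fewer, less correlated candidate sequences.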

Now, we show how each function within the `buffer` can be used:

## Initialize the Buffer State

To demonstrate how to use the buffer, we'll start by initializing its state using the `init` function. This requires a unit of experience data, which is used to infer the structure of the experience that will be added later. For this example, we create a fake timestep:

In [5]:
fake_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)} 
state = buffer.init(fake_timestep)
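
As a rough picture of what the state has to hold, the sketch below allocates zero-filled storage with one array per leaf of the timestep, each with extra batch and time axes. This is an assumption based on the state shapes printed later in this guide, not a copy of what `init` does internally; the helper `allocate` is hypothetical.

```python
import jax
import jax.numpy as jnp

fake_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)}
add_batch_size = 6
max_length_time_axis = 32


def allocate(leaf):
    # One slot per (batch row, timestep), keeping the leaf's own shape and dtype.
    shape = (add_batch_size, max_length_time_axis, *leaf.shape)
    return jnp.zeros(shape, dtype=leaf.dtype)


# Apply `allocate` to every leaf of the pytree.
storage = jax.tree_util.tree_map(allocate, fake_timestep)
```

Because the shapes and dtypes are taken from the example timestep, every later `add` must supply experience with the same structure.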



## Adding Experience to the Buffer
To fill the buffer above its minimum length, we use the `add` function. The function expects batches of sequences of experience.

In [6]:
rng_key = jax.random.PRNGKey(0)
# Now fill the buffer above its minimum length using the `add` function.
# The add function expects batches of trajectories.
# Thus, we create a fake batch of trajectories by broadcasting the `fake_timestep`.
broadcast_fn = lambda x: jnp.broadcast_to(x, (add_batch_size, add_sequence_length, *x.shape))
fake_batch_sequence = jax.tree_util.tree_map(broadcast_fn, fake_timestep)
state = buffer.add(state, fake_batch_sequence)
# After one add the buffer holds 10 timesteps, fewer than min_length_time_axis (16).
assert not buffer.can_sample(state)
state = buffer.add(state, fake_batch_sequence)
# After two adds it holds 20 timesteps, so it is ready to be sampled.
assert buffer.can_sample(state)
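
The arithmetic behind the two assertions can be written out as a simplified model. The helpers `timesteps_stored` and `ready_to_sample` are hypothetical and only mirror the readiness rule described in this guide (enough timesteps along the time axis); in practice `buffer.can_sample` is the source of truth.

```python
def timesteps_stored(num_adds, add_sequence_length, max_length_time_axis):
    """Timesteps held per batch row after `num_adds` adds, capped at capacity."""
    return min(num_adds * add_sequence_length, max_length_time_axis)


def ready_to_sample(num_adds, add_sequence_length, min_length_time_axis,
                    max_length_time_axis):
    """True once the stored timesteps reach the minimum length."""
    stored = timesteps_stored(num_adds, add_sequence_length, max_length_time_axis)
    return stored >= min_length_time_axis


# Values from this guide: sequences of 10 timesteps, minimum 16, capacity 32.
after_one_add = ready_to_sample(1, 10, 16, 32)
after_two_adds = ready_to_sample(2, 10, 16, 32)
```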

## Sampling from the Buffer
To sample from the buffer, we use the `sample` function:

In [7]:
rng_key = jax.random.PRNGKey(0)  # Setup source of randomness
# Sample from the buffer. This returns a batch of **sequences** of data with the same structure as 
# `fake_timestep`.
batch = buffer.sample(state, rng_key)

By inspecting the `batch` object, you can see that it matches the structure of `fake_timestep`, with two extra leading dimensions: a batch dimension of size `sample_batch_size` and a sequence dimension of size `sample_sequence_length`.

In [8]:
print(batch.experience.keys()) # prints dict_keys(['obs', 'reward'])
print(batch.experience['reward'].shape) # prints (4, 8) = (sample_batch_size, sample_sequence_length, *fake_timestep['reward'].shape)

dict_keys(['obs', 'reward'])
(4, 8)
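
Because each sample is a sequence, consecutive timesteps can be paired by slicing along the sequence axis. The sketch below uses a stand-in dictionary with the shapes printed above instead of a real `batch.experience`; the names `current` and `following` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Stand-in for `batch.experience`: (sample_batch_size, sample_sequence_length, ...).
experience = {"obs": jnp.zeros((4, 8, 2)), "reward": jnp.ones((4, 8))}

# Timesteps 0..6 of every sequence, for every leaf of the pytree.
current = jax.tree_util.tree_map(lambda x: x[:, :-1], experience)

# Timesteps 1..7 of every sequence, aligned with `current`.
following = jax.tree_util.tree_map(lambda x: x[:, 1:], experience)
```

Each sequence of length 8 yields 7 aligned (current, following) pairs, which is the form a 1-step TD update would consume.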


## Buffer State and Structure

By inspecting the buffer state, we see that it contains:

- `experience`, a pytree matching the structure of `fake_timestep` but with extra leading axes of size `add_batch_size` and `max_length_time_axis`
- `current_index`, which tracks where along the time dimension of the buffer new experience should be added
- `is_full`, a boolean array that records whether the buffer has been filled to `max_length_time_axis`, after which newly added experience overwrites the oldest experience.

In [9]:
print(state.__dict__.keys())
print(state.experience.keys())
# prints (6,32,2) = (add_batch_size, max_length_time_axis, *fake_timestep['obs'].shape)
print(state.experience['obs'].shape) 

dict_keys(['experience', 'current_index', 'is_full'])
dict_keys(['obs', 'reward'])
(6, 32, 2)
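
The interplay between the current index and `is_full` can be sketched as a small model of a circular write pointer. This is a simplified illustration under the FIFO description above, not Flashbax's code; the helper `advance` is hypothetical.

```python
def advance(current_index, is_full, add_sequence_length, max_length_time_axis):
    """Move the write pointer forward by one added sequence.

    The pointer wraps around at `max_length_time_axis`; once it has reached
    or passed the end, the buffer counts as full and old data is overwritten.
    """
    end = current_index + add_sequence_length
    is_full = is_full or end >= max_length_time_axis
    return end % max_length_time_axis, is_full


# Four adds of 10 timesteps into a time axis of length 32.
index, full = 0, False
history = []
for _ in range(4):
    index, full = advance(index, full, 10, 32)
    history.append((index, full))
```

In this model the fourth add wraps around: the pointer lands at index 8 and the first 8 timesteps are overwritten.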


To understand the specifics of how the trajectory buffer works, we recommend inspecting the objects returned by each of the key buffer functions alongside the code and documentation in `flashbax/buffers/trajectory_buffer.py`. For example, inspecting the `batch` object shows that it is a batch of **sequences** of experience data (i.e. a batch of trajectories) with the same structure as `fake_timestep`.

In [10]:
print(batch.experience.keys()) # prints dict_keys(['obs', 'reward'])
# prints (4, 8) = (sample_batch_size, sample_sequence_length, *fake_timestep['reward'].shape)
print(batch.experience['reward'].shape) 

dict_keys(['obs', 'reward'])
(4, 8)
