# Quickstart: Using the Flat Buffer with Flashbax

This guide demonstrates how to use the Flat Buffer for experience replay in reinforcement learning tasks. The Flat Buffer stores all experience data in a first-in-first-out (FIFO) queue and returns uniformly sampled batches of that experience. This is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602).
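
As a rough illustration of the idea (not Flashbax's implementation), a FIFO replay buffer with uniform sampling can be sketched in plain Python. The class name `ToyReplayBuffer` and its methods are hypothetical and exist only for this sketch.

```python
import random
from collections import deque


class ToyReplayBuffer:
    """Hypothetical FIFO replay buffer; illustration only, not Flashbax code."""

    def __init__(self, max_length: int, seed: int = 0):
        # A deque with maxlen drops the oldest item once it is full (FIFO).
        self._storage = deque(maxlen=max_length)
        self._rng = random.Random(seed)

    def add(self, item) -> None:
        self._storage.append(item)

    def contents(self) -> list:
        # Stored items, oldest first.
        return list(self._storage)

    def sample(self, batch_size: int) -> list:
        # Uniform sampling (with replacement) over everything currently stored.
        return [self._rng.choice(self._storage) for _ in range(batch_size)]


toy = ToyReplayBuffer(max_length=3)
for step in range(5):
    toy.add({"obs": step, "reward": float(step)})

stored_obs = [item["obs"] for item in toy.contents()]  # steps 0 and 1 were evicted
sampled = toy.sample(batch_size=4)
```

Only the three most recent items survive, and every sampled element is drawn from those survivors.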

## Prerequisites

If you are running this notebook locally from a development checkout, uncomment the two lines below.

In [1]:

# import sys
# sys.path.insert(0, "../")

In [2]:
import chex
import jax
import jax.numpy as jnp

# Setup fake devices - we use this later with `jax.pmap`.
DEVICE_COUNT_MOCK = 2
chex.set_n_cpu_devices(DEVICE_COUNT_MOCK)

In [3]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

## Initialize the Flat Buffer

The following code demonstrates how to initialize the Flat Buffer:


In [4]:
# First define hyper-parameters of the buffer.
max_length = 32 # Maximum length of buffer (max number of experiences stored within the state).
min_length = 8 # Minimum number of experiences saved in the buffer state before we can sample.
sample_batch_size = 4 # Batch size of experience data sampled from the buffer.

add_sequences = False # Whether data is added to the buffer in sequences.
add_batch_size = 6    # Batch size of each add, if data is added in batches.
                      # Data can be added in both sequences and batches.

# Instantiate the flat buffer, which is a dataclass of pure functions.
buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)



## Key Functionality of the Flat Buffer

The Flat Buffer provides the following key functions:

1. `init`: Initialize the state of the buffer.
2. `add`: Add a new batch of experience data to the buffer.
3. `can_sample`: Check if the buffer is ready to be sampled.
4. `sample`: Sample a batch from the buffer.
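
Because these are pure functions, the buffer state is passed in and returned explicitly rather than mutated in place. The toy below imitates that calling pattern in plain Python. `ToyState` and the four `toy_*` functions are hypothetical stand-ins chosen for this sketch, not the Flashbax API.

```python
import random
from typing import Any, NamedTuple, Tuple


class ToyState(NamedTuple):
    """Hypothetical immutable buffer state (illustration only)."""
    items: Tuple[Any, ...]


MAX_LENGTH = 4  # toy values chosen for this sketch
MIN_LENGTH = 2


def toy_init() -> ToyState:
    return ToyState(items=())


def toy_add(state: ToyState, item: Any) -> ToyState:
    # Return a new state; keep only the newest MAX_LENGTH items (FIFO).
    return ToyState(items=(state.items + (item,))[-MAX_LENGTH:])


def toy_can_sample(state: ToyState) -> bool:
    return len(state.items) >= MIN_LENGTH


def toy_sample(state: ToyState, rng: random.Random, batch_size: int) -> list:
    # Uniform sampling with replacement from the stored items.
    return [rng.choice(state.items) for _ in range(batch_size)]


state0 = toy_init()
state1 = toy_add(state0, "t0")
ready_after_one = toy_can_sample(state1)  # below MIN_LENGTH
state2 = toy_add(state1, "t1")
ready_after_two = toy_can_sample(state2)  # MIN_LENGTH reached
toy_batch = toy_sample(state2, random.Random(0), batch_size=3)
```

Note that `state0` and `state1` are left untouched by later calls, which is the property that makes this style friendly to function transformations.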

## Initialize the Buffer State

To demonstrate how to use the buffer, we'll start by initializing its state using the `init` function. This requires a unit of experience data, which is used to infer the structure of the experience that will be added later. For this example, we create a fake timestep:

In [5]:
fake_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)} 
state = buffer.init(fake_timestep)



## Adding Experience to the Buffer
To fill the buffer above its minimum length, we use the `add` function. It expects batches of experience, which we create by stacking timesteps. Note: we configured this buffer to expect batches of experience, but it can instead be configured so that individual timesteps are added each time.

In [6]:
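# Stack `add_batch_size` shifted copies of the timestep along a new leading batch axis.
# `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is deprecated in newer JAX releases.
fake_batch = jax.tree_util.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep)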
state = buffer.add(state, fake_batch)
assert not buffer.can_sample(state)  # Buffer is not ready to sample
state = buffer.add(state, fake_batch)
assert buffer.can_sample(state)  # Buffer is now ready to sample
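
To make the stacking step concrete, the following plain-Python sketch mirrors what the `x + i` stacking produces for each leaf of the timestep, using nested lists instead of JAX arrays. The names `plain_timestep` and `plain_batch` are introduced here purely for illustration.

```python
batch_size = 6
plain_timestep = {"obs": [5, 4], "reward": 1.0}

# One shifted copy of the timestep per batch slot, stacked along a new leading axis.
plain_batch = {
    "obs": [[value + i for value in plain_timestep["obs"]] for i in range(batch_size)],
    "reward": [plain_timestep["reward"] + i for i in range(batch_size)],
}

obs_shape = (len(plain_batch["obs"]), len(plain_batch["obs"][0]))
reward_shape = (len(plain_batch["reward"]),)
```

Each leaf gains a leading axis of size `batch_size`, so `obs` goes from shape `(2,)` to `(6, 2)` and `reward` from a scalar to `(6,)`.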

## Sampling from the Buffer
To sample from the buffer, we use the `sample` function:

In [7]:
rng_key = jax.random.PRNGKey(0)  # Setup source of randomness
batch = buffer.sample(state, rng_key)  # Sample a batch of data

By inspecting the `batch` object, you can see that it is a `TransitionSample` object. It contains an `ExperiencePair` object holding the transition data. The `first` and `second` attributes of the `ExperiencePair` each match the structure of `fake_timestep` with an added leading batch dimension.

In [8]:
print(batch.experience.first.keys()) # prints dict_keys(['obs', 'reward'])
print(batch.experience.second.keys()) # prints dict_keys(['obs', 'reward'])
print(batch.experience.first['reward'].shape) # prints (4,) = (sample_batch_size, *fake_timestep['reward'].shape)

dict_keys(['obs', 'reward'])
dict_keys(['obs', 'reward'])
(4,)
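
One way to picture the `first`/`second` pairing is as consecutive timesteps from the same stored row, where `second` is the timestep that follows `first`. The sketch below illustrates that reading in plain Python. It rests on that assumption, and `ToyPair` and `make_pairs` are hypothetical names, not Flashbax classes.

```python
from typing import Any, List, NamedTuple


class ToyPair(NamedTuple):
    """Hypothetical stand-in for a transition pair (illustration only)."""
    first: Any
    second: Any


def make_pairs(timesteps: List[Any]) -> List[ToyPair]:
    # Pair every timestep with its successor: (t0, t1), (t1, t2), ...
    return [ToyPair(a, b) for a, b in zip(timesteps[:-1], timesteps[1:])]


row = [
    {"obs": [5, 4], "reward": 1.0},
    {"obs": [6, 5], "reward": 2.0},
    {"obs": [7, 6], "reward": 3.0},
]
pairs = make_pairs(row)
no_pairs = make_pairs(row[:1])  # a single timestep cannot form a transition
```

Under this reading, a row holding a single timestep yields no pairs, which is consistent with the buffer needing more than one `add` before it can be sampled.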


## Buffer State and Structure
Inspecting the buffer state reveals its structure:

- `experience`: A pytree matching the structure of the timestep but with two extra leading axes of size `add_batch_size` and `max_length // add_batch_size`.
- `current_index`: Tracks where in the buffer experience should be added.
- `is_full`: A boolean array indicating whether the buffer has been filled above `max_length // add_batch_size`.

In [9]:
print(state.__dict__.keys()) 
print(state.experience.keys())
print(state.experience['obs'].shape) # prints (6, 5, 2) = (add_batch_size, max_length//add_batch_size, *fake_timestep['obs'].shape)

dict_keys(['experience', 'current_index', 'is_full'])
dict_keys(['obs', 'reward'])
(6, 5, 2)
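
The shape arithmetic can be checked by hand. The helper below is a hypothetical convenience written for this guide (it is not part of Flashbax); it simply applies the `(add_batch_size, max_length // add_batch_size, *item_shape)` rule described above.

```python
from typing import Tuple


def expected_experience_shape(
    max_length: int, add_batch_size: int, item_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    # Two leading axes (batch rows, time slots per row), then the item's own shape.
    return (add_batch_size, max_length // add_batch_size, *item_shape)


obs_storage_shape = expected_experience_shape(32, 6, (2,))
reward_storage_shape = expected_experience_shape(32, 6, ())

# Integer division discards the remainder, so fewer than max_length slots exist.
slots_available = obs_storage_shape[0] * obs_storage_shape[1]
```

With `max_length = 32` and `add_batch_size = 6`, the storage has 6 × 5 = 30 slots rather than 32, because `32 // 6` rounds down to 5.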


## Using the Buffer with `jax.pmap`
Flashbax buffers can be `jit`-ed and `pmap`-ed. The following code demonstrates how to use the Flat Buffer with `pmap`:

In [10]:
# Define a function to create a fake batch of data
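def get_fake_batch(fake_timestep: chex.ArrayTree, batch_size: int) -> chex.ArrayTree:
    # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is deprecated in newer JAX releases.
    return jax.tree_util.tree_map(lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep)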

add_batch_size = 8

# Re-instantiate the buffer
buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)

# Initialize the buffer's state with a "device" dimension
fake_timestep_per_device = jax.tree_util.tree_map(
    lambda x: jnp.stack([x + i for i in range(DEVICE_COUNT_MOCK)]), fake_timestep
)
state = jax.pmap(buffer.init)(fake_timestep_per_device)

# Fill the buffer above its minimum length
fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)(
    fake_timestep_per_device, add_batch_size
)
# Add two timesteps to form one transition pair
state = jax.pmap(buffer.add)(state, fake_batch)
state = jax.pmap(buffer.add)(state, fake_batch)
assert buffer.can_sample(state).all()

# Sample from the buffer
rng_key_per_device = jax.random.split(rng_key, DEVICE_COUNT_MOCK)
batch = jax.pmap(buffer.sample)(state, rng_key_per_device)




When inspecting the objects, you'll observe an extra leading "device" dimension, replicating the buffer behavior across multiple devices.

In [11]:
print(state.experience['obs'].shape) # prints (2, 8, 4, 2) = (DEVICE_COUNT_MOCK, add_batch_size, max_length//add_batch_size, *fake_timestep['obs'].shape)
print(batch.experience.first['reward'].shape) # prints (2, 4) = (DEVICE_COUNT_MOCK, sample_batch_size, *fake_timestep['reward'].shape)
print(buffer.can_sample(state)) # prints [ True  True] as the state on each device has been filled above `min_length`.

(2, 8, 4, 2)
(2, 4)
[ True  True]
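
Conceptually, `pmap` applies the same function to each slice along the leading device axis, so every device keeps its own independent buffer state. The plain-Python loop below only imitates that behaviour on a list of per-device values. `toy_pmap` and `toy_add` are hypothetical helpers and involve no JAX or real devices.

```python
from typing import Callable, List

DEVICE_COUNT = 2


def toy_pmap(fn: Callable) -> Callable:
    # Hypothetical stand-in: apply fn to each per-device slice, one after another.
    def mapped(*per_device_args: List) -> List:
        return [fn(*args) for args in zip(*per_device_args)]

    return mapped


def toy_add(state: tuple, item: str) -> tuple:
    # Pure function: returns a new per-device state with the item appended.
    return state + (item,)


# One independent toy buffer per device, each just a tuple of stored items.
states = [() for _ in range(DEVICE_COUNT)]
per_device_items = [f"item-from-device-{d}" for d in range(DEVICE_COUNT)]

states = toy_pmap(toy_add)(states, per_device_items)
states = toy_pmap(toy_add)(states, per_device_items)

lengths = [len(s) for s in states]
```

Each per-device state only ever sees its own slice of the inputs, which is why every array in the real example carries a leading axis of size `DEVICE_COUNT_MOCK`.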
