# Quickstart: Using the N-Step Buffer with Flashbax

This guide demonstrates how to use the N-Step Buffer for experience replay in reinforcement learning tasks. The N-Step Buffer stores experience data in a first-in-first-out (FIFO) queue (once full, the oldest experience is overwritten) and returns batches of uniformly sampled experience from it. With an n-step of 1, this is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602).

## Prerequisites

If you are running this notebook locally as a developer, uncomment the two lines below.

In [None]:
# import sys
# sys.path.insert(0, "../")

In [None]:
import chex
import jax
import jax.numpy as jnp

# Set up fake CPU devices (used later with `jax.pmap`). This must run before any JAX computation.
DEVICE_COUNT_MOCK = 2
chex.set_n_cpu_devices(DEVICE_COUNT_MOCK)

In [None]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

## Instantiate the N-Step Buffer

The following code demonstrates how to instantiate the N-Step Buffer. For this section we use an n-step of 1 to keep things simple.

In [None]:
# First define hyper-parameters of the buffer.
max_length = (
    32  # Maximum length of buffer (max number of experiences stored within the state).
)
min_length = (
    8  # Minimum number of experiences saved in the buffer state before we can sample.
)
sample_batch_size = 4  # Batch size of experience data sampled from the buffer.

add_sequences = False  # Will we be adding data in sequences to the buffer?
add_batch_size = 6  # Batch size of the data added to the buffer on each `add` call.
# It is possible to add data in both sequences and batches.

# Instantiate the N-Step Buffer, which is a dataclass of pure functions.
buffer = fbx.make_n_step_buffer(
    max_length=max_length,
    min_length=min_length,
    sample_batch_size=sample_batch_size,
    add_sequences=add_sequences,
    add_batch_size=add_batch_size,
    n_step=1,
)

## Key Functionality of the N-Step Buffer

The N-Step Buffer provides the following key functions:

1. `init`: Initialize the state of the buffer.
2. `add`: Add a new batch of experience data to the buffer.
3. `can_sample`: Check if the buffer is ready to be sampled.
4. `sample`: Sample a batch from the buffer.
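
To make these four functions concrete, here is a toy FIFO buffer in plain JAX that follows the same pattern: a factory returns a bundle of pure functions, and each call takes a state and returns a new one. This is a minimal sketch under our own assumptions: the names `make_toy_fifo_buffer`, `ToyState` and `ToyBuffer` are hypothetical, storage is a single flat axis, and sampling returns single items rather than transitions. It is not how Flashbax implements the N-Step Buffer.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    data: jax.Array           # (capacity, *item_shape) storage
    current_index: jax.Array  # Slot the next item will be written to
    size: jax.Array           # Number of valid items, capped at capacity


class ToyBuffer(NamedTuple):
    init: Callable
    add: Callable
    can_sample: Callable
    sample: Callable


def make_toy_fifo_buffer(capacity: int, min_length: int, sample_batch_size: int) -> ToyBuffer:
    """Hypothetical toy buffer: FIFO overwrite with uniform sampling."""

    def init(item: jax.Array) -> ToyState:
        # Infer the storage shape and dtype from one example item.
        data = jnp.zeros((capacity, *item.shape), dtype=item.dtype)
        return ToyState(data, jnp.array(0), jnp.array(0))

    def add(state: ToyState, items: jax.Array) -> ToyState:
        # Write items at consecutive slots, wrapping so the oldest data is overwritten.
        n = items.shape[0]
        slots = (state.current_index + jnp.arange(n)) % capacity
        return ToyState(
            data=state.data.at[slots].set(items),
            current_index=(state.current_index + n) % capacity,
            size=jnp.minimum(state.size + n, capacity),
        )

    def can_sample(state: ToyState) -> jax.Array:
        return state.size >= min_length

    def sample(state: ToyState, key: jax.Array) -> jax.Array:
        # Draw indices uniformly from the valid region [0, size).
        idx = jax.random.randint(key, (sample_batch_size,), 0, state.size)
        return state.data[idx]

    return ToyBuffer(init, add, can_sample, sample)


toy_buffer = make_toy_fifo_buffer(capacity=4, min_length=2, sample_batch_size=3)
toy_state = toy_buffer.init(jnp.array(0.0))
toy_add = jax.jit(toy_buffer.add)  # Pure functions can be jit-compiled.
toy_state = toy_add(toy_state, jnp.array([1.0, 2.0, 3.0]))
toy_state = toy_add(toy_state, jnp.array([4.0, 5.0]))
print(toy_state.data)  # [5. 2. 3. 4.]: the oldest item (1.0) was overwritten
toy_sample = toy_buffer.sample(toy_state, jax.random.PRNGKey(0))
```

Because the state is passed in and returned explicitly, the same functions work unchanged under `jax.jit`, `jax.vmap` or `jax.pmap`.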

## Initialize the Buffer State

To demonstrate how to use the buffer, we'll start by initializing its state using the `init` function. This requires a unit of experience data, which is used to infer the structure of the experience that will be added later. For this example, we create a fake timestep:

In [None]:
fake_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)}
state = buffer.init(fake_timestep)
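
`init` only needs one example timestep because the structure, shapes and dtypes of every leaf can be read off it. As a rough sketch of that idea, here is a hypothetical helper `sketch_init_storage`; we assume storage is laid out as `(add_batch_size, max_length // add_batch_size, *leaf_shape)`, which matches the shape printed in the Buffer State section, but this is not Flashbax's own code.

```python
import jax
import jax.numpy as jnp


def sketch_init_storage(example_timestep, max_length: int, add_batch_size: int):
    """Hypothetical helper: allocate zeroed storage shaped like `example_timestep`."""
    time_axis_length = max_length // add_batch_size
    return jax.tree_util.tree_map(
        # Prepend (batch, time) axes to each leaf, keeping its dtype.
        lambda x: jnp.zeros((add_batch_size, time_axis_length, *x.shape), dtype=x.dtype),
        example_timestep,
    )


example_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)}
sketch_storage = sketch_init_storage(example_timestep, max_length=32, add_batch_size=6)
print({k: v.shape for k, v in sketch_storage.items()})  # {'obs': (6, 5, 2), 'reward': (6, 5)}
```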

## Adding Experience to the Buffer
To fill the buffer above its minimum length, we use the `add` function. The function expects batches of experience, which we create by stacking timesteps. Note: here we have configured the buffer to expect a batch of `add_batch_size` timesteps on each call, but it can also be configured so that individual timesteps are added each time.

In [None]:
fake_batch = jax.tree_util.tree_map(
    lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep
)
state = buffer.add(state, fake_batch)
assert not buffer.can_sample(state)  # Buffer is not ready to sample
state = buffer.add(state, fake_batch)
assert buffer.can_sample(state)  # Buffer is now ready to sample

## Sampling from the Buffer
To sample from the buffer, we use the `sample` function:

In [None]:
rng_key = jax.random.PRNGKey(0)  # Set up a source of randomness
batch = buffer.sample(state, rng_key)  # Sample a batch of data

By inspecting the batch object, you can see that it is a `TransitionSample` object whose `experience` attribute is an `ExperiencePair` holding the transition data. The `first` and `second` attributes of the `ExperiencePair` each match the structure of `fake_timestep` with an added leading batch dimension of size `sample_batch_size`; together they form the two consecutive steps of a transition.

In [None]:
print(batch.experience.first.keys())  # prints dict_keys(['obs', 'reward'])
print(batch.experience.second.keys())  # prints dict_keys(['obs', 'reward'])
print(
    batch.experience.first["reward"].shape
)  # prints (4,) = (sample_batch_size, *fake_timestep['reward'].shape)
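
Conceptually, each sampled transition pairs a timestep with the one that follows it. The sketch below illustrates this under simplifying assumptions of our own: the helper `sketch_sample_pairs` is hypothetical, it ignores wrap-around and episode boundaries, and it is not Flashbax's sampling code. It draws a row and a time index `t` uniformly, then reads `first` at `t` and `second` at `t + 1`, so each leaf only gains a leading batch dimension.

```python
import jax
import jax.numpy as jnp


def sketch_sample_pairs(storage, key, sample_batch_size: int, num_valid_steps: int):
    """Hypothetical: uniformly sample (t, t + 1) pairs from (rows, time, ...) storage."""
    num_rows = jax.tree_util.tree_leaves(storage)[0].shape[0]
    row_key, time_key = jax.random.split(key)
    rows = jax.random.randint(row_key, (sample_batch_size,), 0, num_rows)
    # Keep t + 1 inside the valid region of the time axis.
    t = jax.random.randint(time_key, (sample_batch_size,), 0, num_valid_steps - 1)
    first = jax.tree_util.tree_map(lambda x: x[rows, t], storage)
    second = jax.tree_util.tree_map(lambda x: x[rows, t + 1], storage)
    return first, second


# Toy storage: 6 rows, 5 timesteps; the reward equals the time index.
pair_storage = {
    "obs": jnp.zeros((6, 5, 2)),
    "reward": jnp.tile(jnp.arange(5.0), (6, 1)),
}
pair_first, pair_second = sketch_sample_pairs(
    pair_storage, jax.random.PRNGKey(1), sample_batch_size=4, num_valid_steps=5
)
print(pair_first["reward"].shape)  # (4,)
print(pair_second["reward"] - pair_first["reward"])  # all ones: second is one step later
```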

## Buffer State and Structure
Inspecting the buffer state reveals its structure:

- `experience`: A pytree matching the structure of the timestep but with two extra axes of size `add_batch_size` and `max_length//add_batch_size`.
- `current_index`: Tracks where in the buffer new experience should be added.
- `is_full`: A boolean indicating whether the buffer's time axis (of length `max_length//add_batch_size`) has been completely filled, after which the oldest experience starts being overwritten.

In [None]:
print(state.__dict__.keys())
print(state.experience.keys())
print(
    state.experience["obs"].shape
)  # prints (6, 5, 2) = (add_batch_size, max_length//add_batch_size, *fake_timestep['obs'].shape)
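
To build intuition for `current_index` and `is_full`, the sketch below models adding one timestep per row (as with `add_sequences=False`). This is an assumption about the bookkeeping, not Flashbax's actual code: the helper `sketch_add_step` is hypothetical, and we assume the write index advances by one along the time axis, wraps modulo `max_length // add_batch_size`, and that `is_full` becomes true the first time it wraps.

```python
import jax.numpy as jnp


def sketch_add_step(data, current_index, is_full, new_timesteps):
    """Hypothetical: write one timestep per row at `current_index` along the time axis."""
    time_axis_length = data.shape[1]
    data = data.at[:, current_index].set(new_timesteps)
    next_index = (current_index + 1) % time_axis_length
    # Once the index wraps back to 0, every slot holds data and old entries get overwritten.
    is_full = jnp.logical_or(is_full, next_index == 0)
    return data, next_index, is_full


sketch_data = jnp.zeros((6, 5))  # (add_batch_size, max_length // add_batch_size)
sketch_index, sketch_full = 0, jnp.array(False)
full_history = []
for step in range(6):
    sketch_data, sketch_index, sketch_full = sketch_add_step(
        sketch_data, sketch_index, sketch_full, jnp.full((6,), float(step))
    )
    full_history.append(bool(sketch_full))

print(full_history)  # [False, False, False, False, True, True]
print(sketch_data[0])  # [5. 1. 2. 3. 4.]: step 5 overwrote the oldest slot
print(sketch_index)  # 1
```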

## Using the Buffer with `jax.pmap`
Flashbax buffers can be `jit`-ed and `pmap`-ed. The following code demonstrates how to use the N-Step Buffer with `pmap`:

In [None]:
# Define a function to create a fake batch of data
def get_fake_batch(fake_timestep: chex.ArrayTree, batch_size: int) -> chex.ArrayTree:
    return jax.tree_util.tree_map(
        lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep
    )


add_batch_size = 8

# Re-instantiate the buffer
buffer = fbx.make_n_step_buffer(
    max_length=max_length,
    min_length=min_length,
    sample_batch_size=sample_batch_size,
    add_sequences=add_sequences,
    add_batch_size=add_batch_size,
    n_step=1,
)

# Initialize the buffer's state with a "device" dimension
fake_timestep_per_device = jax.tree_util.tree_map(
    lambda x: jnp.stack([x + i for i in range(DEVICE_COUNT_MOCK)]), fake_timestep
)
state = jax.pmap(buffer.init)(fake_timestep_per_device)

# Fill the buffer above its minimum length
fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)(
    fake_timestep_per_device, add_batch_size
)
# Add two timesteps to form one transition pair
state = jax.pmap(buffer.add)(state, fake_batch)
state = jax.pmap(buffer.add)(state, fake_batch)
assert buffer.can_sample(state).all()

# Sample from the buffer
rng_key_per_device = jax.random.split(rng_key, DEVICE_COUNT_MOCK)
batch = jax.pmap(buffer.sample)(state, rng_key_per_device)

When inspecting the objects, you'll observe an extra leading "device" dimension: each device holds its own copy of the buffer state and samples its own batch.

In [None]:
print(
    state.experience["obs"].shape
)  # prints (2, 8, 4, 2) = (DEVICE_COUNT_MOCK, add_batch_size, max_length//add_batch_size, *fake_timestep['obs'].shape)
print(
    batch.experience.first["reward"].shape
)  # prints (2, 4) = (DEVICE_COUNT_MOCK, sample_batch_size, *fake_timestep['reward'].shape)
print(
    buffer.can_sample(state)
)  # prints [True, True] as the state on each device is filled above `min_length`.
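
If you want to check per-device shapes without real (or fake) multiple devices, `jax.vmap` over the same stacked inputs produces the same leading axis, although it runs everything on one device rather than distributing the work. A small sketch, with a hypothetical `sketch_init` that allocates `(8, 4, *leaf_shape)` zeros to mirror `add_batch_size = 8` and `max_length // add_batch_size = 4` from the example above:

```python
import jax
import jax.numpy as jnp

num_mock_devices = 2  # Mirrors DEVICE_COUNT_MOCK above.


def sketch_init(example_timestep):
    """Hypothetical per-device init: zeroed (8, 4, *leaf_shape) storage."""
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros((8, 4, *x.shape), dtype=x.dtype), example_timestep
    )


single_timestep = {"obs": jnp.array([5, 4]), "reward": jnp.array(1.0)}
# Stack one timestep per "device" along a new leading axis.
stacked_timestep = jax.tree_util.tree_map(
    lambda x: jnp.stack([x + i for i in range(num_mock_devices)]), single_timestep
)
stacked_storage = jax.vmap(sketch_init)(stacked_timestep)
print(stacked_storage["obs"].shape)  # (2, 8, 4, 2)
```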

## N-step > 1

We will now demonstrate how to use the buffer with an n-step greater than 1.

First, let's simplify the setup: we add a single sequence of 12 timesteps (a batch size of 1) at a time.

In [None]:
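n_step = 3  # Number of steps used to compute n-step returns.
add_batch_size = 1  # A single sequence is added per `add` call.
add_sequence_size = 12  # Number of timesteps in each added sequence.
max_length = 64
sample_batch_size = 4
add_sequences = True  # Data is now added with a time axis.
gamma = 1.0  # Discount factor; not used directly below.
# `min_length` (8) is reused from the first example.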

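Next, we create an N-Step Buffer with an n-step greater than 1. This is where we specify any sequence-mapping functions, which compute values such as n-step returns from each window of timesteps. For this example, we use Flashbax's utility functions `n_step_returns` and `n_step_product` to produce n-step returns and n-step discounts.

We first create our attribute-to-function map. Each key is a tuple whose first entry is the field that receives the result and whose remaining entries are the input fields passed to the function, in order: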

In [None]:
from flashbax.buffers import n_step_buffer

returns_fun = lambda x, y: jax.vmap(n_step_buffer.n_step_returns, in_axes=(0, 0, None))(
    x, y, n_step
)
discount_fun = lambda x: jax.vmap(n_step_buffer.n_step_product, in_axes=(0, None))(
    x, n_step
)
n_step_functional_map = {
    ("reward", "reward", "discount"): returns_fun,
    ("discount", "discount"): discount_fun,
}
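
For intuition about what these mapping functions compute, here is a sketch of an n-step return and an n-step discount for a single window. It is not the implementation of `n_step_buffer.n_step_returns` or `n_step_buffer.n_step_product`; the helper names are hypothetical, and we assume the convention of the hand calculation at the end of this notebook, where the discount at step k scales everything from step k + 1 onwards.

```python
import jax
import jax.numpy as jnp


def sketch_n_step_return(rewards, discounts, gamma=1.0):
    """Hypothetical: r_0 + gamma*d_0*(r_1 + gamma*d_1*(... + gamma*d_{n-2}*r_{n-1}))."""

    def accumulate(future_return, reward_and_discount):
        reward, discount = reward_and_discount
        return reward + gamma * discount * future_return, None

    # Fold backwards from the last reward; d_{n-1} does not enter the return.
    n_step_return, _ = jax.lax.scan(
        accumulate, rewards[-1], (rewards[:-1], discounts[:-1]), reverse=True
    )
    return n_step_return


def sketch_n_step_discount(discounts):
    """Hypothetical: product of the window's discounts; 0 if any step had discount 0."""
    return jnp.prod(discounts)


window_rewards = jnp.array([2.0, 3.0, 4.0])
window_discounts = jnp.array([1.0, 0.0, 1.0])
print(sketch_n_step_return(window_rewards, window_discounts))  # 5.0
print(sketch_n_step_discount(window_discounts))  # 0.0

# vmap over a leading batch axis, much like the vmapped lambdas above.
batched_returns = jax.vmap(sketch_n_step_return)(
    jnp.stack([window_rewards, jnp.ones(3)]), jnp.stack([window_discounts, jnp.ones(3)])
)
```

A zero anywhere in the window zeroes every later term of the return and the whole discount product, which is what lets the product signal that no bootstrapping should happen.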

In [None]:
buffer = fbx.make_n_step_buffer(
    max_length=max_length,
    min_length=min_length,
    sample_batch_size=sample_batch_size,
    add_sequences=add_sequences,
    add_batch_size=add_batch_size,
    n_step=n_step,
    n_step_functional_map=n_step_functional_map,
)

fake_timestep = {
    "obs": jnp.array([5, 4]),
    "reward": jnp.array(1.0),
    "discount": jnp.array(1.0),
}

buffer_state = buffer.init(fake_timestep)

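# Build an (add_batch_size, add_sequence_size, ...) sequence by repeating the fake timestep along a new time axis.
fake_batch_sequence = jax.tree_util.tree_map(
    lambda x: jnp.stack([x for _ in range(add_batch_size)])[:, jnp.newaxis].repeat(
        add_sequence_size, axis=1
    ),
    fake_timestep,
)

# Set the discount at timestep 2 to 0 (e.g. an episode termination) and make the rewards 1, 2, ..., 12.
fake_batch_sequence["discount"] = fake_batch_sequence["discount"].at[0, 2].set(0)
fake_batch_sequence["reward"] = fake_batch_sequence["reward"] + jnp.arange(add_sequence_size)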

buffer_state = buffer.add(buffer_state, fake_batch_sequence)

rng_key = jax.random.PRNGKey(2)

sampled_batch = buffer.sample(buffer_state, rng_key)

print("The Reward Sequence added:",fake_batch_sequence["reward"])
print("The Discount Sequence added:",fake_batch_sequence["discount"])

print("The n-step return of the sampled data:", sampled_batch.experience.first["reward"])
print("The reward after the n steps (not included in the return):", sampled_batch.experience.second["reward"])
print("The n-step discount of the sampled data:", sampled_batch.experience.first["discount"])
print("The discount after the n steps (not included in the product):", sampled_batch.experience.second["discount"])


Let's analyse the above result. We can use the reward value after n steps to identify which sequence was sampled. For the first item in the batch, the reward after n steps is 5, so the sampled reward window was [2. 3. 4. 5.] and the corresponding discount window was [1. 0. 1. 1.]. Computing the 3-step return for this first position by hand gives 2 + 1.0 * (3 + 0.0 * 4) = 5, which matches the n-step return in the sampled experience.

We also use the product of the n-step discounts to tell whether a zero discount (for example, an episode termination) occurred at some point within the n steps. This determines whether we need to bootstrap from a value estimate during training. Here the n-step discount is 0, which is correct, since the window contains a zero discount.

The values in `.second` give us access to information from the timestep immediately after the n-step window, for example the reward that follows the n steps but is not included in the n-step return. This is how we would access the observation used for bootstrapping, or other information such as the next action for a SARSA-like agent.
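
The same arithmetic can be checked without sampling by sliding a 3-step window over the sequence we added (rewards 1 to 12, with a zero discount at index 2). This sketch assumes the same discount convention as the hand calculation above and uses our own variable names; it checks the reasoning, not Flashbax itself. For the window starting at index 1, the reward after n steps is 5, the 3-step return is 5 and the discount product is 0.

```python
import jax.numpy as jnp

check_n = 3                                    # n-step used in the example
check_rewards = jnp.arange(1.0, 13.0)          # 1, 2, ..., 12
check_discounts = jnp.ones(12).at[2].set(0.0)  # zero discount at index 2


def window_stats(start, gamma=1.0):
    """Return (n-step return, discount product, reward after n steps) for one window."""
    g = check_rewards[start + check_n - 1]
    for k in range(check_n - 2, -1, -1):
        # G_k = r_k + gamma * d_k * G_{k+1}
        g = check_rewards[start + k] + gamma * check_discounts[start + k] * g
    discount_product = jnp.prod(check_discounts[start : start + check_n])
    reward_after = check_rewards[start + check_n]
    return float(g), float(discount_product), float(reward_after)


for start in range(4):
    print(start, window_stats(start))
# 0 (6.0, 0.0, 4.0)
# 1 (5.0, 0.0, 5.0)
# 2 (3.0, 0.0, 6.0)
# 3 (15.0, 1.0, 7.0)
```

Only the window starting at index 3 lies entirely after the zero discount, so it is the only one of these four with a non-zero discount product.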