# Vault demonstration

In [1]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

In [2]:
import jax
from typing import NamedTuple
import jax.numpy as jnp
from flashbax.vault import Vault
import flashbax as fbx
from chex import Array

We create a simple timestep structure, with a corresponding flat buffer.

In [3]:
class FbxTransition(NamedTuple):
    obs: Array

tx = FbxTransition(obs=jnp.zeros(shape=(2,)))

buffer = fbx.make_flat_buffer(
    max_length=5,
    min_length=1,
    sample_batch_size=1,
)
buffer_state = buffer.init(tx)
buffer_add = jax.jit(buffer.add, donate_argnums=0)

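The experience in this buffer has shape $(B = 1, T = 5, E = 2)$: one row along the batch axis ($B$), five slots along the time axis ($T$), and observations of size two ($E$). In other words, the buffer can hold 5 timesteps, each with an observation of shape $(2,)$.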

In [5]:
buffer_state.experience.obs.shape

(1, 5, 2)
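As a rough sketch of where this shape comes from, the helper below assumes that a flat buffer stores `max_length` items split evenly across `add_batch_size` rows, and treats an unset `add_batch_size` as a single row. `flat_buffer_storage_shape` is a hypothetical helper for illustration only, not part of flashbax.

```python
def flat_buffer_storage_shape(max_length, obs_shape, add_batch_size=None):
    """Hypothetical helper: the (B, T, *obs_shape) layout assumed for a flat buffer."""
    # Assumption: an unset add_batch_size behaves like a single batch row.
    rows = 1 if add_batch_size is None else add_batch_size
    # Assumption: the max_length capacity is split evenly across the rows.
    time_axis = max_length // rows
    return (rows, time_axis, *obs_shape)


print(flat_buffer_storage_shape(max_length=5, obs_shape=(2,)))  # (1, 5, 2)
```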

We create the vault, based on the buffer's experience structure.

In [6]:
v = Vault(
    vault_name="demo",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp"
)

New vault created at /tmp/demo/20240205140817


We now add nine timesteps (with observations 1 through 9) to the buffer, one at a time. At each step, we print the buffer, write it to the vault, and print the vault contents before adding the next timestep.

In [7]:
for i in range(1, 10):
    print('------------------')
    print("Buffer state:")
    print(buffer_state.experience.obs)
    print()

    v.write(buffer_state)

    print("Vault state:")
    print(v.read().experience.obs)
    print('------------------')

    # Add timestep i: an observation of shape (2,) filled with the value i.
    buffer_state = buffer_add(
        buffer_state,
        FbxTransition(obs=i * jnp.ones(2))
    )

------------------
Buffer state:
[[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]

Vault state:
[]
------------------
------------------
Buffer state:
[[[1. 1.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]

Vault state:
[[[1. 1.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]

Vault state:
[[[1. 1.]
  [2. 2.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [0. 0.]
  [0. 0.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [0. 0.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]]]
------------------
------------------
Buffer state:
[[[6. 6.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]]]
------------------


Notice that when the buffer (implemented as a ring buffer) wraps around and overwrites its oldest entry, the vault still retains all of the data:
```
Buffer state:
[[[6. 6.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]]]
```
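To make the wrap-around behavior concrete, here is a toy pure-Python model. It assumes the vault can be thought of as an append-only list that remembers how many timesteps it has already persisted and, on each write, copies only the newer ones out of the ring buffer. `RingBuffer` and `ToyVault` are hypothetical names; this is a sketch of the idea, not how flashbax's `Vault` is implemented.

```python
class RingBuffer:
    """Hypothetical fixed-capacity ring buffer: once full, the oldest slot is overwritten."""

    def __init__(self, max_length):
        self.max_length = max_length
        self.data = [0.0] * max_length
        self.total_added = 0  # number of timesteps ever added

    def add(self, item):
        # The t-th timestep (0-based) lands in slot t % max_length.
        self.data[self.total_added % self.max_length] = item
        self.total_added += 1


class ToyVault:
    """Hypothetical append-only store that copies only timesteps it has not yet seen."""

    def __init__(self):
        self.data = []

    def write(self, buffer):
        n_new = buffer.total_added - len(self.data)
        if n_new > buffer.max_length:
            lost = n_new - buffer.max_length
            raise ValueError(f"{lost} timestep(s) were overwritten before being written")
        # Copy every timestep added since the last write, oldest first.
        for t in range(len(self.data), buffer.total_added):
            self.data.append(buffer.data[t % buffer.max_length])


buffer = RingBuffer(max_length=5)
vault = ToyVault()
for i in range(1, 7):
    buffer.add(float(i))
    vault.write(buffer)

print(buffer.data)  # [6.0, 2.0, 3.0, 4.0, 5.0]
print(vault.data)   # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Running the same six additions as the notebook reproduces the buffer and vault states quoted above: the ring buffer holds only the latest five timesteps, while the vault keeps every one.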

Note: the vault must be given the buffer state at least once every `max_length` timesteps (5 in this example), i.e. before any data not yet written to the vault is overwritten in the ring buffer.
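This write-frequency requirement can be checked with a small simulation. The sketch below assumes a ring buffer of capacity `max_length` whose state is written to the vault after every `write_every`-th addition, and reports which timesteps were overwritten before any write captured them. `unwritten_overwrites` is a hypothetical helper, not part of flashbax.

```python
def unwritten_overwrites(max_length, write_every, n_steps):
    """Hypothetical helper: 1-based timesteps evicted from the buffer before being written."""
    last_written = 0  # timesteps 1..last_written have been persisted to the vault
    lost = []
    for t in range(1, n_steps + 1):
        # Adding timestep t overwrites timestep t - max_length, if that one exists.
        evicted = t - max_length
        if evicted > last_written:
            lost.append(evicted)
        # Write the buffer to the vault after every write_every-th addition.
        if t % write_every == 0:
            last_written = t
    return lost


print(unwritten_overwrites(max_length=5, write_every=5, n_steps=20))  # []
print(unwritten_overwrites(max_length=5, write_every=6, n_steps=12))  # [1, 7]
```

Writing every 5 steps loses nothing, whereas writing every 6 steps loses one timestep per write interval.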