# Vault demonstration

In [None]:
%%capture
try:
    import flashbax as fbx
except ModuleNotFoundError:
    print('installing flashbax')
    %pip install -q flashbax
    import flashbax as fbx

In [None]:
import jax
from typing import NamedTuple
import jax.numpy as jnp
from flashbax.vault import Vault
import flashbax as fbx
from chex import Array

We create a simple timestep structure, with a corresponding flat buffer.

In [None]:
class FbxTransition(NamedTuple):
    obs: Array

tx = FbxTransition(obs=jnp.zeros(shape=(2,)))

buffer = fbx.make_flat_buffer(
    max_length=5,
    min_length=1,
    sample_batch_size=1,
)
buffer_state = buffer.init(tx)
buffer_add = jax.jit(buffer.add, donate_argnums=0)

The shape of this buffer is $(B = 1, T = 5, E = 2)$, meaning the buffer can hold 5 timesteps, where each observation is of shape $(2,)$.

In [None]:
buffer_state.experience.obs.shape

(1, 5, 2)

We create the vault, based on the buffer's experience structure.

In [None]:
v = Vault(
    vault_name="demo",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp"
)

New vault created at /tmp/demo/20250923012939
Since the provided buffer state has a temporal dimension of 5, you must write to the vault at least every 4 timesteps to avoid data loss.


We now add 10 timesteps to the buffer, and write that buffer to the vault. We inspect the buffer and vault state after each timestep.

In [24]:
for i in range(1, 10):
    print('------------------')
    print("Buffer state:")
    print(buffer_state.experience.obs)
    print()

    v.write(buffer_state)

    print("Vault state:")
    print(v.read().experience.obs)
    print('------------------')

    buffer_state = buffer_add(
        buffer_state,
        FbxTransition(obs=i * jnp.ones(1))
    )

------------------
Buffer state:
[[[5. 5.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]]]

Vault state:
[[[1. 1.]
  [1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]]]

Vault state:
[[[1. 1.]
  [1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]
  [1. 1.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [7. 7.]
  [8. 8.]
  [9. 9.]]]

Vault state:
[[[1. 1.]
  [1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]
  [1. 1.]
  [2. 2.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [8. 8.]
  [9. 9.]]]

Vault state:
[[[1. 1.]
  [1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]
  [7. 7.]
  [8. 8.]
  [9. 9.]
  [1. 1.]
  [2. 2.]
  [3. 3.]]]
------------------
------------------
Buffer state:
[[[1. 1.]
  [2. 2.]
  [3. 3

Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:
```
Buffer state:
[[[6. 6.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]]]

Vault state:
[[[1. 1.]
  [2. 2.]
  [3. 3.]
  [4. 4.]
  [5. 5.]
  [6. 6.]]]
```

Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer).

However, for offline learning scenario where you want to directly write in a full buffer, remember to specify the interval

Here we create a same buffer and add data until the buffer is full

In [None]:
MAX_LENGTH = 5

fbuffer = fbx.make_flat_buffer(
    max_length=MAX_LENGTH,
    min_length=1,
    sample_batch_size=1,
)
fbuffer_state = fbuffer.init(tx)
fbuffer_add = jax.jit(fbuffer.add, donate_argnums=0)

for i in range(MAX_LENGTH):
    fbuffer_state = fbuffer_add(
        fbuffer_state,
        FbxTransition(obs=i * jnp.ones(1))
    )

print(fbuffer_state.is_full)

True


Now we create a new vault to write in the full buffer

In [None]:
fv = Vault(
    vault_name="TestVault",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp"
)

fv.write(buffer_state)
vault_data = fv.read()

print("Vault data shape:", vault_data.experience.obs.shape)
print("Vault data is empty:", jnp.all(vault_data.experience.obs == 0))

New vault created at /tmp/TestVault/20250923015646
Since the provided buffer state has a temporal dimension of 5, you must write to the vault at least every 4 timesteps to avoid data loss.
Vault data shape: (1, 0, 2)
Vault data is empty: True


You can notice that in this way the write is unsuccessful, the correct way is to specify the interval:

In [None]:
del fv

fv = Vault(
    vault_name="TestVault",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp"
)

fv.write(buffer_state, source_interval=(0, MAX_LENGTH))
vault_data = fv.read()

print("Vault data shape:", vault_data.experience.obs.shape)
print("Vault data is empty:", jnp.all(vault_data.experience.obs == 0))

New vault created at /tmp/TestVault/20250923015840
Since the provided buffer state has a temporal dimension of 5, you must write to the vault at least every 4 timesteps to avoid data loss.
Vault data shape: (1, 5, 2)
Vault data is empty: False


if you don't know the length of the buffer, you can use:

In [None]:
max_index = fbx.utils.get_tree_shape_prefix(fbuffer_state.experience, n_axes=2)[1]
print(max_index)

5
