In [1]:
import flashbax as fbx
import jax.numpy as jnp
from jax.tree_util import tree_map
import jax

In [2]:
# Root PRNG key for the whole notebook; a fixed seed keeps every run reproducible.
SEED = 0
key = jax.random.PRNGKey(SEED)

In [3]:
# Buffer A: a trajectory buffer with a single add stream (add_batch_size=1)
# that hands back 2 sequences of 5 timesteps per `sample` call.
buffer_a = fbx.make_trajectory_buffer(
    add_batch_size=1,
    max_length_time_axis=10_000,
    min_length_time_axis=5,
    sample_sequence_length=5,
    period=1,
    sample_batch_size=2,
)

# Example transition handed to `init`; the recorded output shows one
# (1, 10000, ...) storage array allocated per key of this dict.
timestep = {
    "obs": jnp.ones(2),
    "acts": jnp.ones(3),
}

state_a = buffer_a.init(timestep)

# Compile `add` once up front: every iteration below passes identically
# shaped inputs, so the loop reuses a single compiled function instead of
# dispatching the whole update eagerly on each of the steps.
add_to_buffer_a = jax.jit(buffer_a.add)

NUM_STEPS_A = 1500
for step in range(NUM_STEPS_A):
    # Every entry of the transition written at step `step` equals `step`.
    # The two leading singleton axes give each leaf shape (1, 1, ...),
    # matching the (1, 10000, ...) layout of the stored experience.
    step_batch = tree_map(
        lambda x, _step=step: (x * _step)[None, None, ...], timestep
    )
    state_a = add_to_buffer_a(state_a, step_batch)

print(state_a.current_index)
tree_map(lambda x: x.shape, state_a)

1500


TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())

In [4]:
# Buffer B: same layout as a single-stream trajectory buffer, but it hands
# back 13 sequences of 5 timesteps per `sample` call.
buffer_b = fbx.make_trajectory_buffer(
    add_batch_size=1,
    max_length_time_axis=10_000,
    min_length_time_axis=5,
    sample_sequence_length=5,
    period=1,
    sample_batch_size=13,
)

# Example transition handed to `init`; defined here so this cell does not
# rely on a name bound by an earlier cell.
timestep = {
    "obs": jnp.ones(2),
    "acts": jnp.ones(3),
}

state_b = buffer_b.init(timestep)

# Compile `add` once up front: every iteration below passes identically
# shaped inputs, so the loop reuses a single compiled function instead of
# dispatching the whole update eagerly on each of the steps.
add_to_buffer_b = jax.jit(buffer_b.add)

NUM_STEPS_B = 6000
for step in range(NUM_STEPS_B):
    # Every entry of the transition written at step `step` equals
    # 1000 - step. The two leading singleton axes give each leaf shape
    # (1, 1, ...), matching the (1, 10000, ...) stored experience layout.
    step_batch = tree_map(
        lambda x, _step=step: (1000 - x * _step)[None, None, ...], timestep
    )
    state_b = add_to_buffer_b(state_b, step_batch)

print(state_b.current_index)
tree_map(lambda x: x.shape, state_b)

6000


TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())

In [5]:
# A JAX PRNG key should be consumed only once. Split the running key into a
# fresh `key` for later cells plus one dedicated subkey per sampling call, so
# the two draws below do not share randomness.
key, sample_key_a, sample_key_b = jax.random.split(key, 3)

sample_a = buffer_a.sample(state_a, sample_key_a)
print(sample_a.experience['acts'].shape)

sample_b = buffer_b.sample(state_b, sample_key_b)
print(sample_b.experience['acts'].shape)

(2, 5, 3)
(13, 5, 3)


In [6]:
# Combine both buffers behind one sampler. `proportions` is given here as the
# integer weights 2 (buffer_a) and 3 (buffer_b), in the same order as `buffers`.
source_buffers = [buffer_a, buffer_b]
mixer = fbx.make_mixer(
    buffers=source_buffers,
    sample_batch_size=8,
    proportions=[2, 3],
)

In [7]:
# Split off a dedicated subkey for this draw instead of passing the running
# `key` again: a JAX PRNG key should be consumed only once.
key, joint_key = jax.random.split(key)

# The states are listed in the same order as the buffers given to `make_mixer`.
joint_sample = mixer.sample(
    [state_a, state_b],
    joint_key,
)

# NOTE(review): the output recorded for this cell has a leading dimension of 6,
# not the sample_batch_size=8 passed to `make_mixer` — possibly because
# buffer_a was built with sample_batch_size=2. Confirm against the installed
# flashbax version.
joint_sample.experience['acts'].shape

(6, 5, 3)

In [8]:
# Rebuild the mixer with fractional weights: 0.1 for buffer_a and 0.9 for
# buffer_b, in the same order as `buffers`.
source_buffers = [buffer_a, buffer_b]
mixer = fbx.make_mixer(
    buffers=source_buffers,
    sample_batch_size=8,
    proportions=[0.1, 0.9],
)

In [9]:
# Split off a dedicated subkey for this draw instead of passing the running
# `key` again: a JAX PRNG key should be consumed only once.
key, skewed_joint_key = jax.random.split(key)

# The states are listed in the same order as the buffers given to `make_mixer`.
joint_sample = mixer.sample(
    [state_a, state_b],
    skewed_joint_key,
)

joint_sample.experience['acts'].shape

(8, 5, 3)