In [1]:
import flashbax as fbx
import jax.numpy as jnp
from jax.tree_util import tree_map
import jax

# Fixed-seed PRNG key; this same key is passed to every `sample` call below.
key = jax.random.PRNGKey(0)

In [2]:
# Create our first buffer, with a sample batch size of 4
buffer_a = fbx.make_trajectory_buffer(
    add_batch_size=1,
    max_length_time_axis=1000,
    min_length_time_axis=5,
    sample_sequence_length=5,
    period=1,
    sample_batch_size=4,
)

# Template for a single timestep: used to initialise the buffer state and
# as the base values that get scaled in the fill loop below.
timestep = {
    "obs": jnp.ones(2),
    "acts": jnp.ones(3),
}

state_a = buffer_a.init(timestep)

# Build the jitted `add` once instead of re-wrapping it on every iteration.
# `donate_argnums=0` donates the incoming state, so the old `state_a` must
# not be used after the call -- we always rebind it to the returned state.
add_a = jax.jit(buffer_a.add, donate_argnums=0)

for i in range(100):
    # Fill with POSITIVE values
    # `[None, None, ...]` prepends two size-1 axes to every leaf
    # (presumably the add-batch and time axes the buffer expects).
    state_a = add_a(
        state_a,
        tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),
    )

sample_a = buffer_a.sample(state_a, key)
# Show only the shape of every leaf in the sampled batch.
tree_map(lambda x: x.shape, sample_a)

TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})

In [3]:
# Create our second buffer, with a sample batch size of 16
buffer_b = fbx.make_trajectory_buffer(
    add_batch_size=1,
    max_length_time_axis=1000,
    min_length_time_axis=5,
    sample_sequence_length=5,
    period=1,
    sample_batch_size=16,
)

# Template for a single timestep: used to initialise the buffer state and
# as the base values that get scaled (and negated) in the fill loop below.
timestep = {
    "obs": jnp.ones(2),
    "acts": jnp.ones(3),
}

state_b = buffer_b.init(timestep)

# Build the jitted `add` once instead of re-wrapping it on every iteration.
# `donate_argnums=0` donates the incoming state, so the old `state_b` must
# not be used after the call -- we always rebind it to the returned state.
add_b = jax.jit(buffer_b.add, donate_argnums=0)

for i in range(100):
    # Fill with NEGATIVE values
    # `[None, None, ...]` prepends two size-1 axes to every leaf
    # (presumably the add-batch and time axes the buffer expects).
    state_b = add_b(
        state_b,
        tree_map(lambda x, _i=i: (-x * _i)[None, None, ...], timestep),
    )

sample_b = buffer_b.sample(state_b, key)
# Show only the shape of every leaf in the sampled batch.
tree_map(lambda x: x.shape, sample_b)

TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})

In [4]:
# Combine both buffers into a mixer that draws a joint batch of 8,
# split between buffer_a and buffer_b in a 1:3 ratio.
mixer = fbx.make_mixer(
    buffers=[buffer_a, buffer_b],
    sample_batch_size=8,
    proportions=[1, 3],
)

# The mixer's sample function can be wrapped in jit like any other.
mixer_sample = jax.jit(mixer.sample)

In [5]:
# Sample from the mixer, using the usual flashbax API
joint_sample = mixer_sample(
    [state_a, state_b],
    key,
)

# Notice the resulting shape
tree_map(lambda x: x.shape, joint_sample)

TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})

In [6]:
# The first 1/4 * 8 = 2 sequences of the joint batch come from buffer_a
# (POSITIVE VALUES).
tree_map(lambda leaf: leaf[:2], joint_sample)

TrajectoryBufferSample(experience={'acts': Array([[[90., 90., 90.],
        [91., 91., 91.],
        [92., 92., 92.],
        [93., 93., 93.],
        [94., 94., 94.]],

       [[56., 56., 56.],
        [57., 57., 57.],
        [58., 58., 58.],
        [59., 59., 59.],
        [60., 60., 60.]]], dtype=float32), 'obs': Array([[[90., 90.],
        [91., 91.],
        [92., 92.],
        [93., 93.],
        [94., 94.]],

       [[56., 56.],
        [57., 57.],
        [58., 58.],
        [59., 59.],
        [60., 60.]]], dtype=float32)})

In [7]:
# The remaining 3/4 * 8 = 6 sequences of the joint batch come from buffer_b
# (NEGATIVE VALUES).
tree_map(lambda leaf: leaf[2:], joint_sample)

TrajectoryBufferSample(experience={'acts': Array([[[-34., -34., -34.],
        [-35., -35., -35.],
        [-36., -36., -36.],
        [-37., -37., -37.],
        [-38., -38., -38.]],

       [[-88., -88., -88.],
        [-89., -89., -89.],
        [-90., -90., -90.],
        [-91., -91., -91.],
        [-92., -92., -92.]],

       [[-30., -30., -30.],
        [-31., -31., -31.],
        [-32., -32., -32.],
        [-33., -33., -33.],
        [-34., -34., -34.]],

       [[-11., -11., -11.],
        [-12., -12., -12.],
        [-13., -13., -13.],
        [-14., -14., -14.],
        [-15., -15., -15.]],

       [[-78., -78., -78.],
        [-79., -79., -79.],
        [-80., -80., -80.],
        [-81., -81., -81.],
        [-82., -82., -82.]],

       [[-15., -15., -15.],
        [-16., -16., -16.],
        [-17., -17., -17.],
        [-18., -18., -18.],
        [-19., -19., -19.]]], dtype=float32), 'obs': Array([[[-34., -34.],
        [-35., -35.],
        [-36., -36.],
        [-37., -