In [9]:
%%capture
! pip install flashbax

In [25]:
import chex
import flashbax as fbx
import jax
import jax.numpy as jnp

Let's create the buffer and a training loop. Note that the buffer state is created inside the training function (just before the loop), so jitting the function also jits the buffer initialisation.

In [26]:
# Flat buffer with capacity for 100_000 timesteps. Only `init`/`add` are
# exercised in this notebook, so the sampling settings are kept minimal.
buffer = fbx.make_flat_buffer(
    sample_batch_size=1,
    min_length=1,
    max_length=100_000,
)

def train_loop(buffer_fn):
    """Initialise a buffer state and add 50_000 timesteps to it.

    The buffer state is created inside this function, so wrapping the whole
    function in ``jax.jit`` also covers the buffer initialisation.

    Args:
        buffer_fn: A buffer object exposing ``init(timestep)`` and
            ``add(state, timestep)`` (e.g. the result of
            ``fbx.make_flat_buffer``). When this function is jitted with
            ``static_argnums=0`` it must be hashable.

    Returns:
        The buffer state after all adds. The timestep added at loop
        iteration ``i`` has every ``"obs"`` entry equal to ``i``.
    """
    timestep = {
        "obs": jnp.zeros(shape=(10, 20, 30), dtype=jnp.float32),
    }
    buffer_state = buffer_fn.init(timestep)
    # Use the buffer that was passed in (not a module-level global) so the
    # function is fully determined by its argument.
    buffer_add = buffer_fn.add

    def add_step(i, state):
        # Fill every leaf of the template timestep with the loop index `i`.
        filled = jax.tree_util.tree_map(lambda leaf: jnp.ones_like(leaf) * i, timestep)
        return buffer_add(state, filled)

    buffer_state = jax.lax.fori_loop(
        lower=0,
        upper=50_000,  # Add 50_000 elements
        body_fun=add_step,
        init_val=buffer_state,
    )
    return buffer_state



Before anything else, let's check that our function works.

In [27]:
buffer_state_test = train_loop(buffer_fn=buffer)

# Loop iteration `i` adds a timestep whose values are all `i`, and adds are
# written in order, so the slot at time index `i` should hold `i`.
# Check a slot in the middle and the last slot written (no wrap-around:
# 50_000 adds < max_length of 100_000).
for step in (1234, 49_999):
    chex.assert_trees_all_equal(
        buffer_state_test.experience['obs'][0, step, ...],
        jnp.ones(shape=(10, 20, 30), dtype=jnp.float32) * step,
    )

We don't want to run out of RAM, so we'll manually delete our buffer state each time.

In [28]:
# Drop our only reference to the test state so its memory can be reclaimed
# before profiling starts.
del buffer_state_test

Install `memory_profiler`, which provides the `%memit` magic used below.

In [29]:
%%capture
!pip install memory-profiler
%load_ext memory_profiler

In [30]:
# Memory strain of the vanilla train function, called directly without jax.jit.
# This is the baseline for the measurements below; %memit reports the peak
# memory and the increment over the memory in use before the cell.
%memit buffer_state_A = train_loop(buffer_fn=buffer)

peak memory: 5052.51 MiB, increment: 4578.70 MiB


In [31]:
# Release state A before the next measurement so we don't run out of RAM.
del buffer_state_A

In [32]:
# Memory strain of the train function jitted OUTSIDE (whole function under jax.jit).
# `buffer` is a static argument (static_argnums=0), so it must be hashable; JAX
# maps the keyword call `buffer_fn=buffer` onto positional argument 0.
%memit buffer_state_B = jax.jit(train_loop, static_argnums=0)(buffer_fn=buffer)

peak memory: 2765.12 MiB, increment: 2290.14 MiB


In [33]:
# Release state B before the next measurement so we don't run out of RAM.
del buffer_state_B

In [35]:
def train_loop_with_donate(buffer_fn):
    """Like ``train_loop``, but adds via a jitted ``add`` that donates its state.

    NOTE(review): the jitted ``add`` is called from inside
    ``jax.lax.fori_loop``, i.e. while tracing, so the donation request is most
    likely ignored there. The memory measurements in this notebook show the
    same numbers with and without it. Confirm against the JAX buffer-donation
    documentation.

    Args:
        buffer_fn: A buffer object exposing ``init(timestep)`` and
            ``add(state, timestep)`` (e.g. the result of
            ``fbx.make_flat_buffer``). When this function is jitted with
            ``static_argnums=0`` it must be hashable.

    Returns:
        The buffer state after 50_000 adds. The timestep added at loop
        iteration ``i`` has every ``"obs"`` entry equal to ``i``.
    """
    timestep = {
        "obs": jnp.zeros(shape=(10, 20, 30), dtype=jnp.float32),
    }
    buffer_state = buffer_fn.init(timestep)

    # ------------------
    # Note the change (uses the buffer passed in, not a module-level global):
    buffer_add = jax.jit(buffer_fn.add, donate_argnums=0)
    # ------------------

    def add_step(i, state):
        # Fill every leaf of the template timestep with the loop index `i`.
        filled = jax.tree_util.tree_map(lambda leaf: jnp.ones_like(leaf) * i, timestep)
        return buffer_add(state, filled)

    buffer_state = jax.lax.fori_loop(
        lower=0,
        upper=50_000,  # Add 50_000 elements
        body_fun=add_step,
        init_val=buffer_state,
    )
    return buffer_state

In [36]:
# Memory strain of the train function with `add` jitted inside + donate, but the
# function itself NOT jitted outside. Compare with the vanilla baseline above.
%memit buffer_state_C = train_loop_with_donate(buffer_fn=buffer)

peak memory: 5054.29 MiB, increment: 4577.61 MiB


In [37]:
# Release state C before the next measurement so we don't run out of RAM.
del buffer_state_C

In [38]:
# Memory strain of the train function with `add` jitted inside + donate, AND the
# whole function jitted outside. Compare with the outer-jit-only measurement above.
%memit buffer_state_D = jax.jit(train_loop_with_donate, static_argnums=0)(buffer_fn=buffer)

peak memory: 2766.87 MiB, increment: 2290.04 MiB


In [39]:
# Release state D now that all measurements are done.
del buffer_state_D