In [1]:
import jax
import jax.numpy as jnp
from functools import partial
import flax

num_envs = 3

@flax.struct.dataclass
class ReplayBufferState:
    """Contains data related to a replay buffer."""

    data: jnp.ndarray
    insert_position: int
    sample_position: int

def init_replay_buffer_state(data_shape):
    data = jnp.zeros(data_shape, dtype=jnp.float32)
    return ReplayBufferState(data, 0, 0)

@partial(jax.jit, static_argnames="mask_size")
def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array, mask_size: int) -> ReplayBufferState:
    indices_mask = jnp.nonzero(mask, size=mask_size)[0]

    # Apply the mask to the samples to keep the valid ones
    new_samples = jnp.take(samples, indices_mask, axis=0)
    samples_size = len(new_samples)

    # Current buffer state
    data = buffer_state.data
    insert_idx = buffer_state.insert_position
    size_buffer = buffer_state.sample_position

    # Insert the new samples in the buffer
    data = jax.lax.dynamic_update_slice_in_dim(data, new_samples, insert_idx, axis=0)
    insert_idx = (insert_idx + samples_size) % size_buffer
    sample_idx = jnp.minimum(buffer_state.sample_position + samples_size, size_buffer)

    return buffer_state.replace(
        data=data,
        insert_position=insert_idx,
        sample_position=sample_idx,
    )

# Create a buffer state
buffer_state = init_replay_buffer_state((1000, 10))

# Dummy samples
samples = jnp.ones((num_envs, 10))

# Mask to insert only the first two samples
mask = jnp.array([True, True, False])

buffer_state = insert_in_replay_state(buffer_state, samples, mask)

2023-11-14 16:05:30.078392: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function insert_in_replay_state at /tmp/ipykernel_179206/3259525882.py:20 for jit. This concrete value was not available in Python because it depends on the value of the argument mask.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError