In [None]:
import jax
import jax.numpy as jnp
from functools import partial
import flax

# Leading dimension (row count) of the dummy ``samples`` batch built further
# down in this file.
num_envs = 3


@flax.struct.dataclass
class ReplayBufferState:
    """State container for the replay buffer helpers in this file.

    ``init_replay_buffer_state`` builds the initial value and
    ``insert_in_replay_state`` returns updated instances through
    ``.replace(...)``.
    """

    # Storage array. ``init_replay_buffer_state`` allocates it as float32
    # zeros; inserts index it along its first axis.
    data: jnp.ndarray
    # Index along the first axis of ``data`` at which
    # ``insert_in_replay_state`` starts writing; initialised to 0.
    insert_position: int
    # Second position counter, initialised to 0.
    # NOTE(review): presumably the number of filled rows available for
    # sampling -- the sampling code is not visible here, please confirm.
    sample_position: int


def init_replay_buffer_state(data_shape):
    """Build the initial replay buffer state.

    Args:
        data_shape: Shape of the storage array to allocate.

    Returns:
        A ``ReplayBufferState`` whose ``data`` is a float32 array of zeros
        with shape ``data_shape`` and whose two position counters are 0.
    """
    return ReplayBufferState(
        data=jnp.zeros(data_shape, dtype=jnp.float32),
        insert_position=0,
        sample_position=0,
    )


@jax.jit
def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array) -> ReplayBufferState:
    """Insert the rows of ``samples`` selected by ``mask`` into the buffer.

    The buffer is treated as a ring along the first axis of
    ``buffer_state.data``: writing starts at ``insert_position`` and wraps
    around to row 0 once the last row has been passed.

    Args:
        buffer_state: Current buffer state.
        samples: Batch of candidate rows; its first axis must have the same
            length as ``mask`` and its trailing shape must match one row of
            ``buffer_state.data``.
        mask: 1-D boolean array; ``True`` marks the rows of ``samples`` to
            insert. Selected rows are written in their original order.

    Returns:
        A new ``ReplayBufferState`` in which
        * ``data`` holds the selected rows,
        * ``insert_position`` has advanced by the number of selected rows,
          modulo the buffer capacity,
        * ``sample_position`` has advanced by the same amount, saturating at
          the buffer capacity.

    Note:
        If a single call selects more rows than the buffer can hold, several
        rows target the same slot and which one is kept is not specified.
    """
    batch_size = len(mask)
    # The capacity is the static length of the storage, not the fill counter.
    capacity = buffer_state.data.shape[0]
    if capacity == 0:
        # Nothing can be stored; also avoids a modulo by zero below.
        return buffer_state

    num_selected = jnp.sum(mask)

    # Compact the indices of the selected rows to the front. The result must
    # have a static length under jit, so it is padded (with row 0) up to
    # ``batch_size``; the padded entries are discarded by the scatter below.
    (source_rows,) = jnp.where(mask, size=batch_size, fill_value=0)

    slots = jnp.arange(batch_size)
    is_real = slots < num_selected
    # Real slots go to consecutive ring positions; padded slots are pointed
    # one past the end so that ``mode="drop"`` skips them.
    target_rows = jnp.where(
        is_real,
        (buffer_state.insert_position + slots) % capacity,
        capacity,
    )
    data = buffer_state.data.at[target_rows].set(samples[source_rows], mode="drop")

    return buffer_state.replace(
        data=data,
        insert_position=(buffer_state.insert_position + num_selected) % capacity,
        sample_position=jnp.minimum(buffer_state.sample_position + num_selected, capacity),
    )


# Demo: insert a partially masked batch into a freshly created buffer.

# One all-ones row per environment.
samples = jnp.ones((num_envs, 10))

# Select the first two rows for insertion and skip the third.
mask = jnp.array([True, True, False])

# Start from a zero-filled (1000, 10) buffer and run a single insert.
buffer_state = insert_in_replay_state(
    init_replay_buffer_state((1000, 10)),
    samples,
    mask,
)

# Print the first five rows of the storage after the insert.
print(buffer_state.data[:5])