In [1]:
!nvidia-smi

Mon Jul  8 18:54:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   31C    P0              30W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
import chex
import jax
import jax.numpy as jnp
import flashbax as fbx
from qdax.core.neuroevolution.buffers.buffer import QDTransition


In [3]:
# First define hyper-parameters of the buffer.
max_length_time_axis = 1000 # Maximum length of the buffer along the time axis. 
min_length_time_axis = 1000 # Minimum length across the time axis before we can sample.
sample_batch_size = 64 # Batch size of trajectories sampled from the buffer.
add_batch_size = 64 # Batch size of trajectories added to the buffer.
sample_sequence_length = 1000 # Sequence length of trajectories sampled from the buffer.
add_sequence_length = 1000 # Sequence length of trajectories added to the buffer.
period = 1000 # Period at which we sample trajectories from the buffer.

# Instantiate the trajectory buffer, which is a NamedTuple of pure functions.
buffer = fbx.make_trajectory_buffer(
    max_length_time_axis=max_length_time_axis,
    min_length_time_axis=min_length_time_axis,
    sample_batch_size=sample_batch_size,
    add_batch_size=add_batch_size,
    sample_sequence_length=sample_sequence_length,
    period=period
)

In [4]:
dummy_transition = QDTransition.init_dummy(
    observation_dim=28,
    action_dim=8,
    descriptor_dim=4,
)

In [5]:
dummy_transition.dones.shape

()

In [6]:
state = buffer.init(dummy_transition)

In [7]:
random_key = jax.random.PRNGKey(0)

broadcast_fn = lambda x: jnp.broadcast_to(x, (add_batch_size, add_sequence_length, *x.shape))
fake_batch_sequence = jax.tree_map(broadcast_fn, dummy_transition)

state = buffer.add(state, fake_batch_sequence)
assert buffer.can_sample(state) 

In [8]:
fake_batch_sequence.obs.shape

(64, 1000, 28)

In [9]:
state

TrajectoryBufferState(experience=QDTransition(obs=Array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., 

In [10]:
batch = buffer.sample(state, random_key)

In [11]:
print(batch.experience.obs.shape)

(64, 1000, 28)


In [14]:
batch.experience.keys

AttributeError: 'QDTransition' object has no attribute 'keys'