In [1]:
import tensorstore as ts
import numpy as np
import jax
import jax.numpy as jnp
from chex import Array
from typing import NamedTuple
import flashbax as fbx
from flashbax.buffers.trajectory_buffer import TrajectoryBufferState

class Timestep(NamedTuple):
    obs: Array

In [4]:
buffer = fbx.make_flat_buffer(
    max_length=1_000_000,
    min_length=1,
    sample_batch_size=1,
    #add_batch_size=2,
    add_sequences=True,
)



In [5]:
state.experience.obs.shape

NameError: name 'state' is not defined

In [61]:
state = buffer.init(
    Timestep(
        obs=jnp.zeros((2, 4), dtype=jnp.float32),
    ),
)
buffer_add = jax.jit(buffer.add, donate_argnums=0)

In [62]:
state.current_index

Array(0, dtype=int32, weak_type=True)

In [63]:
state = buffer_add(state, Timestep(
    obs=-1 * jnp.arange(160_000, dtype=jnp.float32).reshape((20_000, 2, 4)),
))

In [36]:
# CURRENT LIMITATIONS / TODO LIST
# - Overwriting existing vault -> functionality?
# - Only works when state.experience has depth of 1
# - Indexing is a bit weird
# - Better async stuff

import os
import json
from etils import epath
from pprint import pprint
DRIVER = "file://"
METADATA_FILE = "metadata.json"

class Vault:
    def __init__(
        self,
        base_path: str,
        init_fbx_state: TrajectoryBufferState,
        metadata: dict = None,
        vault_multiplier: int = 1_000,
    ) -> None:
        # Base path is the root folder for this vault. Must be absolute.
        self.base_path = os.path.abspath(base_path)

        # We use epath for metadata
        metadata_path = epath.Path(os.path.join(self.base_path, METADATA_FILE))

        # Check if the vault exists, & either read metadata or create it
        base_path_exists = os.path.exists(self.base_path)
        if base_path_exists:
            self.metadata = json.loads(metadata_path.read_text())
        else:
            # Create the vault root dir
            os.mkdir(self.base_path)
            self.metadata = {
                'version': 0.0,
                'vault_multiplier': vault_multiplier,
                'tree_struct': init_fbx_state.experience._fields,
                **(metadata or {}),
            }
            metadata_path.write_text(json.dumps(self.metadata))
        pprint(self.metadata)

        self.vault_multiplier = vault_multiplier

        self.vault_ds = ts.open(
            self._get_base_spec('vault_index'),
            dtype=jnp.int32,
            shape=(1,),
            create=not base_path_exists,
        ).result()
        self.vault_index = self.vault_ds.read().result()[0]

        self.all_ds = jax.tree_util.tree_map_with_path(
            lambda path, x: self._init_leaf(
                name=path[0].name,
                leaf=x,
                create_checkpoint=not base_path_exists,
            ),
            init_fbx_state.experience,
        )

        # TODO
        self.old_index = 0


    def _get_base_spec(self, name: str):
        return {
            'driver': 'zarr',
            'kvstore': {
                'driver': 'ocdbt',
                'base': f'{DRIVER}{self.base_path}',
                'path': name,
            },
        }

    def _init_leaf(self, name, leaf, create_checkpoint: bool = False):
        spec = self._get_base_spec(name)
        leaf_ds = ts.open(
            spec,
            dtype=leaf.dtype if create_checkpoint else None,
            shape=(leaf.shape[0], self.vault_multiplier * leaf.shape[1], *leaf.shape[2:]) \
                if create_checkpoint else None,
            create=create_checkpoint, # TODO!
        ).result()
        return leaf_ds
    
    def _write_leaf(
        self,
        name: str,
        leaf: Array,
        time_interval: slice = slice(None)
    ):
        leaf_ds = getattr(self.all_ds, name)
        leaf_ds[:, time_interval, ...].write(leaf[:, time_interval, ...]).result()

    def _read_leaf(self, name: str, time_interval: slice = slice(None)):
        leaf_ds = getattr(self.all_ds, name)
        return leaf_ds[:, time_interval, ...].read().result()

    def write(self, state: TrajectoryBufferState, time_interval = None):
        new_index = int(state.current_index)
        time_interval = slice(self.old_index, new_index) if time_interval is None else time_interval
        jax.tree_util.tree_map_with_path(
            lambda path, x: self._write_leaf(path[0].name, x, time_interval),
            state.experience,
        )
        self.vault_index += (new_index - self.old_index)
        self.vault_ds.write(self.vault_index).result()

        self.old_index = new_index

    def read(self, time_interval: slice = slice(None)):
        return jax.tree_util.tree_map_with_path(
            lambda path, _: self._read_leaf(path[0].name, time_interval),
            state.experience,
        )



In [68]:
v = Vault('demo', state)

{'tree_struct': ['obs'], 'vault_multiplier': 1000, 'version': 0.0}


In [69]:
v.vault_index

60000

In [70]:
v.write(state)

In [71]:
v.read(slice(0,10000))

Timestep(obs=array([[[[-0.0000e+00, -1.0000e+00, -2.0000e+00, -3.0000e+00],
         [-4.0000e+00, -5.0000e+00, -6.0000e+00, -7.0000e+00]],

        [[-8.0000e+00, -9.0000e+00, -1.0000e+01, -1.1000e+01],
         [-1.2000e+01, -1.3000e+01, -1.4000e+01, -1.5000e+01]],

        [[-1.6000e+01, -1.7000e+01, -1.8000e+01, -1.9000e+01],
         [-2.0000e+01, -2.1000e+01, -2.2000e+01, -2.3000e+01]],

        ...,

        [[-7.9976e+04, -7.9977e+04, -7.9978e+04, -7.9979e+04],
         [-7.9980e+04, -7.9981e+04, -7.9982e+04, -7.9983e+04]],

        [[-7.9984e+04, -7.9985e+04, -7.9986e+04, -7.9987e+04],
         [-7.9988e+04, -7.9989e+04, -7.9990e+04, -7.9991e+04]],

        [[-7.9992e+04, -7.9993e+04, -7.9994e+04, -7.9995e+04],
         [-7.9996e+04, -7.9997e+04, -7.9998e+04, -7.9999e+04]]]],
      dtype=float32))

In [7]:
ds = ts.open(
    {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'ocdbt',
            'base': f'{DRIVER}{"/Users/callum/Mava/demo"}',
            'path': 'obs',
        },
    },
).result()

In [49]:
aa = {
    'driver': 'zarr',
    'kvstore': {
        'driver': 'ocdbt',
        'base': f'{DRIVER}{"/Users/callum/Mava/demo"}',
        'path': "obs",
    },
}

In [52]:
from etils import epath
import json

path = epath.Path('/Users/callum/Mava/demo/metadata')

In [53]:
path.write_text(json.dumps(aa))

107

In [56]:
json.loads(path.read_text())['driver']

'zarr'

In [13]:
state.experience._fields

('obs',)