In [78]:
import tensorstore as ts
import numpy as np
import jax
import jax.numpy as jnp
from chex import Array
from typing import NamedTuple
import flashbax as fbx
from flashbax.buffers.trajectory_buffer import TrajectoryBufferState

class Observation(NamedTuple):
    """Per-timestep observation stored in the replay buffer."""

    # Observation values for one timestep.
    values: Array
    # Mask accompanying `values` (built with the same shape in this notebook).
    mask: Array

class Timestep(NamedTuple):
    """One buffer entry; currently it carries only the observation."""

    # The observation recorded at this timestep.
    obs: Observation

In [127]:
# Flat buffer that takes whole sequences of timesteps per `add` call
# (add_sequences=True) and returns one sample per `sample` call.
buffer = fbx.make_flat_buffer(
    max_length=1_000_000,
    min_length=1,
    sample_batch_size=1,
    add_sequences=True,
)



In [128]:
# buffer.init takes one example timestep: per-step observations of shape (2, 4).
example_timestep = Timestep(
    obs=Observation(
        values=jnp.zeros((2, 4), dtype=jnp.float32),
        mask=jnp.zeros((2, 4), dtype=jnp.float32),
    ),
)
state = buffer.init(example_timestep)

# Jit-compile add/sample once. `buffer_add` donates its input state (argument 0),
# so a state passed to it must not be reused afterwards.
buffer_add = jax.jit(buffer.add, donate_argnums=0)
buffer_sample = jax.jit(buffer.sample)

In [131]:
# Add one 1,000-step sequence of all-ones observations.
# Re-running this cell adds the sequence again.
sequence_shape = (1_000, 2, 4)
state = buffer_add(state, Timestep(
    obs=Observation(
        values=jnp.ones(sequence_shape, dtype=jnp.float32),
        mask=jnp.ones(sequence_shape, dtype=jnp.float32),
    ),
))

In [132]:
# Buffer write position. `Vault.write` copies up to this index by default.
state.current_index

Array(2000, dtype=int32, weak_type=True)

In [183]:
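# CURRENT LIMITATIONS / TODO LIST
# - Only works while fbx_state.is_full is False, i.e. it cannot handle the
#   source buffer wrapping around.
# - Improve async handling: every tensorstore read/write currently blocks on
#   `.result()` before the next one starts.
# - Only tested with flat buffers.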

import os
import json
from etils import epath
from typing import Optional
from operator import attrgetter
# URL scheme prefixed to the vault's base path in the tensorstore kvstore spec.
DRIVER: str = "file://"
# File (inside the vault's base path) that holds vault-level metadata as JSON.
METADATA_FILE: str = "metadata.json"

def _nested_subattr(path):
    return str.join('.', jax.tree_util.tree_map(lambda s: s.name, path))

class Vault:
    """On-disk store for the experience held in a flashbax buffer state.

    Every leaf of ``TrajectoryBufferState.experience`` is persisted as its own
    zarr array inside an OCDBT key-value store rooted at ``base_path``; the
    leaf's dotted attribute path (e.g. ``"obs.values"``) is used as its key.
    A separate single-element ``vault_index`` array records how many positions
    along axis 1 (the buffer's time axis) have been written, and
    ``metadata.json`` in the same directory holds vault-level metadata.

    Known limitations: the source buffer must not wrap around between writes,
    and only flat buffers have been tested.
    """

    def __init__(
        self,
        base_path: str,
        init_fbx_state: TrajectoryBufferState,
        metadata: Optional[dict] = None,
        vault_multiplier: int = 1_000_000,
    ) -> None:
        """Open the vault at ``base_path``, creating it if the directory is missing.

        Args:
            base_path: Root directory of the vault (converted to an absolute path).
            init_fbx_state: Buffer state whose ``experience`` defines the pytree
                structure, dtypes and shapes of the stored arrays.
            metadata: Extra entries merged into ``metadata.json`` when a new vault
                is created. Ignored when an existing vault is opened.
            vault_multiplier: When creating a vault, each leaf's axis 1 is
                allocated as ``vault_multiplier`` times the buffer's axis 1.
                When opening an existing vault, the value stored in its
                metadata takes precedence.
        """
        # Base path is the root folder for this vault. Must be absolute.
        self.base_path = os.path.abspath(base_path)

        # epath is used only for reading/writing the metadata file.
        metadata_path = epath.Path(os.path.join(self.base_path, METADATA_FILE))

        # Either read the metadata of an existing vault or create a new one.
        base_path_exists = os.path.exists(self.base_path)
        if base_path_exists:
            self.metadata = json.loads(metadata_path.read_text())
        else:
            # makedirs (not mkdir) so nested vault paths work too.
            os.makedirs(self.base_path)
            self.metadata = {
                'version': 0.0,
                'vault_multiplier': vault_multiplier,
                'tree_struct': init_fbx_state.experience._fields,
                **(metadata or {}),
            }
            metadata_path.write_text(json.dumps(self.metadata))

        # Prefer the multiplier recorded in the vault's metadata so the attribute
        # reflects how the stored arrays were actually allocated.
        self.vault_multiplier = self.metadata.get('vault_multiplier', vault_multiplier)

        # Single-element int32 array holding the number of positions written.
        self.vault_ds = ts.open(
            self._get_base_spec('vault_index'),
            dtype=jnp.int32,
            shape=(1,),
            create=not base_path_exists,
        ).result()
        self.vault_index = int(self.vault_ds.read().result()[0])

        # One tensorstore dataset per experience leaf, in the same pytree structure.
        self.all_ds = jax.tree_util.tree_map_with_path(
            lambda path, x: self._init_leaf(
                name=_nested_subattr(path),
                leaf=x,
                create_checkpoint=not base_path_exists,
            ),
            init_fbx_state.experience,
        )

        # A one-position slice of the experience, used only as a pytree template by `read`.
        self.fbx_sample_experience = jax.tree_util.tree_map(
            lambda x: x[:, 0:1, ...],
            init_fbx_state.experience,
        )
        # Buffer position up to which `write` has copied data. It starts at 0, so
        # the first default write copies the buffer from its beginning.
        self.last_received_fbx_index = 0

    def _get_base_spec(self, name: str):
        """Return the tensorstore spec for the zarr array stored under key ``name``.

        All arrays live in one OCDBT key-value store located at ``base_path``.
        """
        return {
            'driver': 'zarr',
            'kvstore': {
                'driver': 'ocdbt',
                'base': f'{DRIVER}{self.base_path}',
                'path': name,
            },
        }

    def _init_leaf(self, name, leaf, create_checkpoint: bool = False):
        """Open (or, if ``create_checkpoint``, create) the dataset for one leaf.

        When creating, the dataset copies ``leaf``'s dtype and shape, except that
        axis 1 is scaled by ``vault_multiplier``. When opening, shape and dtype
        are left unconstrained and taken from the existing array.
        """
        spec = self._get_base_spec(name)
        if create_checkpoint:
            dtype = leaf.dtype
            shape = (leaf.shape[0], self.vault_multiplier * leaf.shape[1], *leaf.shape[2:])
        else:
            dtype = None
            shape = None
        return ts.open(
            spec,
            dtype=dtype,
            shape=shape,
            create=create_checkpoint,
        ).result()

    def _write_leaf(
        self,
        name: str,
        leaf: Array,
        source_interval: tuple,
        dest_interval: tuple,
    ):
        """Copy ``leaf[:, source_interval]`` into the dataset ``name`` at ``dest_interval`` (axis 1)."""
        leaf_ds = attrgetter(name)(self.all_ds)
        leaf_ds[:, slice(*dest_interval), ...].write(
            leaf[:, slice(*source_interval), ...],
        ).result()

    def write(
        self,
        fbx_state: TrajectoryBufferState,
        source_interval: tuple = (None, None),
        dest_interval: tuple = (None, None),
    ):
        """Copy a range of positions (axis 1) from ``fbx_state`` into the vault.

        Args:
            fbx_state: Buffer state to copy experience from.
            source_interval: ``(start, stop)`` in the buffer. The default
                ``(None, None)`` means everything received since the last write:
                ``(last_received_fbx_index, fbx_state.current_index)``.
            dest_interval: ``(start, stop)`` in the vault. The default
                ``(None, None)`` appends at the current ``vault_index``.

        Raises:
            ValueError: If the source interval has negative length, e.g. because
                the buffer wrapped around since the last write (unsupported).
            AssertionError: If source and destination lengths differ.
        """
        fbx_current_index = int(fbx_state.current_index)

        if source_interval == (None, None):
            source_interval = (self.last_received_fbx_index, fbx_current_index)
        write_length = source_interval[1] - source_interval[0]
        if write_length < 0:
            # A write position behind the last one means the buffer wrapped;
            # writing would corrupt vault_index, so refuse explicitly.
            raise ValueError(
                f"Invalid source interval {source_interval} (negative length); "
                "buffer wraparound is not supported."
            )
        if write_length == 0:
            return

        if dest_interval == (None, None):
            dest_interval = (self.vault_index, self.vault_index + write_length)

        assert write_length == (dest_interval[1] - dest_interval[0]), (
            f"source {source_interval} and destination {dest_interval} lengths differ"
        )

        print(f"Writing {source_interval} from buffer into {dest_interval} in vault")

        jax.tree_util.tree_map_with_path(
            lambda path, x: self._write_leaf(
                name=_nested_subattr(path),
                leaf=x,
                source_interval=source_interval,
                dest_interval=dest_interval,
            ),
            fbx_state.experience,
        )
        # The vault's length is the furthest position written. An explicit
        # dest_interval may overwrite earlier data, so this is not a running sum.
        self.vault_index = max(self.vault_index, dest_interval[1])
        self.vault_ds.write(self.vault_index).result()

        self.last_received_fbx_index = fbx_current_index

    def _read_leaf(
        self,
        name: str,
        read_interval: tuple,
    ):
        """Read positions ``read_interval`` (axis 1) of the dataset ``name``."""
        leaf_ds = attrgetter(name)(self.all_ds)
        return leaf_ds[:, slice(*read_interval), ...].read().result()

    def read(self, read_interval: tuple = (None, None)):
        """Read a range of positions from every leaf.

        Args:
            read_interval: ``(start, stop)`` along axis 1. The default
                ``(None, None)`` reads everything written, ``(0, vault_index)``.

        Returns:
            A pytree with the same structure as the buffer's experience, holding
            the arrays returned by tensorstore.
        """
        if read_interval == (None, None):
            read_interval = (0, self.vault_index)

        return jax.tree_util.tree_map_with_path(
            lambda path, _: self._read_leaf(_nested_subattr(path), read_interval),
            self.fbx_sample_experience,
        )

    def get_buffer(self, size: int, key: Array, starting_index: Optional[int] = None):
        """Load ``size`` consecutive positions from the vault as a full buffer state.

        Args:
            size: Number of positions to load; must satisfy ``0 < size <= vault_index``.
            key: JAX PRNG key used to draw ``starting_index`` when it is not given.
            starting_index: First vault position to load. If ``None``, it is drawn
                uniformly from every valid window start ``[0, vault_index - size]``.

        Returns:
            A ``TrajectoryBufferState`` holding exactly ``size`` positions, marked
            full, with ``current_index`` 0.
        """
        assert 0 < size <= self.vault_index, (
            f"size must be in (0, {self.vault_index}], got {size}"
        )
        if starting_index is None:
            # randint's maxval is exclusive; +1 makes the last full window reachable.
            starting_index = int(jax.random.randint(
                key=key,
                shape=(),
                minval=0,
                maxval=self.vault_index - size + 1,
            ))
        return TrajectoryBufferState(
            experience=self.read((starting_index, starting_index + size)),
            # The loaded window fills the buffer exactly once. A full flashbax
            # buffer keeps its write head in [0, max_length), and here it has
            # wrapped to 0: the boundary between the oldest and newest position.
            current_index=0,
            is_full=True,
        )



In [190]:
v = Vault('demo', state)
# A freshly created vault is empty. Populate it from the buffer so the cells
# below also work on a clean run. Reopening a non-empty vault skips this, to
# avoid appending the same buffer contents a second time.
if v.vault_index == 0:
    v.write(state)

In [192]:
# Read back everything written so far: positions [0, v.vault_index).
vault_contents = v.read()
vault_contents

Timestep(obs=Observation(values=array([[[[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        ...,

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]]]], dtype=float32), mask=array([[[[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        ...,

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]]]], dtype=float32)))

In [63]:
# Add a 100,000-step sequence whose values count upwards, then copy into the
# vault everything this Vault object has not written yet.
# Timestep.obs must be an Observation (values + mask); a bare array would not
# match the buffer's pytree structure.
new_steps_shape = (100_000, 2, 4)
state = buffer_add(state, Timestep(
    obs=Observation(
        values=jnp.arange(800_000, dtype=jnp.float32).reshape(new_steps_shape),
        mask=jnp.ones(new_steps_shape, dtype=jnp.float32),
    ),
))
v.write(state)

Writing (840000, 940000) from buffer into (1620000, 1720000) in vault


In [193]:
# Draw a random 10-position window from the vault as a full buffer state.
loaded_state = v.get_buffer(size=10, key=jax.random.PRNGKey(43))

In [199]:
# Sample from the buffer state loaded out of the vault.
sample_key = jax.random.PRNGKey(32)
sample = buffer_sample(loaded_state, sample_key)
sample

TransitionSample(experience=ExperiencePair(first=Timestep(obs=Observation(values=Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32), mask=Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32))), second=Timestep(obs=Observation(values=Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32), mask=Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)))))