In [21]:
import genjax
import rich
from rich import inspect
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import pyarrow as pa
import pandas as pd
import numpy as np

# Turn on genjax's pretty-printing (presumably rich-based, given the `rich`
# imports above — TODO confirm). The returned object is bound to `console`;
# it is not referenced by any of the cells visible here.
console = genjax.pretty()

In [102]:
@genjax.gen
def model(key, x):
    """Single-site model: one traced draw from ``genjax.Normal``.

    The draw is recorded at address ``"x"`` with arguments ``(x, 1.0)`` and
    ``shape=(2, 2)``. Returns the key produced by the traced call (presumably
    the advanced PRNG key — confirm against genjax) together with the sample.
    """
    normal_site = genjax.trace("x", genjax.Normal, shape=(2, 2))
    next_key, sample = normal_site(key, (x, 1.0))
    return next_key, sample

In [103]:
# Draw NUM_SAMPLES independent traces of `model` in one jitted, vmapped call:
# vmap maps over axis 0 of the per-trace keys, while the model arguments are
# shared by every trace (in_axes=None).
NUM_SAMPLES = 100
MODEL_ARGS = (5.0,)  # passed to model as `x`

key = jax.random.PRNGKey(314159)
all_keys = jax.random.split(key, NUM_SAMPLES + 1)
# Keep the first key for later use; the remaining NUM_SAMPLES keys (already a
# stacked array, so no list round-trip is needed) drive the batch.
key, sub_keys = all_keys[0], all_keys[1:]
_, tr = jax.jit(jax.vmap(model.simulate, in_axes=(0, None)))(sub_keys, MODEL_ARGS)
tr

In [117]:
def _leaf_columns(index: int, leaf) -> list:
    """Convert one pytree leaf into a list of ``(column_name, 1-D array)`` pairs.

    0-d leaves are promoted to shape ``(1,)``. 1-D leaves become a single
    column named ``str(index)``. Leaves with ``ndim > 1`` are flattened past
    their leading axis and split into columns ``f"{index}_{j}"``, so every
    column holds plain scalars rather than nested arrays (pyarrow/Feather
    cannot store object columns of multi-dimensional arrays).
    """
    arr = np.atleast_1d(np.asarray(leaf))
    if arr.ndim == 1:
        return [(str(index), arr)]
    # Compute the trailing width explicitly: reshape's -1 inference fails when
    # the leading axis has length 0.
    width = int(np.prod(arr.shape[1:]))
    flat = arr.reshape(arr.shape[0], width)
    return [(f"{index}_{j}", flat[:, j]) for j in range(width)]


def get_dataframe(tr: "genjax.Trace") -> pd.DataFrame:
    """Flatten a (possibly batched) trace pytree into a DataFrame.

    Each leaf of ``tr`` (in ``jax.tree_util`` flattening order) contributes
    one or more columns, named by leaf position: ``"i"`` for scalar/1-D
    leaves and ``"i_j"`` for the flattened elements of higher-rank leaves.
    Leaves of different lengths are padded with NaN (for example, a scalar
    leaf next to batched leaves).

    Args:
        tr: Any JAX pytree, typically a genjax trace.

    Returns:
        A DataFrame with string column names and a default RangeIndex (as
        required by ``DataFrame.to_feather``).
    """
    named_columns = [
        pair
        for i, leaf in enumerate(jtu.tree_leaves(tr))
        for pair in _leaf_columns(i, leaf)
    ]
    # Build one row per column and transpose; the row-wise construction is what
    # pads ragged lengths with NaN.
    df = pd.DataFrame([column for _, column in named_columns]).T
    df.columns = [name for name, _ in named_columns]
    return df

In [121]:
# Tabulate the sampled trace and persist it as a Feather file named "scratch"
# (relative to the kernel's working directory).
FEATHER_PATH = "scratch"
df = get_dataframe(tr)
df.to_feather(FEATHER_PATH)