GenJAX supports serializing traces to bytes. To support different on-disk formats, GenJAX offers several serialization backends. Here we use `msgpack_serialize`, which encodes traces with the `MsgPack` (MessagePack) binary format.

In [1]:
import genjax
from genjax._src.core.serialization.msgpack import msgpack_serialize
import jax
import jax.numpy as jnp

In [2]:
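@genjax.static_gen_fn
def model(mu, coins):
    # Gaussian random choice recorded at address "x".
    x = genjax.normal(mu, 1.0) @ "x"
    # One coin flip per entry of `coins`, recorded at address "y";
    # normalizing keeps every flip probability in [0, 1].
    y = genjax.flip(coins / jnp.sum(coins)) @ "y"
    return y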

key = jax.random.PRNGKey(314159)
tr = model.simulate(key, (-2.1, jnp.array([1, 1, 0])))
print(tr)

StaticTrace(
  gen_fn=StaticGenerativeFunction(source=<function model>),
  args=(-2.1, i32[3]),
  retval=i32[3],
  address_choices=Trie(
    inner={
      'x':
      DistributionTrace(
        gen_fn=TFPDistribution(
          make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
        ),
        args=(-2.1, 1.0),
        value=f32[],
        score=f32[]
      ),
      'y':
      DistributionTrace(
        gen_fn=TFPDistribution(make_distribution=<function <lambda>>),
        args=(f32[3],),
        value=i32[3],
        score=f32[]
      )
    }
  ),
  cache=Trie(inner={}),
  score=f32[]
)


We can now convert `tr` into bytes using `serialize`. GenJAX also provides `pickle`-style aliases (e.g. `dumps`) that can be used instead.

In [3]:
payload = msgpack_serialize.serialize(tr)  # or msgpack_serialize.dumps(tr)

Deserialization works a little differently from `pickle.loads`, which takes only bytes: in addition to the payload, you must also pass the generative function that produced the trace, which supplies the trace's structure.

In [5]:
retrieved_tr = msgpack_serialize.deserialize(payload, model)
print(retrieved_tr)

StaticTrace(
  gen_fn=StaticGenerativeFunction(source=<function model>),
  args=(-2.1, i32[3]),
  retval=i32[3],
  address_choices=Trie(
    inner={
      'x':
      DistributionTrace(
        gen_fn=TFPDistribution(
          make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
        ),
        args=(-2.1, 1.0),
        value=f32[],
        score=f32[]
      ),
      'y':
      DistributionTrace(
        gen_fn=TFPDistribution(make_distribution=<function <lambda>>),
        args=(f32[3],),
        value=i32[3],
        score=f32[]
      )
    }
  ),
  cache=Trie(inner={}),
  score=f32[]
)
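Why the generative function is needed can be illustrated with a toy, stdlib-only sketch. It assumes a design in which the payload stores only leaf values and the caller supplies the structure when loading. This is not GenJAX's implementation: JSON stands in for MsgPack, and `flatten`, `unflatten`, `serialize`, and `deserialize` are hypothetical helpers.

```python
import json
from typing import Any


def flatten(tree: Any) -> list:
    """Collect the leaves of a nested dict/list/tuple in a fixed order."""
    if isinstance(tree, dict):
        # Sorted keys make the leaf order independent of insertion order.
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten(child)]
    return [tree]


def unflatten(template: Any, leaves: list) -> Any:
    """Rebuild a structure shaped like `template`, filled with `leaves` in order."""
    it = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(child) for child in node)
        try:
            return next(it)  # the template's own leaf value is ignored
        except StopIteration:
            raise ValueError("payload has fewer leaves than the template") from None

    result = build(template)
    sentinel = object()
    if next(it, sentinel) is not sentinel:
        raise ValueError("payload has more leaves than the template")
    return result


def serialize(tree: Any) -> bytes:
    # Only the leaf values are written; the nesting structure is not stored.
    return json.dumps(flatten(tree)).encode("utf-8")


def deserialize(payload: bytes, template: Any) -> Any:
    # The caller supplies the structure, analogous to passing `model`
    # to `msgpack_serialize.deserialize`.
    return unflatten(template, json.loads(payload.decode("utf-8")))


trace_like = {"x": -1.7, "y": [1, 1, 0]}
restored = deserialize(serialize(trace_like), {"x": 0.0, "y": [0, 0, 0]})
print(restored)  # {'x': -1.7, 'y': [1, 1, 0]}
```

Because the payload in this toy version carries no structure, the same bytes could fit many shapes, and a template with a different number of leaves is rejected with a `ValueError`. Loading only succeeds when the caller provides a matching structure.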


Because the `MsgPack` backend produces plain bytes, a serialized trace can be written to disk. Even after the Python runtime terminates, the stored trace can still be deserialized, as long as the generative function passed to `deserialize` induces the same trace structure. For instance, a freshly defined copy of `model` deserializes the trace just as well:

In [11]:
@genjax.static_gen_fn
def model_copy(mu, coins):
    x = genjax.normal(mu, 1.0) @ "x"
    y = genjax.flip(coins / jnp.sum(coins)) @ "y"
    return y

retrieved_tr = msgpack_serialize.deserialize(payload, model_copy)
print(retrieved_tr)

StaticTrace(
  gen_fn=StaticGenerativeFunction(source=<function model_copy>),
  args=(-2.1, i32[3]),
  retval=i32[3],
  address_choices=Trie(
    inner={
      'x':
      DistributionTrace(
        gen_fn=TFPDistribution(
          make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
        ),
        args=(-2.1, 1.0),
        value=f32[],
        score=f32[]
      ),
      'y':
      DistributionTrace(
        gen_fn=TFPDistribution(make_distribution=<function <lambda>>),
        args=(f32[3],),
        value=i32[3],
        score=f32[]
      )
    }
  ),
  cache=Trie(inner={}),
  score=f32[]
)
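Assuming `payload` is an ordinary `bytes` object, as the byte representation above suggests, persisting it needs nothing GenJAX-specific. The following stdlib-only sketch shows one way to do it; `save_payload` and `load_payload` are hypothetical helper names, and the temporary-file-plus-rename step is an optional safeguard, not something the `MsgPack` backend requires.

```python
import os
import tempfile
from pathlib import Path


def save_payload(path, payload: bytes) -> None:
    """Write serialized bytes to `path` via a temporary file and a rename.

    An interrupted write leaves any existing file at `path` untouched
    instead of leaving a partially written payload there.
    """
    path = Path(path)
    # Create the temporary file next to the target so the rename stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before the rename
        os.replace(tmp_name, path)  # swap the complete file into place
    except BaseException:
        # Remove the temporary file if anything failed before the rename.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def load_payload(path) -> bytes:
    """Read back the raw bytes written by `save_payload`."""
    return Path(path).read_bytes()


# Hypothetical usage with the notebook objects above:
#   save_payload("trace.msgpack", msgpack_serialize.serialize(tr))
#   ...later, in a fresh Python process, after re-defining the generative function...
#   payload = load_payload("trace.msgpack")
#   retrieved_tr = msgpack_serialize.deserialize(payload, model_copy)
```

Writing with `Path(path).write_bytes(payload)` would also work; the rename step only matters if a half-written file on disk would be a problem.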
