In [1]:
import genjax
from genjax import dippl
from genjax import gensp
from genjax import select, dirac
import equinox as eqx
import jax
import jax.numpy as jnp
import adevjax

# Pretty-printing console for genjax objects. `show_locals=True` is forwarded to
# genjax.pretty (presumably enables local-variable display, e.g. in tracebacks — confirm).
console = genjax.pretty(show_locals=True)

# Seed for the notebook's root PRNG key. Change SEED for a different, but still
# reproducible, run.
SEED = 314159
key = jax.random.PRNGKey(SEED)

I0000 00:00:1696109307.511520   21263 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [2]:
class Linear(eqx.Module):
    """Affine layer computing ``weight @ x + bias`` with standard-normal init.

    Attributes:
        weight: parameter matrix of shape ``(out_size, in_size)``.
        bias: parameter vector of shape ``(out_size,)``.
    """

    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        """Sample ``weight`` and ``bias`` i.i.d. from N(0, 1).

        Args:
            in_size: length of the input vector.
            out_size: length of the output vector.
            key: JAX PRNG key; split into one subkey per parameter.
        """
        weight_key, bias_key = jax.random.split(key)
        weight_shape = (out_size, in_size)
        bias_shape = (out_size,)
        self.weight = jax.random.normal(weight_key, weight_shape)
        self.bias = jax.random.normal(bias_key, bias_shape)

    def __call__(self, x):
        """Return ``weight @ x + bias``.

        In this notebook ``x`` is a vector of length ``in_size``.
        """
        projected = jnp.matmul(self.weight, x)
        return projected + self.bias

In [3]:
# Sizes for the toy VAE: a 1-d latent `x` and a 5-d observation `v`.
latent_size, data_size = 1, 5

# The decoder maps latent -> observation and the encoder maps observation -> latent.
# Each layer uses its own fixed PRNG key, so parameter initialisation is reproducible.
decoder = Linear(latent_size, data_size, key=jax.random.PRNGKey(0))
encoder = Linear(data_size, latent_size, key=jax.random.PRNGKey(1))

# Legacy names: after this cell they describe the encoder (5 -> 1). They are kept
# in case a later cell reads them.
in_size, out_size = data_size, latent_size

In [4]:
@genjax.gen
def decoder_program(decoder):
    """Generative model over a latent ``"x"`` and an observation ``"v"``.

    ``x`` is drawn with ``dippl.normal_reinforce(0.0, 1.0)``. ``v`` is the
    deterministic output ``decoder([x])``, recorded through ``dirac`` so that it
    appears in the choice map at address ``"v"``.
    """
    x = dippl.normal_reinforce(0.0, 1.0) @ "x"
    decoded = decoder(jnp.array([x]))
    dirac(decoded) @ "v"


# Wrap the program as a distribution over choice maps holding both "x" and "v".
# The raw program keeps its own name, so it is not shadowed by the wrapper.
decoder_model = gensp.choice_map_distribution(decoder_program, select("x", "v"), None)

# Derive a subkey instead of consuming `key` directly. JAX keys must not be
# reused across independent random draws.
decoder_key, _ = jax.random.split(key)
(w, decoder_v) = decoder_model.random_weighted(decoder_key, decoder)
decoder_v




├── :v
│   └──  f32[5]
└── :x
    └──  f32[]

In [5]:
@genjax.gen
def encoder_program(encoder, chm):
    """Variational model: read observation ``"v"`` from ``chm`` and sample latent ``"x"``.

    ``chm`` is expected to be a choice-map value whose ``get_leaf_value()`` can be
    indexed by ``"v"`` (for example, a sample from the decoder's choice-map
    distribution).
    """
    observed_v = chm.get_leaf_value()["v"]
    # The encoder built in this notebook is 5 -> 1, so take element 0 as the scalar location.
    latent_loc = encoder(observed_v)[0]
    dippl.normal_reinforce(latent_loc, 1.0) @ "x"


# Distribution over choice maps holding only the latent "x".
encoder_model = gensp.choice_map_distribution(encoder_program, select("x"), None)

# Use a subkey derived from `key` rather than `key` itself, so this draw is
# independent of other draws made with `key`.
_, encoder_key = jax.random.split(key)
(w, encoder_v) = encoder_model.random_weighted(encoder_key, encoder, decoder_v)
encoder_v




└── :x
    └──  f32[]

In [8]:
def variational_grad(
    key, data: genjax.ValueChoiceMap, encoder: Linear, decoder: Linear
):
    """Estimate gradients of the VAE loss with respect to ``encoder`` and ``decoder``.

    The loss does three things in order:
      1. draws a latent choice map from ``encoder_model`` via ``dippl.upper``;
      2. merges it with the observed ``data``;
      3. scores the merged choice map under ``decoder_model`` via ``dippl.lower``.

    Args:
        key: JAX PRNG key used by the stochastic gradient estimator.
        data: observed choice map, expected to provide ``"v"`` for the encoder.
        encoder: encoder parameters.
        decoder: decoder parameters.

    Returns:
        The result of ``vae_loss.grad_estimate`` for the ``(encoder, decoder)``
        pair. This is presumably gradients with the same structure — confirm
        against the dippl docs.
    """

    @dippl.loss
    def vae_loss(encoder, decoder):
        latent_chm = dippl.upper(encoder_model)(encoder, data)
        joint_chm = gensp.merge(latent_chm, data)
        dippl.lower(decoder_model)(joint_chm, decoder)

    params = (encoder, decoder)
    return vae_loss.grad_estimate(key, params)