In [207]:
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import flax
import flax.linen as nn
from functools import partial
import numpyro.distributions as dist
import numpyro
import coix
import optax

import flax.linen as nn

In [208]:
numpyro.set_platform("cpu")
coix.set_backend("coix.numpyro")

In [209]:
class LSTM_MDN(nn.Module):
    n_mixture_components: int
    n_features: int

    @nn.compact
    def __call__(self, z_prev, x_curr, carry=None):
        x_curr = jnp.repeat(x_curr, z_prev.shape[0], axis=0) # repeat for each particle
        x = jnp.concatenate([z_prev, x_curr], axis=-1)
        # lstm_cell = nn.OptimizedLSTMCell(name="lstm_cell", features=self.n_features)
        # use GRU for now, since I'm not sure how to init the carry for LSTM
        lstm_cell = nn.GRUCell(name="gru_cell", features=self.n_features)
        if carry is None:
            carry = self.param('carry_init', lambda key, shape: jnp.zeros(shape), x.shape[:-1] + (self.n_features,))
        carry, x = lstm_cell(carry, x)
        mu_t = nn.Dense(self.n_mixture_components)(x)
        log_sigma_t = nn.Dense(self.n_mixture_components)(x)
        pi_t = nn.Dense(self.n_mixture_components)(x)
        return mu_t, jnp.exp(log_sigma_t), nn.softmax(pi_t)

In [210]:
lstm_mdn = LSTM_MDN(n_mixture_components=3, n_features=50)

In [232]:
def ssm_proposal(proposal, t, inputs):
    mu_t, sigma_t, pi_t = proposal(inputs["zs"][..., t], inputs["xs"][t])
    k = numpyro.sample("k", dist.Categorical(pi_t))
    z_t = numpyro.sample("z", dist.Normal(mu_t[k], sigma_t[k]))
    return z_t

def ssm_target(proposal, t, inputs):
    if t == 0:
        z_t = numpyro.sample("z", dist.Normal(0, 5))
    else:
        z_t_loc = inputs["zs"][t-1] / 2 + 25 * inputs["zs"][t-1] / (1 + inputs["z"][t-1] ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample("z", dist.Normal(z_t_loc, jnp.sqrt(10)))
    numpyro.sample("x", dist.Normal(z_t ** 2 / 20, 1), obs=inputs["xs"][t])
    inputs = {"zs": inputs["zs"].at[:, t].set(z_t), "xs": inputs["xs"]}
    return inputs

In [233]:
def make_ssm(params, num_particles=10, T_max=1000):
    network = coix.util.BindModule(lstm_mdn, params)
    make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-1)
    targets = lambda t: make_particle_plate()(
        partial(ssm_target, network, t)
    )
    proposals = lambda t: make_particle_plate()(
        partial(ssm_proposal, network, t)
    )
    program = coix.algo.nasmc(targets, proposals, num_targets=T_max)
    return program

In [234]:
def ssm(xs = None, T_max = 1000):
    z_0 = numpyro.sample("z_0", dist.Normal(0, 5))
    z_t_m1 = z_0
    for t in range(1, T_max):
        z_t_loc = z_t_m1 / 2 + 25 * z_t_m1 / (1 + z_t_m1 ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample(f"z_{t}", dist.Normal(z_t_loc, jnp.sqrt(10)))
        x_t = numpyro.sample(f"x_{t}", dist.Normal(z_t ** 2 / 20, 1), obs=xs[t - 1] if xs is not None else None)
        z_t_m1 = z_t
    return x_t

In [235]:
def loss_fn(params, key, num_particles=10, T_max = 1000):
    shuffle_rng, rng_key = random.split(key)
    tr = numpyro.handlers.trace(numpyro.handlers.seed(ssm, rng_key)).get_trace()
    zs = jnp.zeros((num_particles, T_max))
    xs = jnp.stack([0, *[tr[f"x_{t}"]["value"] for t in range(1, T_max)]])
    assert xs.shape[0] == zs.shape[1] == T_max
    inputs={"zs": zs, "xs": xs}

    program = make_ssm(params, num_particles=num_particles, T_max=T_max)
    _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(inputs)
    return metrics["loss"], metrics

In [236]:
num_particles = 10
num_steps = 1000
init_params = lstm_mdn.init(random.PRNGKey(0), z_prev=jnp.zeros((num_particles,)), x_curr=jnp.zeros(1), carry=None)
lstm_mdn_params, _ = coix.util.train(
    partial(
        loss_fn,
        num_particles=num_particles,
    ),
    init_params,
    optax.adam(3e-4),
    num_steps=num_steps,
)

Compiling the first train step...


0
(10, 1000)
(1000,)


TypeError: ssm_proposal() takes 3 positional arguments but 4 were given