In [None]:
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import vmap
import jax.numpy as jnp
import jax.scipy.stats as stats
import jax.random as jr

import blackjax

from itertools import count
from functools import partial

from dynamax.parameters import to_unconstrained, from_unconstrained, log_det_jac_constrain
from dynamax.utils.utils import pytree_stack, ensure_array_has_batch_dim
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

from datetime import date
# Seed the PRNG key from today's date (YYYYMMDD as an integer), so the key
# is the same for every run on a given day and changes from day to day.
rng_key = jax.random.key(int(format(date.today(), "%Y%m%d")))

# Plot ACFs (autocorrelation functions)

In [None]:
from util.param import sample_prior
from util.sample import map_sims
from util.train import to_train_array
from util.numerics import compute_acf
from simulators import lgssm, svssm, sirssm, lvssm
from matplotlib.patches import Patch


def get_acf_plot(key, ssm, max_lag, state_dim, emission_dim, num_timesteps, target_vars):
    """Simulate emissions from prior draws of ``ssm`` and average their ACFs.

    Despite the name, nothing is drawn here: the function returns the data
    and the legend handle that the plotting cells consume.

    Args:
        key: JAX PRNG key. It is split once for the prior draw and once more
            for every simulated parameter set.
        ssm: Simulator module exposing ``setup(state_dim, emission_dim, 0,
            target_vars)``; the returned mapping must contain the keys
            ``'ssm'``, ``'props'`` and ``'exp_info'``.
        max_lag: Forwarded unchanged to ``compute_acf``.
        state_dim: Forwarded to ``ssm.setup``.
        emission_dim: Forwarded to ``ssm.setup``.
        num_timesteps: Forwarded to ``map_sims``.
        target_vars: Forwarded to ``ssm.setup``.

    Returns:
        A tuple ``(acf_mean, maxx, custom_patch)`` where ``acf_mean`` is the
        mean over axis 0 of the stacked ``compute_acf`` results, ``maxx`` is
        ``jnp.max(acf_mean)`` and ``custom_patch`` is an invisible matplotlib
        ``Patch`` whose label is the upper-cased ``exp_info['sim']`` string.
    """
    # NOTE(review): the meaning of the literal third argument (0) is defined
    # by the simulator modules and is not visible here -- confirm.
    setup = ssm.setup(state_dim, emission_dim, 0, target_vars)
    ssmodel = setup['ssm']
    props = setup['props']
    info = setup['exp_info']

    key, subkey = jr.split(key)
    # Presumably the trailing 1 is the number of prior draws -- TODO confirm.
    params = sample_prior(subkey, props, 1)

    acfs_all_runs = []
    for param in params:
        cp = to_train_array(param, props)
        key, subkey = jr.split(key)
        # Only the emissions feed the ACF; the simulated states are not
        # needed, so they are no longer collected and stacked.
        _, emissions = map_sims(subkey, cp, props, ssmodel, num_timesteps)
        acfs_all_runs.append(compute_acf(emissions, max_lag))

    acf_mean = jnp.array(acfs_all_runs).mean(0)
    maxx = jnp.max(acf_mean)
    # Colourless patch: used purely to show the simulator name in a legend.
    custom_patch = Patch(color='none', label=info['sim'].upper())

    return acf_mean, maxx, custom_patch

# Simulators to compare; the per-simulator settings below are listed in the
# same order as `ssms`.
ssms = [lgssm, svssm, lvssm, sirssm]
state_dims = [10, 1, 2, 3]
# NOTE(review): `emission_dims` is never read in this cell -- the call below
# passes a literal 1 as the emission dimension for every simulator. Confirm
# whether that is intentional (it only differs from this list for lgssm).
emission_dims = [10, 1, 1, 1]
target_vars = [['d4'], ['d4'], ['d3'], ['d3']]

# One result bucket per simulator, keyed by the simulator module's short name.
outputs = {
    name: {'means': [], 'maxxs': []}
    for name in ('lgssm', 'svssm', 'lvssm', 'sirssm')
}

# Not filled in by this cell.
figs = []
axs = []

key = jr.PRNGKey(0)

for i, ssm in enumerate(ssms):

    # The bucket key is the last component of the module's dotted name.
    bucket = outputs[ssm.__name__.split('.')[-1]]

    # 10 independent trials per simulator, each with a fresh subkey.
    for trial in range(10):

        key, subkey = jr.split(key)
        mean, maxx, patch = get_acf_plot(subkey, ssm, 100, state_dims[i], 1, 200, target_vars[i])

        bucket['means'].append(mean)
        bucket['maxxs'].append(maxx)

    # Keep the legend handle from the last trial of this simulator.
    bucket['patch'] = patch

In [None]:
# Overlay every trial's ACF curve for one simulator together with their mean.
modname = 'lgssm'
# Transposed so that plt.plot draws one line per trial
# (assumes each stored curve is 1-D over lags -- TODO confirm).
trial_curves = jnp.array(outputs[modname]['means']).T
mean = trial_curves.mean(axis=1)
plt.plot(mean, label='mean')
plt.plot(trial_curves)
plt.legend()
plt.show()

In [None]:
# One panel per simulator showing the ACF curve of every trial.
fig, ax = plt.subplots(4, 1, figsize=(4, 4))

for i, output in enumerate(outputs):

    means = jnp.array(outputs[output]['means']).T
    maxxs = jnp.array(outputs[output]['maxxs'])
    patch = outputs[output]['patch']

    ax[i].plot(means)
    # `maxxs` holds one peak value per trial, so the upper limit must be the
    # largest of them. Indexing it with the panel index `i` picked a single
    # arbitrary trial and could clip the other curves.
    ax[i].set_ylim(0, float(jnp.max(maxxs)))
    ax[i].legend(handles=[patch], loc='upper right')

ax[3].set_xlabel('Lag')
plt.subplots_adjust(hspace=0.4)  # increase space between plots

# One panel per simulator showing the ACF averaged over all trials.
fig, ax = plt.subplots(4, 1, figsize=(4, 4))

for i, (output, entry) in enumerate(outputs.items()):

    means = jnp.array(entry['means']).T
    patch = entry['patch']

    # Average across the trial axis (axis 1 after the transpose).
    ax[i].plot(means.mean(axis=1))
    ax[i].legend(handles=[patch], loc='upper right')