## Attempt: Michaelis-Menten with PriorCVAE
**Objective**: learn trajectories of the Michaelis-Menten model with fixed initial conditions but varying parameter values.
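
For reference, here is a sketch of the reaction network usually meant by the stochastic Michaelis-Menten model. It assumes the standard scheme with substrate S, enzyme E, complex C and product P, and stochastic mass-action propensities $a_j(x)$. Under the objective above, the rate constants $\theta = (k_1, k_2, k_3)$ vary between trajectories while the initial counts $x(0)$ stay fixed. Which entry of the `initial` list used further down maps to which species is an assumption here.

$$
\begin{aligned}
\mathrm{S} + \mathrm{E} &\xrightarrow{k_1} \mathrm{C}, &\quad a_1(x) &= k_1\, x_{\mathrm{S}}\, x_{\mathrm{E}},\\
\mathrm{C} &\xrightarrow{k_2} \mathrm{S} + \mathrm{E}, &\quad a_2(x) &= k_2\, x_{\mathrm{C}},\\
\mathrm{C} &\xrightarrow{k_3} \mathrm{E} + \mathrm{P}, &\quad a_3(x) &= k_3\, x_{\mathrm{C}}.
\end{aligned}
$$

Under this scheme every trajectory keeps $x_{\mathrm{E}} + x_{\mathrm{C}}$ and $x_{\mathrm{S}} + x_{\mathrm{C}} + x_{\mathrm{P}}$ constant. This gives a cheap sanity check on simulated or reconstructed trajectories.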

In [1]:
import random as rnd

import numpy as np
import matplotlib.pyplot as plt
import numpyro
numpyro.set_host_device_count(4)  # only takes effect before JAX initializes its backend
from numpyro.infer import Predictive
from numpyro.diagnostics import hpdi
import jax
import jax.numpy as jnp
from jax import random
import optax

# Enable 64-bit floats before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

from priorCVAE_copy.models import MLPEncoder, MLPDecoder, VAE
from priorCVAE_copy.mcmc import run_mcmc_vae, vae_mcmc_inference_model
from priorCVAE_copy.datasets import MMDataset
from priorCVAE_copy.trainer import VAETrainer
from priorCVAE_copy.losses import SquaredSumAndKL

### Set Arguments

In [2]:
args = {
    "conditional": True,

    # architecture
    "input_dim": 100,
    "hidden_dim": 80,
    "latent_dim": 60,

    # VAE training
    "batch_size": 500,
    "num_iterations": 1000,
    "learning_rate": 1e-3,
    "vae_var": 1.0,

    # MCMC inference
    "true_ls": 0.2,
    "num_warmup": 1000,
    "num_mcmc_samples": 4000,
    "num_chains": 4,
    "thinning": 1,
}
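
`SquaredSumAndKL` comes from the local `priorCVAE_copy` package, and its source is not shown here. Judging only by its name and the `vae_var` argument, one plausible form is a squared-error reconstruction term scaled by `vae_var` plus the closed-form KL divergence from the encoder's Gaussian to a standard normal. The sketch below is an assumption, not the package's code. The function name `squared_sum_and_kl` is hypothetical, and whether the real loss sums or averages over the batch is unknown.

```python
import jax.numpy as jnp


def squared_sum_and_kl(y, y_hat, z_mu, z_logvar, vae_var=1.0):
    """Hypothetical stand-in for SquaredSumAndKL (not the priorCVAE_copy implementation)."""
    # Squared-error reconstruction, scaled like a Gaussian log-likelihood with variance vae_var.
    recon = jnp.sum((y - y_hat) ** 2) / (2.0 * vae_var)
    # Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims and batch.
    kl = -0.5 * jnp.sum(1.0 + z_logvar - z_mu ** 2 - jnp.exp(z_logvar))
    return recon + kl


# Shapes follow `args`: batch_size = 500, input_dim = 100, latent_dim = 60.
y = jnp.ones((500, 100))
z_mu = jnp.zeros((500, 60))
z_logvar = jnp.zeros((500, 60))
loss_perfect = squared_sum_and_kl(y, y, z_mu, z_logvar)
loss_offset = squared_sum_and_kl(y, y + 1.0, z_mu, z_logvar, vae_var=2.0)
```

With a perfect reconstruction and a posterior equal to the prior, both terms vanish. A unit error on all 50,000 entries with `vae_var = 2` costs 50,000 / 4 = 12,500. A larger `vae_var` therefore weights the reconstruction less relative to the KL term.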

### Generate Data

In [3]:
mm_sampler = MMDataset(n_data=args["input_dim"])
sample_x_train, sample_y_train, sample_k_train = mm_sampler.simulatedata(n_samples=1)



UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was body at /Users/lthao/Documents/GitHub/PINTS_priorCVAE/pints_jax/toy/stochastic/_markov_jump_model.py:128 traced for while_loop.
------------------------------
The leaked intermediate value was created on line /Users/lthao/Documents/GitHub/PINTS_priorCVAE/pints_jax/toy/stochastic/_markov_jump_model.py:139 (body). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/lthao/Documents/GitHub/PINTS_priorCVAE/priorCVAE_copy/datasets/gp_dataset.py:125 (simulatedata)
/Users/lthao/Documents/GitHub/PINTS_priorCVAE/priorCVAE_copy/priors/MM.py:36 (MM)
/Users/lthao/Documents/GitHub/PINTS_priorCVAE/pints_jax/toy/stochastic/_markov_jump_model.py:301 (simulate)
/Users/lthao/Documents/GitHub/PINTS_priorCVAE/pints_jax/toy/stochastic/_markov_jump_model.py:169 (simulate_raw)
/Users/lthao/Documents/GitHub/PINTS_priorCVAE/pints_jax/toy/stochastic/_markov_jump_model.py:139 (body)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
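
The traceback points at the body of a `lax.while_loop` in the `pints_jax` Markov jump simulator (`_markov_jump_model.py`, line 139), where a traced value escapes the loop. This typically happens when the loop body writes an intermediate value to outside state, such as an object attribute, a global, or a Python list, instead of returning it. The usual remedy is to keep everything the loop must remember (time, state, output buffer, PRNG key) in the loop carry.

The sketch below illustrates that pattern and is not the `pints_jax` code. It is a self-contained Gillespie simulator for the Michaelis-Menten scheme that records counts on a fixed time grid. The names `simulate_mm`, `propensities` and `STOICH`, and the species order (S, E, C, P), are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

# Hypothetical stoichiometry for species order (S, E, C, P);
# row j is the change in counts caused by reaction j.
STOICH = jnp.array([[-1.0, -1.0, 1.0, 0.0],   # S + E -> C  (k1)
                    [1.0, 1.0, -1.0, 0.0],    # C -> S + E  (k2)
                    [0.0, 1.0, -1.0, 1.0]])   # C -> E + P  (k3)


def propensities(x, k):
    """Stochastic mass-action rates of the three reactions."""
    return jnp.array([k[0] * x[0] * x[1], k[1] * x[2], k[2] * x[2]])


@jax.jit
def simulate_mm(key, k, x0, times):
    """Gillespie simulation recorded on `times`; all loop state lives in the carry."""
    dtype = times.dtype
    k = k.astype(dtype)
    x0 = x0.astype(dtype)
    stoich = STOICH.astype(dtype)
    n_times = times.shape[0]
    out = jnp.zeros((n_times, x0.shape[0]), dtype=dtype)

    def cond(carry):
        _, _, _, idx, _ = carry
        return idx < n_times

    def body(carry):
        key, t, x, idx, out = carry
        key, key_tau, key_rxn = random.split(key, 3)
        a = propensities(x, k)
        a0 = jnp.sum(a)
        # Exponential waiting time with rate a0 (infinite if no reaction can fire).
        safe_a0 = jnp.where(a0 > 0, a0, 1.0)
        tau = jnp.where(a0 > 0, random.exponential(key_tau, dtype=dtype) / safe_a0, jnp.inf)
        t_event = t + tau
        fire = t_event < times[idx]
        # Choose reaction j with probability a_j / a0 (log(0) = -inf is never chosen).
        j = random.categorical(key_rxn, jnp.log(a))
        # Either apply the reaction, or jump to the next grid time, record the
        # unchanged state there and redraw the waiting time (memorylessness).
        x_new = jnp.where(fire, x + stoich[j], x)
        t_new = jnp.where(fire, t_event, times[idx])
        out_new = jnp.where(fire, out, out.at[idx].set(x))
        idx_new = jnp.where(fire, idx, idx + 1)
        return key, t_new, x_new, idx_new, out_new

    init = (key, jnp.zeros((), dtype=dtype), x0, jnp.int32(0), out)
    *_, out = lax.while_loop(cond, body, init)
    return out


key = random.PRNGKey(0)
times = jnp.linspace(0.0, 5.0, 11)
x0 = jnp.array([100.0, 20.0, 0.0, 0.0])  # small counts keep the example fast
traj = simulate_mm(key, jnp.array([1e-3, 0.1, 0.1]), x0, times)

# Same initial counts, three different parameter vectors.
ks = jnp.array([[1e-3, 0.1, 0.1], [2e-3, 0.1, 0.2], [5e-4, 0.2, 0.1]])
batch = jax.vmap(simulate_mm, in_axes=(0, 0, None, None))(random.split(key, 3), ks, x0, times)
```

The simulator is a pure function of `(key, k)`, so `jax.vmap` over keys and parameter vectors yields trajectories with shared initial counts and different rate constants. That is the kind of training set the objective describes. The grid step is exact because the waiting time is memoryless. If the next event would fall after the next grid point, the state at that point is unchanged, and the residual waiting time has the same exponential distribution as a fresh draw.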

In [None]:
sample_k_train.shape

In [None]:
import pints
import pints.toy as toy
import pints.toy.stochastic
import numpyro.distributions as npdist

In [None]:
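initial = [1e4, 2e3, 2e4, 0]  # initial molecule counts of the four species
model = toy.stochastic.MichaelisMentenModel(initial)
# `Predictive` expects a NumPyro model *function* (called later with its arguments),
# whereas `model.simulate()` runs the simulator immediately, without the
# `parameters` and `times` it requires, so this line fails as written.
mm_predictive = Predictive(model.simulate(), num_samples=3)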