## Vertical Upward Throw (鉛直投げ上げ)

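この例題は、以下のサイトで紹介されているコードを参考に作成しています。速度に比例する空気抵抗を受けながら鉛直に投げ上げられた物体の運動 $\frac{dy}{dt} = v,\quad m\frac{dv}{dt} = -mg - rv$ を常微分方程式としてモデル化し、観測された高度から抵抗係数と初期条件をベイズ推定します。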

https://docs.pymc.io/pymc-examples/examples/ode_models/ODE_API_introduction.html


## Install Packages

In [None]:
# IPython shell commands (not plain Python): install NumPyro, used for the
# Bayesian model and MCMC below, and japanize_matplotlib, presumably needed so
# the Japanese plot labels render correctly.
!pip install numpyro
!pip install japanize_matplotlib

【重要】パッケージのインストール完了後に、ランタイムを再起動してください！

## Import Packages

In [None]:
import jax
import jax.numpy as jnp
import jax.experimental.ode as ode

import numpyro
import numpyro.distributions as dist
import arviz as az

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import japanize_matplotlib

In [None]:
# Use a larger default font size for every figure in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Run JAX on the CPU and expose 4 host devices, presumably so the 4 MCMC chains
# requested later (num_chains=4) can run in parallel.
# NOTE: per the NumPyro docs, set_host_device_count only takes effect if it is
# called before JAX initializes its backends, so keep this cell before any JAX
# computation.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Generate Data

In [None]:
def dz_dt(z, t, m, g, r):
    """Right-hand side of the ODE for a vertically thrown body with linear drag.

    The state is ``z = [y, v]`` (height, vertical velocity) and the dynamics are::

        dy/dt = v
        m * dv/dt = -m * g - r * v

    Args:
        z: State vector ``[y, v]``.
        t: Current time. Unused because the system is autonomous, but required
            by the ``odeint`` callback signature ``func(y, t, *args)``.
        m: Mass of the body.
        g: Gravitational acceleration.
        r: Linear drag coefficient.

    Returns:
        ``jnp`` array ``[dy/dt, dv/dt]`` with the same layout as ``z``.
    """
    # The height z[0] does not enter the dynamics, so only the velocity is read.
    v = z[1]

    dy_dt = v
    dv_dt = (- m * g - r * v) / m

    return jnp.stack([dy_dt, dv_dt])

In [None]:
# Ground-truth parameters used to simulate the trajectory:
# mass, gravitational acceleration and linear drag coefficient.
m = 2.0
g = 9.8
r = 0.4

# Time grid 0.0, 0.5, ..., 9.5 and initial state [y(0), v(0)] = [0, 50].
t_true = jnp.arange(0.0, 10.0, 0.5, dtype=float)
z_init = jnp.array([0.0, 50.0], dtype=float)

# Integrate the ODE; each row of z is the state [y, v] at one time of t_true.
z = ode.odeint(dz_dt, z_init, t_true, m, g, r)
y_true, v_true = z[:, 0], z[:, 1]

In [None]:
# Observation noise (standard deviation) and number of leading time points observed.
sigma = 5.0
n_observed = 12

# Seed NumPy's legacy global RNG so the synthetic observations are reproducible,
# then add Gaussian noise to the first n_observed true heights.
np.random.seed(0)
t_observed = t_true[:n_observed]
y_observed = np.random.normal(loc=y_true[:n_observed], scale=sigma)

In [None]:
# Plot the simulated true trajectory together with the noisy observations.
plt.figure(figsize=(10, 6))

plt.plot(t_true, y_true, '.-', label='真値')
plt.plot(t_observed, y_observed, 'o', label='観測値')

# Axis labels carry units in brackets. Time is in seconds, consistent with the
# metre height axis and g = 9.8 (SI units); the old label put the variable name
# "t" where the unit belongs.
plt.xlabel('時間 [s]')
plt.ylabel('高度 [m]')

plt.legend();

## Define Model & Run Inference

In [None]:
def model(t, y_observed=None):
    """NumPyro model of a vertical throw with unknown drag and initial state.

    Latent sample sites and their priors:
        sigma  -- observation noise scale, HalfNormal(10)
                  (HalfCauchy(1) was also tried as an alternative).
        gamma  -- drag coefficient passed to ``dz_dt`` as ``r``, LogNormal(0, 1).
        y_init -- initial height, Normal(0, 10).
        v_init -- initial velocity, Normal(50, 10).

    The mass ``m`` and gravity ``g`` are read from module-level globals.

    Args:
        t: Times at which the ODE solution is evaluated; the sampled initial
            state is taken to hold at ``t[0]``.
        y_observed: Observed heights at ``t``. When ``None``, the site ``y``
            is sampled instead of conditioned (used for prediction).
    """
    sigma = numpyro.sample('sigma', dist.HalfNormal(10))
    gamma = numpyro.sample('gamma', dist.LogNormal(0, 1))

    y0 = numpyro.sample('y_init', dist.Normal(0, 10))
    v0 = numpyro.sample('v_init', dist.Normal(50, 10))

    # Solve the ODE from the sampled initial state; column 0 holds the height.
    trajectory = ode.odeint(dz_dt, jnp.stack([y0, v0]), t, m, g, gamma)
    height = trajectory[:, 0]

    numpyro.sample('y', dist.Normal(height, sigma), obs=y_observed)

In [None]:
# Option 1: start MCMC from explicit initial values for every latent site,
# with a stricter target acceptance probability.
init_values = dict(y_init=1.0, v_init=55.0, sigma=4.0, gamma=0.2)
init_strategy = numpyro.infer.init_to_value(values=init_values)

nuts = numpyro.infer.NUTS(
    model,
    target_accept_prob=0.95,
    init_strategy=init_strategy,
)

In [None]:
# Option 2: default NUTS settings, without explicit initial values.
# NOTE: this rebinds `nuts`, discarding the kernel built in the previous cell;
# run only the option you want before the MCMC cell.
nuts = numpyro.infer.NUTS(model=model)

In [None]:
# 4 chains x (2000 warm-up + 1000 kept) iterations with the selected NUTS kernel.
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=2000,
    num_samples=1000,
    num_chains=4,
)
mcmc.run(jax.random.PRNGKey(0), t_observed, y_observed=y_observed)

# Posterior draws as a dict of arrays, plus an ArviZ InferenceData for diagnostics.
mcmc_samples = mcmc.get_samples()
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots of the posterior, with extra spacing between the panels.
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Posterior summary table (displayed as the cell output).
az.summary(data=idata)

## Check Prediction

In [None]:
# Prediction grid 0.0, 0.5, ..., 9.5, covering both observed and unobserved times.
t_pred = jnp.arange(0.0, 10.0, 0.5, dtype=float)

In [None]:
# Posterior predictive draws of the observation site 'y' on the grid t_pred
# (y_observed is omitted, so 'y' is sampled rather than conditioned).
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(jax.random.PRNGKey(2), t_pred)

y_pred = ppc_samples['y']

In [None]:
# Mean over posterior draws (axis 0), and the 5th/95th percentiles (a 90% band)
# at each time point.
mu_pred = y_pred.mean(axis=0)
pi_pred = jnp.percentile(y_pred, q=jnp.array([5, 95]), axis=0)

In [None]:
# Compare observations and the true trajectory with the posterior predictive
# mean and its 5-95 percentile band.
plt.figure(figsize=(10, 6))

plt.plot(t_observed, y_observed, 'o', color='C1', label='観測値')
plt.plot(t_true, y_true, '--', color='C2', label='真値')

plt.plot(t_pred, mu_pred, '-.', color='C0', label='予測値 (平均)')
# Label the shaded band so the legend explains it (previously unlabelled).
plt.fill_between(t_pred, pi_pred[0, :], pi_pred[1, :], color='C0', alpha=0.2,
                 label='90%信用区間')

plt.title('事後予測分布 (90%-Credible Interval)')
# Time is in seconds, consistent with the metre height axis; the old label put
# the variable name "t" where the unit belongs.
plt.xlabel('時間 [s]')
plt.ylabel('高度 [m]')

plt.legend();