## Lotka-Volterra Model

この例題は、以下のサイトで紹介されているコードを参考に作成しています。

- https://docs.pymc.io/pymc-examples/examples/ode_models/ODE_with_manual_gradients.html
- https://num.pyro.ai/en/latest/examples/ode.html
- https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html


## Install Packages

In [None]:
!pip install numpyro
!pip install japanize_matplotlib

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import jax
import jax.numpy as jnp
import jax.experimental.ode as ode

import numpyro
import numpyro.distributions as dist
import numpyro.examples.datasets as datasets

import arviz as az
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import japanize_matplotlib

In [None]:
# Enlarge the default font size used by every figure in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Expose 4 host (CPU) devices, matching num_chains=4 used for MCMC later.
# NOTE(review): per the numpyro docs these settings should be applied before JAX
# performs any computation — keep this cell near the top of the notebook.
numpyro.set_host_device_count(4)
numpyro.set_platform('cpu')

## Load Data

In [None]:
# Annual capture counts (in thousands, per the plot labels) for 1900-1920:
# 21 years, one value per year for each species.
year = np.arange(1900, 1921)

hare = np.array([
    30.0, 47.2, 70.2, 77.4, 36.3, 20.6, 18.1,
    21.4, 22.0, 25.4, 27.1, 40.3, 57.0, 76.6,
    52.3, 19.5, 11.2, 7.6, 14.6, 16.2, 24.7,
])

lynx = np.array([
    4.0, 6.1, 9.8, 35.2, 59.4, 41.7, 19.0,
    13.0, 8.3, 9.1, 7.4, 8.0, 12.3, 19.5,
    45.7, 51.1, 29.7, 15.8, 9.7, 10.1, 8.6,
])

In [None]:
# Plot the raw observations (one line per species) against the calendar year.
fig = plt.figure(figsize=(12, 4))

plt.plot(year, hare, 'o-', label='hare (カンジキウサギ)')
plt.plot(year, lynx, 'o-', label='lynx (カナダオオヤマネコ)')

# Take the year range from `year` itself so the title always matches the
# plotted data (np.arange(1900, 1921) ends at 1920, not 1921).
plt.title(f'捕獲頭数記録（{year[0]}年〜{year[-1]}年）')
plt.ylabel('捕獲頭数 [千頭]')
plt.xticks(year, rotation=45)
plt.legend();

## Convert Data Format

In [None]:
# Time grid for the ODE solver: 0, 1, ..., N-1 (years since 1900), as floats.
N = year.size
t = jnp.arange(N).astype(float)

In [None]:
# Stack the two series column-wise -> shape (N, 2): column 0 = hare, column 1 = lynx.
data = np.column_stack((hare, lynx))

## Lotka-Volterra Model

In [None]:
def dz_dt(z, t, a, b, c, d):
    """Right-hand side of the Lotka-Volterra equations.

    Args:
        z: state vector ``[u, v]``; in this notebook u is the hare series and
            v the lynx series (the solver is started from ``data[0]``).
        t: current time. Not used in the formula, but required by
            ``odeint``'s ``func(y, t, *args)`` calling convention.
        a, b, c, d: model coefficients.

    Returns:
        ``jnp`` array ``[du/dt, dv/dt]`` where
        du/dt = (a - b*v) * u and dv/dt = (-c + d*u) * v.
    """
    u, v = z[0], z[1]
    prey_rate = (a - b * v) * u
    predator_rate = (-c + d * u) * v
    return jnp.stack([prey_rate, predator_rate])

In [None]:
# Try solving the ODE with rough, hand-picked coefficients,
# starting from the first observation (year 1900).
a, b, c, d = 0.547, 0.027, 0.799, 0.024

z_init = data[0]

z = ode.odeint(dz_dt, z_init, t, a, b, c, d, rtol=1e-6, atol=1e-5, mxstep=1000)

In [None]:
# Plot the ODE solution computed with the hand-picked coefficients.
fig = plt.figure(figsize=(12, 4))

plt.plot(year, z[:, 0], 'o-', label='hare (カンジキウサギ)')
plt.plot(year, z[:, 1], 'o-', label='lynx (カナダオオヤマネコ)')

plt.title('Lotka-Volterra 方程式の解')
plt.ylabel('捕獲頭数 [千頭]')
# Same integer year ticks as the observation plot, matching the x data above.
plt.xticks(year, rotation=45)
plt.legend();

## Define Model & Inference

In [None]:
def model(t, y=None):
    """Bayesian Lotka-Volterra model for the two population series.

    The latent trajectory is the ``odeint`` solution of ``dz_dt`` started from
    ``z_init``. Observations are log-normally distributed around it, with one
    noise scale per component (``sigma`` has two entries).

    Args:
        t: time points at which the ODE is solved and observations are modeled.
        y: observed data, presumably shaped ``(len(t), 2)`` like ``data``;
            ``None`` when simulating (e.g. posterior predictive).
    """
    # Initial state, one value per component.
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))

    # Coefficients passed to dz_dt, each with a normal prior truncated at 0.
    # Sites are sampled in the order a, b, c, d.
    priors = (('a', 1.0, 0.5), ('b', 0.05, 0.05), ('c', 1.0, 0.5), ('d', 0.05, 0.05))
    a, b, c, d = [
        numpyro.sample(name, dist.TruncatedNormal(low=0, loc=loc, scale=scale))
        for name, loc, scale in priors
    ]

    z = ode.odeint(dz_dt, z_init, t, a, b, c, d, rtol=1e-6, atol=1e-5, mxstep=1000)

    # Per-component observation noise scale, broadcast over the last axis of z.
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1), sample_shape=(2,))

    numpyro.sample("y", dist.LogNormal(jnp.log(z), sigma), obs=y)

In [None]:
# Option A: do not set the MCMC initial values manually
# (no init_strategy is passed, so NUTS uses its default initialization).
nuts = numpyro.infer.NUTS(model=model)

In [None]:
# Option B: start MCMC from hand-picked initial values.
# Each value must match the shape of its sample site. `z_init` and `sigma`
# are declared with sample_shape=(2,) in `model`, so both get length-2 arrays.
# A scalar for `sigma` would broadcast silently, and the sampler would then
# likely carry a single shared sigma instead of one per species.
init_values = {
    'a': 1.0,
    'b': 0.05,
    'c': 1.0,
    'd': 0.05,
    'z_init': data[0, :],
    'sigma': np.full(2, 0.5),
}
init_strategy = numpyro.infer.init_to_value(values=init_values)

# Raise the target acceptance probability (numpyro's default is 0.8)
# so NUTS takes smaller, more careful steps.
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95, init_strategy=init_strategy)

In [None]:
# 4 chains x (2000 warm-up + 1000 retained) NUTS iterations, conditioned on the observations.
mcmc = numpyro.infer.MCMC(nuts, num_chains=4, num_warmup=2000, num_samples=1000)

mcmc.run(jax.random.PRNGKey(0), t, y=data)
mcmc_samples = mcmc.get_samples()

# Wrap the run in an ArviZ InferenceData object for diagnostics and plotting.
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots for the sampled variables, with extra spacing between subplots.
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Summary table of posterior statistics (displayed as the cell's output).
az.summary(data=idata)

## Check Prediction

In [None]:
# Extend the time grid 10 steps beyond the observed period to forecast.
t_pred = jnp.arange(len(t) + 10).astype(float)

In [None]:
# Posterior predictive simulation: re-run `model` on the extended grid t_pred
# once per posterior draw, sampling a fresh observation `y` each time.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(jax.random.PRNGKey(2), t_pred)

y_pred = ppc_samples['y']

In [None]:
# Reduce over axis 0 (the posterior draws): the predictive mean, and the 5th/95th
# percentiles, which bound a 90% band (pi_pred[0] = lower, pi_pred[1] = upper).
mu_pred = y_pred.mean(axis=0)
pi_pred = jnp.percentile(y_pred, jnp.array([5, 95]), axis=0)

In [None]:
# Observations (dots), posterior-predictive mean (dashed) and 90% band (shaded)
# for each species; x values are calendar years (t + 1900).
fig = plt.figure(figsize=(12, 5))

plt.plot(t + 1900, data[:, 0], marker='o', linestyle='none', color='C0', label='カンジキウサギ（観測値）')
plt.plot(t + 1900, data[:, 1], marker='o', linestyle='none', color='C1', label='カナダオオヤマネコ（観測値）')

plt.plot(t_pred + 1900, mu_pred[:, 0], linestyle='--', color='C0', label='カンジキウサギ（予測値）')
plt.plot(t_pred + 1900, mu_pred[:, 1], linestyle='--', color='C1', label='カナダオオヤマネコ（予測値）')

plt.fill_between(t_pred + 1900, pi_pred[0, :, 0], pi_pred[1, :, 0], color='C0', alpha=0.2)
plt.fill_between(t_pred + 1900, pi_pred[0, :, 1], pi_pred[1, :, 1], color='C1', alpha=0.2)

plt.ylim(0, 160)

plt.title('事後予測分布 (90%-Credible Interval)')
plt.ylabel('捕獲頭数 [千頭]')
plt.legend();