## Change Point Detection


## Install Packages

In [None]:
# IPython/Jupyter shell escape (not plain Python syntax): install NumPyro with pip.
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！ (Important: restart the runtime once the package installation has finished.)

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import jax.numpy as jnp

import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from numpyro.contrib.control_flow import scan

In [None]:
# Notebook-wide matplotlib defaults: larger font, wide-and-short figures.
plt.rcParams.update({
    'font.size': 14,
    'figure.figsize': (8, 4),
})

In [None]:
# Run JAX on the CPU and expose 4 host (CPU) devices, matching the 4 MCMC
# chains requested later, so the chains can be sampled in parallel.
# NOTE(review): per the NumPyro docs these settings only take effect if they
# are applied before JAX initialises its backend (before the first JAX
# computation), so keep this cell ahead of any jnp work.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Generate Data

In [None]:
# Reproducible synthetic series: four consecutive segments of 100 points,
# each Gaussian noise (sd 0.5) around a segment mean of 0, 1, -1 and 0,
# so the true change points sit at t = 100, 200 and 300.
np.random.seed(1)

# The comprehension draws the segments in order, so the global RNG stream is
# consumed exactly as if each segment were drawn on its own line.
y0, y1, y2, y3 = [np.random.normal(level, 0.5, 100) for level in (0, 1, -1, 0)]

y_obs = np.concatenate((y0, y1, y2, y3))
# Integer time index 0 .. N-1.
t_obs = np.arange(y_obs.size)

# Quick look at the raw series; the level shifts at t = 100, 200, 300 should
# be visible by eye.
plt.plot(t_obs, y_obs)
plt.xlabel('Time');  # trailing ';' suppresses the notebook echo of the return value

## Define Model & Inference

In [None]:
# Hand the observations to JAX as a device array for the model. Unless JAX's
# 64-bit mode is enabled, this stores the (float64) NumPy data as float32.
y_obs = jnp.asarray(y_obs)

In [None]:
def model(y_obs, num_steps=None):
    """Random-walk level model with heavy-tailed (Cauchy) increments.

    The latent level ``u`` starts at ``c`` and then moves by increments ``b``.
    The Cauchy prior keeps most increments close to zero while still allowing
    occasional large jumps; those jumps are the change points. Observations
    are the level plus Gaussian noise.

    Args:
        y_obs: 1-D array of observations, or ``None`` to draw ``y`` from the
            prior (``num_steps`` must then be given).
        num_steps: Length of the series. Defaults to ``len(y_obs)``.

    Sample sites:
        sd_b: scale of the Cauchy increments.
        c: initial level, ``u[0]``.
        b: increments ``u[t] - u[t-1]`` for ``t = 1 .. num_steps - 1``,
            shape ``(num_steps - 1,)``.
        u: deterministic latent level, shape ``(num_steps,)``.
        sd_y: observation-noise scale.
        y: the observed (or prior-sampled) series.

    Raises:
        ValueError: if both ``y_obs`` and ``num_steps`` are None, or the series
            length is smaller than 1.
    """
    if num_steps is None:
        if y_obs is None:
            raise ValueError('num_steps is required when y_obs is None')
        num_steps = len(y_obs)
    if num_steps < 1:
        raise ValueError('the series must contain at least one time step')

    sd_b = numpyro.sample('sd_b', dist.HalfNormal(10))

    c = numpyro.sample('c', dist.Normal(0, 10))
    # Only num_steps - 1 increments, so u[0] is c itself. Drawing num_steps
    # increments with u[0] = c + b[0] would make c and b[0] enter the
    # likelihood only through their sum (non-identifiable), leaving a
    # heavy-tailed ridge in the posterior for NUTS to wander along.
    b = numpyro.sample('b', dist.Cauchy(0, sd_b), sample_shape=(num_steps - 1,))
    # Prepend a 0 to the running sum: u = [c, c + b1, c + b1 + b2, ...].
    u = numpyro.deterministic('u', c + jnp.pad(jnp.cumsum(b), (1, 0)))

    sd_y = numpyro.sample('sd_y', dist.HalfNormal(10))

    numpyro.sample('y', dist.Normal(u, sd_y), obs=y_obs)

In [None]:
# No-U-Turn Sampler. Raising target_accept_prob from the default 0.8 to 0.95
# makes warm-up adapt a smaller step size; presumably this is meant to tame the
# heavy-tailed Cauchy increments in the model.
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95)

# 4 chains, each with 5000 warm-up draws followed by 5000 retained draws.
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=5000,
    num_samples=5000,
    num_chains=4,
)
mcmc.run(jax.random.PRNGKey(1), y_obs=y_obs)

# Retained posterior draws, chains concatenated: dict of site name -> array.
trace = mcmc.get_samples()

In [None]:
# Wrap the fitted sampler's posterior in an ArviZ InferenceData for diagnostics.
idata = az.from_numpyro(posterior=mcmc)

In [None]:
# Trace plots for the scalar parameters only; the per-time-step 'b' and 'u'
# sites are left out.
az.plot_trace(idata, var_names=['c', 'sd_b', 'sd_y'])
# Widen the gaps between subplots of the current figure (presumably the one
# plot_trace just drew) so titles and tick labels do not overlap.
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Summary statistics and convergence diagnostics for the scalar parameters
# (left as the cell's final expression so the notebook displays the table).
az.summary(data=idata, var_names=['c', 'sd_b', 'sd_y'])

## Check Latent Variable (u)

In [None]:
# Posterior draws of the latent level u with all chains concatenated:
# one row per retained draw (4 chains x 5000), one column per time step.
u_sampled = trace['u']

In [None]:
# Pointwise posterior mean of the latent level at each time step ...
mu = u_sampled.mean(axis=0)
# ... and its pointwise 5th / 95th percentiles (a 90% band):
# row 0 is the lower edge, row 1 the upper edge.
pi = jnp.percentile(u_sampled, q=jnp.array([5, 95]), axis=0)

In [None]:
# Observed series (faded) overlaid with the posterior mean of the latent
# level (dash-dot line) and its 5-95 percentile band (shaded).
plt.plot(t_obs, y_obs, color='C0', alpha=0.5)

plt.plot(t_obs, mu, '-.', color='C1')
# pi[0] / pi[1] hold the pointwise 5th / 95th percentiles of u.
plt.fill_between(t_obs, pi[0, :], pi[1, :], color='C1', alpha=0.5)

# Fixed y-range wide enough for the simulated data (segment means in [-1, 1],
# noise sd 0.5).
plt.ylim([-2.5, 2.5])
plt.xlabel('Time');

In [None]:
# Spaghetti plot of five individual posterior draws of u: one line per draw,
# x-axis is the time index.
# NOTE(review): get_samples() concatenates the chains, so these are presumably
# the first five consecutive draws of the first chain and may be strongly
# autocorrelated; consider spacing them out (e.g. u_sampled[::4000]) for a
# more representative picture.
plt.plot(u_sampled[:5, :].T)

plt.ylim([-2.5, 2.5])
plt.xlabel('Time');