## 変化点検出

この例題は以下のサイトで紹介されているコードを参考に作成しています。

- https://docs.pymc.io/notebooks/getting_started.html

もともとは、「炭鉱での事故発生頻度と安全基準の改定の関連性を調べる」という趣旨の解析ですが、このノートブックではデータを人工データに変更しています。

## Install Package

In [None]:
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import arviz as az

import jax
import jax.numpy as jnp

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Set the default matplotlib font size used by the figures below.
plt.rcParams['font.size'] = 16

In [None]:
# Run on CPU and request 4 host devices; the MCMC run later in this
# notebook uses num_chains=4.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Generate Data

In [None]:
# True Poisson rates used to simulate the data: r0 applies to years up to
# and including the switchpoint, r1 to the years after it.
r0 = 2
r1 = 1

# Years 1851 through 1969 (np.arange excludes the stop value).
years = np.arange(1851, 1970)

# True change point of the synthetic data.
switchpoint = 1890

# Per-year rate: r0 where year <= switchpoint, otherwise r1.
r = np.where(years <= switchpoint, r0, r1)

# Draw one Poisson count per year using a fixed PRNG key.
count = dist.Poisson(r).sample(jax.random.PRNGKey(0))

In [None]:
# Scatter plot of the simulated yearly counts.
fig = plt.figure(figsize=(12, 4))

marker_style = dict(markersize=8, alpha=0.4)
plt.plot(years, count, 'o', **marker_style)

plt.xlabel('Year')
# Trailing semicolon keeps the last expression's value out of the cell output.
plt.ylabel('Number of accidents');

## Define Model & Inference

In [None]:
# Rebind the data as JAX arrays before passing them to the model.
years = jnp.array(years)
count = jnp.array(count)

In [None]:
def model(years, count=None):
    """Poisson change-point model with one discrete switchpoint.

    A switchpoint index is drawn from a uniform categorical distribution over
    the positions of ``years``. Years less than or equal to the switchpoint
    year use rate ``r0``; the remaining years use rate ``r1``.

    Args:
        years: Array of year values; indexed by the sampled switchpoint index.
        count: Values passed as ``obs`` to the ``'obs'`` Poisson sample site.
            Defaults to ``None``.
    """
    num_years = len(years)

    # Equal prior probability for every candidate switchpoint position.
    uniform_probs = jnp.ones(num_years) / num_years
    index = numpyro.sample('switchpoint_index', dist.Categorical(probs=uniform_probs))

    # Record the switchpoint as a year value rather than an index.
    switch_year = numpyro.deterministic('switchpoint', years[index])

    rate_before = numpyro.sample('r0', dist.HalfNormal(10))
    rate_after = numpyro.sample('r1', dist.HalfNormal(10))

    # Per-year rate: r0 up to and including the switchpoint year, r1 after.
    is_before = years <= switch_year
    rate_per_year = numpyro.deterministic(
        'r_switched', jnp.where(is_before, rate_before, rate_after)
    )

    numpyro.sample('obs', dist.Poisson(rate_per_year), obs=count)

In [None]:
# NUTS kernel with a high target acceptance probability, wrapped in
# DiscreteHMCGibbs. The model contains a discrete latent site
# ('switchpoint_index'); presumably the wrapper is what handles it
# (see the NumPyro DiscreteHMCGibbs docs).
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99)
kernel = numpyro.infer.DiscreteHMCGibbs(nuts)

# 4 chains, each with 500 warmup steps followed by 3000 samples.
mcmc = numpyro.infer.MCMC(kernel, num_warmup=500, num_samples=3000, num_chains=4)

# Fixed PRNG key; `years` and `count` are forwarded to `model`.
mcmc.run(jax.random.PRNGKey(1), years, count)
# Posterior samples keyed by site name, used by the result cells below.
trace = mcmc.get_samples()

In [None]:
# Convert the NumPyro MCMC result into an ArviZ object for plotting and summaries.
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots for the two rates and the switchpoint year.
axes = az.plot_trace(idata, var_names=['r0', 'r1', 'switchpoint'])

# Rotate the x tick labels of the axes at row 2, column 0
# (presumably the 'switchpoint' panel, given the var_names order).
plt.setp(axes[2, 0].get_xticklabels(), rotation=45, ha='right')
# Add spacing between the subplots.
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Summary table for the two rates and the switchpoint year.
az.summary(idata, var_names=['r0', 'r1', 'switchpoint'])

## Check Results

In [None]:
# Average over the sample axis (axis 0) of the collected posterior draws.
switchpoint_mean = jnp.mean(trace['switchpoint'], axis=0)
r_switched_mean = jnp.mean(trace['r_switched'], axis=0)

In [None]:
# 95% highest-density interval of the switchpoint year. Restricting
# var_names avoids computing intervals for every posterior variable
# (including the per-year 'r_switched' vector) when only this one is used.
hdi = az.hdi(idata, hdi_prob=0.95, var_names=['switchpoint'])
switchpoint_hdi = hdi['switchpoint']

In [None]:
plt.figure(figsize=(12, 6))

# Observed yearly counts.
plt.plot(years, count, 'o', markersize=8, alpha=0.6)
plt.ylabel('Number of accidents')
plt.xlabel('Year')

# Vertical extent shared by the switchpoint line and its interval band.
count_low, count_high = count.min(), count.max()

# Posterior mean of the switchpoint year.
plt.vlines(switchpoint_mean, count_low, count_high, color='C1')

# Shaded band between the lower and upper HDI bounds of the switchpoint.
hdi_low, hdi_high = switchpoint_hdi[0], switchpoint_hdi[1]
plt.fill_betweenx(y=[count_low, count_high], x1=hdi_low, x2=hdi_high, alpha=0.5, color='C1')

# Posterior mean of the per-year rate, drawn as a dashed black line.
plt.plot(years, r_switched_mean, 'k--', linewidth=2);