<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/03_trips/02_fit_zero_inflated_poisson.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【ゼロ過剰ポアソン分布】出張回数

出張回数のデータをゼロ過剰ポアソン分布にあてはめ、ポアソン分布モデルとの当てはまりを情報量基準（WAIC および PSIS-LOO）により比較する。

## Package Installation

In [None]:
# Install NumPyro into the notebook environment via the IPython `!` shell escape.
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run JAX on the CPU and expose 4 host devices, matching num_chains=4 used by
# the MCMC cell, so the chains can be placed on separate devices.
# NOTE(review): NumPyro documents that set_host_device_count only takes effect
# when called before JAX initialises its backend — keep this cell early.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

In [None]:
# Default font size (points) for all subsequent matplotlib figures.
plt.rcParams['font.size'] = 12

## Load & Check Data

In [None]:
# Observed trip counts: 200 non-negative integers, the large majority of them
# zero, which is what motivates trying a zero-inflated model.
y = np.array([
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 0, 0, 1, 0, 3, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0])

In [None]:
# Quick visual check of the raw counts.
print(y)

## Define Model & Inference

In [None]:
def model(y=None, num_data=0):
    """Zero-inflated Poisson model for the trip counts.

    Each observation is a structural zero with probability ``psi`` and is
    otherwise drawn from ``Poisson(mu)``, so zeros can arise from either
    component.

    Parameters
    ----------
    y : array-like of int, optional
        Observed counts. When ``None`` the ``obs`` site is sampled instead of
        conditioned on, which is how posterior-predictive draws are produced.
    num_data : int, optional
        Plate size (number of observations). If left at the default ``0``
        while ``y`` is given, ``len(y)`` is used, so ``model(y=y)`` works on
        its own. When ``y`` is ``None`` it must be passed explicitly.
    """
    # Infer the plate size from the data when the caller did not supply it;
    # the default of 0 is incompatible with a non-empty ``y``.
    if y is not None and num_data == 0:
        num_data = len(y)

    # Gate: probability that an observation is a structural zero.
    psi = numpyro.sample('psi', dist.Uniform(low=0.0, high=1.0))
    # Poisson rate. The site name 'mu' is kept because it names the variable
    # in the posterior samples / InferenceData; the local name now matches it.
    mu = numpyro.sample('mu', dist.HalfNormal(10))

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.ZeroInflatedPoisson(gate=psi, rate=mu), obs=y)

In [None]:
# Fit the model with NUTS: 4 chains, each with 500 warm-up and 3000 kept draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=4,
)

rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, y=y, num_data=len(y))

# Posterior draws keyed by sample-site name; reused for posterior prediction.
mcmc_samples = mcmc.get_samples()

# ArviZ view of the run, used for diagnostics, WAIC/LOO and model comparison.
idata = az.from_numpyro(mcmc)

## Check MCMC-samples

In [None]:
# Trace plots for each parameter; extra spacing keeps subplot titles apart.
az.plot_trace(idata);
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Posterior summary table with convergence diagnostics (displayed as cell output).
az.summary(idata)

In [None]:
# Marginal posterior densities of psi and mu.
az.plot_posterior(idata)
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

## Posterior Predictive Check

In [None]:
# Posterior-predictive replicates of `obs`: y is omitted so the likelihood
# site is sampled, once per posterior draw, with the plate sized to the data.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(jax.random.PRNGKey(1), num_data=len(y))

# Bundle posterior and posterior-predictive draws for ArviZ plotting.
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# Overlay 1000 posterior-predictive draws on the observed counts.
az.plot_ppc(idata_ppc, num_pp_samples=1000);

## Information Criteria

In [None]:
# WAIC on the deviance scale (-2 * elpd, so lower is better).
az.waic(idata, scale='deviance')

In [None]:
# PSIS-LOO cross-validation, also on the deviance scale.
az.loo(idata, scale='deviance')

## Compare Models

ポアソン分布モデルの推論結果（`idata.nc`。このノートブックでは作成しないため事前に用意が必要）を読み込み、ゼロ過剰ポアソン分布モデルと WAIC・LOO で比較する。

In [None]:
# Load the Poisson model's InferenceData from `idata.nc` in the working directory.
# NOTE(review): this file is not created anywhere in this notebook — presumably
# it is saved by the companion Poisson-model notebook, which must be run first.
idata_poisson = az.from_netcdf('idata.nc')

In [None]:
# Models to compare, keyed by the label shown in the comparison tables/plots.
dict_idata = {
    'Poisson': idata_poisson,
    'Zero Inflated Poisson': idata,
}

In [None]:
# Rank the models by WAIC on the deviance scale (lower is better).
df_waic = az.compare(dict_idata, ic='waic', scale='deviance')
df_waic

In [None]:
# Visualise the WAIC-based model comparison.
az.plot_compare(df_waic, figsize=(8, 3));

In [None]:
# Rank the models by PSIS-LOO on the deviance scale (lower is better).
df_loo = az.compare(dict_idata, ic='loo', scale='deviance')
df_loo

In [None]:
# Visualise the LOO-based model comparison.
az.plot_compare(df_loo, figsize=(8, 3));