<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/06_smoke/01_poisson_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【ポアソン回帰】冠動脈心疾患と喫煙習慣

冠動脈心疾患のデータを使って、ポアソン回帰を行ってみる。

## Package Installation

In [None]:
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run JAX on the CPU and expose 4 host devices so the 4 MCMC chains used
# below can be placed on separate devices. NumPyro documents that the
# device count only takes effect if set before JAX initialises its backend,
# which is why this cell comes right after the imports.
numpyro.set_platform(platform='cpu')
numpyro.set_host_device_count(n=4)

In [None]:
# Notebook-wide matplotlib defaults (font size and default figure size).
plt.rcParams.update({'font.size': 12, 'figure.figsize': [8, 6]})

## Load Data

In [None]:
# One row per (age category, smoking status) group: 5 age categories for
# smokers (smoke=1) followed by the same 5 for non-smokers (smoke=0).
data = pd.DataFrame(
    {
        'agecat': [1, 2, 3, 4, 5] * 2,
        'deaths': [32, 104, 206, 186, 102, 2, 12, 28, 28, 31],
        'population': [52407, 43248, 28612, 12663, 5317, 18790, 10673, 5710, 2585, 1462],
        'smoke': [1] * 5 + [0] * 5,
    }
)

In [None]:
# Show the whole table; it has only 10 rows (5 age categories x smoker / non-smoker).
data

## Preprocess & Scale Data

In [None]:
# Pull each column out as a plain NumPy array for the model and the plots.
agecat = data['agecat'].to_numpy()
smoke = data['smoke'].to_numpy()
deaths = data['deaths'].to_numpy()

population = data['population'].to_numpy()

In [None]:
# Standardise the age category (np.std uses ddof=0, i.e. the population SD).
# Centring makes b1 the log death rate of a non-smoker at the mean age category.
agecat_mu, agecat_sd = np.mean(agecat), np.std(agecat)

agecat_scaled = (agecat - agecat_mu) / agecat_sd

## Check Data

In [None]:
# Crude death rate (deaths / population) per age category, coloured by smoking status.
fig, ax = plt.subplots()
sns.scatterplot(x=agecat, y=deaths / population, hue=smoke, s=150, ax=ax)

# agecat lists each category twice (smokers and non-smokers); tick each one once.
ax.set_xticks(np.unique(agecat))
ax.set(
    xlabel='Age Category',
    ylabel='Death Rate',
    title='Death rate by age category and smoking status',
);

In [None]:
# Log death rate per age category. The Poisson regression below is linear on
# this scale, so roughly straight, parallel trends are what the model assumes.
fig, ax = plt.subplots()
sns.scatterplot(x=agecat, y=np.log(deaths / population), hue=smoke, s=150, ax=ax)

# agecat lists each category twice (smokers and non-smokers); tick each one once.
ax.set_xticks(np.unique(agecat))
ax.set(
    xlabel='Age Category',
    ylabel='log( Death Rate )',
    title='Log death rate by age category and smoking status',
);

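## Poisson Regression

各グループ $i$ の死亡数を、人口を曝露量（オフセット）とするポアソン回帰でモデル化する：

$$
\mathrm{deaths}_i \sim \mathrm{Poisson}(\mu_i), \qquad
\mu_i = \mathrm{population}_i \cdot \exp\left(b_1 + b_2\,\mathrm{smoke}_i + b_3\,\mathrm{agecat\_scaled}_i\right), \qquad
b_k \sim \mathcal{N}(0, 10)
$$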

In [None]:
def model(smoke=None, agecat_scaled=None, deaths=None, population=None, num_data=0):
    """Poisson regression of death counts on smoking status and scaled age category.

    For each group i:
        theta_i  = b1 + b2 * smoke_i + b3 * agecat_scaled_i
        mu_i     = exp(theta_i) * population_i
        deaths_i ~ Poisson(mu_i)

    so exp(theta_i) is a per-person death rate and ``population`` acts as the
    exposure.

    Args:
        smoke: per-group smoking indicator (0/1 in this notebook).
        agecat_scaled: per-group standardised age category.
        deaths: observed death counts. Pass None to sample the ``obs`` site
            instead of conditioning on it (used for posterior predictive checks).
        population: per-group population (exposure).
        num_data: number of groups, i.e. the size of the ``data`` plate.
            When 0 / None (the default), it is inferred from ``len(smoke)``.
            Previously the default produced a size-0 plate that conflicted
            with the per-group rate vector.
    """
    b1 = numpyro.sample('b1', dist.Normal(0, 10))  # intercept: log rate, non-smoker at mean age category
    b2 = numpyro.sample('b2', dist.Normal(0, 10))  # smoking effect on the log rate
    b3 = numpyro.sample('b3', dist.Normal(0, 10))  # effect per 1 SD of age category on the log rate

    theta = b1 + b2 * smoke + b3 * agecat_scaled

    # Expected deaths = rate * exposure.
    mu = jax.numpy.exp(theta) * population

    if not num_data:
        num_data = len(smoke)

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Poisson(rate=mu), obs=deaths)
    

In [None]:
# Fit the model with NUTS: 4 chains, 500 warm-up draws and 3000 kept draws per chain.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=4,
)

mcmc.run(
    jax.random.PRNGKey(0),
    smoke=smoke,
    agecat_scaled=agecat_scaled,
    deaths=deaths,
    population=population,
    num_data=len(agecat_scaled),
)

# Draws as a dict (used by Predictive below) and as ArviZ InferenceData
# (used for diagnostics and plots).
mcmc_samples = mcmc.get_samples()
idata = az.from_numpyro(mcmc)

In [None]:
# Inspect the InferenceData object built from the MCMC run.
idata

In [None]:
# Trace and marginal density of every sampled parameter; widen the gaps
# between panels on the figure plot_trace just created.
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# ArviZ posterior summary (point estimates, intervals, and convergence
# diagnostics such as R-hat) for the sampled coefficients b1, b2, b3.
az.summary(idata)

In [None]:
# Posteriors of the smoking (b2) and scaled-age (b3) coefficients.
# ref_val=0 shows how much posterior mass lies on each side of "no effect".
az.plot_posterior(
    idata,
    ref_val=0,
    var_names=['b2', 'b3'],
);

## Posterior Predictive Check

In [None]:
# Simulate replicated death counts from the fitted model. `deaths` is
# deliberately not passed, so the 'obs' site is sampled instead of conditioned on.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)

ppc_samples = predictive(
    jax.random.PRNGKey(1),
    smoke=smoke,
    agecat_scaled=agecat_scaled,
    population=population,
    num_data=len(agecat_scaled),
)

# Attach the replicated data to a new InferenceData alongside the posterior.
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# Expected (12000, 10): 4 chains x 3000 draws, flattened by mcmc.get_samples(),
# by one replicated count per group.
ppc_samples['obs'].shape

In [None]:
# One panel per group: the posterior predictive distribution of the death
# count, with the observed count as a dashed red line.
ppc_obs = np.asarray(ppc_samples['obs'])

# Derive the grid from the data instead of hard-coding 10 groups / 5x2 panels.
n_groups = len(deaths)
n_cols = 2
n_rows = -(-n_groups // n_cols)  # ceiling division

# 2.4 in per row reproduces the original 12x12 figure for 5 rows.
fig = plt.figure(figsize=(12, 2.4 * n_rows))

for k in range(n_groups):

    ax = fig.add_subplot(n_rows, n_cols, k + 1)

    # Pass ax explicitly rather than relying on the implicit current axes.
    az.plot_dist(ppc_obs[:, k], ax=ax)
    ax.axvline(deaths[k], color='r', linestyle='dashed')
    ax.set_title('Age Category = {}, Smoke = {}'.format(agecat[k], smoke[k]))
    ax.set_xlabel('Deaths')

plt.tight_layout()

## Export Inference Data

In [None]:
# Save the fitted model's InferenceData to a NetCDF file for later use.
# NOTE(review): this writes `idata`, not `idata_ppc`, so the posterior
# predictive samples are not included in the file. Confirm that is intended.
idata.to_netcdf('idata_base.nc')