# とにかく NumPyro に慣れるためのノート

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
%config InlineBackend.figure_format = 'retina'

In [None]:
# Check which devices JAX will run on (e.g. whether only the CPU is available)
jax.devices()

In [None]:
# Example data set: one 'yield' count per row plus a 0/1 'group' label
data = pd.DataFrame({

    'yield':[7, 13, 13, 11, 5, 6, 8, 11, 10, 11, 11, 11, 11, 14, 8, 15, 10, 9, 13, 12, 8, 15, 7, 11, 5, 11,
             15, 10, 13, 9, 8, 12, 13, 6, 8, 5, 13, 8, 5, 10, 18, 9, 7, 12, 11, 5, 9, 10, 13, 13, 7, 12, 8, 16, 10,
             6, 12, 13, 10, 12, 9, 7, 12, 11, 8, 15, 13, 11, 9, 17, 11, 10, 15, 19, 11, 13, 12, 9, 10, 10],
    'group':[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

})

In [None]:
# Pull the 'yield' column out of the DataFrame as the observation vector
y = data['yield'].values

In [None]:
# Sample mean of the observed yields
np.mean(y)

In [None]:
# Sample variance of the observed yields (np.var with its default ddof)
np.var(y)

### ポイント
平均と分散がほぼ同じ値になっている。ポワソン分布は平均と分散が等しい分布なので、このデータはポワソン分布でモデル化できそうだと考えられる

In [None]:
# Histograms of the observed yields at two bin widths, side by side
fig, axes = plt.subplots(1, 2, figsize = (12, 4))

for ax, width in zip(axes, (1, 3)):
    sns.histplot(y, binwidth = width, ax = ax)
    ax.set_xlabel('Yield')

plt.tight_layout()

## モデル定義

In [None]:
def model(y = None, num_data = 0):
    """Poisson model with a single rate shared by all observations.

    Args:
        y: Observed counts passed as ``obs`` to the 'obs' site; left as
            ``None`` the site is not conditioned on data.
        num_data: Size of the 'data' plate.
    """
    # Prior on the (non-negative) Poisson rate
    rate_prior = dist.HalfNormal(10)
    mu = numpyro.sample('mu', rate_prior)

    # Likelihood of the observations inside the 'data' plate
    with numpyro.plate('data', num_data):
        likelihood = dist.Poisson(mu)
        numpyro.sample('obs', likelihood, obs = y)

In [None]:
# NUTS kernel for `model`, wrapped in an MCMC driver
# (500 warm-up steps, 3000 samples, 4 chains)
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_warmup = 500, num_samples = 3000, num_chains = 4)

# Run the sampler with a fixed PRNG key, conditioning on the observed yields
mcmc.run(jax.random.PRNGKey(0), y = y, num_data = len(y))

In [None]:
# Collect the posterior draws from the sampler
mcmc_samples = mcmc.get_samples()

In [None]:
# Display the collected posterior draws
mcmc_samples

In [None]:
# Convert the MCMC result into an ArviZ inference data object
idata = az.from_numpyro(mcmc)

In [None]:
# Display the inference data object
idata

In [None]:
# Trace plot of the sampled parameters
az.plot_trace(idata)

In [None]:
# Summary table of the posterior
az.summary(idata)

In [None]:
# Posterior distribution plot
az.plot_posterior(idata)

In [None]:
# Posterior draws stored for the 'mu' site
idata.posterior['mu']

### 事後予測チェック PPC
モデルが上手くできているか？

In [None]:
# Posterior-predictive sampler built from the model and the posterior draws
predictive = numpyro.infer.Predictive(model, mcmc_samples)

# y is not passed here, so 'obs' is not conditioned on the observed data
ppc_samples = predictive(jax.random.PRNGKey(1), num_data = len(y))

In [None]:
# Inference data object that also carries the posterior-predictive samples
idata_ppc = az.from_numpyro(mcmc, posterior_predictive = ppc_samples)

In [None]:
# Display the inference data object with posterior-predictive samples
idata_ppc

idata_ppc ができたら確認していく

In [None]:
# Mean and variance of the simulated 'obs' values, reduced over axis 1
# (assumes axis 0 indexes draws and axis 1 indexes observations — TODO confirm)
ppc_mean = ppc_samples['obs'].mean(axis = 1)
ppc_var = ppc_samples['obs'].var(axis = 1)

In [None]:
# Compare posterior-predictive summary statistics with the observed ones
fig, axes = plt.subplots(1, 2, figsize = (12, 4))

# Left panel: density of the simulated means; dashed red line = observed mean
sns.kdeplot(ppc_mean, ax = axes[0])
axes[0].axvline(y.mean(), color = 'r', linestyle = 'dashed')
axes[0].set_xlabel('stats = mean')

# Right panel: density of the simulated variances; dashed red line = observed variance
sns.kdeplot(ppc_var, ax = axes[1])
axes[1].axvline(y.var(), color = 'r', linestyle = 'dashed')
axes[1].set_xlabel('stats = var')

In [None]:
# Posterior-predictive check as KDE curves, drawing 50 posterior-predictive samples
az.plot_ppc(idata_ppc, kind = 'kde', num_pp_samples = 50, figsize = (12, 4))

In [None]:
# Same check with 3000 posterior-predictive samples
az.plot_ppc(idata_ppc, kind = 'kde', num_pp_samples = 3000, figsize = (12, 4))

実際のデータ（黒線）が、推定したモデルから得られた事後予測サンプル（3000本の青線）の中に埋もれているので、モデルはデータをうまく再現できていそうですね、という話。なお kind = 'kde' を指定しているので、描かれている線はヒストグラムではなくカーネル密度推定の曲線

### A/B 比較について

In [None]:
# Yield distribution per group, one panel per 'group' value
sns.displot(data = data, x = 'yield', hue = 'group', col = 'group')

In [None]:
# Mean of each column within each group
data.groupby('group').mean()

In [None]:
# Observation vector and the matching group labels
y = data['yield'].values
g = data['group'].values

In [None]:
def model(y = None, g = None, num_data = 0):
    """Poisson A/B model with one rate per group.

    Args:
        y: Observed counts passed as ``obs`` to the 'obs' site.
        g: Group index of each observation; used to index the two-element
            rate vector, so values are expected to be 0 or 1.
        num_data: Size of the 'data' plate.
    """
    # Two rates drawn from the same prior, one per group
    mu = numpyro.sample('mu', dist.HalfNormal(10), sample_shape = (2,))

    # Look up the rate that applies to each observation
    rate_per_obs = mu[g]

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Poisson(rate_per_obs), obs = y)

    # Record the between-group difference so it appears in the samples
    numpyro.deterministic('mu_diff', mu[1] - mu[0])

In [None]:
# NUTS kernel and MCMC driver for the A/B model
# (500 warm-up steps, 3000 samples, 4 chains)
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_warmup = 500, num_samples = 3000, num_chains = 4)

# Run with a fixed PRNG key, passing both the yields and the group labels
mcmc.run(jax.random.PRNGKey(0), y = y, g = g, num_data = len(y))

In [None]:
# Collect the posterior draws of the A/B model
mcmc_samples = mcmc.get_samples()

In [None]:
# Rebind `idata` to the inference data of the A/B model fit
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plot of the A/B model parameters
az.plot_trace(idata)

### A/Bのパラメーターを比較する

In [None]:
# Posterior of the between-group rate difference, with 0 marked as the reference value
az.plot_posterior(idata, var_names = ['mu_diff'], ref_val = 0)
# Raw string so the LaTeX backslash is not parsed as an (invalid) escape sequence
plt.xlabel(r'$\lambda_1 - \lambda_0$', fontsize = 18)

### ポイント
ここがベイズの旨み。有意水準で判断する検定だと「5％に届かないので有意とは言えない」で終わってしまうが、ベイズ推定では差そのものの事後分布（実際にはこんな形の差になりそうだという分布）が得られるので、直感的に理解しやすい

# ゼロ過剰ポワソン

In [None]:
# Second example data set (many zero yields), with the same 0/1 'group' labels
data_2 = pd.DataFrame({
    'yield':[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 3, 0, 2, 0, 2,
             0, 2, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 3, 0,
             1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'group':[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
})

In [None]:
# Observation vector and group labels of the second data set
y_2 = data_2['yield'].values
g_2 = data_2['group'].values

In [None]:
# Sample mean of the second data set's yields
np.mean(y_2)

In [None]:
# Sample variance of the second data set's yields (np.var with its default ddof)
np.var(y_2)

In [None]:
# Histograms of the second data set's yields; both panels use the same bin width
fig, axes = plt.subplots(1, 2, figsize = (12, 4))

for ax in axes:
    sns.histplot(y_2, binwidth = 0.5, ax = ax)
    ax.set_xlabel('Yield')

plt.tight_layout()

In [None]:
def model_2(y = None, g = None, num_data = 0):
    """Zero-inflated Poisson A/B model with per-group parameters.

    Args:
        y: Observed counts passed as ``obs`` to the 'obs' site.
        g: Group index of each observation; used to index the two-element
            parameter vectors, so values are expected to be 0 or 1.
        num_data: Size of the 'data' plate.
    """
    # One Uniform(0, 1) parameter per group, passed below as the first
    # argument of ZeroInflatedPoisson
    psi = numpyro.sample('psi', dist.Uniform(low = 0.0, high = 1.0), sample_shape = (2,))
    # One Poisson rate per group; note the site is registered under the name 'mu'
    lam = numpyro.sample('mu', dist.HalfNormal(10), sample_shape = (2,))

    # Look up the parameters that apply to each observation
    gate_per_obs = psi[g]
    rate_per_obs = lam[g]

    with numpyro.plate('data', num_data):
        likelihood = dist.ZeroInflatedPoisson(gate_per_obs, rate = rate_per_obs)
        numpyro.sample('obs', likelihood, obs = y)

    # Record the between-group differences so they appear in the samples
    numpyro.deterministic('psi_diff', psi[1] - psi[0])
    numpyro.deterministic('lam_diff', lam[1] - lam[0])

In [None]:
# NUTS kernel and MCMC driver for the zero-inflated model
# (500 warm-up steps, 3000 samples, 4 chains)
nuts = numpyro.infer.NUTS(model_2)
mcmc_2 = numpyro.infer.MCMC(nuts, num_warmup = 500, num_samples = 3000, num_chains = 4)

# Run with a fixed PRNG key on the second data set
mcmc_2.run(jax.random.PRNGKey(0), y = y_2, g = g_2, num_data = len(y_2))

In [None]:
# Collect the posterior draws of the zero-inflated model
mcmc_samples_2 = mcmc_2.get_samples()

In [None]:
# Inference data object for the zero-inflated model fit
idata_2 = az.from_numpyro(mcmc_2)

In [None]:
# Trace plot of the zero-inflated model; the trailing semicolon suppresses the cell's text output
az.plot_trace(idata_2);

事後予測チェック

In [None]:
# Posterior of the between-group difference in psi, with 0 marked as the reference value
az.plot_posterior(idata_2, var_names = ['psi_diff'], ref_val = 0)
# Raw string so the LaTeX backslash is not parsed as an (invalid) escape sequence
plt.xlabel(r'$\psi_1 - \psi_0$', fontsize = 18)

In [None]:
# Posterior of the between-group difference in the Poisson rate, with 0 as the reference value
az.plot_posterior(idata_2, var_names = ['lam_diff'], ref_val = 0)
# Raw string so the LaTeX backslash is not parsed as an (invalid) escape sequence
plt.xlabel(r'$\lambda_1 - \lambda_0$', fontsize = 18)

### 情報基準（WAIC）

In [None]:
# WAIC on the deviance scale for the Poisson A/B model.
# NOTE(review): `idata` was built from the fit to `y`/`g`, not `y_2`/`g_2` —
# confirm this is the intended data set before comparing against `idata_2`.
az.waic(idata, scale = 'deviance')

In [None]:
# WAIC on the deviance scale for the zero-inflated Poisson model
az.waic(idata_2, scale = 'deviance')

In [None]:
# WAIC is only comparable between models evaluated on the same observations.
# `idata` comes from the fit to `y`/`g`, whereas `idata_2` comes from the fit
# to `y_2`/`g_2`, so refit the plain Poisson A/B model (`model`) on the second
# data set before comparing the two.
mcmc_poisson_2 = numpyro.infer.MCMC(
    numpyro.infer.NUTS(model), num_warmup = 500, num_samples = 3000, num_chains = 4
)
mcmc_poisson_2.run(jax.random.PRNGKey(0), y = y_2, g = g_2, num_data = len(y_2))
idata_poisson_2 = az.from_numpyro(mcmc_poisson_2)

dict_idata = {'Poisson': idata_poisson_2, 'ZeroInfPoisson': idata_2}

# Rank the two models by WAIC on the deviance scale
df_waic = az.compare(dict_idata, ic = 'waic', scale = 'deviance')

In [None]:
# Display the model comparison table
df_waic

In [None]:
# Plot the model comparison
az.plot_compare(df_waic, figsize = (8,3))

# 線形回帰

In [58]:
# Linear regression fitted to all observations at once

def model_linear(x_scaled = None, y_scaled = None, num_data = 0):
    """Simple linear regression ``y_scaled ~ Normal(a * x_scaled + b, sd)``.

    Args:
        x_scaled: Predictor values, one per observation.
        y_scaled: Response values passed as ``obs`` to the 'obs' site.
        num_data: Size of the 'data' plate.
    """
    # Priors on slope and intercept
    a = numpyro.sample('a', dist.Normal(0, 10))
    b = numpyro.sample('b', dist.Normal(0, 10))

    # Regression line
    mu = a * x_scaled + b

    # Prior on the observation noise scale
    sd = numpyro.sample('sd', dist.HalfCauchy(5))

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Normal(mu, sd), obs = y_scaled)

In [57]:
# e.g. linear regression fitted separately to two groups (such as male / female)

def model_linear_2_groups(x_scaled = None, y_scaled = None, group = None, num_data = 0):
    """Linear regression with a separate slope and intercept per group.

    Args:
        x_scaled: Predictor values, one per observation.
        y_scaled: Response values passed as ``obs`` to the 'obs' site.
        group: Group index of each observation; used to index the
            two-element slope and intercept vectors, so values are expected
            to be 0 or 1.
        num_data: Size of the 'plate_obs' plate.
    """
    # One slope and one intercept per group
    a = numpyro.sample('a', dist.Normal(0, 10), sample_shape = (2,))
    b = numpyro.sample('b', dist.Normal(0, 10), sample_shape = (2,))

    # Regression line, using each observation's own group coefficients
    mu = a[group] * x_scaled + b[group]

    # Prior on the observation noise scale, shared by both groups
    sd = numpyro.sample('sd', dist.HalfCauchy(5))

    with numpyro.plate('plate_obs', num_data):
        numpyro.sample('obs', dist.Normal(mu, sd), obs = y_scaled)