<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/02_plants/02_compare_mean_values.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【ポアソン分布の比較】野菜の収穫量

肥料の違いが収穫量に与える効果を、ポアソン分布の母数の比較によって評価してみる。

## Package Installation

In [None]:
# Install NumPyro into the notebook environment (`!` runs the line as a shell command).
!pip install numpyro

インストール完了後にランタイムを再起動してください！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run JAX on the CPU backend.
numpyro.set_platform('cpu')
# Expose 4 host (CPU) devices to JAX; the sampler in this notebook runs 4 chains.
numpyro.set_host_device_count(4)

In [None]:
# Use a base font size of 12 for every figure drawn in this notebook.
plt.rcParams.update({'font.size': 12})

## Load & Check Data

In [None]:
# Harvest counts for 80 plants: the first 50 rows are group 0, the last 30 group 1.
data = pd.DataFrame({
    'yield': [
        # group 0 (50 observations)
        7, 13, 13, 11, 5, 6, 8, 11, 10, 11,
        11, 11, 11, 14, 8, 15, 10, 9, 13, 12,
        8, 15, 7, 11, 5, 11, 15, 10, 13, 9,
        8, 12, 13, 6, 8, 5, 13, 8, 5, 10,
        18, 9, 7, 12, 11, 5, 9, 10, 13, 13,
        # group 1 (30 observations)
        7, 12, 8, 16, 10, 6, 12, 13, 10, 12,
        9, 7, 12, 11, 8, 15, 13, 11, 9, 17,
        11, 10, 15, 19, 11, 13, 12, 9, 10, 10,
    ],
    'group': [0] * 50 + [1] * 30,
})

In [None]:
# Preview the first 20 rows (all of them belong to group 0).
data.head(20)

In [None]:
# Distribution of the yields, one panel and one colour per group.
# The trailing `;` suppresses the notebook's text output for the returned object.
sns.displot(data=data, x='yield', hue='group', col='group');

In [None]:
# Sample mean of the yield within each group.
data.groupby('group').mean()

## Define Model & Inference

In [None]:
# Pull the observed counts and the group labels out of the frame as NumPy arrays.
y = data['yield'].to_numpy()
g = data['group'].to_numpy()

In [None]:
def model(y=None, g=None, num_data=0):
    """Two-group Poisson model with a separate rate for each group.

    Args:
        y: Observed counts, one per data point. When left as ``None`` the
            ``obs`` site is sampled rather than conditioned on.
        g: Integer group index per data point; it indexes into the two
            rates in ``mu`` to select the rate for each observation.
        num_data: Size of the ``data`` plate, i.e. the number of
            observations.

    Sample sites:
        mu: The two group rates, drawn from ``HalfNormal(10)``.
        obs: Poisson-distributed counts with the per-observation rate.
        mu_diff: Deterministic ``mu[1] - mu[0]``, recorded so the
            difference between the rates appears in the samples.
    """
    mu = numpyro.sample('mu', dist.HalfNormal(10), sample_shape=(2,))

    # Expand the two rates to one rate per observation.
    mu_dup = mu[g]

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Poisson(mu_dup), obs=y)

    # Called for its side effect of registering the site; the return value is not needed.
    numpyro.deterministic('mu_diff', mu[1] - mu[0])

In [None]:
# NUTS kernel for the model; 4 chains, each with 500 warm-up and 3000 retained draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=4)

# Run with a fixed PRNG key (seed 0); the data are forwarded as model keyword arguments.
mcmc.run(jax.random.PRNGKey(0), y=y, g=g, num_data=len(y))
# Posterior draws collected from the sampler.
mcmc_samples = mcmc.get_samples()

# Convert the run to an ArviZ object for the plotting cells.
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots for the variables stored in the inference data.
az.plot_trace(idata)
# Add horizontal and vertical spacing between the panels of the current figure.
plt.subplots_adjust(wspace=0.5, hspace=0.5)

## Compare Parameters

In [None]:
# Posterior of the rate difference, with 0 (no difference) marked as the reference value.
az.plot_posterior(idata, var_names=['mu_diff'], ref_val=0)

# Raw string so the LaTeX backslashes are not parsed as (invalid) Python escape sequences.
plt.xlabel(r'$\lambda_1 - \lambda_0$', fontsize=18);