<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/07_fruits/01_fit_glmm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【一般化線形混合モデル】果物の収穫量

果物の収穫量のデータに対してポアソン回帰（ランダム効果あり）のモデルを適用してみる。

## Package Installation

In [None]:
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run on the CPU backend and request 4 host devices, matching the
# num_chains=4 used when MCMC is run later in this notebook.
# NOTE(review): these calls presumably must run before JAX is first used
# in the session -- confirm against the NumPyro docs.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

In [None]:
# Default font size for every matplotlib figure drawn below.
plt.rcParams['font.size'] = 12

## Load & Check Data

In [None]:
# Fruit counts, one row per observation (20 rows in total).
data = pd.DataFrame(
    {
        'num_fruits': [
            445, 378, 383, 406, 458,
            433, 568, 521, 446, 437,
            327, 508, 342, 385, 363,
            495, 347, 419, 380, 582,
        ],
    }
)

In [None]:
data

In [None]:
# Histogram of the fruit counts (bins 50 wide) with a KDE curve overlaid.
sns.displot(
    data=data,
    x='num_fruits',
    binwidth=50,
    kde=True,
);
plt.xlabel('Number of Fruits');

In [None]:
# Same histogram with finer bins (25 wide) to show more detail.
sns.displot(
    data=data,
    x='num_fruits',
    binwidth=25,
    kde=True,
);
plt.xlabel('Number of Fruits');

In [None]:
# A plain Poisson model implies mean == variance, so print both to compare.
fruit_counts = data['num_fruits']
print('平均：{:.2f}'.format(fruit_counts.mean()))
print('分散：{:.2f}'.format(fruit_counts.var()))

## Define Model & Inference

In [None]:
# Observed counts as a NumPy array; each entry is treated as one tree.
y = data['num_fruits'].to_numpy()
num_trees = y.shape[0]

In [None]:
def model(y=None, num_data=0):
    """Poisson regression with one random intercept per data point (GLMM).

    Generative process::

        b    ~ Normal(0, 10)            # shared intercept on the log scale
        s    ~ HalfCauchy(5)            # std-dev of the random effects
        r[i] ~ Normal(0, s)             # random effect, i = 0..num_data-1
        y[i] ~ Poisson(exp(b + r[i]))

    Args:
        y: Observed counts, one per data point. Leave as ``None`` to have
            the ``obs`` site sampled instead of conditioned on (e.g. for
            predictive draws).
        num_data: Number of data points; sizes both the random-effect
            vector ``r`` and the ``data`` plate.
    """
    b = numpyro.sample('b', dist.Normal(0, 10))

    s = numpyro.sample('s', dist.HalfCauchy(5))

    # Size the random effects from the ``num_data`` argument, not from a
    # notebook-level global, so the model is self-contained and ``r`` always
    # matches the size of the ``data`` plate below.
    r = numpyro.sample('r', dist.Normal(0, s), sample_shape=(num_data,))

    # Linear predictor on the log scale; exp() is the inverse log link.
    theta = b + r
    mu = jax.numpy.exp(theta)

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Poisson(mu), obs=y)
    

In [None]:
# NUTS sampler: 4 chains, each with 500 warm-up steps and 3000 kept draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=4,
)

# Fixed PRNG key so the run is reproducible.
mcmc.run(jax.random.PRNGKey(0), num_data=len(y), y=y)

# Posterior draws keyed by sample-site name, plus an ArviZ view of the run.
mcmc_samples = mcmc.get_samples()
idata = az.from_numpyro(mcmc)

In [None]:
idata

In [None]:
# Trace plots for every sampled site; widen the gaps between the panels.
az.plot_trace(idata)
plt.subplots_adjust(hspace=0.5, wspace=0.5)

In [None]:
az.summary(idata)

## Posterior Predictive Check

In [None]:
# Replay the model over the posterior draws. ``y`` is not passed, so the
# 'obs' site is sampled rather than conditioned on the data.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(
    jax.random.PRNGKey(1),
    num_data=len(y),
)

# Bundle the posterior and the predictive draws into one ArviZ object.
idata_ppc = az.from_numpyro(posterior=mcmc, posterior_predictive=ppc_samples)

In [None]:
ppc_samples['obs'].shape

In [None]:
# One panel per observation: the posterior-predictive density of the count,
# with the observed value marked by a dashed red line.
# The grid is derived from the data size instead of a hard-coded 20, so the
# cell keeps working if the data set changes.
num_panels = len(y)
num_cols = 2
num_rows = (num_panels + num_cols - 1) // num_cols  # ceiling division

# 2.4 inches per row reproduces the original 12x24 figure for 20 panels.
fig = plt.figure(figsize=(12, 2.4 * num_rows))

for k in range(num_panels):

    ax = fig.add_subplot(num_rows, num_cols, k + 1)

    # Pass the axes explicitly rather than relying on the current axes.
    az.plot_dist(ppc_samples['obs'][:, k], ax=ax)
    ax.axvline(y[k], color='r', linestyle='dashed')
    ax.set_title('ID = {}'.format(k))

plt.tight_layout()

## Check Random Effects

In [None]:
# Violin plot of the posterior of each random effect r, laid out in a single
# row with one column per tree.
az.plot_violin(idata.posterior['r'], grid=(1, num_trees), figsize=(12, 4));

## Check New Feature

In [None]:
# Same fruit counts as before, now paired with a 'span' value for each row.
data_updated = pd.DataFrame(
    {
        'num_fruits': [
            445, 378, 383, 406, 458,
            433, 568, 521, 446, 437,
            327, 508, 342, 385, 363,
            495, 347, 419, 380, 582,
        ],
        'span': [
            4.5, 3.8, 3.7, 4.2, 4.5,
            4.1, 4.2, 4.5, 3.9, 4.0,
            3.2, 3.9, 3.0, 4.1, 3.7,
            4.9, 4.3, 4.0, 3.1, 4.8,
        ],
    }
)

In [None]:
data_updated.head()

In [None]:
span = data_updated['span']

In [None]:
# Average the posterior draws of r over the leading (sample) axis, leaving
# one posterior-mean random effect per tree.
r_mean = mcmc_samples['r'].mean(axis=0)

In [None]:
# Scatter the posterior-mean random effect of each tree against its span to
# see whether the new feature relates to the per-tree variation.
fig = plt.figure(figsize=(6, 6))

ax = sns.scatterplot(x=span, y=r_mean, s=100)
ax.set_xlabel('Span')
ax.set_ylabel('Random Effects');