<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/08_porkbuns/01_fit_glmm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【一般化線形混合モデル】肉まんの販売個数

とあるエリアのコンビニエンスストア 10店舗における肉まんの販売個数のデータをモデル化する。特徴量としては気温データのみが与えられており、その他の特徴量はない。そこで、店舗ごとの立地などの違いを「ランダム効果」として扱い、ポアソン分布でモデル化してみる。

## Package Installation

In [None]:
# Install NumPyro (IPython shell escape); restart the runtime after it finishes.
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run NumPyro on the CPU backend and expose 4 host devices, matching the
# num_chains=4 used in the MCMC cell so the chains can run in parallel.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

In [None]:
# Default font size and figure size for every plot in this notebook.
plt.rcParams.update({'font.size': 12, 'figure.figsize': [8, 6]})

## Load Data

In [None]:
# Synthetic sales data for 10 stores x 10 days in long format: one row per
# (store, day), ordered by store first and then by day.
data = pd.DataFrame({

    # The same 10 daily temperatures repeat for every store.
    'temperature': np.tile([13.2, 2.7, 8.7, 4.5, 10.8, 13.8, 6.2, 5.8, 2.5, 7.0], 10),

    # Daily sales counts, one inner list per store (store 0 first), flattened.
    'num_sold': [
        count
        for store_sales in [
            [0, 5, 2, 1, 0, 0, 3, 4, 5, 3],  # store 0
            [0, 1, 1, 2, 3, 2, 3, 3, 2, 0],  # store 1
            [1, 1, 0, 0, 1, 0, 3, 0, 2, 2],  # store 2
            [2, 3, 5, 2, 2, 2, 3, 1, 4, 3],  # store 3
            [2, 6, 1, 4, 4, 1, 5, 6, 4, 2],  # store 4
            [3, 3, 0, 3, 1, 1, 2, 4, 3, 1],  # store 5
            [1, 3, 1, 1, 3, 2, 7, 3, 5, 3],  # store 6
            [6, 3, 3, 3, 1, 0, 5, 1, 2, 4],  # store 7
            [0, 8, 2, 5, 4, 3, 7, 8, 7, 3],  # store 8
            [2, 3, 0, 2, 3, 1, 1, 1, 3, 1],  # store 9
        ]
        for count in store_sales
    ],

    # Store index of each row: ten 0s, then ten 1s, ..., then ten 9s.
    'store_id': np.repeat(np.arange(10, dtype=np.int64), 10),
})

In [None]:
# Show the first 20 rows (stores 0 and 1).
data.head(20)

## Preprocess & Visualize Data

In [None]:
# Pull the model inputs out of the DataFrame as NumPy arrays.
x, y, store_id = (data[col].values for col in ('temperature', 'num_sold', 'store_id'))

# Store IDs are 0-based, so the number of stores is the largest ID plus one.
num_stores = store_id.max() + 1

In [None]:
# Days per store, derived from the data rather than hard-coded.
# Assumes the rows are ordered by store and then by day, with every store
# observed on the same number of days (true for `data` above); the reshape
# raises if the counts do not divide evenly.
num_days = int(len(y) // num_stores)

# Every store shares the same temperature series, so the first store's rows
# give the temperature of each day.
temperature = x[:num_days]
num_sold = y.reshape(num_stores, num_days)  # rows: stores, columns: days

In [None]:
# Heatmap of the sales count for every (store, day) pair.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(num_sold, annot=True, cmap='jet', ax=ax)
ax.set(title='Number of Sales', xlabel='Date ID', ylabel='Store ID')
fig.tight_layout()

In [None]:
# Each day's temperature against that day's total sales over all stores
# (column sums of the store x day matrix).
fig = plt.figure(figsize=(8, 6))

sns.scatterplot(x=temperature, y=num_sold.sum(axis=0), s=100)

plt.title('Temperature vs Sales')

plt.xlabel('Temperature')
# Label wording matches the 'Number of Sales' used by the other plots.
plt.ylabel('Number of Sales (Sum of all stores)');

## Scale Data

In [None]:
# Standardize temperature (population std, ddof=0). The mean and sd are kept
# so that new temperatures can be scaled the same way for prediction.
x_mu, x_sd = x.mean(), x.std()
x_scaled = (x - x_mu) / x_sd

## Define Model & Inference

In [None]:
def model(x_scaled=None, store_id=None, y=None, num_data=0, n_stores=None):
    """Poisson GLMM with a per-store random intercept.

    log(mu_i) = a * x_scaled[i] + b + r[store_id[i]],  obs_i ~ Poisson(mu_i)

    Parameters
    ----------
    x_scaled : array-like
        Standardized temperature of each observation.
    store_id : array-like of int
        Store index of each observation; selects the store's random effect.
    y : array-like of int, optional
        Observed sales counts. Leave as ``None`` to sample ``obs`` (prediction).
    num_data : int
        Size of the ``data`` plate. When 0 (the default) and ``x_scaled`` is
        given, it is inferred as ``len(x_scaled)``.
    n_stores : int, optional
        Length of the random-effect vector ``r``. When omitted, it falls back
        to the notebook-level ``num_stores``.
    """
    if n_stores is None:
        # Keep existing calls working; pass n_stores explicitly to avoid
        # depending on the global defined in the preprocessing cell.
        n_stores = num_stores
    if not num_data and x_scaled is not None:
        # The plate size must match the data length; derive it when not given.
        num_data = len(x_scaled)

    # Fixed effects: slope on standardized temperature (a) and intercept (b).
    a = numpyro.sample('a', dist.Normal(0, 10))
    b = numpyro.sample('b', dist.Normal(0, 10))

    # Store random intercepts r[j] ~ Normal(0, s), with a HalfCauchy prior on s.
    s = numpyro.sample('s', dist.HalfCauchy(5))
    r = numpyro.sample('r', dist.Normal(0, s), sample_shape=(n_stores,))

    theta = a * x_scaled + r[store_id] + b

    mu = jax.numpy.exp(theta)  # log link

    with numpyro.plate('data', num_data):

        numpyro.sample('obs', dist.Poisson(mu), obs=y)
    

In [None]:
# Fit the model with NUTS: 4 chains, each with 500 warm-up and 3000 kept draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=4,
)

mcmc.run(
    jax.random.PRNGKey(0),
    x_scaled=x_scaled,
    store_id=store_id,
    y=y,
    num_data=len(y),
)
mcmc_samples = mcmc.get_samples()

# ArviZ InferenceData for diagnostics and plots.
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots for every sampled parameter, with extra spacing between panels.
az.plot_trace(idata)
plt.subplots_adjust(hspace=0.5, wspace=0.5)

In [None]:
# Posterior summary statistics and convergence diagnostics for each parameter.
az.summary(idata)

## Posterior Predictive Check-1

In [None]:
# Posterior predictive draws of 'obs' at the observed inputs; y is omitted so
# the 'obs' site is sampled instead of conditioned on.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(
    jax.random.PRNGKey(1),
    x_scaled=x_scaled,
    store_id=store_id,
    num_data=len(y),
)

idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# Shape of the posterior predictive draws: (number of posterior draws, number of observations).
ppc_samples['obs'].shape

In [None]:
# For the first 30 observations: posterior predictive distribution, with the
# observed count marked by a red dashed line.
fig, axes = plt.subplots(10, 3, figsize=(12, 24))

for k, ax in enumerate(axes.flat):
    az.plot_dist(ppc_samples['obs'][:, k], ax=ax)
    ax.axvline(y[k], color='r', linestyle='dashed')
    ax.set_title('Temperature = {}, Store ID = {}'.format(x[k], store_id[k]))

plt.tight_layout()

## Posterior Predictive Check-2

In [None]:
# Temperature grid 0, 1, ..., 19 for prediction, standardized with the
# training mean and sd.
x_new = np.arange(20)
x_scaled_new = (x_new - x_mu) / x_sd

In [None]:
# Predict for a single store: its ID repeated once per grid temperature.
store_id_new = 4
store_id_dup = np.full_like(x_new, store_id_new, dtype=int)

In [None]:
# Posterior predictive draws of 'obs' over the temperature grid for store_id_new.
predictive = numpyro.infer.Predictive(model, posterior_samples=mcmc_samples)
ppc_samples = predictive(
    jax.random.PRNGKey(1),
    x_scaled=x_scaled_new,
    store_id=store_id_dup,
    num_data=len(x_scaled_new),
)

# NOTE(review): no posterior/MCMC object is passed, so ArviZ gets no chain
# information here; presumably all draws land in a single chain — confirm
# with obs_pred.shape below.
idata_ppc = az.from_numpyro(posterior_predictive=ppc_samples)

In [None]:
# Posterior predictive draws of the 'obs' site over the temperature grid.
obs_pred = idata_ppc.posterior_predictive['obs']

In [None]:
# Inspect the array layout (leading sample axes, then the grid points).
obs_pred.shape

In [None]:
# Posterior predictive mean at each grid temperature: average over the first
# axis, then over the next one.
# NOTE(review): presumably these are the chain and draw axes — confirm with
# the obs_pred.shape output above.
obs_mean = obs_pred.mean(axis=0).mean(axis=0)

In [None]:
# Should be one value per grid temperature.
obs_mean.shape

In [None]:
# Posterior predictive distribution: HDI band plus the mean curve.
az.plot_hdi(x_new, obs_pred)
plt.plot(x_new, obs_mean)

# Observed data for the same store.
idx = data['store_id'] == store_id_new
plt.scatter(data.loc[idx, 'temperature'], data.loc[idx, 'num_sold'], s=100)

plt.title('Store ID = {}'.format(store_id_new))
plt.xticks(x_new)
plt.xlabel('Temperature')
plt.ylabel('Number of Sales');

## Check Random Effects

In [None]:
# Posterior of each store's random effect r[j], one violin per store in a single row.
az.plot_violin(idata.posterior['r'], grid=(1, num_stores), figsize=(12, 4));

## Compare with True Values

このデータは人工データであり、ランダム効果の真の値がわかっているので、推定値と比較してみる。

In [None]:
# True random effects used to generate the synthetic data, one per store (store 0 first).
r_true = np.array([-0.22, -0.79, -1.05, -0.02, 0.25, 0.01, 0.11, -0.11, 0.46, -0.27])

In [None]:
# Posterior mean of each store's random effect (averaged over all MCMC draws).
r_mean = mcmc_samples['r'].mean(axis=0)

In [None]:
# True vs. estimated (posterior mean) random effect for each store.
stores = np.arange(num_stores)
fig = plt.figure(figsize=(10, 4))

for values, label in ((r_true, 'True Values'), (r_mean, 'Estimated Values')):
    plt.plot(stores, values, 'o-', markersize=8, label=label)

plt.xticks(stores)
plt.xlabel('Store ID')
plt.ylabel('Random Effects')
plt.legend()

plt.tight_layout()

In [None]:
# Estimated vs. true random effect per store. Points close to the dashed
# y = x line indicate accurate recovery.
fig = plt.figure(figsize=(8, 6))

# x = true value, y = estimate, matching the axis labels set below.
sns.scatterplot(x=r_true, y=r_mean, s=100)

# Identity line spanning the range of both series, for reference.
all_effects = np.concatenate([r_true, np.asarray(r_mean)])
lim = [all_effects.min(), all_effects.max()]
plt.plot(lim, lim, 'k--', linewidth=1)

plt.xlabel('True Value')
plt.ylabel('Estimated Value');