<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/04_babies/03_linear_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【線形回帰】新生児の体重

新生児の性別と妊娠期間の両方を組み込んだ線形回帰モデルを構成してみる。このモデルでは男児と女児で切片のみが異なり、傾き（回帰係数）は共通であるものとする。最後に、男児と女児で傾きも異なるとしたモデル（別途保存した `idata.nc`）との比較を行っている。

## Package Installation

In [None]:
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run on CPU and request 4 host devices; the MCMC call in this notebook uses num_chains=4.
# NOTE(review): presumably this must run before JAX initializes its backend — confirm against the NumPyro docs.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

In [None]:
# Global matplotlib defaults (font size and figure size) for the figures below.
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': [8, 6],
})

## Load & Check Data

In [None]:
# Observations for 50 newborns, one row per baby.
#   weeks  : gestation period (presumably in weeks, per the column name)
#   weight : birth weight (unit is not stated in this file; presumably grams — TODO confirm)
#   gender : 0/1 code; the prediction cells in this notebook label 0 as "Baby Boy" and 1 as "Baby Girl"
data = pd.DataFrame({
    
    'weeks':[36, 38, 39, 41, 37, 38, 40, 40, 38, 41, 38, 38, 40, 40, 36, 39, 40,
       35, 39, 38, 37, 43, 39, 39, 40, 40, 37, 38, 39, 38, 40, 40, 42, 37,
       41, 38, 37, 39, 40, 40, 38, 41, 38, 37, 39, 39, 43, 38, 38, 38],
    
    'weight':[2980, 2707, 3049, 3429, 2500, 2845, 3071, 3435, 3058, 3123, 3215,
       2902, 3015, 2983, 2727, 3121, 3114, 2511, 3327, 2864, 2749, 3621,
       2860, 3074, 3234, 3083, 2797, 3025, 3129, 2990, 3035, 2990, 3513,
       2687, 3380, 2863, 2715, 3012, 3083, 2938, 2837, 3455, 3175, 2646,
       2889, 2975, 3474, 3052, 3167, 2762],
    
    'gender':[0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 1, 0, 0, 1]
})

In [None]:
# Preview the first 10 rows of the dataset.
data.head(10)

## Preprocess & Scale Data

In [None]:
# Pull the predictor, the response and the group code out of the
# DataFrame as plain NumPy arrays.
x, y, gender = (
    data[column].values for column in ('weeks', 'weight', 'gender')
)

In [None]:
# Standardization: shift to zero mean and divide by the standard
# deviation (NumPy's default population form, ddof=0).
# The mean/sd are kept so new inputs can be scaled the same way later.
x_mu, x_sd = x.mean(), x.std()
y_mu, y_sd = y.mean(), y.std()

x_scaled = (x - x_mu) / x_sd
y_scaled = (y - y_mu) / y_sd

## Define Model & Inference

In [None]:
def model(x_scaled=None, y_scaled=None, gender=None, num_data=0):
    """Linear regression with one shared slope and one intercept per gender code.

    Likelihood: ``obs ~ Normal(a * x_scaled + b[gender], sd)``.

    Args:
        x_scaled: Predictor values (this notebook passes standardized weeks).
        y_scaled: Observed responses. When left as ``None`` the ``obs`` site
            is sampled instead of being conditioned on, which is how the
            posterior predictive step in this notebook calls the model.
        gender: Integer index array (0 or 1) selecting which entry of ``b``
            is used as the intercept for each data point.
        num_data: Size of the ``data`` plate; callers in this notebook pass
            the number of data points.
    """
    # One slope shared by both groups.
    a = numpyro.sample('a', dist.Normal(0, 10))
    # Two intercepts, one per gender code.
    b = numpyro.sample('b', dist.Normal(0, 10), sample_shape=(2,))

    mu = a * x_scaled + b[gender]

    sd = numpyro.sample('sd', dist.HalfCauchy(5))

    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Normal(mu, sd), obs=y_scaled)

    # Record the intercept difference as a named site so it appears in the
    # results; the returned value is not needed inside the model.
    numpyro.deterministic('b_diff', b[1] - b[0])

In [None]:
# Sample the posterior with NUTS: 4 chains, each with 500 warm-up
# iterations followed by 3000 retained draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=4,
)
mcmc.run(
    jax.random.PRNGKey(0),
    x_scaled=x_scaled,
    y_scaled=y_scaled,
    gender=gender,
    num_data=len(y_scaled),
)

# Keep the draws both as returned by get_samples() (passed to Predictive
# later in this notebook) and as an ArviZ InferenceData (used for the
# diagnostics and model comparison).
mcmc_samples = mcmc.get_samples()
idata = az.from_numpyro(mcmc)

In [None]:
# Display the InferenceData object built from the MCMC run.
idata

In [None]:
# Trace plots for the sampled sites; widen the gaps between panels
# (presumably so titles and tick labels do not overlap).
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Tabulate summary statistics of the posterior.
az.summary(idata)

In [None]:
# Posterior of the intercept difference b[1] - b[0], with 0 marked as the reference value.
az.plot_posterior(idata, var_names=['b_diff'], ref_val=0);

## Information Criteria

In [None]:
# WAIC for this model.
az.waic(idata)

In [None]:
# Leave-one-out cross-validation estimate for this model.
az.loo(idata)

## Compare Models

In [None]:
# Load a previously saved InferenceData for the comparison model.
# NOTE(review): 'idata.nc' is a relative path and is not written anywhere in the visible part of
# this notebook, so it must already exist in the working directory.
idata_imported = az.from_netcdf('idata.nc')

In [None]:
# Models to compare: this notebook's fit (one shared slope) and the imported fit.
idata_dict = {'Slope is same':idata, 'Slope is different':idata_imported}

In [None]:
# Rank the two models by WAIC, reported on the deviance scale.
df_waic = az.compare(idata_dict, ic='waic', scale='deviance')
df_waic

In [None]:
# Plot the WAIC comparison table.
az.plot_compare(df_waic, figsize=(8, 4));

In [None]:
# Rank the two models by LOO, reported on the deviance scale.
df_loo = az.compare(idata_dict, ic='loo', scale='deviance')
df_loo

In [None]:
# Plot the LOO comparison table.
az.plot_compare(df_loo, figsize=(8, 4));

## Posterior Predictive Check

In [None]:
# Predictive sampler for the model, driven by the posterior draws in mcmc_samples.
predictive = numpyro.infer.Predictive(model, mcmc_samples)

In [None]:
# Prediction grid: integer weeks 35 through 43 (arange excludes the stop value 44),
# scaled with the same mean/sd that were used for the observed weeks.
x_new = np.arange(35, 44)
x_scaled_new = (x_new - x_mu) / x_sd

In [None]:
# Case : Baby Boy
# All-zero gender codes, so b[0] is the intercept at every grid point.
gender_new = np.zeros_like(x_new, dtype=int)

In [None]:
# Draw posterior predictive samples on the new grid. y_scaled is not
# passed, so the model samples 'obs' rather than conditioning on data.
ppc_samples = predictive(
    jax.random.PRNGKey(1),
    x_scaled=x_scaled_new,
    gender=gender_new,
    num_data=len(x_scaled_new),
)

# Attach the predictive draws to an InferenceData alongside the MCMC run.
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# Predictive draws for 'obs' and their mean at each grid point, averaging
# over chains first and then over draws.
# NOTE(review): assumes the leading dims are named 'chain' and 'draw', in that order (ArviZ convention).
obs_pred = idata_ppc.posterior_predictive['obs']
obs_mean = obs_pred.mean(dim='chain').mean(dim='draw')

In [None]:
# Inspect the array shape of the predictive draws.
obs_pred.shape

In [None]:
# Show the posterior predictive distribution (HDI band plus the mean line)
az.plot_hdi(x_scaled_new, obs_pred)
plt.plot(x_scaled_new, obs_mean)

# Show the observed data (all babies, colored by gender code)
sns.scatterplot(x=x_scaled, y=y_scaled, hue=gender, s=80)

plt.xlabel('Period (Standardized)')
plt.ylabel('Weight (Standardized)')
plt.title('Posterior Prediction : Baby Boy');

In [None]:
# Case : Baby Girl
# All-one gender codes, so b[1] is the intercept at every grid point.
gender_new = np.ones_like(x_new, dtype=int)

In [None]:
# Draw posterior predictive samples on the new grid. y_scaled is not
# passed, so the model samples 'obs' rather than conditioning on data.
ppc_samples = predictive(
    jax.random.PRNGKey(1),
    x_scaled=x_scaled_new,
    gender=gender_new,
    num_data=len(x_scaled_new),
)

# Attach the predictive draws to an InferenceData alongside the MCMC run.
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# Predictive draws for 'obs' and their mean at each grid point, averaging
# over chains first and then over draws.
# NOTE(review): assumes the leading dims are named 'chain' and 'draw', in that order (ArviZ convention).
obs_pred = idata_ppc.posterior_predictive['obs']
obs_mean = obs_pred.mean(dim='chain').mean(dim='draw')

In [None]:
# Show the posterior predictive distribution (HDI band plus the mean line)
az.plot_hdi(x_scaled_new, obs_pred)
plt.plot(x_scaled_new, obs_mean)

# Show the observed data (all babies, colored by gender code)
sns.scatterplot(x=x_scaled, y=y_scaled, hue=gender, s=80)

plt.xlabel('Period (Standardized)')
plt.ylabel('Weight (Standardized)')
plt.title('Posterior Prediction : Baby Girl');