<a href="https://colab.research.google.com/github/eohta/udemy-numpyro-basic/blob/main/04_babies/01_linear_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 【線形回帰】新生児の体重

新生児の性別を無視して、妊娠期間と体重の関係を線形回帰モデルにあてはめてみる。

## Package Installation

In [None]:
# Notebook shell command: install NumPyro into the current runtime.
!pip install numpyro

インストール完了後にランタイムを再スタートして下さい！

## Import Package

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Run on the CPU backend and expose 4 host devices to JAX
# (the MCMC run further down uses num_chains=4).
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

In [None]:
# Notebook-wide matplotlib defaults: base font size 12 and an 8 x 6 figure.
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': [8, 6],
})

## Load & Check Data

In [None]:
# Newborn data set, 50 records, one row per baby.
data = pd.DataFrame({
    
    # Gestation period in weeks (plotted below as 'Period [weeks]').
    'weeks':[36, 38, 39, 41, 37, 38, 40, 40, 38, 41, 38, 38, 40, 40, 36, 39, 40,
       35, 39, 38, 37, 43, 39, 39, 40, 40, 37, 38, 39, 38, 40, 40, 42, 37,
       41, 38, 37, 39, 40, 40, 38, 41, 38, 37, 39, 39, 43, 38, 38, 38],
    
    # Birth weight in grams (plotted below as 'Weight [g]').
    'weight':[2980, 2707, 3049, 3429, 2500, 2845, 3071, 3435, 3058, 3123, 3215,
       2902, 3015, 2983, 2727, 3121, 3114, 2511, 3327, 2864, 2749, 3621,
       2860, 3074, 3234, 3083, 2797, 3025, 3129, 2990, 3035, 2990, 3513,
       2687, 3380, 2863, 2715, 3012, 3083, 2938, 2837, 3455, 3175, 2646,
       2889, 2975, 3474, 3052, 3167, 2762],
    
    # Binary gender indicator.
    # NOTE(review): which of 0/1 means male or female is not stated here - confirm.
    'gender':[0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 1, 0, 0, 1]
})

In [None]:
# Preview the first 10 rows.
data.head(10)

In [None]:
# Raw data: weight against gestation period, points colored by the gender indicator.
sns.scatterplot(data=data, x='weeks', y='weight', hue='gender', s=100)

plt.xlabel('Period [weeks]')
# Trailing semicolon suppresses the notebook's echo of the returned object.
plt.ylabel('Weight [g]');

## Preprocess & Scale Data

In [None]:
# Pull predictor (weeks) and response (weight) out of the frame as NumPy arrays.
x = data['weeks'].to_numpy()
y = data['weight'].to_numpy()

In [None]:
# Standardization: centre each variable on its mean and divide by its
# standard deviation (np.std default, i.e. ddof=0). The *_mu / *_sd values
# are kept so results can be mapped back to the original units.

x_mu, x_sd = np.mean(x), np.std(x)
y_mu, y_sd = np.mean(y), np.std(y)

x_scaled = (x - x_mu) / x_sd
y_scaled = (y - y_mu) / y_sd

## Define Model & Inference

In [None]:
def model(x_scaled=None, y_scaled=None, num_data=0):
    """NumPyro model for a simple linear regression of y_scaled on x_scaled.

    Sample sites (declared in this order):
        'a'   -- slope, prior Normal(0, 10)
        'b'   -- intercept, prior Normal(0, 10)
        'sd'  -- observation noise scale, prior HalfCauchy(5)
        'obs' -- Normal(a * x_scaled + b, sd), declared inside plate 'data'

    Args:
        x_scaled: Predictor values; multiplied by the slope and shifted by
            the intercept to form the mean of 'obs'.
        y_scaled: Observed responses, forwarded as ``obs=`` to the 'obs'
            site. With the default None the site is not conditioned.
        num_data: Size given to the 'data' plate.
    """
    slope = numpyro.sample('a', dist.Normal(0, 10))
    intercept = numpyro.sample('b', dist.Normal(0, 10))

    # Linear predictor for every data point.
    expected = slope * x_scaled + intercept

    noise_scale = numpyro.sample('sd', dist.HalfCauchy(5))

    with numpyro.plate('data', num_data):
        likelihood = dist.Normal(expected, noise_scale)
        numpyro.sample('obs', likelihood, obs=y_scaled)
        

In [None]:
# NUTS kernel over `model`; 4 chains, each with 500 warm-up steps and 3000 kept draws.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=4)

# Fixed PRNG key (0) so the run is repeatable; the plate size is the number of observations.
mcmc.run(jax.random.PRNGKey(0), x_scaled=x_scaled, y_scaled=y_scaled, num_data=len(y_scaled))
# Posterior draws, indexed by sample-site name (e.g. mcmc_samples['b']).
mcmc_samples = mcmc.get_samples()

# Convert the MCMC result to an ArviZ object for the diagnostics below.
idata = az.from_numpyro(mcmc)

In [None]:
# Display the ArviZ object built from the MCMC run.
idata

In [None]:
# Trace plot of the sampled parameters.
az.plot_trace(idata)
# Increase horizontal/vertical spacing between the subplots of the current figure.
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Tabular posterior summary from ArviZ.
az.summary(idata)

## Check MCMC-samples

In [None]:
def plot_lines(trace, samples_to_plot=50):
    """Draw regression lines for the most recent posterior draws.

    For each of the last ``samples_to_plot`` draws, plots the line
    ``a * x + b`` over x in [-3, 3] on the current matplotlib axes
    (green, alpha 0.1 so overlapping lines show density).

    Args:
        trace: Mapping with entries 'a' (slope draws) and 'b' (intercept
            draws); each is expected to be a 1-D array of equal length.
        samples_to_plot: Number of lines to draw. Clamped to the number of
            draws available; zero or negative draws nothing.

    Returns:
        None. The lines are added to the current axes as a side effect.
    """
    x_scaled_new = np.linspace(-3, 3, 50)

    slopes = np.asarray(trace['a'])
    intercepts = np.asarray(trace['b'])

    # Never request more draws than exist (the original indexed past the
    # start of short traces), and draw exactly the requested count (the
    # original range(1, n) produced n - 1 lines).
    num_lines = max(0, min(samples_to_plot, len(slopes), len(intercepts)))
    if num_lines == 0:
        # Guard needed because arr[-0:] would select the whole array.
        return

    for a_sample, b_sample in zip(slopes[-num_lines:], intercepts[-num_lines:]):
        mu = a_sample * x_scaled_new + b_sample
        plt.plot(x_scaled_new, mu, c='g', alpha=0.1)
    

In [None]:
# Check how many draws of 'b' are available.
mcmc_samples['b'].shape

In [None]:
# Regression lines from posterior draws, drawn first so the data points sit on top.
plot_lines(mcmc_samples)

# Standardized observations, colored by the gender indicator.
sns.scatterplot(x=x_scaled, y=y_scaled, hue=data['gender'], s=80)

plt.xlabel('Period (Standardized)')
plt.ylabel('Weight (Standardized)');

In [None]:
# seaborn's own scatter-plus-fitted-line plot of the same standardized data
# (presumably shown for comparison with the posterior lines above).
sns.regplot(x=x_scaled, y=y_scaled, scatter_kws={'s':80})

plt.xlabel('Period (Standardized)')
plt.ylabel('Weight (Standardized)');

## Posterior Predictive Check

In [None]:
# Predictive sampler that replays `model` using the posterior draws.
predictive = numpyro.infer.Predictive(model, mcmc_samples)

# y_scaled is not passed (stays None), so 'obs' is not conditioned on the data here.
# Uses PRNG key 1, distinct from key 0 used for the MCMC run.
ppc_samples = predictive(jax.random.PRNGKey(1), x_scaled=x_scaled, num_data=len(y_scaled))

# ArviZ object holding both the MCMC result and the posterior predictive draws.
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)

In [None]:
# HDI band of the posterior predictive 'obs' draws as a function of x_scaled
# (interval width is ArviZ's default, since none is passed).
az.plot_hdi(x_scaled, idata_ppc.posterior_predictive['obs'])

# Standardized observations on top, colored by the gender indicator.
sns.scatterplot(x=x_scaled, y=y_scaled, hue=data['gender'], s=80)

plt.xlabel('Period (Standardized)')
plt.ylabel('Weight (Standardized)');