## Robust Linear Regression

この例題は、以下のサイトで紹介されているコードを参考に作成しています。

- https://docs.pymc.io/pymc-examples/examples/generalized_linear_models/GLM-robust.html


## Install Packages

In [None]:
# Install NumPyro into the notebook environment (IPython shell escape via "!").
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Use a larger default font for every Matplotlib figure created in this notebook.
plt.rcParams['font.size'] = 14

In [None]:
# Run on the CPU backend and ask for 4 host devices; the MCMC runs in this
# notebook use num_chains=4.
# NOTE(review): NumPyro documents that set_host_device_count only takes effect
# at the start of the program, so keep this cell before any JAX computation.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Load & Check Data

In [None]:
# Read the "clean" dataset and pull its x / y columns out as arrays.
data_clean = pd.read_csv('data_clean.csv')

x_clean, y_clean = (data_clean[column].values for column in ('x', 'y'))

In [None]:
# Read the observed dataset and pull its x / y columns out as arrays.
data_observed = pd.read_csv('data_observed.csv')

x_observed, y_observed = (data_observed[column].values for column in ('x', 'y'))

In [None]:
# Observed points as 'x' markers, with the clean data drawn as a line on top.
fig = plt.figure(figsize=(8, 6))

plt.plot(
    x_observed, y_observed, 'x',
    x_clean, y_clean,
)

plt.title('Data and Underlying Model');

## Define Model & Inference : Linear Model

In [None]:
def model(x, y):
    """Linear regression with a Normal likelihood.

    Sample sites:
      * 'a'   -- slope, prior Normal(0, 10)
      * 'b'   -- intercept, prior Normal(0, 10)
      * 'sd'  -- noise scale, prior HalfNormal(10)
      * 'obs' -- Normal(a * x + b, sd), conditioned on ``y``

    Args:
        x: Predictor values used to build the mean ``a * x + b``.
        y: Values passed as ``obs`` to the 'obs' site.
    """
    slope = numpyro.sample('a', dist.Normal(0, 10))
    intercept = numpyro.sample('b', dist.Normal(0, 10))
    noise_scale = numpyro.sample('sd', dist.HalfNormal(10))

    likelihood = dist.Normal(slope * x + intercept, noise_scale)
    numpyro.sample('obs', likelihood, obs=y)

In [None]:
# NUTS sampler for the Normal-likelihood model: 4 chains, each with
# 500 warm-up steps followed by 3000 kept samples.
nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=4)

# Fixed PRNG key (0); x_observed / y_observed are forwarded to model(x, y).
mcmc.run(jax.random.PRNGKey(0), x_observed, y_observed)
# Posterior draws; indexed by site name ('a', 'b') when plotting regression lines.
trace = mcmc.get_samples()

# ArviZ container built from the same run, used for diagnostics and WAIC.
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plot for the linear model; widen the gaps between subplots afterwards.
az.plot_trace(idata)
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Summary table for the linear model (displayed as the cell output).
az.summary(idata)

## Define Model & Inference : Robust Linear Model

In [None]:
def model_robust(x, y):
    """Linear regression with a Student-t likelihood.

    Sample sites:
      * 'a'   -- slope, prior Normal(0, 10)
      * 'b'   -- intercept, prior Normal(0, 10)
      * 'nu'  -- degrees of freedom, prior Gamma(2, 0.1)
      * 'sc'  -- scale, prior HalfNormal(10)
      * 'obs' -- StudentT(nu, loc=a * x + b, scale=sc), conditioned on ``y``

    Args:
        x: Predictor values used to build the location ``a * x + b``.
        y: Values passed as ``obs`` to the 'obs' site.
    """
    slope = numpyro.sample('a', dist.Normal(0, 10))
    intercept = numpyro.sample('b', dist.Normal(0, 10))
    dof = numpyro.sample('nu', dist.Gamma(2, 0.1))
    spread = numpyro.sample('sc', dist.HalfNormal(10))

    likelihood = dist.StudentT(dof, loc=slope * x + intercept, scale=spread)
    numpyro.sample('obs', likelihood, obs=y)

In [None]:
# Same sampler settings as the linear model, now for the Student-t model.
# NOTE: this rebinds `nuts` and `mcmc`, replacing the objects from the linear run.
nuts = numpyro.infer.NUTS(model_robust)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=4)

# Fixed PRNG key (0); x_observed / y_observed are forwarded to model_robust(x, y).
mcmc.run(jax.random.PRNGKey(0), x_observed, y_observed)
# Posterior draws; indexed by site name ('a', 'b') when plotting regression lines.
trace_robust = mcmc.get_samples()

# ArviZ container built from the same run, used for diagnostics and WAIC.
idata_robust = az.from_numpyro(mcmc)

In [None]:
# Trace plot for the robust model; widen the gaps between subplots afterwards.
az.plot_trace(idata_robust)
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

## Visualize Parameters

In [None]:
def plot_lines(trace, ax, samples_to_plot=50):
    """Overlay regression lines from the last posterior draws onto ``ax``.

    Each line is ``a * x + b`` evaluated on ``np.linspace(0, 1)`` and drawn in
    translucent green. Draws are taken from the end of ``trace['a']`` and
    ``trace['b']``, newest first.

    Args:
        trace: Mapping with indexable sequences under the keys ``'a'`` (slope)
            and ``'b'`` (intercept).
        ax: Object with a Matplotlib-style ``plot(x, y, **kwargs)`` method.
        samples_to_plot: Number of lines to draw. Clamped to the number of
            draws available; values <= 0 draw nothing.
    """
    x = np.linspace(0, 1)

    slopes = trace['a']
    intercepts = trace['b']

    # Never ask for more draws than exist: list/NumPy input would raise
    # IndexError on an out-of-range negative index.
    line_count = max(0, min(samples_to_plot, len(slopes), len(intercepts)))

    # The upper bound is inclusive (line_count + 1) so that exactly
    # `line_count` lines are drawn; range(1, samples_to_plot) was one short.
    for k in range(1, line_count + 1):
        ax.plot(x, slopes[-k] * x + intercepts[-k], c='g', alpha=0.1)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# One panel per model: posterior lines first, then the observed points and
# the clean data line on top.
panels = (
    ('Linear Regression', trace),
    ('Robust Linear Regression', trace_robust),
)

for ax, (title, samples) in zip(axes, panels):
    plot_lines(samples, ax)

    ax.plot(x_observed, y_observed, 'x')
    ax.plot(x_clean, y_clean, lw=2)
    ax.set_title(title)

plt.tight_layout()

## Information Criteria

In [None]:
# WAIC of the linear model, reported on the deviance scale.
az.waic(idata, scale='deviance')

In [None]:
# WAIC of the robust model, reported on the deviance scale.
az.waic(idata_robust, scale='deviance')

In [None]:
# Label each model's InferenceData; the keys are passed on to az.compare.
dict_idata = {'Normal':idata, 'Robust':idata_robust}

In [None]:
# Compare the two models by WAIC on the deviance scale and display the result.
df_waic = az.compare(dict_idata, ic='waic', scale='deviance')
df_waic

In [None]:
# Plot the comparison result; the trailing ';' suppresses the cell's text output.
az.plot_compare(df_waic, figsize=(8, 3));