## Student's t-distribution

## Install Packages

In [None]:
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import arviz as az

import jax
import jax.numpy as jnp

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import scipy
import scipy.stats

In [None]:
# Notebook-wide matplotlib defaults: 14pt font and an 8x5 inch figure.
plt.rcParams.update({
    'font.size': 14,
    'figure.figsize': (8, 5),
})

## Student's t-distribution

https://ja.wikipedia.org/wiki/T%E5%88%86%E5%B8%83

$\nu = 1$ の場合は、コーシー分布に一致する

In [None]:
# Densities of Student's t (scipy standard form: location 0, scale 1)
# for several degrees of freedom nu.
x = np.linspace(-5, 5, 100)

# Sanity check: at nu = 1 the t density coincides with the standard Cauchy density.
np.testing.assert_allclose(scipy.stats.t.pdf(x, 1), scipy.stats.cauchy.pdf(x))

for nu in [1, 2, 10]:
    y = scipy.stats.t.pdf(x, nu)
    plt.plot(x, y, label=r'$\nu={}$'.format(nu))

plt.title('Student\'s t-distribution')
plt.xlabel('x')
plt.ylabel('density')
plt.xlim([-5, 5])
plt.legend();

In [None]:
# Draw 100,000 samples from numpyro's StudentT (df=nu, loc=0, scale=1)
# with a fixed PRNG key and show them as a histogram.
nu = 10
rng_key = jax.random.PRNGKey(0)
t_dist = dist.StudentT(nu, loc=0, scale=1)
y = t_dist.sample(rng_key, sample_shape=(100000,))

ax = sns.histplot(y)
ax.set_title("Student's t-distribution")
ax.set_xlim(-5, 5);

## Gamma distribution

https://ja.wikipedia.org/wiki/%E3%82%AC%E3%83%B3%E3%83%9E%E5%88%86%E5%B8%83

$\alpha=1$ のとき、レート $\beta$ の指数分布に一致する。以下では $\beta$ をレートパラメータとし、`scipy.stats.gamma` には `scale` $= 1/\beta$ として渡す。

In [None]:
# Gamma densities for several shapes alpha with rate beta = 1
# (scipy's default scale=1 corresponds to rate 1).
x = np.linspace(0, 20, 100)

# Sanity check: with alpha = 1 the gamma density equals the exponential
# density (both with rate 1 here).
np.testing.assert_allclose(scipy.stats.gamma.pdf(x, 1), scipy.stats.expon.pdf(x))

for alpha in [1, 2, 5, 10]:
    y = scipy.stats.gamma.pdf(x, alpha)
    plt.plot(x, y, label=r'$\alpha={}, \beta=1$'.format(alpha))

plt.title('Gamma Distribution')
plt.xlabel('x')
plt.ylabel('density')
plt.legend();

In [None]:
# Gamma densities for a fixed shape alpha = 2 and several rates beta.
# scipy.stats.gamma takes a scale parameter, so the rate is converted
# with scale = 1 / beta.
# The grid is defined here rather than reused from the previous cell so this
# cell runs on its own; its range matches the x-limits set below.
x = np.linspace(0, 20, 100)

for beta in [0.3, 0.5, 1]:
    scale = 1 / beta
    y = scipy.stats.gamma.pdf(x, 2, scale=scale)
    plt.plot(x, y, label=r'$\alpha=2, \beta={}$'.format(beta))

plt.title('Gamma Distribution')
plt.xlabel('x')
plt.ylabel('density')
plt.xlim([0, 20])
plt.legend();

## Exponential Distribution

In [None]:
# Exponential densities for several scales; each curve is labelled with
# its rate lambda = 1 / scale.
x = np.linspace(0, 10, 100)

for scale in (1, 2, 5, 10):
    density = scipy.stats.expon.pdf(x, scale=scale)
    plt.plot(x, density, label=rf'$\lambda=1/{scale:.0f}$')

plt.title('Exponential Distribution')
plt.xlim([0, 10])
plt.legend();

## Prior for $\nu$ in Student's t-distribution

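https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations

Stan の推奨では、自由度 $\nu$ の事前分布として $\mathrm{Gamma}(2, 0.1)$（形状・レート表記）が挙げられている。ここではそれを $\mathrm{Exponential}(1/30)$（レート表記）と比較する。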


In [None]:
# Compare two candidate priors for the Student's t degrees of freedom nu.
# numpyro parameterizes Exponential by its rate and Gamma by (concentration, rate):
#   Exponential(rate=1/30)             -> mean 30
#   Gamma(concentration=2, rate=0.1)   -> mean 2 / 0.1 = 20
n_samples = 100000
y0 = dist.Exponential(1/30).sample(jax.random.PRNGKey(0), sample_shape=(n_samples,))
y1 = dist.Gamma(2, 0.1).sample(jax.random.PRNGKey(1), sample_shape=(n_samples,))

fig, ax = plt.subplots(figsize=(10, 4))

# Use one shared bin width and density scaling. With default bins each sample
# set gets its own bin width (the exponential has a much longer tail), so the
# bar heights of the two histograms would not be on a comparable scale.
sns.histplot(y0, color='C0', binwidth=1, stat='density', alpha=0.5,
             label='Exponential(1/30)', ax=ax)
sns.histplot(y1, color='C1', binwidth=1, stat='density', alpha=0.5,
             label='Gamma(2, 0.1)', ax=ax)

ax.set_xlim([0, 100])
ax.set_xlabel(r'$\nu$')
ax.set_title(r'Prior candidates for $\nu$')
ax.legend();