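# Implementing Bayesian Estimation

Bayesian estimation is often implemented by writing the probabilistic model directly as code, an approach known as probabilistic programming.

Concrete tools for this include Stan and NumPyro.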

## Example

Consider the linear regression model

$$
y_i = \alpha+X_i \beta + e_i, \quad e_i \sim \mathcal{N}(0, \sigma)
$$

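Placing priors on the parameters gives the full Bayesian model below, where $\sigma$ is a standard deviation (the same parameterization that Stan and NumPyro use for the normal distribution):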

$$
\begin{aligned}
y_i & \sim \mathcal{N}\left(\alpha+X_i \beta, \sigma\right) \\
\alpha & \sim \mathcal{N}(0,1) \\
\beta_j & \sim \mathcal{N}(0,1) \\
\sigma & \sim \operatorname{HalfNormal}(1) \quad(\sigma \geq 0)
\end{aligned}
$$
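
Read from the bottom up, the specification above is a recipe for generating data: draw the parameters from their priors, then draw $y$ given those parameters. The following is a minimal NumPy sketch of one such draw. The function name `simulate_from_prior` and the small design matrix `X_demo` are assumptions made for illustration and are not part of the original notes.

```python
import numpy as np

def simulate_from_prior(X, rng):
    """Draw (alpha, beta, sigma) from the priors, then y from the likelihood."""
    N, K = X.shape
    alpha = rng.normal(0.0, 1.0)             # alpha ~ N(0, 1)
    beta = rng.normal(0.0, 1.0, size=K)      # beta_j ~ N(0, 1)
    sigma = np.abs(rng.normal(0.0, 1.0))     # |N(0, 1)| follows HalfNormal(1)
    y = rng.normal(alpha + X @ beta, sigma)  # y_i ~ N(alpha + X_i beta, sigma)
    return alpha, beta, sigma, y

rng = np.random.default_rng(1)
X_demo = rng.normal(size=(5, 3))  # hypothetical design matrix, illustration only
alpha, beta, sigma, y_sim = simulate_from_prior(X_demo, rng)
print(y_sim.shape)
```

Repeating this draw many times gives a rough picture of which datasets the priors consider plausible before any data are observed.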

```python
# Generate sample data
import numpy as np

rng = np.random.default_rng(0)
N, K = 200, 3
X = rng.normal(size=(N, K))

alpha_true = 1.0
beta_true = np.array([0.5, -1.2, 0.3])
sigma_true = 0.7

y = alpha_true + X @ beta_true + rng.normal(scale=sigma_true, size=N)
```
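
Before fitting anything Bayesian, an ordinary least-squares fit is a cheap sanity check that the simulated data really carry the intended parameters. The sketch below regenerates the same data so that it runs on its own. The names `A`, `coef`, and `sigma_hat` are illustrative, and least squares is used here only as a reference point, not as the estimation method of these notes.

```python
import numpy as np

# Regenerate the same simulated data so this block is self-contained
rng = np.random.default_rng(0)
N, K = 200, 3
X = rng.normal(size=(N, K))
beta_true = np.array([0.5, -1.2, 0.3])
y = 1.0 + X @ beta_true + rng.normal(scale=0.7, size=N)

# Ordinary least squares with an explicit intercept column
A = np.column_stack([np.ones(N), X])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
resid = y - A @ coef
sigma_hat = np.sqrt(resid @ resid / (N - K - 1))  # residual standard deviation

print("alpha_hat:", coef[0])
print("beta_hat :", coef[1:])
print("sigma_hat:", sigma_hat)
```

With 200 observations the priors are weak relative to the data, so the posterior means from the samplers should land close to these least-squares values.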

## Stan

```stan
data {
  int<lower=0> N;   // number of data items
  int<lower=0> K;   // number of predictors
  matrix[N, K] x;   // predictor matrix
  vector[N] y;      // outcome vector
}
parameters {
  real alpha;           // intercept
  vector[K] beta;       // coefficients for predictors
  real<lower=0> sigma;  // error scale
}
model {
  alpha ~ normal(0, 1);
  beta  ~ normal(0, 1);
  sigma ~ normal(0, 1); // half-normal, because sigma is declared with <lower=0>
  y ~ normal(x * beta + alpha, sigma);  // likelihood
}
```
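
Each `~` statement in the `model` block adds a term to the log posterior density, up to an additive constant. As a rough illustration of what the block defines, here is a plain NumPy version of that density on the constrained scale. This is a sketch under that reading, not Stan's internals: Stan itself samples `sigma` on an unconstrained scale and adds a Jacobian term. The function name `log_posterior` is an assumption for illustration.

```python
import numpy as np

def log_posterior(alpha, beta, sigma, X, y):
    """Unnormalized log posterior of the regression model (constants dropped)."""
    if sigma <= 0:
        return -np.inf  # outside the support declared by <lower=0>
    log_prior = (
        -0.5 * alpha**2          # alpha ~ normal(0, 1)
        - 0.5 * np.sum(beta**2)  # beta ~ normal(0, 1)
        - 0.5 * sigma**2         # sigma ~ normal(0, 1), restricted to sigma > 0
    )
    resid = y - (alpha + X @ beta)
    # y ~ normal(x * beta + alpha, sigma)
    log_lik = -len(y) * np.log(sigma) - 0.5 * np.sum(resid**2) / sigma**2
    return log_prior + log_lik

# Regenerate the sample data so this block runs on its own
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
beta_true = np.array([0.5, -1.2, 0.3])
y = 1.0 + X @ beta_true + rng.normal(scale=0.7, size=200)

lp_true = log_posterior(1.0, beta_true, 0.7, X, y)
lp_zero = log_posterior(0.0, np.zeros(3), 0.7, X, y)
print(lp_true > lp_zero)
```

A sampler only needs this function (and, for NUTS, its gradient) to explore the posterior, which is why the model block is all that has to be written by hand.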

An example using the `cmdstanpy` package:

```python
from cmdstanpy import CmdStanModel
model = CmdStanModel(stan_file="hoge.stan")

data = {"N": N, "K": K, "x": X, "y": y}  # keys must match the names in the Stan data block
fit = model.sample(
    data=data,
    chains=4,
    iter_warmup=1000,
    iter_sampling=1000,
    seed=0,
)

df = fit.draws_pd()  # pandas DataFrame
print(df[["alpha", "beta[1]", "beta[2]", "beta[3]", "sigma"]].describe())
```

## NumPyro

NumPyro is built on JAX, a fast numerical computing library.

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(X, y=None):
    N, K = X.shape
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    beta  = numpyro.sample("beta",  dist.Normal(0.0, 1.0).expand([K]))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mu = alpha + jnp.dot(X, beta)
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), X=jnp.array(X), y=jnp.array(y))

samples = mcmc.get_samples(group_by_chain=False)
# samples["beta"] shape: (num_draws, K)
print({k: (v.mean(0), v.std(0)) for k, v in samples.items() if k in ["alpha", "sigma"]})
print("beta mean:", samples["beta"].mean(0))
print("beta sd  :", samples["beta"].std(0))
```

```
{'alpha': (Array(0.97254834, dtype=float64), Array(0.05002473, dtype=float64)), 'sigma': (Array(0.71080548, dtype=float64), Array(0.03604194, dtype=float64))}
beta mean: [ 0.43455282 -1.26495461  0.33638894]
beta sd  : [0.04932507 0.05165164 0.05145463]
```


## PyMC

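The modeling syntax is similar to NumPyro, but the sampler can be swapped for a different backend.

#### Feature: a choice of MCMC samplers

- Python NUTS sampler (the default; sometimes faster than NumPyro)
- NumPyro JAX NUTS sampler
- BlackJAX NUTS sampler (reportedly fast on large datasets)
- Nutpie NUTS sampler (written in Rust; reportedly about as fast as JAX)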

```python
pm.sample(nuts_sampler="blackjax")
```

[Faster Sampling with JAX and Numba — PyMC example gallery](https://www.pymc.io/projects/examples/en/latest/samplers/fast_sampling_with_jax_and_numba.html)

```python
import pymc as pm
import numpy as np

with pm.Model() as m:
    # Note: these priors are wider than the N(0, 1) / HalfNormal(1) priors used above
    alpha = pm.Normal("alpha", mu=0.0, sigma=5.0)
    beta  = pm.Normal("beta",  mu=0.0, sigma=2.0, shape=K)
    sigma = pm.HalfNormal("sigma", sigma=2.0)

    mu = alpha + pm.math.dot(X, beta)
    y_obs = pm.Normal("y", mu=mu, sigma=sigma, observed=y)

    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        random_seed=0,
        target_accept=0.8,
    )

# ArviZ makes it easy to summarize and compare the results in one place
import arviz as az
display(az.summary(idata, var_names=["alpha", "beta", "sigma"]))
```

```
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
```


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| alpha | 0.975 | 0.051 | 0.879 | 1.071 | 0.001 | 0.001 | 5105.0 | 3164.0 | 1.0 |
| beta[0] | 0.437 | 0.05 | 0.343 | 0.529 | 0.001 | 0.001 | 5539.0 | 3541.0 | 1.0 |
| beta[1] | -1.267 | 0.051 | -1.363 | -1.172 | 0.001 | 0.001 | 7651.0 | 3447.0 | 1.0 |
| beta[2] | 0.338 | 0.052 | 0.242 | 0.437 | 0.001 | 0.001 | 5716.0 | 3368.0 | 1.0 |
| sigma | 0.712 | 0.038 | 0.646 | 0.785 | 0.001 | 0.001 | 5299.0 | 3226.0 | 1.0 |
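
The `r_hat` column in the summary compares the variance between chains with the variance within chains; values close to 1.0 suggest that the chains agree. To make that idea concrete, here is a minimal NumPy sketch of the basic split R-hat. It is a simplified textbook form and not ArviZ's implementation, which (as far as I know) uses a more robust rank-normalized variant by default. The function name `split_rhat` and the synthetic chains are assumptions for illustration only.

```python
import numpy as np

def split_rhat(draws):
    """Basic split R-hat for draws of shape (n_chains, n_draws)."""
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in two so that drift within a chain also shows up
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # B: scaled variance of chain means
    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                      # synthetic chains, same target
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])  # two chains shifted away
print(split_rhat(mixed), split_rhat(stuck))
```

The well-mixed chains give a value near 1, while the shifted chains give a value well above it, which is the pattern to look for when reading the `r_hat` column.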
