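# Notebook 4: Bayesian Inference (Annealed SMC + IBIS)

This notebook shows how to carry out full Bayesian inference with particle methods.

**Learning objectives:**
- Understand MAP-II versus full Bayesian inference
- Learn how Annealed SMC works
- Learn how IBIS handles online/streaming inference
- Quantify uncertainty in the hyperparameters

**Key concepts:**
- **Annealed SMC**: an annealing path from the prior to the posterior
- **IBIS**: Iterated Batch Importance Sampling, for streaming data
- **Particle methods**: approximate the posterior distribution with a cloud of particles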


In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from plotting_style import COLORS, PALETTES, setup_plot_style, get_figure_size, format_axes
setup_plot_style()

from infodynamics_jax.core import Phi, SupervisedData
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.particle import AnnealedSMC, AnnealedSMCCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG

print(f"JAX version: {jax.__version__}")

In [None]:
%load_ext autoreload
%autoreload 2
import sys
if '..' not in sys.path:
    sys.path.insert(0, '..')

import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from utils import COLORS, PALETTES, setup_plot_style, get_figure_size, format_axes
setup_plot_style()

from infodynamics_jax.core import Phi, SupervisedData
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.particle import AnnealedSMC, AnnealedSMCCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG

print(f"JAX version: {jax.__version__}")

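## 1. MAP-II vs Full Bayesian Inference

### MAP-II (Maximum A Posteriori Type-II)
- Finds a **single best** hyperparameter setting: φ* = argmax_φ p(φ | y)
- Fast and deterministic
- No quantification of hyperparameter uncertainty
- Risk of overfitting on small datasets

### Full Bayesian (Annealed SMC)
- Approximates the **entire posterior**: p(φ | y)
- Represents uncertainty with a particle cloud
- More robust to overfitting
- Computational cost: O(P × T), where P = number of particles and T = number of annealing steps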
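
To make the reweight / resample / rejuvenate cycle concrete, here is a minimal NumPy sketch of annealed SMC on a toy conjugate model (a scalar mean with a standard normal prior and unit-variance Gaussian noise). It is not the `AnnealedSMC` implementation from `infodynamics_jax`: every function name below is hypothetical, and random-walk Metropolis stands in for HMC to keep the example short. The conjugate model is chosen only because its posterior mean and log evidence have closed forms to compare against.

```python
import numpy as np

def logsumexp(a):
    """Numerically stable log(sum(exp(a)))."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def log_prior(theta):
    """Standard normal prior on a scalar parameter (toy assumption)."""
    return -0.5 * theta**2 - 0.5 * np.log(2 * np.pi)

def log_lik(theta, y, sigma=1.0):
    """Gaussian log-likelihood of y for each particle; theta has shape (P,)."""
    resid = y[None, :] - theta[:, None]
    return np.sum(-0.5 * (resid / sigma) ** 2
                  - 0.5 * np.log(2 * np.pi * sigma**2), axis=1)

def annealed_smc(y, n_particles=2000, n_steps=20, ess_threshold=0.5,
                 mh_steps=2, mh_scale=0.5, seed=0):
    """Hypothetical toy sampler; targets prior * likelihood**beta for beta in [0, 1]."""
    rng = np.random.default_rng(seed)
    betas = np.linspace(0.0, 1.0, n_steps + 1)          # linear temperature ladder
    theta = rng.standard_normal(n_particles)            # beta = 0: sample the prior
    logw = np.full(n_particles, -np.log(n_particles))   # normalized log-weights
    log_z, ess_trace = 0.0, []

    for b_prev, b in zip(betas[:-1], betas[1:]):
        # 1. Reweight by the likelihood raised to the temperature increment.
        logw = logw + (b - b_prev) * log_lik(theta, y)
        step_log_z = logsumexp(logw)                     # estimates log(Z_b / Z_b_prev)
        log_z += step_log_z
        logw -= step_log_z
        ess = 1.0 / np.sum(np.exp(2.0 * logw))
        ess_trace.append(ess)

        # 2. Multinomial resampling when the ESS falls below the threshold.
        if ess < ess_threshold * n_particles:
            w = np.exp(logw)
            idx = rng.choice(n_particles, size=n_particles, p=w / w.sum())
            theta = theta[idx]
            logw = np.full(n_particles, -np.log(n_particles))

        # 3. Rejuvenate with random-walk Metropolis targeting prior * lik**b.
        for _ in range(mh_steps):
            prop = theta + mh_scale * rng.standard_normal(n_particles)
            log_acc = (log_prior(prop) + b * log_lik(prop, y)
                       - log_prior(theta) - b * log_lik(theta, y))
            accept = np.log(rng.uniform(size=n_particles)) < log_acc
            theta = np.where(accept, prop, theta)

    return theta, logw, log_z, np.array(ess_trace)

# Toy data: 10 noisy observations of an unknown mean.
y = np.random.default_rng(1).normal(1.0, 1.0, size=10)
theta, logw, log_z, ess_trace = annealed_smc(y)
post_mean = np.sum(np.exp(logw) * theta)

# Closed-form answers for this conjugate model, for comparison.
n = y.size
exact_mean = y.sum() / (1 + n)
exact_log_z = (-0.5 * n * np.log(2 * np.pi) - 0.5 * np.log(1 + n)
               - 0.5 * (np.sum(y**2) - y.sum() ** 2 / (1 + n)))
print(f"posterior mean: SMC {post_mean:.3f} vs exact {exact_mean:.3f}")
print(f"log evidence:   SMC {log_z:.3f} vs exact {exact_log_z:.3f}")
```

The running sum of `step_log_z` is the marginal-likelihood estimate that makes SMC usable for model selection. Each of the T annealing steps touches all P particles, which is where the O(P × T) cost comes from.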


## 2. Generate Data

In [None]:
key = jax.random.key(789)

# Small dataset, so that hyperparameter uncertainty is visible
N_train = 30
X_train = jnp.linspace(-4, 4, N_train)[:, None]

# True function
def true_function(x):
    return jnp.sin(2 * x[:, 0]) + 0.3 * x[:, 0]

f_train = true_function(X_train)

# Add observation noise
key, subkey = jax.random.split(key)
noise_std = 0.3
Y_train = f_train + noise_std * jax.random.normal(subkey, (N_train,))

# Test set
X_test = jnp.linspace(-5, 5, 100)[:, None]
f_test = true_function(X_test)

print(f"Training set: {N_train} points")
print(f"True noise standard deviation: {noise_std}")

## 3. Baseline: MAP-II Optimization

In [None]:
# Initialization
kernel_params = KernelParams(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
M = 10
Z = jnp.linspace(X_train.min(), X_train.max(), M)[:, None]

phi_init = Phi(
    kernel_params=kernel_params,
    Z=Z,
    likelihood_params={"noise_var": jnp.array(0.1)},
    jitter=1e-5,
)

# Energy (used by Annealed SMC in Section 4; the MAP-II run below optimizes the VFE objective)
gaussian_likelihood = get_likelihood("gaussian")
inertial_cfg = InertialCFG(estimator="gh", gh_n=20, inner_steps=0)
inertial_energy = InertialEnergy(
    kernel_fn=rbf_kernel,
    likelihood=gaussian_likelihood,
    cfg=inertial_cfg,
)

# Run MAP-II
from infodynamics_jax.inference.optimisation.vfe import make_vfe_objective
vfe_objective = make_vfe_objective(kernel_fn=rbf_kernel, residual="fitc")
typeii_cfg = TypeIICFG(steps=100, lr=1e-2, optimizer="adam", jit=True, constrain_params=True)
method = TypeII(cfg=typeii_cfg)

key, subkey = jax.random.split(key)
result_map = method.run(
    energy=vfe_objective,
    phi_init=phi_init,
    energy_args=(X_train, Y_train),
)

phi_map = result_map.phi

print("MAP-II results:")
print(f"  Lengthscale: {float(phi_map.kernel_params.lengthscale):.3f}")
print(f"  Variance: {float(phi_map.kernel_params.variance):.3f}")
print(f"  Noise variance: {float(phi_map.likelihood_params['noise_var']):.3f}")

## 4. Run Annealed SMC

In [None]:
def init_particles_fn(key, n_particles: int):
    """Initialize particles by sampling from the prior."""
    keys = jax.random.split(key, n_particles)

    def init_one(key_i):
        key_l, key_v, key_z, key_n = jax.random.split(key_i, 4)

        # Lengthscale: log-normal distribution
        lengthscale = jnp.exp(jax.random.normal(key_l, ()) * 0.5)

        # Variance: log-normal distribution
        variance = jnp.exp(jax.random.normal(key_v, ()) * 0.5)

        # Inducing points: small Gaussian perturbation around the initial Z
        Z_noisy = phi_init.Z + jax.random.normal(key_z, phi_init.Z.shape) * 0.2

        # Noise variance: log-normal distribution centered on 0.1
        noise_var = jnp.exp(jnp.log(0.1) + jax.random.normal(key_n, ()) * 0.5)

        # Clip to lower bounds (exp already guarantees positivity; this keeps values away from zero)
        lengthscale = jnp.maximum(lengthscale, 0.1)
        variance = jnp.maximum(variance, 0.1)
        noise_var = jnp.maximum(noise_var, 0.01)
        
        return Phi(
            kernel_params=KernelParams(lengthscale=lengthscale, variance=variance),
            Z=Z_noisy,
            likelihood_params={"noise_var": noise_var},
            jitter=phi_init.jitter,
        )
    
    return jax.vmap(init_one)(keys)

print("Particle initialization function created!")

In [None]:
# Configure Annealed SMC
smc_cfg = AnnealedSMCCFG(
    n_particles=64,           # number of particles
    n_steps=20,               # number of annealing steps
    ess_threshold=0.5,        # resample when ESS < 0.5 * n_particles
    rejuvenation="hmc",       # rejuvenate particles with HMC
    rejuvenation_steps=2,     # HMC steps per rejuvenation
    jit=True,                 # enable JIT compilation
)

method_smc = AnnealedSMC(cfg=smc_cfg)

print("Annealed SMC configuration:")
print(f"  Particles: {smc_cfg.n_particles}")
print(f"  Annealing steps: {smc_cfg.n_steps}")
print(f"  Rejuvenation: {smc_cfg.rejuvenation} ({smc_cfg.rejuvenation_steps} steps)")

In [None]:
# Run Annealed SMC
print("Running Annealed SMC...")
print("This may take a minute...")

key, subkey = jax.random.split(key)
result_smc = method_smc.run(
    energy=inertial_energy,
    init_particles_fn=init_particles_fn,
    key=subkey,
    energy_args=(X_train, Y_train),
)

particles = result_smc.particles
logw = result_smc.logw
ess_trace = result_smc.ess_trace

print("\nAnnealed SMC results:")
print(f"  Final ESS: {ess_trace[-1]:.1f} / {smc_cfg.n_particles}")
print(f"  Log-weight range: [{logw.min():.2f}, {logw.max():.2f}]")

## 5. Analyze the Posterior

In [None]:
# Extract per-particle values
lengthscales = np.array(particles.kernel_params.lengthscale)
variances = np.array(particles.kernel_params.variance)
noise_vars = np.array(particles.likelihood_params["noise_var"])

# Normalize the weights (subtract the max log-weight for numerical stability)
weights = np.exp(logw - logw.max())
weights = weights / weights.sum()

# Weighted statistics
lengthscale_mean = float(np.sum(weights * lengthscales))
lengthscale_std = float(np.sqrt(np.sum(weights * (lengthscales - lengthscale_mean)**2)))

variance_mean = float(np.sum(weights * variances))
variance_std = float(np.sqrt(np.sum(weights * (variances - variance_mean)**2)))

noise_var_mean = float(np.sum(weights * noise_vars))
noise_var_std = float(np.sqrt(np.sum(weights * (noise_vars - noise_var_mean)**2)))

print("Posterior statistics (weighted):")
print("\nLengthscale:")
print(f"  Mean: {lengthscale_mean:.3f} ± {lengthscale_std:.3f}")
print(f"  MAP:  {float(phi_map.kernel_params.lengthscale):.3f}")

print("\nVariance:")
print(f"  Mean: {variance_mean:.3f} ± {variance_std:.3f}")
print(f"  MAP:  {float(phi_map.kernel_params.variance):.3f}")

print("\nNoise variance:")
print(f"  Mean: {noise_var_mean:.3f} ± {noise_var_std:.3f}")
print(f"  MAP:  {float(phi_map.likelihood_params['noise_var']):.3f}")
print(f"  True: {noise_std**2:.3f}")
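
The weighted summaries above repeat the same pattern for each hyperparameter. As a sketch only, the pattern can be collected into a few small NumPy helpers; the function names are hypothetical and not part of `infodynamics_jax`. The block also adds an effective sample size and a simple interpolated weighted quantile, which gives credible intervals in addition to mean ± standard deviation.

```python
import numpy as np

def normalize_log_weights(logw):
    """Turn unnormalized log-weights into weights that sum to 1 (max-shifted for stability)."""
    logw = np.asarray(logw, dtype=float)
    w = np.exp(logw - logw.max())
    return w / w.sum()

def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2) for normalized weights; equals P when weights are uniform."""
    return 1.0 / np.sum(np.asarray(weights) ** 2)

def weighted_mean_std(values, weights):
    """Weighted mean and (biased) weighted standard deviation."""
    mean = np.sum(weights * values)
    var = np.sum(weights * (values - mean) ** 2)
    return mean, np.sqrt(var)

def weighted_quantile(values, weights, q):
    """Approximate quantile by interpolating the weighted empirical CDF."""
    order = np.argsort(values)
    cdf = np.cumsum(weights[order])
    return np.interp(q, cdf, values[order])

# Synthetic stand-ins for `logw` and one hyperparameter's particle values.
rng = np.random.default_rng(0)
demo_logw = rng.normal(size=64)
demo_values = np.exp(rng.normal(size=64) * 0.5)

demo_w = normalize_log_weights(demo_logw)
demo_mean, demo_std = weighted_mean_std(demo_values, demo_w)
lo, hi = (weighted_quantile(demo_values, demo_w, q) for q in (0.025, 0.975))
print(f"ESS: {effective_sample_size(demo_w):.1f} / {demo_w.size}")
print(f"mean ± std: {demo_mean:.3f} ± {demo_std:.3f}")
print(f"95% interval: [{lo:.3f}, {hi:.3f}]")
```

In the notebook these would be called with `logw` and, for example, `lengthscales`. A low ESS is a warning that the weighted statistics rest on only a handful of particles.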

In [None]:
# Visualize the posterior distributions and the predictions
from infodynamics_jax.gp.predict import predict_typeii

# 1. Plot the ESS trace and the posterior histograms
fig, axes = plt.subplots(2, 2, figsize=get_figure_size('wide', 1.0))

# ESS trace
ax = axes[0, 0]
ax.plot(ess_trace, lw=2, color=COLORS['primary'])
ax.axhline(smc_cfg.ess_threshold * smc_cfg.n_particles, ls='--',
          color=COLORS['accent'], label=f'ESS threshold ({smc_cfg.ess_threshold * smc_cfg.n_particles:.0f})')
format_axes(ax, title='Effective sample size (ESS) trace', xlabel='Annealing step', ylabel='ESS', legend=True)

# Lengthscale posterior
ax = axes[0, 1]
ax.hist(lengthscales, weights=weights, bins=30, alpha=0.7,
       color=COLORS['primary'], edgecolor='black', density=True)
ax.axvline(lengthscale_mean, ls='--', lw=2, color=COLORS['accent'],
          label=f'Mean: {lengthscale_mean:.3f}')
ax.axvline(float(phi_map.kernel_params.lengthscale), ls=':', lw=2,
          color=COLORS['secondary'], label=f'MAP: {float(phi_map.kernel_params.lengthscale):.3f}')
format_axes(ax, title='Lengthscale posterior', xlabel='Lengthscale', ylabel='Density', legend=True)

# Variance posterior
ax = axes[1, 0]
ax.hist(variances, weights=weights, bins=30, alpha=0.7,
       color=COLORS['secondary'], edgecolor='black', density=True)
ax.axvline(variance_mean, ls='--', lw=2, color=COLORS['accent'],
          label=f'Mean: {variance_mean:.3f}')
ax.axvline(float(phi_map.kernel_params.variance), ls=':', lw=2,
          color=COLORS['primary'], label=f'MAP: {float(phi_map.kernel_params.variance):.3f}')
format_axes(ax, title='Variance posterior', xlabel='Variance', ylabel='Density', legend=True)

# Noise-variance posterior
ax = axes[1, 1]
ax.hist(noise_vars, weights=weights, bins=30, alpha=0.7,
       color=COLORS['tertiary'], edgecolor='black', density=True)
ax.axvline(noise_var_mean, ls='--', lw=2, color=COLORS['accent'],
          label=f'Mean: {noise_var_mean:.3f}')
ax.axvline(float(phi_map.likelihood_params['noise_var']), ls=':', lw=2,
          color=COLORS['primary'], label=f'MAP: {float(phi_map.likelihood_params["noise_var"]):.3f}')
ax.axvline(noise_std**2, ls='-', lw=2, color='red', alpha=0.7,
          label=f'True: {noise_std**2:.3f}')
format_axes(ax, title='Noise-variance posterior', xlabel='Noise variance', ylabel='Density', legend=True)

plt.tight_layout()
plt.show()

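# 2. Predict with the weighted particles (Bayesian model averaging)
print("\nComputing weighted predictions (Bayesian model averaging)...")

# Predict with each particle, then take the weighted average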
def predict_with_particle(phi_particle):
    mu, var = predict_typeii(
        phi_particle,
        X_test,
        X_train,
        Y_train,
        kernel_fn=rbf_kernel,
        residual="fitc",
    )
    return mu, var

# Collect predictions particle by particle (a plain Python loop, not vectorized)
n_particles = len(particles.kernel_params.lengthscale)
mu_predictions = []
var_predictions = []

# Get the jitter (may be a scalar or an array)
if hasattr(particles, 'jitter'):
    if isinstance(particles.jitter, (int, float)):
        jitter_val = particles.jitter
    else:
        jitter_val = particles.jitter[0] if len(particles.jitter) > 0 else 1e-5
else:
    jitter_val = 1e-5

for i in range(n_particles):
    phi_i = Phi(
        kernel_params=KernelParams(
            lengthscale=particles.kernel_params.lengthscale[i],
            variance=particles.kernel_params.variance[i]
        ),
        Z=particles.Z[i],
        likelihood_params={"noise_var": particles.likelihood_params["noise_var"][i]},
        jitter=jitter_val,
    )
    mu_i, var_i = predict_with_particle(phi_i)
    mu_predictions.append(mu_i)
    var_predictions.append(var_i)

mu_predictions = jnp.array(mu_predictions)  # (n_particles, N_test)
var_predictions = jnp.array(var_predictions)  # (n_particles, N_test)

# Weighted average
weights_array = jnp.array(weights)
mu_weighted = jnp.sum(weights_array[:, None] * mu_predictions, axis=0)
# For the variance, use the law of total variance: E[var] + Var[mean]
var_weighted = jnp.sum(weights_array[:, None] * var_predictions, axis=0) + \
               jnp.sum(weights_array[:, None] * (mu_predictions - mu_weighted[None, :])**2, axis=0)

# 3. Plot the predictions
fig, ax = plt.subplots(1, 1, figsize=get_figure_size('wide', 0.6))

# True function on the test grid
ax.plot(X_test[:, 0], f_test, 'k-', lw=2, label='True function', alpha=0.7)

# Training data
ax.scatter(X_train[:, 0], Y_train, s=50, c=COLORS['primary'],
          alpha=0.6, edgecolors='black', label='Training data', zorder=5)

# MAP prediction
mu_map, var_map = predict_typeii(
    phi_map,
    X_test,
    X_train,
    Y_train,
    kernel_fn=rbf_kernel,
    residual="fitc",
)
std_map = jnp.sqrt(var_map)
ax.plot(X_test[:, 0], mu_map, '--', lw=2, color=COLORS['secondary'],
       label='MAP prediction', alpha=0.8)
ax.fill_between(X_test[:, 0], mu_map - 2*std_map, mu_map + 2*std_map,
               alpha=0.2, color=COLORS['secondary'])

# Weighted Bayesian prediction
std_weighted = jnp.sqrt(var_weighted)
ax.plot(X_test[:, 0], mu_weighted, '-', lw=2, color=COLORS['accent'],
       label='Bayesian model average', alpha=0.9)
ax.fill_between(X_test[:, 0], mu_weighted - 2*std_weighted, mu_weighted + 2*std_weighted,
               alpha=0.3, color=COLORS['accent'], label='±2σ uncertainty')

format_axes(ax, title='MAP vs Bayesian model average', xlabel='x', ylabel='y', legend=True)
plt.tight_layout()
plt.show()

print("\nPlotting complete!")
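
The predictive variance of the model average is computed as E[var] + Var[mean]. A small worked example, using a hypothetical helper `mixture_moments` and made-up numbers, shows why the second term matters: when particles disagree about the mean, the averaged prediction is more uncertain than any single particle.

```python
import numpy as np

def mixture_moments(weights, means, variances):
    """Mean and variance of a weighted mixture of predictive distributions.

    weights: (P,) normalized particle weights
    means, variances: (P, N) per-particle predictive means and variances
    """
    w = weights[:, None]
    mean = np.sum(w * means, axis=0)
    within = np.sum(w * variances, axis=0)                   # E[var]
    between = np.sum(w * (means - mean[None, :]) ** 2, axis=0)  # Var[mean]
    return mean, within + between

# Two equally weighted particles that disagree about the mean at one test point.
demo_weights = np.array([0.5, 0.5])
demo_means = np.array([[-1.0], [1.0]])
demo_vars = np.array([[1.0], [1.0]])

mix_mean, mix_var = mixture_moments(demo_weights, demo_means, demo_vars)
print("mixture mean:", mix_mean, "mixture variance:", mix_var)
```

Each particle has variance 1, but the mixture has variance 2: one unit from the average within-particle variance and one from the spread of the particle means around the mixture mean of 0.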

## 6. IBIS: Online Inference

In [None]:
# IBIS for streaming data
# Data arrive in batches and the posterior is updated incrementally

# Generate the streaming data
N_total = 100
X_all = jnp.linspace(-5, 5, N_total)[:, None]
f_all = true_function(X_all)
key, subkey = jax.random.split(key)
Y_all = f_all + noise_std * jax.random.normal(subkey, (N_total,))

data = SupervisedData(X_all, Y_all)
batch_size = 10
n_batches = N_total // batch_size

print(f"Total data points: {N_total}")
print(f"Batch size: {batch_size}")
print(f"Number of batches: {n_batches}")

In [None]:
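# IBIS loop (simplified)
# Note: a real implementation needs more elaborate particle-update logic
print("\nIBIS demonstration:")
print("(A full implementation requires more elaborate particle management)")
print("\nKey ideas:")
print("1. Data arrive in batches")
print("2. Update the particle weights for each batch")
print("3. Resample when the ESS falls too low")
print("4. Use MCMC moves to maintain particle diversity")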
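
The four ideas listed above can be sketched end to end on a toy problem. The block below is a minimal NumPy illustration, assuming a scalar mean with a standard normal prior and unit-variance Gaussian noise. It does not use `SupervisedData` or any `infodynamics_jax` API, and all function names are hypothetical.

```python
import numpy as np

def log_prior(theta):
    """Standard normal prior on a scalar parameter (toy assumption)."""
    return -0.5 * theta**2 - 0.5 * np.log(2 * np.pi)

def log_lik(theta, y, sigma=1.0):
    """Gaussian log-likelihood of y for each particle; theta has shape (P,)."""
    resid = y[None, :] - theta[:, None]
    return np.sum(-0.5 * (resid / sigma) ** 2
                  - 0.5 * np.log(2 * np.pi * sigma**2), axis=1)

def normalize(logw):
    """Return log-weights shifted so that the weights sum to 1."""
    logw = logw - np.max(logw)
    return logw - np.log(np.sum(np.exp(logw)))

def ibis(batches, n_particles=1000, ess_threshold=0.5, mh_steps=3, seed=0):
    """Hypothetical toy IBIS loop: reweight per batch, resample-move on low ESS."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal(n_particles)             # start from the prior
    logw = np.full(n_particles, -np.log(n_particles))
    seen = np.empty(0)                                   # all data observed so far
    ess_trace = []

    for batch in batches:                                # 1. data arrive in batches
        seen = np.concatenate([seen, batch])
        # 2. Reweight using the likelihood of the new batch only.
        logw = normalize(logw + log_lik(theta, batch))
        ess = 1.0 / np.sum(np.exp(2.0 * logw))
        ess_trace.append(ess)

        if ess < ess_threshold * n_particles:
            # 3. Multinomial resampling when the ESS is too low.
            w = np.exp(logw)
            theta = theta[rng.choice(n_particles, size=n_particles, p=w / w.sum())]
            logw = np.full(n_particles, -np.log(n_particles))
            # 4. Metropolis moves targeting the posterior given all data seen so far;
            #    the proposal scale follows the spread of the resampled particles.
            scale = max(2.38 * theta.std(), 1e-3)
            for _ in range(mh_steps):
                prop = theta + scale * rng.standard_normal(n_particles)
                log_acc = (log_prior(prop) + log_lik(prop, seen)
                           - log_prior(theta) - log_lik(theta, seen))
                accept = np.log(rng.uniform(size=n_particles)) < log_acc
                theta = np.where(accept, prop, theta)

    return theta, logw, np.array(ess_trace)

# Stream 50 observations in 5 batches of 10.
y_all = np.random.default_rng(2).normal(0.5, 1.0, size=50)
theta, logw, ess_trace = ibis(np.array_split(y_all, 5))
post_mean = np.sum(np.exp(logw) * theta)
exact_mean = y_all.sum() / (1 + y_all.size)              # conjugate closed form
print(f"posterior mean: IBIS {post_mean:.3f} vs exact {exact_mean:.3f}")
print("ESS after each batch:", np.round(ess_trace, 1))
```

Reweighting touches only the new batch, but the MCMC move re-evaluates the likelihood on all data seen so far, so the cost of each rejuvenation grows as the stream gets longer.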

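## Summary

In this notebook we learned:

1. **Posterior uncertainty**: SMC provides the full posterior distribution, not just a point estimate
2. **MAP vs Bayesian mean**: the MAP estimate can differ from the posterior mean
3. **Hyperparameter correlations**: the joint distribution reveals dependencies between parameters
4. **Uncertainty quantification**: the posterior standard deviation quantifies the uncertainty of each estimate

**When to use Annealed SMC:**
- Small datasets (high parameter uncertainty)
- Robust inference is needed (to avoid overfitting)
- You want to quantify hyperparameter uncertainty
- Model selection (via the marginal-likelihood estimate)

**When to use MAP-II:**
- Large datasets (the posterior is concentrated)
- Speed is critical
- A point estimate is sufficient
- Production deployment

**When IBIS is a good fit:**
- Data arrive sequentially (streaming)
- The dataset is too large to process in a single batch
- Real-time updates are required
- Monitoring for distribution shift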
