# Notebook 3: GP 分類與非共軛似然

本 notebook 示範如何使用高斯過程進行分類，處理非共軛似然（Bernoulli）。

**學習目標：**
- 理解共軛 vs 非共軛似然
- 了解 Gauss-Hermite 與 Monte Carlo 兩種估計器的差異（本 notebook 以 Gauss-Hermite 為例）
- 進行二分類（多分類的做法見文末「擴展」）
- 視覺化決策邊界和不確定性

**關鍵概念：**
- **非共軛似然**: Bernoulli、Poisson 等，需要近似方法
- **Gauss-Hermite**: 確定性數值積分
- **Monte Carlo**: 隨機採樣估計


In [None]:
# Notebook setup: auto-reload, CPU-only JAX, plotting style, and project imports.
# NOTE(review): the next cell repeats this setup but imports the same plotting helpers
# from `utils` (after adding '..' to sys.path) instead of `plotting_style` — confirm
# which of the two cells is current.
%load_ext autoreload
%autoreload 2

import os
# Set in the environment before `import jax` so the platform choice is in place at import time.
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
# Also request the CPU platform through the JAX config.
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Shared plotting helpers; the style is applied once here for every figure below.
from plotting_style import COLORS, PALETTES, setup_plot_style, get_figure_size, format_axes, create_custom_colormap
setup_plot_style()

# Project API used below: model parameters (Phi, KernelParams), RBF kernel, likelihood
# registry, energy, Type-II optimiser, runner, and variational-state prediction helpers.
from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG
from infodynamics_jax.gp.ansatz.state import VariationalState
from infodynamics_jax.gp.ansatz.expected import qfi_from_qu_full

print(f"JAX version: {jax.__version__}")

In [None]:
# Notebook setup: auto-reload, import path, CPU-only JAX, plotting style, and project imports.
# NOTE(review): near-duplicate of the preceding setup cell, which imports the same plotting
# helpers from `plotting_style` instead of `utils` — confirm which of the two cells is current.
%load_ext autoreload
%autoreload 2
import sys
# Put the parent directory on the import path once
# (presumably where `utils` and/or `infodynamics_jax` live — confirm).
if '..' not in sys.path:
    sys.path.insert(0, '..')

import os
# Set in the environment before `import jax` so the platform choice is in place at import time.
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
# Also request the CPU platform through the JAX config.
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Shared plotting helpers; the style is applied once here for every figure below.
from utils import COLORS, PALETTES, setup_plot_style, get_figure_size, format_axes, create_custom_colormap
setup_plot_style()

# Project API used below: model parameters (Phi, KernelParams), RBF kernel, likelihood
# registry, energy, Type-II optimiser, runner, and variational-state prediction helpers.
from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG
from infodynamics_jax.gp.ansatz.state import VariationalState
from infodynamics_jax.gp.ansatz.expected import qfi_from_qu_full

print(f"JAX version: {jax.__version__}")

## 1. 生成二分類數據

In [None]:
# Synthetic binary-classification data: two equally sized Gaussian blobs in 2-D.
N_train = 100
key = jax.random.key(456)

n_per_class = N_train // 2
blob_scale = 0.8

# Class 0: cluster centred at (-2, -2).
key, subkey = jax.random.split(key)
X_class0 = jnp.array([-2.0, -2.0]) + blob_scale * jax.random.normal(subkey, (n_per_class, 2))

# Class 1: cluster centred at (2, 2).
key, subkey = jax.random.split(key)
X_class1 = jnp.array([2.0, 2.0]) + blob_scale * jax.random.normal(subkey, (n_per_class, 2))

# Stack inputs row-wise and build the matching 0/1 label vector.
X_train = jnp.concatenate([X_class0, X_class1], axis=0)
Y_train = jnp.concatenate([jnp.zeros(n_per_class), jnp.ones(n_per_class)])

# Shuffle inputs and labels with one shared permutation so the pairs stay aligned.
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, N_train)
X_train, Y_train = X_train[perm], Y_train[perm]

print(f"訓練集: {N_train} 個點")
print(f"類別 0: {jnp.sum(Y_train == 0)} 個點")
print(f"類別 1: {jnp.sum(Y_train == 1)} 個點")

## 2. 訓練 GP 分類器

In [None]:
# Look up the Bernoulli likelihood used for the binary labels.
bernoulli_likelihood = get_likelihood("bernoulli")

# Initial kernel hyperparameters.
kernel_params = KernelParams(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))

# Inducing points on a regular grid spanning the range of the training inputs.
# A square grid with int(sqrt(M)) points per side yields only 16 points for M = 20,
# so build a near-square n_cols x n_rows grid with n_cols * n_rows >= M (5 x 4 here)
# and keep the first M points; Z then really contains M inducing points.
M = 20
n_cols = int(np.ceil(np.sqrt(M)))
n_rows = -(-M // n_cols)  # ceiling division
Z_x1 = jnp.linspace(X_train[:, 0].min(), X_train[:, 0].max(), n_cols)
Z_x2 = jnp.linspace(X_train[:, 1].min(), X_train[:, 1].max(), n_rows)
Z_grid = jnp.stack(jnp.meshgrid(Z_x1, Z_x2), axis=-1).reshape(-1, 2)
Z = Z_grid[:M]

# Model parameters (the Bernoulli likelihood has no noise variance).
phi_init = Phi(
    kernel_params=kernel_params,
    Z=Z,
    likelihood_params={},  # Bernoulli has no extra parameters
    jitter=1e-5,
)

# Energy configuration using the Gauss-Hermite estimator.
inertial_cfg = InertialCFG(
    estimator="gh",  # Gauss-Hermite numerical quadrature
    gh_n=20,         # number of quadrature points
    inner_steps=0,   # no inner optimisation
)

inertial_energy = InertialEnergy(
    kernel_fn=rbf_kernel,
    likelihood=bernoulli_likelihood,
    cfg=inertial_cfg,
)

print("InertialEnergy 已創建（Bernoulli 似然 + Gauss-Hermite 估計器）")

In [None]:
# Configure the TypeII method.
typeii_cfg = TypeIICFG(
    steps=150,
    lr=1e-2,
    optimizer="adam",
    jit=True,
    constrain_params=True,
)
method = TypeII(cfg=typeii_cfg)

# Run it on the training data with a fresh subkey.
key, subkey = jax.random.split(key)
run_cfg = RunCFG(jit=True)
out = run(
    key=subkey,
    method=method,
    energy=inertial_energy,
    phi_init=phi_init,
    energy_args=(X_train, Y_train),
    cfg=run_cfg,
)

# Pull the optimised parameters and the energy trace out of the run result.
result = out.result
phi_opt, energy_trace = result.phi, result.energy_trace

# Report the final energy and the fitted kernel hyperparameters.
fitted_kernel = phi_opt.kernel_params
print("\n優化完成！")
print(f"最終能量: {energy_trace[-1]:.2f}")
print(f"優化後的超參數:")
print(f"  長度尺度: {float(fitted_kernel.lengthscale):.3f}")
print(f"  方差: {float(fitted_kernel.variance):.3f}")

## 3. 進行預測並視覺化決策邊界

In [None]:
# Prediction grid covering the training inputs with a margin of 1 on every side.
resolution = 100
x1_min, x1_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
x2_min, x2_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
xx1, xx2 = jnp.meshgrid(
    jnp.linspace(x1_min, x1_max, resolution),
    jnp.linspace(x2_min, x2_max, resolution)
)
X_grid = jnp.c_[xx1.ravel(), xx2.ravel()]

# Variational state built from the optimised parameters and the training data
# (presumably the posterior q(u) with mean m_u and factor L_u — confirm against the library).
state = VariationalState.initialise(phi_opt, X_train, Y_train)

# Latent mean and variance at every grid point.
mu_grid, var_grid = qfi_from_qu_full(
    phi_opt, X_grid, rbf_kernel, state.m_u, state.L_u
)

mu_grid = mu_grid.squeeze()
var_grid = var_grid.squeeze()

# Predictive class probability: p(y=1 | x) = E_{f ~ N(mu, var)}[sigmoid(f)].
# Plugging only the mean into the sigmoid ignores the latent variance and is
# over-confident away from the data, so the sigmoid is averaged over the latent
# Gaussian with Gauss-Hermite quadrature:
#   E[h(f)] ~= (1 / sqrt(pi)) * sum_i w_i * h(mu + sqrt(2 * var) * x_i).
# The 0.5 decision boundary (mu = 0) is unchanged by this averaging.
# Assumes a logistic link, as stated in the notebook text, and that mu_grid / var_grid
# are 1-D after squeeze (one entry per grid point) — confirm.
gh_nodes, gh_weights = np.polynomial.hermite.hermgauss(20)
gh_nodes = jnp.asarray(gh_nodes)
gh_weights = jnp.asarray(gh_weights)
latent_std = jnp.sqrt(jnp.maximum(var_grid, 0.0))  # guard against tiny negative variances
f_nodes = mu_grid[:, None] + jnp.sqrt(2.0) * latent_std[:, None] * gh_nodes[None, :]
prob_class1 = (jax.nn.sigmoid(f_nodes) @ gh_weights) / jnp.sqrt(jnp.pi)
prob_class1 = prob_class1.reshape(xx1.shape)

print(f"預測完成！")
print(f"概率範圍: [{prob_class1.min():.3f}, {prob_class1.max():.3f}]")

In [None]:
# Two side-by-side panels: class probability with the decision boundary (left)
# and the standard deviation sqrt(var_grid) over the grid (right).
fig, axes = plt.subplots(1, 2, figsize=get_figure_size('wide', 0.5))
ax_prob, ax_std = axes

# Panel 1: filled probability contours plus the p = 0.5 boundary line.
prob_cmap = create_custom_colormap('probability')
contour = ax_prob.contourf(
    xx1, xx2, prob_class1,
    levels=30, cmap=prob_cmap, alpha=0.85, zorder=1,
)
ax_prob.contour(
    xx1, xx2, prob_class1,
    levels=[0.5], colors='black', linewidths=3, linestyles='-', zorder=2,
)

# Training points, one scatter per class (class 0 drawn first, then class 1).
class_styles = (
    (0, 'class0', '類別 0', 'o'),
    (1, 'class1', '類別 1', 's'),
)
for label_value, color_key, legend_label, marker_shape in class_styles:
    members = X_train[Y_train == label_value]
    ax_prob.scatter(
        members[:, 0], members[:, 1],
        c=COLORS[color_key], s=70, alpha=0.9, label=legend_label,
        edgecolors='white', linewidths=1.5, marker=marker_shape, zorder=4,
    )

format_axes(
    ax_prob,
    title='決策邊界與類別概率', xlabel='特徵 X₁', ylabel='特徵 X₂',
    legend=True, legend_loc='upper right',
)
plt.colorbar(contour, ax=ax_prob, label='P(類別 1)', shrink=0.8)

# Panel 2: standard deviation on the grid, with all training points overlaid in gray.
std_grid = jnp.sqrt(var_grid).reshape(xx1.shape)
uncertainty_cmap = create_custom_colormap('uncertainty')
contour2 = ax_std.contourf(
    xx1, xx2, std_grid,
    levels=30, cmap=uncertainty_cmap, alpha=0.85, zorder=1,
)
ax_std.scatter(
    X_train[:, 0], X_train[:, 1],
    c='gray', s=50, alpha=0.6, edgecolors='white', linewidths=1, zorder=3,
)

format_axes(ax_std, title='預測不確定性分析', xlabel='特徵 X₁', ylabel='特徵 X₂', legend=False)
plt.colorbar(contour2, ax=ax_std, label='標準差', shrink=0.8)

plt.tight_layout()
plt.show()

## 總結

在本 notebook 中，我們學習了：

1. **非共軛似然**: Bernoulli 用於二分類
2. **近似方法**: Gauss-Hermite（確定性）vs Monte Carlo（隨機）
3. **GP 分類**: 潛在函數 f(x) 由 GP 建模，概率 p(y=1|x) = σ(f(x))
4. **不確定性量化**: GP 後驗同時給出潛在函數的均值與方差，可直接視覺化預測不確定性

**何時使用每個估計器：**

| 估計器 | 優點 | 缺點 | 最佳適用 |
|--------|------|------|---------|
| **Gauss-Hermite** | 確定性、穩定 | 限於 1D 積分 | 標準分類 |
| **Monte Carlo** | 可擴展到高維 | 噪聲、收斂較慢 | 多輸出、複雜模型 |

**擴展：**
- **多類別**: 使用 softmax 似然
- **計數數據**: Poisson 或 Negative Binomial
- **有序數據**: 使用有序（ordinal）似然處理有序類別（評級、嚴重程度）
