# Notebook 2: Kernel Selection and Comparison

This notebook runs GP regression with several different kernels and compares their properties.

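**Learning goals:**
- Understand the mathematical properties of different kernels and when each one is appropriate
- Learn how to choose a suitable kernel
- Compare the RBF, Matérn, Periodic, and Linear kernels
- Visualize the covariance structure of each kernel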

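**Kernel types:**
- **RBF (Squared Exponential)**: infinitely differentiable; suited to smooth functions
- **Matérn**: smoothness is controlled by the order ν (e.g. 3/2, 5/2); well suited to real-world data
- **Periodic**: captures periodic patterns
- **Linear**: captures linear trends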
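For reference, the usual closed forms of these kernels fit in a few lines. The sketch below uses NumPy and common textbook parameterisations. The `periodic` (exp-sine-squared) form and the `linear` form, in particular treating `offset` as a shift of the inputs, are assumptions and may not match how `infodynamics_jax` defines its kernels.

```python
import numpy as np

def sq_dist(x1, x2):
    """Pairwise squared Euclidean distances between rows of x1 (N, D) and x2 (M, D)."""
    diff = x1[:, None, :] - x2[None, :, :]
    return np.sum(diff**2, axis=-1)

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    # k(r) = s^2 exp(-r^2 / (2 l^2))
    return variance * np.exp(-0.5 * sq_dist(x1, x2) / lengthscale**2)

def matern32(x1, x2, lengthscale=1.0, variance=1.0):
    # k(r) = s^2 (1 + sqrt(3) r / l) exp(-sqrt(3) r / l)
    a = np.sqrt(3.0) * np.sqrt(sq_dist(x1, x2)) / lengthscale
    return variance * (1.0 + a) * np.exp(-a)

def matern52(x1, x2, lengthscale=1.0, variance=1.0):
    # k(r) = s^2 (1 + sqrt(5) r / l + 5 r^2 / (3 l^2)) exp(-sqrt(5) r / l)
    a = np.sqrt(5.0) * np.sqrt(sq_dist(x1, x2)) / lengthscale
    return variance * (1.0 + a + a**2 / 3.0) * np.exp(-a)

def periodic(x1, x2, lengthscale=1.0, variance=1.0, period=1.0):
    # Assumed exp-sine-squared form: k(r) = s^2 exp(-2 sin^2(pi r / p) / l^2)
    r = np.sqrt(sq_dist(x1, x2))
    return variance * np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / lengthscale**2)

def linear(x1, x2, variance=1.0, offset=0.0):
    # Assumed form: k(x, x') = s^2 (x - c)^T (x' - c) with c = offset
    return variance * (x1 - offset) @ (x2 - offset).T
```

For an `(N, 1)` input `X`, each call such as `matern32(X, X)` returns the `N x N` Gram matrix; for the four stationary kernels its diagonal equals `variance`.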


## Environment Setup

In [None]:
%load_ext autoreload
%autoreload 2
import sys
if '..' not in sys.path:
    sys.path.insert(0, '..')

import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from utils import COLORS, PALETTES, setup_plot_style, get_figure_size, format_axes
setup_plot_style()

from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.kernels.matern32 import matern32 as matern32_kernel
from infodynamics_jax.gp.kernels.matern52 import matern52 as matern52_kernel
from infodynamics_jax.gp.kernels.periodic import periodic as periodic_kernel
from infodynamics_jax.gp.kernels.linear import linear as linear_kernel
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.inference.optimisation.vfe import make_vfe_objective
from infodynamics_jax.gp.predict import predict_typeii

print(f"JAX version: {jax.__version__}")

## 1. Generate Test Functions

In [None]:
key = jax.random.key(123)
N_train = 50
noise_std = 0.2

X = jnp.linspace(-5, 5, N_train)[:, None]
X_test = jnp.linspace(-6, 6, 200)[:, None]  # extends 1 unit beyond the training range on each side

# Test function 1: smooth (suited to RBF)
def smooth_function(x):
    return 2 * jnp.sin(x[:, 0]) + jnp.cos(2 * x[:, 0])

# Test function 2: less smooth, with a kink at x = 0 from |x| (suited to Matérn 3/2)
def less_smooth_function(x):
    return jnp.abs(x[:, 0]) + 0.5 * jnp.sin(3 * x[:, 0])

# Test function 3: periodic with period 3 (suited to Periodic)
def periodic_function(x):
    return 2 * jnp.sin(2 * jnp.pi * x[:, 0] / 3.0)

# Test function 4: linear trend plus a small wiggle (suited to Linear)
def linear_function(x):
    return 1.5 * x[:, 0] + 0.5 * jnp.sin(x[:, 0])

print("Test functions created!")

## 2. Helper Function: Training and Prediction

In [None]:
def train_and_predict(X_train, Y_train, X_test, kernel_fn, kernel_params_init,
                      n_inducing=15, n_steps=150, key=None):
    """Fit a sparse GP and predict at X_test.

    Runs TypeII optimisation of Phi against the VFE objective, then predicts.
    Returns (predictive mean, predictive std, optimised Phi, energy trace).
    `key` is accepted but currently unused.
    """
    # Inducing points: evenly spaced over the training range (assumes 1-D inputs)
    Z = jnp.linspace(X_train.min(), X_train.max(), n_inducing)[:, None]

    # Initial parameters: kernel hyperparameters, inducing points, noise variance
    phi_init = Phi(
        kernel_params=kernel_params_init,
        Z=Z,
        likelihood_params={"noise_var": jnp.array(0.1)},
        jitter=1e-5,
    )

    # Sparse VFE objective (with the "fitc" residual option)
    vfe_objective = make_vfe_objective(kernel_fn=kernel_fn, residual="fitc")

    # Configure and run Type-II optimisation
    typeii_cfg = TypeIICFG(
        steps=n_steps,
        lr=1e-2,
        optimizer="adam",
        jit=True,
        constrain_params=True,
        min_noise_var=1e-3,
        clip_grad_norm=10.0,
    )
    method = TypeII(cfg=typeii_cfg)

    result = method.run(
        energy=vfe_objective,
        phi_init=phi_init,
        energy_args=(X_train, Y_train),
    )

    phi_opt = result.phi

    # Predict at the test inputs
    mu_test, predictive_var = predict_typeii(
        phi_opt, X_test, X_train, Y_train, kernel_fn, residual="fitc"
    )

    mu_test = mu_test.squeeze()
    std_test = jnp.sqrt(predictive_var.squeeze())

    return mu_test, std_test, phi_opt, result.energy_trace
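With only 50 training points, an exact (non-sparse) GP is cheap and makes a useful baseline for checking the sparse fit. The sketch below is a standard Cholesky-based implementation in NumPy, not the library's `predict_typeii`. It returns the latent-function variance (add `noise_var` for the variance of a noisy observation), and the hypothetical `rbf` helper stands in for any kernel with the signature `kernel(X1, X2) -> (N, M) array`.

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared exponential kernel for inputs of shape (N, D) and (M, D)
    d2 = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * d2 / lengthscale**2)

def exact_gp_predict(X_train, y_train, X_new, kernel, noise_var):
    """Exact GP posterior mean and latent variance at X_new."""
    n = X_train.shape[0]
    K = kernel(X_train, X_train) + noise_var * np.eye(n)
    L = np.linalg.cholesky(K)                                   # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))   # alpha = K^{-1} y
    K_s = kernel(X_new, X_train)                                # (M, N) cross-covariance
    mean = K_s @ alpha
    V = np.linalg.solve(L, K_s.T)                               # L^{-1} K_s^T, shape (N, M)
    var = np.diag(kernel(X_new, X_new)) - np.sum(V**2, axis=0)
    return mean, np.maximum(var, 0.0)                           # clip tiny negative round-off
```

Once section 3 has run, a call such as `exact_gp_predict(np.asarray(X), np.asarray(Y_train_rbf), np.asarray(X_test), rbf, noise_var=noise_std**2)` could be compared with `mu_rbf`; note that its hyperparameters are fixed rather than optimised.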

## 3. Experiments: Each Kernel on a Function It Suits

In [None]:
# RBF kernel on the smooth function
f_train = smooth_function(X)
key, subkey = jax.random.split(key)
Y_train_rbf = f_train + noise_std * jax.random.normal(subkey, (N_train,))
f_test = smooth_function(X_test)

kernel_params_rbf = KernelParams(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
key, subkey = jax.random.split(key)
mu_rbf, std_rbf, phi_rbf, trace_rbf = train_and_predict(
    X, Y_train_rbf, X_test, rbf_kernel, kernel_params_rbf, key=subkey
)

print("RBF kernel results:")
print(f"  Optimised lengthscale: {float(phi_rbf.kernel_params.lengthscale):.3f}")
print(f"  Optimised variance: {float(phi_rbf.kernel_params.variance):.3f}")
print(f"  MSE (vs. true function on X_test): {float(jnp.mean((f_test - mu_rbf)**2)):.4f}")
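The MSE above scores only the predictive mean and ignores `std_rbf`, and it pools the interpolation range `[-5, 5]` with the extrapolation margins of `X_test`. A minimal sketch of two complementary checks, assuming 1-D inputs: a Gaussian negative log predictive density (NLPD), which also rewards calibrated uncertainty, and an MSE split into inside and outside the training range. Both helper names are hypothetical, not part of `infodynamics_jax`.

```python
import numpy as np

def gaussian_nlpd(y_true, mu, std):
    """Mean negative log density of y_true under independent N(mu, std^2)."""
    y_true = np.asarray(y_true)
    mu = np.asarray(mu)
    var = np.asarray(std) ** 2
    return np.mean(0.5 * np.log(2.0 * np.pi * var) + 0.5 * (y_true - mu) ** 2 / var)

def split_mse(x, y_true, mu, lo=-5.0, hi=5.0):
    """Return (MSE inside [lo, hi], MSE outside it) for 1-D inputs x."""
    x = np.asarray(x)
    err2 = (np.asarray(y_true) - np.asarray(mu)) ** 2
    inside = (x >= lo) & (x <= hi)
    return err2[inside].mean(), err2[~inside].mean()
```

With the notebook's arrays these could be called as `gaussian_nlpd(f_test, mu_rbf, std_rbf)` and `split_mse(X_test[:, 0], f_test, mu_rbf)`. Whether the variance from `predict_typeii` includes observation noise is not stated in this notebook, which affects how the NLPD against the noise-free `f_test` should be read.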

In [None]:
# Matérn 3/2 kernel on the less smooth function
f_train = less_smooth_function(X)
key, subkey = jax.random.split(key)
Y_train_matern = f_train + noise_std * jax.random.normal(subkey, (N_train,))
f_test_matern = less_smooth_function(X_test)

kernel_params_matern = KernelParams(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
key, subkey = jax.random.split(key)
mu_matern, std_matern, phi_matern, trace_matern = train_and_predict(
    X, Y_train_matern, X_test, matern32_kernel, kernel_params_matern, key=subkey
)

print("Matérn 3/2 kernel results:")
print(f"  Optimised lengthscale: {float(phi_matern.kernel_params.lengthscale):.3f}")
print(f"  Optimised variance: {float(phi_matern.kernel_params.variance):.3f}")
print(f"  MSE (vs. true function on X_test): {float(jnp.mean((f_test_matern - mu_matern)**2)):.4f}")

In [None]:
# Periodic kernel on the periodic function
f_train = periodic_function(X)
key, subkey = jax.random.split(key)
Y_train_periodic = f_train + noise_std * jax.random.normal(subkey, (N_train,))
f_test_periodic = periodic_function(X_test)

# The period is initialised at its true value (3.0); in practice it is unknown
kernel_params_periodic = KernelParams(
    lengthscale=jnp.array(1.0),
    variance=jnp.array(1.0),
    period=jnp.array(3.0)
)

key, subkey = jax.random.split(key)
mu_periodic, std_periodic, phi_periodic, trace_periodic = train_and_predict(
    X, Y_train_periodic, X_test, periodic_kernel, kernel_params_periodic, key=subkey
)

print("Periodic kernel results:")
print(f"  Optimised lengthscale: {float(phi_periodic.kernel_params.lengthscale):.3f}")
print(f"  Optimised variance: {float(phi_periodic.kernel_params.variance):.3f}")
print(f"  Optimised period: {float(phi_periodic.kernel_params.period):.3f} (true: 3.0)")
print(f"  MSE (vs. true function on X_test): {float(jnp.mean((f_test_periodic - mu_periodic)**2)):.4f}")
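The cell above starts the period at its true value, which a real analysis would not know. The sketch below evaluates the exact log marginal likelihood over a grid of periods for data shaped like `periodic_function`, holding the other hyperparameters fixed at illustrative values (lengthscale 1.0, variance 4.0, noise variance 0.04). It uses NumPy and the standard exp-sine-squared form, which may differ from the library's `periodic`. Besides the global peak near 3.0, the curve has a second local peak around 6.0 (twice the true period), so a gradient-based optimiser started far from 3.0 can settle in the wrong mode.

```python
import numpy as np

def periodic(x1, x2, lengthscale=1.0, variance=1.0, period=1.0):
    # Exp-sine-squared kernel for 1-D inputs of shape (N, 1) and (M, 1)
    r = np.abs(x1[:, None, 0] - x2[None, :, 0])
    return variance * np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / lengthscale**2)

def log_marginal_likelihood(X, y, kernel, noise_var):
    """log N(y | 0, K + noise_var * I), computed via a Cholesky factor."""
    n = y.shape[0]
    L = np.linalg.cholesky(kernel(X, X) + noise_var * np.eye(n))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * n * np.log(2.0 * np.pi)

# Synthetic data resembling periodic_function on [-5, 5]
rng = np.random.default_rng(0)
X_demo = np.linspace(-5.0, 5.0, 50)[:, None]
y_demo = 2.0 * np.sin(2.0 * np.pi * X_demo[:, 0] / 3.0) + 0.2 * rng.standard_normal(50)

# Scan the period; every other hyperparameter stays fixed
periods = np.linspace(0.5, 8.0, 301)
lml = np.array([
    log_marginal_likelihood(
        X_demo, y_demo,
        lambda a, b: periodic(a, b, lengthscale=1.0, variance=4.0, period=p),
        noise_var=0.04,
    )
    for p in periods
])
best_period = periods[np.argmax(lml)]
print(f"Best period on the grid: {best_period:.3f}")
```

Plotting `lml` against `periods` makes the multiple modes visible; in practice, restarting the optimiser from several initial periods is a common safeguard.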

In [None]:
# Linear kernel on the linear trend
f_train = linear_function(X)
key, subkey = jax.random.split(key)
Y_train_linear = f_train + noise_std * jax.random.normal(subkey, (N_train,))
f_test_linear = linear_function(X_test)

kernel_params_linear = KernelParams(
    variance=jnp.array(1.0),
    offset=jnp.array(0.0)
)

key, subkey = jax.random.split(key)
mu_linear, std_linear, phi_linear, trace_linear = train_and_predict(
    X, Y_train_linear, X_test, linear_kernel, kernel_params_linear, key=subkey
)

print("Linear kernel results:")
print(f"  Optimised variance: {float(phi_linear.kernel_params.variance):.3f}")
print(f"  Optimised offset: {float(phi_linear.kernel_params.offset):.3f}")
print(f"  MSE (vs. true function on X_test): {float(jnp.mean((f_test_linear - mu_linear)**2)):.4f}")

## 4. Visualize the Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=get_figure_size('wide', 1.2))

experiments = [
    ("RBF on the smooth function", X, Y_train_rbf, X_test, f_test, mu_rbf, std_rbf, PALETTES['main'][0]),
    ("Matérn 3/2 on the less smooth function", X, Y_train_matern, X_test, f_test_matern, mu_matern, std_matern, PALETTES['main'][1]),
    ("Periodic on the periodic function", X, Y_train_periodic, X_test, f_test_periodic, mu_periodic, std_periodic, PALETTES['main'][2]),
    ("Linear on the linear trend", X, Y_train_linear, X_test, f_test_linear, mu_linear, std_linear, PALETTES['main'][3]),
]

for ax, (title, X_tr, Y_tr, X_te, f_te, mu, std, color) in zip(axes.flat, experiments):
    # Draw back to front: uncertainty bands, GP mean, true function, training data
    ax.fill_between(X_te[:, 0], mu - 2 * std, mu + 2 * std,
                    color=color, alpha=0.2, label='±2 std', zorder=1)
    ax.fill_between(X_te[:, 0], mu - std, mu + std,
                    color=color, alpha=0.3, label='±1 std', zorder=2)
    ax.plot(X_te[:, 0], mu, color=color, linewidth=2.5,
            label='GP mean', zorder=3)
    ax.plot(X_te[:, 0], f_te, color=COLORS['true'], linewidth=2,
            label='True function', zorder=4, linestyle='--', alpha=0.8)
    ax.scatter(X_tr[:, 0], Y_tr, c=COLORS['train'], s=40, alpha=0.7,
               label='Training data', zorder=5, edgecolors='white', linewidths=0.6)
    format_axes(ax, title=title, xlabel='Input X', ylabel='Output Y',
                legend=True, legend_loc='best')

plt.suptitle('Kernel comparison: each kernel on a matching function',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.show()

## Summary

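**Kernel selection guide:**

| Kernel | Best suited for | Smoothness of sample paths | Key parameters |
|--------|-----------------|----------------------------|----------------|
| **RBF** | Very smooth functions | C^∞ (infinitely differentiable) | lengthscale, variance |
| **Matérn 3/2** | Moderately smooth functions | C¹ (once differentiable) | lengthscale, variance |
| **Matérn 5/2** | Smooth, but not infinitely so | C² (twice differentiable) | lengthscale, variance |
| **Periodic** | Periodic/seasonal data | Depends on the base kernel (C^∞ for the standard exp-sine-squared form) | period, lengthscale, variance |
| **Linear** | Linear trends | C^∞ (samples are straight lines) | variance, offset |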
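The smoothness column describes GP sample paths. One way to see the difference is to draw prior samples through a Cholesky factor and compare their mean squared second differences: Matérn 3/2 paths are not twice differentiable, so on a fine grid their second differences come out far larger than those of RBF paths. This is a NumPy sketch with illustrative settings (lengthscale 1, variance 1, a 100-point grid on [0, 5]); the 1e-6 jitter adds a small white-noise floor to both.

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared exponential kernel on 1-D inputs
    r = np.subtract.outer(x1, x2) / lengthscale
    return variance * np.exp(-0.5 * r**2)

def matern32(x1, x2, lengthscale=1.0, variance=1.0):
    # Matérn 3/2 kernel on 1-D inputs
    a = np.sqrt(3.0) * np.abs(np.subtract.outer(x1, x2)) / lengthscale
    return variance * (1.0 + a) * np.exp(-a)

def sample_prior(kernel, x, n_samples, rng, jitter=1e-6):
    """Draw prior samples f = L z, where K + jitter*I = L L^T and z ~ N(0, I)."""
    K = kernel(x, x) + jitter * np.eye(x.shape[0])
    L = np.linalg.cholesky(K)
    return (L @ rng.standard_normal((x.shape[0], n_samples))).T  # shape (n_samples, N)

def roughness(samples):
    """Mean squared second difference along each sample path."""
    return np.mean(np.diff(samples, n=2, axis=1) ** 2)

rng = np.random.default_rng(0)
x_grid = np.linspace(0.0, 5.0, 100)
f_rbf = sample_prior(rbf, x_grid, 5, rng)
f_m32 = sample_prior(matern32, x_grid, 5, rng)
print(f"RBF roughness:        {roughness(f_rbf):.2e}")
print(f"Matérn 3/2 roughness: {roughness(f_m32):.2e}")
```

Plotting a few rows of `f_rbf` and `f_m32` against `x_grid` shows the same contrast visually.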

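**Key insights:**
1. **RBF** is the most widely used kernel but assumes infinite smoothness.
2. The **Matérn** family adds flexibility for modelling rougher functions.
3. The **Periodic** kernel learns the period as a hyperparameter; because the marginal likelihood is often multimodal in the period, the initial value matters.
4. The **Linear** kernel captures global trends.
5. Kernels can be **combined** (added or multiplied) to build richer models.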
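Composition is safe because sums and elementwise products of valid kernels are again valid kernels (the product case follows from the Schur product theorem). A minimal NumPy sketch with simplified 1-D stand-ins for the kernels (the `linear` form without an offset is an assumption) builds a locally periodic kernel (Periodic × RBF) and a trend-plus-smooth kernel (Linear + RBF), then checks that neither Gram matrix has a meaningfully negative eigenvalue. `add_kernels` and `mul_kernels` are hypothetical helpers, not library functions.

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    return variance * np.exp(-0.5 * (np.subtract.outer(x1, x2) / lengthscale) ** 2)

def periodic(x1, x2, lengthscale=1.0, variance=1.0, period=1.0):
    r = np.abs(np.subtract.outer(x1, x2))
    return variance * np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / lengthscale**2)

def linear(x1, x2, variance=1.0):
    return variance * np.multiply.outer(x1, x2)

def add_kernels(k1, k2):
    """Sum kernel: k(x, x') = k1(x, x') + k2(x, x')."""
    return lambda a, b: k1(a, b) + k2(a, b)

def mul_kernels(k1, k2):
    """Product kernel: elementwise product of the two Gram matrices."""
    return lambda a, b: k1(a, b) * k2(a, b)

# Periodic pattern whose shape can drift slowly (RBF lengthscale 4)
locally_periodic = mul_kernels(
    lambda a, b: periodic(a, b, period=3.0),
    lambda a, b: rbf(a, b, lengthscale=4.0),
)
# Global linear trend plus smooth local deviations
trend_plus_smooth = add_kernels(linear, rbf)

x_grid = np.linspace(-5.0, 5.0, 40)
min_eigs = {}
for name, kernel in [("Periodic x RBF", locally_periodic), ("Linear + RBF", trend_plus_smooth)]:
    K = kernel(x_grid, x_grid)
    min_eigs[name] = np.linalg.eigvalsh(K).min()
    print(f"{name}: smallest eigenvalue {min_eigs[name]:.2e}")
```

Any tiny negative values printed here are floating-point round-off, not a failure of positive semi-definiteness.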

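**Best practices:**
- Start with RBF when little is known about the function
- Use Matérn for real-world data, which usually has limited smoothness
- Use Periodic when a time series shows clear periodicity
- Combine kernels, e.g. `Linear + RBF` for a trend plus smooth deviations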
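One practical reason to prefer `Linear + RBF` over RBF alone is extrapolation: far from the data the RBF posterior mean reverts to the zero prior mean, while the linear component carries the trend forward. The NumPy sketch below uses fixed, illustrative hyperparameters (no optimisation), data shaped like `linear_function`, and a `linear` kernel without an offset, which is an assumption.

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    return variance * np.exp(-0.5 * (np.subtract.outer(x1, x2) / lengthscale) ** 2)

def linear(x1, x2, variance=1.0):
    return variance * np.multiply.outer(x1, x2)

def linear_plus_rbf(x1, x2):
    return linear(x1, x2) + rbf(x1, x2)

def gp_posterior_mean(x_train, y_train, x_new, kernel, noise_var):
    """Exact GP posterior mean k(x_new, X) (K + noise_var I)^{-1} y."""
    K = kernel(x_train, x_train) + noise_var * np.eye(x_train.shape[0])
    return kernel(x_new, x_train) @ np.linalg.solve(K, y_train)

# Synthetic data resembling linear_function on [-5, 5]
rng = np.random.default_rng(1)
x_obs = np.linspace(-5.0, 5.0, 50)
y_obs = 1.5 * x_obs + 0.5 * np.sin(x_obs) + 0.2 * rng.standard_normal(50)
x_far = np.array([10.0, 15.0])  # well outside the training range

mean_rbf = gp_posterior_mean(x_obs, y_obs, x_far, rbf, noise_var=0.04)
mean_lin_rbf = gp_posterior_mean(x_obs, y_obs, x_far, linear_plus_rbf, noise_var=0.04)
print("RBF only:    ", np.round(mean_rbf, 2))
print("Linear + RBF:", np.round(mean_lin_rbf, 2))
```

Inside the training range the two models fit similarly; the difference only shows once predictions leave it.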
