# Kalman Filter GVAR Analysis

This notebook implements Kalman Filter methods for Global Vector Autoregression (GVAR) modeling.

## Overview
GVAR models are used to analyze interconnected time series across multiple countries/regions, and the Kalman Filter provides a framework for state-space estimation and forecasting.

## Setup
The following cell imports required libraries for data manipulation, analysis, and visualization.

In [None]:
from __future__ import annotations

import os
import warnings; warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


In [None]:
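# NOTE: The triple-quoted block below is an inactive example of the configuration.
# To run the analysis on your own data, edit `fname` in the main script
# (under `if __name__ == "__main__":`) so that it points to your Excel file.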
'''
if __name__ == "__main__":
    fname = "/Users/poppy/iCloud Drive (Archive)/Desktop/GVAR/df_country_data_climate.xlsx"
    xls = pd.ExcelFile(fname)
    sheets = xls.sheet_names

    sheets_9_of_12 = [
        "India", "Brazil", "Chile", "Indonesia", "Mexico", "Peru",
        "Philippines", "South Africa", "Thailand"

    ttls = ['y', 'Dp', 'eq', 'ep', 'r', 'ys', 'Dps', 'eqs', 'rs', 'lrs', 'ensos']
    
    # ----------------------------------------------------------------------
    # More variable examples (customize based on your dataset)
    # ----------------------------------------------------------------------
    # For instance, you can replace or extend the list of exogenous variables:
    # 
    # Example 1 — add other macro indicators:
    # ttls = ['y', 'Dp', 'eq', 'ep', 'r',
    #          'ys', 'Dps', 'eqs', 'rs', 'lrs', 'ensos',
    #          'oil', 'gdp_us', 'commodity', 'trade']
    #
    # Example 2 — replace ENSO with another climate index:
    # ttls = ['y', 'Dp', 'eq', 'ep', 'r',
    #          'ys', 'Dps', 'eqs', 'rs', 'lrs', 'iod']   # Indian Ocean Dipole
    #
    # Example 3 — combine multiple environmental shocks:
    # ttls = ['y', 'Dp', 'eq', 'ep', 'r',
    #          'ys', 'Dps', 'eqs', 'rs', 'lrs',
    #          'ensos', 'iod', 'nao', 'pna']  # ENSO + IOD + NAO + PNA
    #
    # Note: You only need to ensure the same variable names exist in your Excel file.
    # The rest of the code will automatically adjust (lags, indices, etc.)
    # ----------------------------------------------------------------------

    ]
'''
# ---------- helpers ----------
def intersect_stable(a_list, b_list):
    """Return the elements of ``a_list`` that also occur in ``b_list``.

    The order of ``a_list`` is preserved. When a value occurs several times
    in ``b_list`` the position of its last occurrence is reported.

    Returns:
        A 3-tuple ``(common, idx_a, idx_b)`` where ``common`` is the list of
        shared values, ``idx_a`` holds their positions in ``a_list`` and
        ``idx_b`` their positions in ``b_list`` (both integer numpy arrays).
    """
    # Later duplicates overwrite earlier ones, so the last position wins.
    position_in_b = {}
    for pos, item in enumerate(b_list):
        position_in_b[item] = pos

    hits = [
        (item, pos_a, position_in_b[item])
        for pos_a, item in enumerate(a_list)
        if item in position_in_b
    ]

    common = [hit[0] for hit in hits]
    idx_a = np.array([hit[1] for hit in hits], dtype=int)
    idx_b = np.array([hit[2] for hit in hits], dtype=int)
    return common, idx_a, idx_b

# ---------- Kalman filter (multi-lag VARX) with standardization ----------
def kalman_multilag_filter(
    Y: np.ndarray,
    Z_all: np.ndarray,
    Q0: np.ndarray,
    R0: np.ndarray,
    P0: np.ndarray,
    lambda_: float,
    lambda_R: float,
    lambda_Q: float,
    weights: np.ndarray | None = None,
    std_ew_hist_IN: np.ndarray | float | None = None,
    dropout0: np.ndarray | None = None,
    idx_0: int | None = None,
    lags: int = 2,
    standardize: bool = True,
    lambda_norm: float | None = None,
    eps: float = 1e-8,
):
    """Run a recursive two-lag VARX coefficient filter over ``Y``.

    At every step ``t >= 2`` the regressor ``[Y[t-1], Y[t-2], Z_all[t-1]]``
    is placed block-diagonally into ``H_t`` (one block per equation) and the
    stacked coefficient vector ``theta`` is obtained from the linear system

        A = H' R^-1 H + c1 * P_pred^-1 + c2 * D
        b = H' R^-1 y + c1 * P_pred^-1 theta_prev

    with ``P_pred = P / lambda_ + Q``. ``R`` and ``Q`` are then updated as
    exponentially weighted averages of the outer products of the one-step
    error and of the coefficient change.

    NOTE: ``c1``, ``c2`` and ``D`` are NOT parameters; they are looked up as
    module-level names and must be defined (``D`` with shape ``(p, p)``)
    before the function is called.

    Args:
        Y: ``(n, mY)`` array of endogenous series.
        Z_all: array with ``mX`` columns of exogenous series; row ``t - 1``
            is used at step ``t``.
        Q0, R0, P0: initial ``Q`` (p x p), ``R`` (mY x mY) and ``P`` (p x p),
            where ``p = (lags * mY + mX) * mY``. They are copied, not
            modified.
        lambda_: divisor applied to ``P`` when forming ``P_pred``.
        lambda_R, lambda_Q: decay factors of the running ``R`` / ``Q``.
        weights: optional per-row weights for the summary error; uniform
            when omitted.
        std_ew_hist_IN: reference scale(s) the weighted error std is divided
            by. When ``None`` the summary error ``merr`` is NaN.
        dropout0: optional vector; entries below 1 scale the sub-block of
            ``P`` and ``Q`` selected by those entries by ``(d_i * d_j) ** 3``.
        idx_0: optional coefficient index forced to zero at every step.
        lags: must be 2; the regressor layout is fixed to two lags.
        standardize: if true, inputs are standardized with running
            (exponentially weighted) means and variances.
        lambda_norm: decay of those running statistics; defaults to
            ``lambda_R``.
        eps: lower bound applied to variances before taking square roots.

    Returns:
        ``(merr, e_std, e_raw, theta_est, Y_pred_std, Y_pred_raw, P_hist,
        R_hist, Q_hist)``. Rows 0 and 1 of the per-step outputs are never
        written and keep their initial fill (NaN, or 0 for ``theta_est`` and
        the ``*_hist`` arrays).

    Raises:
        ValueError: if ``lags`` is not 2.
    """
    n, mY = Y.shape
    mX = Z_all.shape[1]

    # The regressor below is hard-wired to [lag 1, lag 2, exogenous]; any
    # other value used to surface later as an opaque broadcasting error.
    if lags != 2:
        raise ValueError(
            f"kalman_multilag_filter only supports lags=2 (got lags={lags})"
        )

    m = lags * mY + mX
    p = m * mY

    P = P0.copy()
    R = R0.copy()
    Q = Q0.copy()

    if dropout0 is not None:
        # Best effort, as in the original: a malformed dropout vector leaves
        # P and Q as they are instead of aborting the run.
        try:
            dropout_exp = 3
            d0 = np.asarray(dropout0).ravel()[:p]
            idx_dropout = d0 < 1.0
            if np.any(idx_dropout):
                scale = d0[idx_dropout] ** dropout_exp
                block_scale = scale[:, None] * scale[None, :]
                sel = np.ix_(idx_dropout, idx_dropout)
                P[sel] *= block_scale
                Q[sel] *= block_scale
        except (TypeError, ValueError, IndexError):
            pass

    theta = np.zeros((p, 1))
    theta_est = np.zeros((n, p))
    Y_pred_std = np.full((n, mY), np.nan)
    Y_pred_raw = np.full((n, mY), np.nan)
    P_hist = np.zeros((p, p, n))
    R_hist = np.zeros((mY, mY, n))
    Q_hist = np.zeros((p, p, n))
    e_std = np.full((n, mY), np.nan)
    e_raw = np.full((n, mY), np.nan)

    # ---------- Causal (past-only) standardization init ----------
    if standardize:
        if lambda_norm is None:
            lambda_norm = lambda_R
        eps_floor = eps

        # Running means start from the earliest samples; variances start at 1.
        mu_Z = Z_all[1, :].copy()
        mu_Y = Y[1, :].copy()
        mu_Y_prev1 = Y[1, :].copy()   # mean tracker for the lag-1 regressor
        mu_Y_prev2 = Y[0, :].copy()   # mean tracker for the lag-2 regressor

        var_Z = np.ones(mX)
        var_Y = np.ones(mY)
        var_Y_prev1 = np.ones(mY)
        var_Y_prev2 = np.ones(mY)

        decay = lambda_norm
        gain = 1.0 - lambda_norm

    for t in range(2, n):
        if standardize:
            # The target is scaled with the statistics available BEFORE
            # Y[t] is absorbed, and the same pair maps the prediction back
            # to raw units further down.
            mu_Y_inv = mu_Y.copy()
            std_Y_inv = np.sqrt(np.maximum(var_Y, eps_floor))

            Y_raw_t = Y[t, :]
            Y_prev1_raw = Y[t - 1, :]
            Y_prev2_raw = Y[t - 2, :]
            Z_raw = Z_all[t - 1, :]

            # Variances are updated first, against the not-yet-updated means.
            var_Z = decay * var_Z + gain * (Z_raw - mu_Z) ** 2
            var_Y = decay * var_Y + gain * (Y_raw_t - mu_Y) ** 2
            var_Y_prev1 = decay * var_Y_prev1 + gain * (Y_prev1_raw - mu_Y_prev1) ** 2
            var_Y_prev2 = decay * var_Y_prev2 + gain * (Y_prev2_raw - mu_Y_prev2) ** 2

            mu_Z = decay * mu_Z + gain * Z_raw
            mu_Y = decay * mu_Y + gain * Y_raw_t
            mu_Y_prev1 = decay * mu_Y_prev1 + gain * Y_prev1_raw
            mu_Y_prev2 = decay * mu_Y_prev2 + gain * Y_prev2_raw

            std_Z = np.sqrt(np.maximum(var_Z, eps_floor))
            std_Y_prev1 = np.sqrt(np.maximum(var_Y_prev1, eps_floor))
            std_Y_prev2 = np.sqrt(np.maximum(var_Y_prev2, eps_floor))

            # Regressors use the freshly updated statistics; the target
            # uses the pre-update ones saved above.
            Z_t = (Z_raw - mu_Z) / std_Z
            Y_prev1 = (Y_prev1_raw - mu_Y_prev1) / std_Y_prev1
            Y_prev2 = (Y_prev2_raw - mu_Y_prev2) / std_Y_prev2
            Y_t = ((Y_raw_t - mu_Y_inv) / std_Y_inv).reshape(-1, 1)
        else:
            Y_t = Y[t, :].reshape(-1, 1)
            Y_prev1 = Y[t - 1, :]
            Y_prev2 = Y[t - 2, :]
            Z_t = Z_all[t - 1, :]

        # Block-diagonal design: equation j owns coefficients j*m .. (j+1)*m-1.
        X_t = np.concatenate([Y_prev1, Y_prev2, Z_t], axis=0)
        H_t = np.zeros((mY, p))
        for j in range(mY):
            H_t[j, j * m:(j + 1) * m] = X_t

        if idx_0 is not None:
            # Best effort, as in the original: an unusable index is ignored.
            try:
                H_t[:, idx_0] = 0.0
                theta[idx_0, 0] = 0.0
            except (TypeError, ValueError, IndexError):
                pass

        P_pred = P / lambda_ + Q
        theta_pred = theta

        R_inv = np.linalg.inv(R)
        P_inv = np.linalg.inv(P_pred)

        # c1, c2 and D are module-level names (see docstring).
        A = H_t.T @ R_inv @ H_t + c1 * P_inv + c2 * D
        b = H_t.T @ R_inv @ Y_t + c1 * P_inv @ theta_pred

        theta = np.linalg.solve(A, b)
        P = np.linalg.inv(A)

        # One-step-ahead prediction and error use the PREVIOUS coefficients.
        y_pred = H_t @ theta_pred
        e_t = Y_t - y_pred
        delta_theta = theta - theta_pred

        R = lambda_R * R + (1.0 - lambda_R) * (e_t @ e_t.T)
        Q = lambda_Q * Q + (1.0 - lambda_Q) * (delta_theta @ delta_theta.T)

        theta_est[t, :] = theta.ravel()
        Y_pred_std[t, :] = y_pred.ravel()
        P_hist[:, :, t] = P
        R_hist[:, :, t] = R
        Q_hist[:, :, t] = Q
        e_std[t, :] = e_t.ravel()

        if standardize:
            yhat_raw = Y_pred_std[t, :] * std_Y_inv + mu_Y_inv
            Y_pred_raw[t, :] = yhat_raw
            e_raw[t, :] = Y[t, :] - yhat_raw
        else:
            Y_pred_raw[t, :] = Y_pred_std[t, :]
            e_raw[t, :] = Y[t, :] - Y_pred_raw[t, :]

    # Summary error: weighted std of the standardized errors relative to the
    # reference scale. Without a reference there is nothing to compare to.
    if std_ew_hist_IN is None:
        merr = np.nan
    else:
        try:
            if weights is None:
                weights = np.ones((n, 1))
            w = np.asarray(weights).reshape(-1, 1)
            w = w / np.nansum(w)
            std_ew_hist = np.sqrt(np.nansum((e_std ** 2) * w, axis=0))
            denom = std_ew_hist_IN
            if np.isscalar(denom):
                denom = float(denom)
            merr = np.nanmean(std_ew_hist / denom)
        except (TypeError, ValueError):
            # Incompatible weights / reference shapes: report "no score".
            merr = np.nan

    return merr, e_std, e_raw, theta_est, Y_pred_std, Y_pred_raw, P_hist, R_hist, Q_hist

def kalman_multilag_filter_em(
    Y: np.ndarray,
    Z_all: np.ndarray,
    Q0: np.ndarray,
    R0: np.ndarray,
    P0: np.ndarray,
    lambda_: float,
    lambda_R: float,
    lambda_Q: float,
    max_iter: int = 40,
    tol: float = 1e-4,
    lags: int = 2,
    standardize: bool = True,
    lambda_norm: float | None = None,
    eps: float = 1e-8,
    alpha: float = 0.9,
    beta: float = 0.8

):
    """Iteratively re-estimate the ``Q`` / ``R`` fed into the filter.

    Each outer iteration runs ``kalman_multilag_filter`` from ``P0`` with the
    current ``Q`` and ``R``, then blends them with sample covariances taken
    from that run:

        R <- alpha * R + (1 - alpha) * cov(standardized one-step errors)
        Q <- beta  * Q + (1 - beta)  * cov(step-to-step coefficient changes)

    (each sample covariance gets ``1e-6 * I`` added). Iteration stops after
    ``max_iter`` passes or once the error summary changes by less than
    ``tol`` between two consecutive passes.

    The filter is called without a reference scale, so its own ``merr`` is
    NaN; in that case the root-mean-square of the standardized one-step
    errors is used as the error summary instead, which is what makes the
    convergence test meaningful.

    Args:
        Y, Z_all, Q0, R0, P0, lambda_, lambda_R, lambda_Q, lags, standardize,
            lambda_norm, eps: forwarded to ``kalman_multilag_filter``.
        max_iter: maximum number of outer iterations (at least 1).
        tol: absolute change in the error summary that counts as converged.
        alpha: weight kept on the previous ``R`` when blending.
        beta: weight kept on the previous ``Q`` when blending.

    Returns:
        The nine outputs of the last filter run (with ``merr`` replaced by
        the error summary described above) followed by the final ``Q`` and
        ``R``.

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    Q = Q0.copy()
    R = R0.copy()
    P_init = P0.copy()

    mY = Y.shape[1]
    merr_hist = []

    for outer in range(max_iter):
        merr, e_std, e_raw, theta_est, Y_pred_std, Y_pred_raw, P_hist, R_hist, Q_hist = \
            kalman_multilag_filter(
                Y, Z_all, Q, R, P_init, lambda_, lambda_R, lambda_Q,
                lags=lags, standardize=standardize, lambda_norm=lambda_norm, eps=eps
            )

        # Select whole rows: an element-wise mask would flatten e_std to
        # 1-D and np.cov would then return a single scalar variance.
        valid_rows = np.isfinite(e_std).all(axis=1)

        if not np.isfinite(merr):
            if valid_rows.any():
                merr = float(np.sqrt(np.mean(e_std[valid_rows] ** 2)))
            else:
                merr = np.nan
        merr_hist.append(merr)

        # A covariance needs at least two observations; otherwise keep the
        # current estimate.
        if valid_rows.sum() >= 2:
            R_new = np.atleast_2d(np.cov(e_std[valid_rows].T)) + 1e-6 * np.eye(mY)
        else:
            R_new = R.copy()

        dtheta = np.diff(theta_est, axis=0)
        if dtheta.shape[0] >= 2:
            Q_new = np.atleast_2d(np.cov(dtheta.T)) + 1e-6 * np.eye(dtheta.shape[1])
        else:
            Q_new = Q.copy()

        R = alpha * R + (1 - alpha) * R_new
        Q = beta * Q + (1 - beta) * Q_new

        if outer > 0 and abs(merr_hist[-1] - merr_hist[-2]) < tol:
            print(f"[KF-EM] Converged at iter {outer}, merr={merr_hist[-1]:.6f}")
            break

    return merr, e_std, e_raw, theta_est, Y_pred_std, Y_pred_raw, P_hist, R_hist, Q_hist, Q, R

def varx_rolling_predict(Y: np.ndarray,
                         Z: np.ndarray,
                         lags: int = 2,
                         window: int | None = None,
                         ridge: float = 1e-6):
    """One-step-ahead forecasts from a ridge-regularised VARX refit each step.

    For every ``t`` the coefficients are re-estimated on the rows
    ``s0 .. t-1`` (all usable rows when ``window`` is None, otherwise at most
    the last ``window`` of them) by solving ``(X'X + ridge*I) B = X'Y``, where
    a row of ``X`` is ``[Y[s-1], ..., Y[s-lags], Z[s-1]]`` and the target is
    ``Y[s]``. The fitted model then predicts ``Y[t+1]``.

    Returns:
        ``(Y_pred, e_raw, rmse)``: forecasts and forecast errors with the
        shape of ``Y`` (NaN where no forecast was made) and the per-row
        root-mean-square of the errors.
    """
    n_obs, n_endog = Y.shape
    n_reg = lags * n_endog + Z.shape[1]
    forecasts = np.full((n_obs, n_endog), np.nan)
    errors = np.full((n_obs, n_endog), np.nan)
    first_usable = lags + 1
    identity = np.eye(n_reg)

    def regressors(s):
        # Lagged endogenous values (most recent first) followed by Z[s-1].
        lagged = np.concatenate([Y[s - i, :] for i in range(1, lags + 1)], axis=0)
        return np.concatenate([lagged, Z[s - 1, :]], axis=0)

    for t in range(first_usable, n_obs - 1):
        if window is None:
            start = first_usable
        else:
            start = max(first_usable, t - window)
        n_train = t - start
        if n_train <= 0:
            continue

        design = np.array([regressors(s) for s in range(start, t)], dtype=float)
        targets = np.array(Y[start:t, :], dtype=float)

        gram = design.T @ design + ridge * identity
        coef = np.linalg.solve(gram, design.T @ targets)

        step_ahead = regressors(t + 1) @ coef
        forecasts[t + 1, :] = step_ahead
        errors[t + 1, :] = Y[t + 1, :] - step_ahead

    rmse = np.sqrt(np.nanmean(errors ** 2, axis=1))
    return forecasts, errors, rmse


# ---------- main script ----------
# c1 and c2 live at module level on purpose: kalman_multilag_filter reads
# them (together with D, which is set per country below) as global names
# when it assembles A = H' R^-1 H + c1 * P_pred^-1 + c2 * D.
c1 = 1.02   # temporal smoothing weight
c2 = 0.6   # shrinkage weight

if __name__ == "__main__":

    def weighted_corr(x, y, w):
        """Weighted Pearson correlation of ``x`` and ``y``.

        The three inputs are flattened and cut to their common length, rows
        with a NaN in any of them are dropped, and the remaining weights are
        renormalised. Returns NaN with fewer than three usable rows or when
        either weighted variance is not positive.
        """
        x = np.asarray(x).ravel()
        y = np.asarray(y).ravel()
        w = np.asarray(w).ravel()

        n = min(len(x), len(y), len(w))
        x, y, w = x[:n], y[:n], w[:n]

        m = ~np.isnan(x) & ~np.isnan(y) & ~np.isnan(w)
        if m.sum() < 3:
            return np.nan

        x, y, w = x[m], y[m], w[m]
        w = w / np.sum(w)

        mx = np.sum(w * x)
        my = np.sum(w * y)

        cov_xy = np.sum(w * (x - mx) * (y - my))
        var_x = np.sum(w * (x - mx) ** 2)
        var_y = np.sum(w * (y - my) ** 2)

        if var_x <= 0 or var_y <= 0:
            return np.nan

        return cov_xy / np.sqrt(var_x * var_y)

    def _safe_corr(a, b):
        """Unweighted Pearson correlation over rows where both are non-NaN."""
        a = np.asarray(a).ravel()
        b = np.asarray(b).ravel()
        m = ~np.isnan(a) & ~np.isnan(b)
        if m.sum() < 3:
            return np.nan
        return np.corrcoef(a[m], b[m])[0, 1]

    def plot_pred_corr(Yd, Y_pred, Xd, ttls_Y, ttls_X, country, method, w_obs=None):
        """Plot data vs. prediction per variable with correlation summaries.

        Row 1 shows the series and its prediction, row 2 a scatter of the
        two plus three weighted correlations: prediction vs. data, previous
        value vs. data, and the strongest single lagged exogenous column
        vs. data.

        ``w_obs`` holds one weight per row of ``Yd``; uniform weights are
        used when it is omitted.
        """
        # Everything is cut to rows 1..n-1 so that a lag-1 partner exists.
        Y_true = Yd[1:, :]
        Yhat = Y_pred[1:, :]
        Y_lag1 = Yd[:-1, :]
        Z_lag1 = Xd[:-1, :]

        # Weights are sliced the same way so weight i belongs to row i.
        if w_obs is None:
            w_use = np.ones(Y_true.shape[0])
        else:
            w_use = np.asarray(w_obs).ravel()[1:]

        mY = Y_true.shape[1]
        fig = plt.figure(figsize=(16, 6))
        gs = fig.add_gridspec(2, mY, height_ratios=[1.2, 1.0], hspace=0.35, wspace=0.25)

        for j in range(mY):
            ax_ts = fig.add_subplot(gs[0, j])
            ax_ts.plot(Y_true[:, j], color='tab:blue', lw=1.0, label='data')
            ax_ts.plot(Yhat[:, j], color='tab:orange', lw=1.0, ls='--', label='prediction')
            ax_ts.set_title(ttls_Y[j], fontsize=11)
            if j == 0:
                ax_ts.legend(loc='upper right', fontsize=9)

            ax_sc = fig.add_subplot(gs[1, j])
            ax_sc.scatter(Y_true[:, j], Yhat[:, j], s=10, color='tab:purple', alpha=0.6)
            lim = np.nanmax(np.abs(np.concatenate([Y_true[:, j], Yhat[:, j]])))
            if np.isfinite(lim) and lim > 0:
                ax_sc.plot([-lim, lim], [-lim, lim], color='gray', lw=1, ls=':')
                ax_sc.set_xlim(-lim, lim)
                ax_sc.set_ylim(-lim, lim)
            ax_sc.set_xlabel('data', fontsize=9)
            ax_sc.set_ylabel('prediction', fontsize=9)

            r_pred = weighted_corr(Y_true[:, j], Yhat[:, j], w_use)
            r_lag = weighted_corr(Y_lag1[:, j], Y_true[:, j], w_use)
            r_exog_all = np.array([weighted_corr(Z_lag1[:, k], Y_true[:, j], w_use)
                                   for k in range(Z_lag1.shape[1])])

            if r_exog_all.size > 0 and np.any(np.isfinite(r_exog_all)):
                k_best = int(np.nanargmax(np.abs(r_exog_all)))
                r_exog = r_exog_all[k_best]
                exog_name = ttls_X[k_best] if k_best < len(ttls_X) else f"X{k_best}"
            else:
                r_exog, exog_name = np.nan, "N/A"

            txt = (
                f"corr(Y, Yhat) = {r_pred: .3f}\n"
                f"corr(Y[-1], Y) = {r_lag: .3f}\n"
                f"max corr(Z[-1], Y) = {r_exog: .3f} ({exog_name})"
            )
            ax_sc.text(0.02, -0.25, txt, transform=ax_sc.transAxes, fontsize=9, va='top')

        fig.suptitle(f"{country} — {method}", fontsize=14, y=0.98)
        plt.show()

    fname = "/Users/poppy/iCloud Drive (Archive)/Desktop/GVAR/df_country_data_climate.xlsx"
    xls = pd.ExcelFile(fname)
    sheets = xls.sheet_names

    sheets_9_of_12 = [
        "India", "Brazil", "Chile", "Indonesia", "Mexico", "Peru",
        "Philippines", "South Africa", "Thailand"
    ]

    # First five names are the endogenous variables, the rest are exogenous.
    ttls = ['y', 'Dp', 'eq', 'ep', 'r', 'ys', 'Dps', 'eqs', 'rs', 'lrs', 'ensos']
    ttls_Y = [ttls[i] for i in range(0, 5)]
    ttls_X = [ttls[i] for i in range(5, 11)]
    lags = 2
    mY0 = len(ttls_Y)
    mX0 = len(ttls_X)
    m0 = lags * mY0 + mX0
    n_country = len(sheets_9_of_12)
    # E_hists[i, j, k]: drop-one error for equation i, regressor j, country k,
    # laid out on the GLOBAL variable grid (mY0 x m0).
    E_hists = np.full((mY0, m0, n_country), np.nan)
    QRF_results = np.full((n_country, 5), np.nan)  # [Q*, R_a, R_c, P_a, P_c]

    lambda_ = 0.99
    lambda_R = 0.85
    lambda_Q = 0.85
    lambda_e = 1 - 1 / 20

    for k_country, country in enumerate(sheets_9_of_12):
        print(f"***** Country {country} *****")
        T = pd.read_excel(fname, sheet_name=country)
        T_col = list(T.columns)
        mT = T.shape[1] - 1

        # Variables actually present in this sheet, with their positions in
        # the global lists (idx_T*) and in the sheet (idx_Y / idx_X).
        ttl_Y, idx_TY, idx_Y = intersect_stable(ttls_Y, T_col)
        ttl_X, idx_TX, idx_X = intersect_stable(ttls_X, T_col)
        mY = len(idx_Y)
        mX = len(idx_X)

        # Column of each local regressor on the global grid. The offsets use
        # the GLOBAL block width mY0, so a sheet that lacks a variable still
        # lands in the right columns of E_hists.
        idx_TYX = np.concatenate(
            [k * mY0 + idx_TY for k in range(lags)] + [lags * mY0 + idx_TX]
        )
        # Human-readable name of each local regressor, in filter order.
        term_labels = [f"L{k + 1}_{name}" for k in range(lags) for name in ttl_Y] + list(ttl_X)

        Y = T.iloc[:, idx_Y].to_numpy()
        X = T.iloc[:, idx_X].to_numpy()
        nY = Y.shape[0]

        Y_df = T.iloc[:, idx_Y].copy()
        X_df = T.iloc[:, idx_X].copy()

        # The filter works on first differences.
        Yd = np.diff(Y, axis=0)
        Xd = np.diff(X, axis=0)

        n = Xd.shape[0]
        m = (lags * mY + mX)
        p = m * mY

        # Exponentially decaying weights, largest on the most recent row.
        weights = np.exp(-(1 - lambda_e) * np.arange(n, 0, -1)).reshape(-1, 1)
        weights = weights / np.sum(weights)

        print(f"\n=== {country}: EM-KF optimization ===")

        D = np.eye(p)             # shrinkage matrix read by the filter as a global
        Q0 = (0.10) * np.eye(p)   # initial guess; may be chosen smaller
        R0 = (0.10) * np.eye(mY)
        P0 = (1.0) * np.eye(p)

        merr, e_std, e_raw, theta_est, Y_pred_std, Y_pred_raw, \
        P_hist, R_hist, Q_hist, Q_final, R_final = kalman_multilag_filter_em(
            Yd, Xd, Q0, R0, P0,
            lambda_=lambda_, lambda_R=lambda_R, lambda_Q=lambda_Q,
            max_iter=5, tol=1e-4, lags=lags,
            standardize=True, lambda_norm=lambda_R,
            alpha=0.9, beta=0.8
        )

        print(f"Final EM-KF converged: merr={merr:.6f}")
        print(f"Q trace={np.trace(Q_final)/p:.6f}, R trace={np.trace(R_final)/mY:.6f}")

        # Baseline error scale: weighted std of the standardized errors over
        # the rows the filter actually produced.
        ml = np.isfinite(e_std).any(axis=1)
        w_eff = weights[ml] / np.sum(weights[ml])
        std_ew_hist = np.sqrt(np.sum((e_std[ml] ** 2) * w_eff, axis=0))

        # Drop-one analysis: rerun the filter once per coefficient with that
        # coefficient's dropout entry set to 0 and store the error relative
        # to the baseline. Row i of dropout0s is the mask used for coefficient i.
        dropout0s = 1.0 - np.eye(p)
        for i_Y in range(mY):
            for j_YX in range(m):
                idx_0 = i_Y * m + j_YX
                dropout0 = dropout0s[idx_0].copy()

                merr_drop, *_ = kalman_multilag_filter(
                    Yd, Xd, Q_final, R_final, P0, lambda_, lambda_R, lambda_Q,
                    weights=weights, std_ew_hist_IN=std_ew_hist, dropout0=dropout0,
                    standardize=True, lambda_norm=lambda_R, lags=lags
                )
                E_hists[idx_TY[i_Y], idx_TYX[j_YX], k_country] = merr_drop

        # Per-country heatmap + histogram
        E_hists_k = E_hists[np.ix_(idx_TY, idx_TYX, [k_country])][:, :, 0]
        fig, axs = plt.subplots(2, 1, figsize=(12, 6))
        im = axs[0].imshow(E_hists_k, aspect='auto', vmin=0.95, vmax=1.05, cmap="coolwarm")

        axs[0].set_yticks(np.arange(len(ttl_Y)))
        axs[0].set_yticklabels(ttl_Y)
        labels_x = list(term_labels)
        axs[0].set_xticks(np.arange(len(idx_TYX)))
        axs[0].set_xticklabels(labels_x, rotation=45, ha="right")
        axs[0].set_title(f"{country}: mean errors (drop-one coeff)")
        fig.colorbar(im, ax=axs[0], fraction=0.046, pad=0.04)

        vals = E_hists_k.ravel()
        axs[1].hist(vals[~np.isnan(vals)], bins=20, edgecolor="black")
        axs[1].set_title(f"{country}: histogram of errors")
        axs[1].set_xlabel("Error value")
        axs[1].set_ylabel("Frequency")
        plt.tight_layout()
        plt.show()

        # === All coefficients: time trajectories (KF only) ===
        E_k = E_hists_k
        finite_mask = np.isfinite(E_k)
        if not np.any(finite_mask):
            print("No finite entries in E_k; skip all-coeff plot.")
        else:
            n_cols = min(8, m)              # at most 8 panels per row
            n_rows = int(np.ceil(m / n_cols))
            fig, axes = plt.subplots(mY * n_rows, n_cols,
                                     figsize=(3*n_cols, 2*mY*n_rows),
                                     sharex=True)
            axes = np.array(axes).reshape(mY * n_rows, n_cols)

            for i_loc in range(mY):
                for j_loc in range(m):
                    coef_idx = i_loc * m + j_loc
                    row = i_loc * n_rows + j_loc // n_cols
                    col = j_loc % n_cols
                    ax = axes[row, col]

                    series = theta_est[:, coef_idx].copy()
                    if series.shape[0] >= 2:
                        series[:2] = np.nan  # ignore warm-up steps
                    ax.plot(series, lw=0.8)

                    # Fixed y-range so panels are comparable.
                    ax.set_ylim(-0.4, 0.4)

                    term = term_labels[j_loc]
                    ax.set_title(f"Eq:{ttl_Y[i_loc]} | {term}", fontsize=8)
                    ax.grid(True, alpha=0.3)

            fig.suptitle(f"{country} — All coefficient trajectories (Kalman)", y=0.98)
            plt.tight_layout()
            plt.show()

        # Mean diagonal of Q, R and P at every filtered step (steps 0 and 1
        # are warm-up and hold no estimate).
        steps = slice(2, n)
        Q_tr = np.trace(Q_hist[:, :, steps], axis1=0, axis2=1) / p
        R_tr = np.trace(R_hist[:, :, steps], axis1=0, axis2=1) / mY
        P_tr = np.trace(P_hist[:, :, steps], axis1=0, axis2=1) / p

        e_raw = Yd - Y_pred_raw
        rmse_norm = np.sqrt(np.nanmean(e_std**2, axis=1))
        rmse_raw = np.sqrt(np.nanmean(e_raw**2, axis=1))

        # ---- Null model: Yhat(t) = Yd(t-1) ----
        Yhat_null = np.full_like(Yd, np.nan)
        Yhat_null[1:, :] = Yd[:-1, :]
        E_null = Yd - Yhat_null
        rmse_null = np.sqrt(np.nanmean(E_null**2, axis=1))

        # Rolling VARX benchmark; computed once and reused for the plots below.
        Yhat_varx_raw, E_varx_raw, rmse_varx_raw = varx_rolling_predict(Yd, Xd, lags=lags)

        plt.figure(954, figsize=(10, 8))
        plt.clf()
        plt.subplot(2, 2, 1); plt.plot(np.sqrt(Q_tr)); plt.title('Q (rms of coeff update noise)')
        plt.subplot(2, 2, 2); plt.plot(np.sqrt(R_tr)); plt.title('R (rms of obs noise)')
        plt.subplot(2, 2, 3); plt.plot(np.sqrt(P_tr)); plt.title('P (sqrt mean coeff var)')
        ax = plt.subplot(2, 2, 4)
        ax.plot(rmse_raw, label='raw-units', alpha=0.8)
        ax.plot(rmse_varx_raw, label='VARX raw', linestyle='--', alpha=0.9)
        ax.plot(rmse_null, label='rmse_null', linestyle='-', alpha=0.7)
        ax.set_title('Y−Yhat (rms of prediction error)')
        ax.legend()
        plt.tight_layout()
        plt.show()

        # Per-variable prediction plots. The country-specific name lists are
        # passed so titles match the columns that were actually modelled.
        # Kalman
        plot_pred_corr(Yd, Y_pred_raw, Xd, ttl_Y, ttl_X, country, "Kalman", w_obs=weights)

        # VARX
        plot_pred_corr(Yd, Yhat_varx_raw, Xd, ttl_Y, ttl_X, country, "VARX", w_obs=weights)

        # Naive lag-1
        plot_pred_corr(Yd, Yhat_null, Xd, ttl_Y, ttl_X, country, "Naive lag-1", w_obs=weights)