# **1 Global Set-up**

## **1.1 Define All Variables and folder path**

In [None]:
# Shared storage roots; every dataset path below is assembled from these prefixes.
_PROJECT_ROOT = "/data3/interns/NRT_CO2_Emission_Map_Project"
_HAOHU_DIR = f"{_PROJECT_ROOT}/HaoHu_work"
_MINGJUAN_DIR = f"{_PROJECT_ROOT}/MingjuanZhang_work"

# NetCDF files holding the prediction targets, keyed by variable name.
target_variable = dict(
    xco2=f"{_HAOHU_DIR}/XCO2_resample/global_grid_0.1_2019_2025_xco2.nc",
    emission=f"{_PROJECT_ROOT}/ML_XCO2/CarbonMonitor0Power_emission_201901_202505.nc",
)

# NetCDF files holding the predictor variables, keyed by variable name.
# Insertion order is: ERA5_resample files, XCO2_resample files, MingjuanZhang_work files,
# then the two stand-alone files.
feature_variables = {}

# Files under ERA5_resample are named after their variable.
for _name in ("t2m", "d2m", "u10", "v10", "msl", "sp", "skt", "tp", "e", "ssr", "str", "tcw", "blh"):
    feature_variables[_name] = f"{_HAOHU_DIR}/ERA5_resample/{_name}_daily_0p1deg.nc"

# Files under XCO2_resample share a prefix and differ only in the trailing tag.
for _name, _suffix in (("NO2", "NO2"), ("is_weekend", "weekday_weekend")):
    feature_variables[_name] = f"{_HAOHU_DIR}/XCO2_resample/global_grid_0.1_2019_2025_{_suffix}.nc"

# Files under MingjuanZhang_work share a suffix and differ only in the leading stem.
for _name, _stem in (
    ("population", "Population"),
    ("elevation", "SRTM_elevation"),
    ("landuse", "Landuse"),
    ("aspect", "SRTM_aspect"),
    ("ndvi", "NDVI"),
    ("gpp", "GPP"),
    ("lai", "LAI"),
    ("ntl", "VIIRS_NTL"),
    ("evi", "EVI"),
    ("slope", "SRTM_slope"),
):
    feature_variables[_name] = f"{_MINGJUAN_DIR}/{_stem}_global_0.1degree_2019_2025_ns.nc"

feature_variables["odiac"] = f"{_HAOHU_DIR}/odiac_interp_2019_2025.nc"
feature_variables["CO2_fire"] = f"{_PROJECT_ROOT}/PinyiLu_work/GFAS_resample/GFAS_resample_final.nc"


## **1.2 Load all modules**

In [None]:
import os
import time

import matplotlib as mpl
import matplotlib.collections as mcoll
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import xarray as xr
import xgboost as xgb
from sklearn.base import clone
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, GroupKFold, train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from xgboost import XGBRegressor, callback

# **2 Model Training and Validation**

## **2.1 Load Data**

In [None]:
# Load the pre-processed sample table from a .npy file.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects, which can
# execute code embedded in the file. Only load trusted files, and confirm the flag is needed.
data = np.load("../xco2_nonnan_processed.npy", allow_pickle=True)

# presumably a structured/record array, since later cells address named columns
# such as "lat", "lon" and "xco2" — TODO confirm
df = pd.DataFrame(data)

# Bare expression: rendered as the cell output.
df

## **2.2 Split spacetime block to training and test**

In [None]:
# Index of the 10-degree cell each sample falls in (offsets shift the coordinates to start at 0).
for _coord, _offset in (("lat", 90), ("lon", 180)):
    df[f"{_coord}_bin"] = ((df[_coord] + _offset) // 10).astype(int)

# Calendar month of each sample as a string produced by a monthly Period.
_month_start = pd.to_datetime(df[["year", "month"]].assign(day=1))
df["time_bin"] = _month_start.dt.to_period("M").astype(str)

# One identifier per (lat cell, lon cell, month) block; whole blocks are kept together below.
_lat_part = df["lat_bin"].astype(str)
_lon_part = df["lon_bin"].astype(str)
df["spacetime_block"] = _lat_part + "_" + _lon_part + "_" + df["time_bin"]

# Hold out 20 % of the blocks (not of the rows) as the test set.
blocks = df["spacetime_block"].unique()
train_blocks, test_blocks = train_test_split(blocks, test_size=0.20, random_state=42)

_in_test = df["spacetime_block"].isin(test_blocks)
df["split"] = np.where(_in_test, "test", "trainval")

df_trainval = df.loc[~_in_test].copy()
df_test = df.loc[_in_test].copy()

print("Total number of blocks: ", len(blocks))

## **2.3 Grid-search best models**

### 2.3.1 Define grid params

In [None]:
# Hyper-parameter grid; every entry currently lists a single candidate value.
param_grid = dict(
    n_estimators=[300],
    max_depth=[10],
    learning_rate=[0.1],
    gamma=[0.1],
    min_child_weight=[2],
    subsample=[0.9],
    colsample_bytree=[0.9],
    reg_alpha=[0.01],
    reg_lambda=[0.01],
)

# Flat keyword arguments: the first candidate of each sequence, scalars passed through unchanged.
params = {}
for _key, _candidates in param_grid.items():
    if isinstance(_candidates, (list, tuple, np.ndarray)):
        params[_key] = _candidates[0]
    else:
        params[_key] = _candidates

### 2.3.2 Prepare model

In [None]:
# Columns that are excluded from the model inputs (coordinates, targets, split bookkeeping).
non_feature_cols = {
    "lat", "lon", "xco2", "time", "split", "time_bin",
    "lat_bin", "lon_bin", "spacetime_block", "month", "emission",
}
features = [col for col in df.columns if col not in non_feature_cols]
print(features)

# Design matrix, target and grouping key, all taken from the train/validation rows.
X = df_trainval.loc[:, features]
y = df_trainval.loc[:, "xco2"]
groups = df_trainval.loc[:, "spacetime_block"]

# Five folds that never split a spacetime block between training and validation.
gkf = GroupKFold(n_splits=5)

# Fixed estimator settings, combined with the flattened grid values in `params`.
xgb_settings = dict(
    objective="reg:squarederror",
    random_state=42,
    tree_method="hist",
    device="cuda",
)
xgb_model = XGBRegressor(**xgb_settings, **params)

# Until the grid search replaces it, the un-tuned estimator doubles as the "best" model.
best_model = xgb_model


### 2.3.3 Cross-validation for training dataset

In [None]:
# Metrics computed on every fold; the "RMSE" entry is the one used to pick and refit the model.
scoring = {
    "RMSE": "neg_root_mean_squared_error",
    "R2":   "r2"
}

# Materialise the grouped folds once so GridSearchCV receives explicit index pairs.
splits = list(gkf.split(X, y, groups=groups))

# Force numeric dtypes; values that cannot be parsed become NaN.
# X is converted column by column. The target is converted in one vectorised call rather
# than Series.apply, which would call pd.to_numeric once per element.
X = X.apply(pd.to_numeric, errors='coerce')
y = pd.to_numeric(y, errors='coerce')

grid_search = GridSearchCV(
    estimator=xgb_model,
    param_grid=param_grid,
    cv=splits,
    scoring=scoring,
    refit="RMSE",
    return_train_score=True,
    verbose=2,
    n_jobs=1
)
grid_search.fit(X, y)

best_model = grid_search.best_estimator_
best_idx = grid_search.best_index_

# The RMSE scorer is negated ("neg_..."), so flip the sign of the mean; the spread needs no flip.
mean_test_rmse = -grid_search.cv_results_["mean_test_RMSE"][best_idx]
mean_test_r2 = grid_search.cv_results_["mean_test_R2"][best_idx]

std_test_rmse = grid_search.cv_results_["std_test_RMSE"][best_idx]
std_test_r2 = grid_search.cv_results_["std_test_R2"][best_idx]

print(f"✅ The Best params：{grid_search.best_params_}")
print(f"✅ CV mean RMSE：{mean_test_rmse:.4f}，sd：{std_test_rmse:.4f}")
print(f"✅ CV mean R²：{mean_test_r2:.4f}，sd：{std_test_r2:.4f}")

### 2.3.4 Prepare 5-CV dataset saving -> oof_df

In [None]:
# Template estimator; each fold trains a fresh clone of it.
base_est = best_model

# Pre-filled with sentinels (NaN prediction, fold -1) that the loop below overwrites.
n = len(y)
oof_pred = np.full(n, np.nan, dtype=float)
fold_ids = np.full(n, -1, dtype=int)

splits = list(gkf.split(X, y, groups=groups))

for k, fold in enumerate(splits):
    tr_idx, va_idx = fold
    est = clone(base_est)
    est.fit(X.iloc[tr_idx], y.iloc[tr_idx])
    # Predict only the rows this fold held out, and remember which fold produced them.
    pred = est.predict(X.iloc[va_idx])
    oof_pred[va_idx] = pred
    fold_ids[va_idx] = k

# Coordinates are attached when the source frame has them.
keep_cols = [c for c in ("lat", "lon") if c in df.columns]

oof_columns = {
    "y_true": y.values,
    "y_pred": oof_pred,
    "fold":  fold_ids,
}
oof_df = pd.DataFrame(oof_columns, index=y.index)
oof_df = oof_df.join(df.loc[y.index, keep_cols])

## **2.4 Train the final model**

### 2.4.1 Using all the trainval data

In [None]:
# Fit the selected estimator on the complete train/validation data (X, y).
# NOTE(review): when best_model comes from GridSearchCV(refit="RMSE").best_estimator_ it has
# presumably already been refit on this same data, which would make this call redundant — confirm.
best_model.fit(X, y)

# Disabled: saving the fitted model to disk.
# best_model.save_model(f"Trained_xgb_model_full/xgb_model_full_random.json") 

### 2.4.2 Plot 5-fold CV scatter plots (train and test)

In [None]:
# The "train" panel shows the out-of-fold predictions stored in oof_df; the "test" panel shows
# best_model scored on the held-out blocks.
y_train = oof_df["y_true"]
y_pred_train = oof_df["y_pred"]
X_test = df_test[features]
y_test = df_test["xco2"]

y_pred_test = best_model.predict(X_test)

# Summary metrics per split; MB is the mean of (prediction - observation).
rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
r2_test = r2_score(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)
mb_test = np.mean(y_pred_test - y_test)

rmse_train = np.sqrt(mean_squared_error(y_train, y_pred_train))
r2_train = r2_score(y_train, y_pred_train)
mae_train = mean_absolute_error(y_train, y_pred_train)
mb_train = np.mean(y_pred_train - y_train)

# ===== Plot Figures=====
mpl.rcParams.update({
    "font.size": 11,
    "axes.linewidth": 1.0,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.major.size": 4,
    "ytick.major.size": 4,
})

# Common, slightly padded axis range so the 1:1 line is comparable across both panels.
vmin = np.nanmin([y_train.min(), y_test.min(), y_pred_train.min(), y_pred_test.min()])
vmax = np.nanmax([y_train.max(), y_test.max(), y_pred_train.max(), y_pred_test.max()])
pad = 0.02 * (vmax - vmin)
lims = (vmin - pad, vmax + pad)


def panel(ax, y_true, y_pred, tag, rmse, r2, mae, mb, gridsize=200, cmap="RdYlBu_r"):
    """Draw one observed-vs-predicted density panel and return its hexbin collection.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    y_true, y_pred : observed and predicted values (x and y of the hexbin).
    tag : panel letter, rendered as "(tag)" in the top-left corner.
    rmse, r2, mae, mb : pre-computed metrics printed in the annotation box.
    gridsize, cmap : forwarded to ``Axes.hexbin``.

    The axis range is read from the module-level ``lims`` defined above.
    """
    hb = ax.hexbin(
        y_true, y_pred,
        gridsize=gridsize, extent=[*lims, *lims], cmap=cmap,
        mincnt=1, bins='log'
    )
    ax.plot(lims, lims, ls="--", lw=1.2, color="k")  # 1:1 reference line
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel("Observed XCO$_2$ (ppm)")
    ax.set_ylabel("Predicted XCO$_2$ (ppm)")
    ax.text(0.02, 0.98, f"({tag})", transform=ax.transAxes, va="top", ha="left", fontsize=13, fontweight="bold")
    ax.text(0.4, 0.3,
            f"RMSE = {rmse:.2f} ppm\n"
            f"MAE = {mae:.2f} ppm\n"
            f"MB = {mb:+.2f} ppm\n"
            r"$R^2$ = " + f"{r2:.2f}",
            transform=ax.transAxes, va="top", ha="left",
            bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="0.6", lw=0.8))
    return hb


fig, axes = plt.subplots(1, 2, figsize=(8, 3.8), sharex=True, sharey=True)
hb1 = panel(axes[0], y_train, y_pred_train, "a", rmse_train, r2_train, mae_train, mb_train)
hb2 = panel(axes[1], y_test,  y_pred_test,  "b", rmse_test,  r2_test,  mae_test,  mb_test)

# A single colorbar is drawn for both panels, so both must use the same colour limits.
# Each hexbin otherwise scales to its own counts, and a colorbar built from hb2 alone
# would not describe panel (a).
_count_lo = min(hb1.get_array().min(), hb2.get_array().min())
_count_hi = max(hb1.get_array().max(), hb2.get_array().max())
for _hb in (hb1, hb2):
    _hb.set_clim(_count_lo, _count_hi)

fig.tight_layout()
cbar = fig.colorbar(hb2, ax=axes.ravel().tolist(), shrink=0.9, pad=0.02)
cbar.set_label("Point density (log scale)")
plt.show()

In [None]:
# Sample standard deviation (ddof=1) of y_test.
# NOTE: relies on y_test as assigned in the preceding cell (df_test["xco2"]); run cells in order.
std_y_test = np.std(y_test, ddof=1)
print("STD(y_test) =", std_y_test)

### **2.4.3 Plot seasonal (spring, summer, fall, winter) (Train)**

In [None]:
# Seasonal breakdown of the train/validation performance.
# The panels use the out-of-fold predictions in oof_df, not best_model.predict on df_trainval:
# best_model was fitted on those very rows, so scoring it there would report in-sample fit.
# The OOF columns are read directly and y_test / y_pred_test are left untouched, so they keep
# referring to the held-out test set.
seasons = {
    "DJF (Dec–Feb)": [12, 1, 2],
    "MAM (Mar–May)": [3, 4, 5],
    "JJA (Jun–Aug)": [6, 7, 8],
    "SON (Sep–Nov)": [9,10,11],
}
tags = ["a", "b", "c", "d"]

# Month of every out-of-fold row, looked up by index label in df_trainval.
oof_month = df_trainval.loc[oof_df.index, "month"]
oof_true = oof_df["y_true"].to_numpy()
oof_hat = oof_df["y_pred"].to_numpy()

season_preds = {}
for name, months in seasons.items():
    mask = oof_month.isin(months).to_numpy()
    if not mask.any():
        season_preds[name] = None  # rendered as a "No data" panel below
        continue
    season_preds[name] = (oof_true[mask], oof_hat[mask])

# Common, slightly padded axis range across all seasons that have data.
all_true = np.concatenate([v[0] for v in season_preds.values() if v is not None])
all_pred = np.concatenate([v[1] for v in season_preds.values() if v is not None])
vmin = np.nanmin([all_true.min(), all_pred.min()])
vmax = np.nanmax([all_true.max(), all_pred.max()])
pad = 0.02 * (vmax - vmin)
lims = (vmin - pad, vmax + pad)


mpl.rcParams.update({
    "font.size": 11,
    "axes.linewidth": 1.0,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.major.size": 4,
    "ytick.major.size": 4,
})


def panel(ax, y_true, y_pred, tag, title, rmse, mae, mb, r2,
          gridsize=200, cmap="RdYlBu_r"):
    """Draw one observed-vs-predicted density panel and return its hexbin collection.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    y_true, y_pred : observed and predicted values (x and y of the hexbin).
    tag : panel letter, rendered as "(tag)" in the top-left corner.
    title : accepted but not drawn; panels are identified by ``tag`` only.
    rmse, mae, mb, r2 : pre-computed metrics printed in the annotation box.
    gridsize, cmap : forwarded to ``Axes.hexbin``.

    The axis range is read from the module-level ``lims`` defined above.
    """
    hb = ax.hexbin(
        y_true, y_pred,
        gridsize=gridsize, extent=[*lims, *lims], cmap=cmap,
        mincnt=1, bins='log'
    )
    ax.plot(lims, lims, ls="--", lw=1.2, color="k")  # 1:1 reference line
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel("Observed XCO$_2$ (ppm)")
    ax.set_ylabel("Predicted XCO$_2$ (ppm)")
    ax.text(0.02, 0.98, f"({tag})", transform=ax.transAxes,
            va="top", ha="left", fontsize=13, fontweight="bold")
    ax.text(0.40, 0.3,
            f"RMSE = {rmse:.2f} ppm\nMAE = {mae:.2f} ppm\nMB = {mb:+.2f} ppm\n$R^2$ = {r2:.2f}",
            transform=ax.transAxes, va="top", ha="left",
            bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="0.6", lw=0.8))
    return hb


fig, axes = plt.subplots(2, 2, figsize=(7.0, 6.5), sharex=True, sharey=True)

for ax, tag, name in zip(axes.ravel(), tags, seasons):
    pair = season_preds[name]
    if pair is None:
        ax.text(0.5, 0.5, f"No data: {name}", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")
        continue

    y_s, yhat_s = pair
    rmse = np.sqrt(mean_squared_error(y_s, yhat_s))
    mae = mean_absolute_error(y_s, yhat_s)
    mb = np.mean(yhat_s - y_s)
    r2 = r2_score(y_s, yhat_s)

    panel(ax, y_s, yhat_s, tag, name, rmse, mae, mb, r2, gridsize=80, cmap="RdYlBu_r")

fig.tight_layout()
plt.show()

### **2.4.4 Plot seasonal (spring, summer, fall, winter) (Test)**

In [None]:
# Seasonal breakdown of best_model's performance on the held-out test rows.
seasons = {
    "DJF (Dec–Feb)": [12, 1, 2],
    "MAM (Mar–May)": [3, 4, 5],
    "JJA (Jun–Aug)": [6, 7, 8],
    "SON (Sep–Nov)": [9, 10, 11],
}
tags = list("abcd")

# Observed/predicted pair per season; None marks a season without any test rows.
season_preds = {}
for name, months in seasons.items():
    mask = df_test["month"].isin(months)
    if mask.any():
        X_s = df_test.loc[mask, features]
        y_s = df_test.loc[mask, "xco2"].values
        yhat_s = best_model.predict(X_s)
        season_preds[name] = (y_s, yhat_s)
    else:
        season_preds[name] = None

# Common, slightly padded axis range across all seasons that have data.
_available = [pair for pair in season_preds.values() if pair is not None]
all_true = np.concatenate([pair[0] for pair in _available])
all_pred = np.concatenate([pair[1] for pair in _available])
vmin = np.nanmin([all_true.min(), all_pred.min()])
vmax = np.nanmax([all_true.max(), all_pred.max()])
pad = 0.02 * (vmax - vmin)
lims = (vmin - pad, vmax + pad)

for _key, _value in {
    "font.size": 11,
    "axes.linewidth": 1.0,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.major.size": 4,
    "ytick.major.size": 4,
}.items():
    mpl.rcParams[_key] = _value


def panel(ax, y_true, y_pred, tag, title, rmse, mae, mb, r2,
          gridsize=200, cmap="RdYlBu_r"):
    """Render a log-density hexbin of predictions against observations on ``ax``.

    ``tag`` is printed as "(tag)" in the corner and the four metrics go into an annotation
    box. ``title`` is accepted but not drawn. The plotted range comes from the module-level
    ``lims``. Returns the hexbin collection.
    """
    box_text = f"RMSE = {rmse:.2f} ppm\nMAE = {mae:.2f} ppm\nMB = {mb:+.2f} ppm\n$R^2$ = {r2:.2f}"
    box_style = dict(boxstyle="round,pad=0.25", fc="white", ec="0.6", lw=0.8)

    hb = ax.hexbin(y_true, y_pred, gridsize=gridsize, extent=[*lims, *lims],
                   cmap=cmap, mincnt=1, bins='log')
    ax.plot(lims, lims, ls="--", lw=1.2, color="k")
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel("Observed XCO$_2$ (ppm)")
    ax.set_ylabel("Predicted XCO$_2$ (ppm)")
    ax.text(0.02, 0.98, f"({tag})", transform=ax.transAxes,
            va="top", ha="left", fontsize=13, fontweight="bold")
    ax.text(0.40, 0.30, box_text, transform=ax.transAxes,
            va="top", ha="left", bbox=box_style)
    return hb


fig, axes = plt.subplots(2, 2, figsize=(7.0, 6.5), sharex=True, sharey=True)
for ax, tag, name in zip(axes.ravel(), tags, seasons):
    pair = season_preds[name]
    if pair is not None:
        y_s, yhat_s = pair
        rmse = np.sqrt(mean_squared_error(y_s, yhat_s))
        mae = mean_absolute_error(y_s, yhat_s)
        mb = np.mean(yhat_s - y_s)
        r2 = r2_score(y_s, yhat_s)
        panel(ax, y_s, yhat_s, tag, name, rmse, mae, mb, r2, gridsize=80, cmap="RdYlBu_r")
    else:
        # Empty season: replace the panel with a centred notice.
        ax.text(0.5, 0.5, f"No data: {name}", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")

fig.tight_layout()
plt.show()

### **2.4.5 Plot latitude-belt scattering**

In [None]:
# Per-latitude-band error metrics of best_model on the train/validation rows.
# NOTE(review): these are predictions of best_model on df_trainval itself; if best_model was
# fitted on these rows the numbers are in-sample — confirm whether the out-of-fold
# predictions (oof_df) were intended here.
BIN_DEG = 20  # width of one latitude band, in degrees

data_df = df_trainval.copy()
required_cols = {"lon", "lat", "y_true", "y_pred"}
# Kept under neutral names so the held-out X_test / y_test / y_pred_test are not overwritten
# with train/validation data.
data_df["y_pred"] = best_model.predict(data_df[features])
data_df["y_true"] = data_df["xco2"]

missing = required_cols - set(data_df.columns)
if missing:
    raise KeyError(f"data_df is missing required columns: {sorted(missing)}")

# Wrap longitudes given on a 0..360 convention into -180..180.
lon = data_df["lon"].to_numpy()
if np.nanmax(lon) > 180:
    data_df["lon"] = ((lon + 180) % 360) - 180


def _clean_xy(y_true, y_pred):
    """Return both inputs as float arrays restricted to positions where both are finite."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    m = np.isfinite(y_true) & np.isfinite(y_pred)
    return y_true[m], y_pred[m]


def _metric_row(yt, yp):
    """Return sample count, RMSE, MAE, MB, R2 and both means for one group of pairs.

    All metrics are NaN for an empty group. R2 is NaN when it is undefined (fewer than two
    samples or a constant observation vector).
    """
    n = len(yt)
    if n == 0:
        return {"n": 0, "RMSE": np.nan, "MAE": np.nan, "MB": np.nan, "R2": np.nan,
                "y_true_mean": np.nan, "y_pred_mean": np.nan}
    return {
        "n": n,
        "RMSE": float(np.sqrt(mean_squared_error(yt, yp))),
        "MAE": float(mean_absolute_error(yt, yp)),
        "MB": float(np.mean(yp - yt)),
        "R2": float(r2_score(yt, yp)) if n > 1 and np.var(yt) > 0 else np.nan,
        "y_true_mean": float(np.mean(yt)),
        "y_pred_mean": float(np.mean(yp)),
    }


bins = np.arange(-90, 90 + BIN_DEG, BIN_DEG)
labels = [f"[{bins[i]},{bins[i+1]})" for i in range(len(bins)-1)]
labels[-1] = labels[-1].replace(")", "]")

# right=False gives half-open [lo, hi) bands, matching the labels. lat == 90 falls one past
# the last band and is clipped back into it, which is why the last label is closed.
lat_vals = data_df["lat"].to_numpy()
idx = np.digitize(lat_vals, bins, right=False) - 1
idx = np.clip(idx, 0, len(labels)-1)

data_df = data_df.assign(
    lat_band_idx=idx,
    lat_band=pd.Categorical([labels[i] for i in idx],
                            categories=labels, ordered=True)
)

rows = []
for i, band in enumerate(labels):
    sub = data_df[data_df["lat_band_idx"] == i]
    yt, yp = _clean_xy(sub["y_true"].to_numpy(), sub["y_pred"].to_numpy())
    rows.append({"lat_band": band, **_metric_row(yt, yp)})

metrics_by_latband = pd.DataFrame(rows).set_index("lat_band")

# Append an "All" row computed over every finite pair.
yt_all, yp_all = _clean_xy(data_df["y_true"].to_numpy(), data_df["y_pred"].to_numpy())
if len(yt_all) > 0:
    overall = pd.DataFrame([_metric_row(yt_all, yp_all)], index=["All"])
    metrics_by_latband = pd.concat([metrics_by_latband, overall], axis=0)

order = ["n", "RMSE", "MAE", "MB", "R2", "y_true_mean", "y_pred_mean"]
metrics_by_latband = metrics_by_latband[order]
metrics_by_latband[order[1:]] = metrics_by_latband[order[1:]].round(3)

# The band width in the message follows BIN_DEG instead of a hard-coded number.
print(f"every {BIN_DEG}° latitude belt performance：")
metrics_by_latband

In [None]:
# Per-latitude-band error metrics of best_model on the held-out test rows.
BIN_DEG = 20  # width of one latitude band, in degrees

data_df = df_test.copy()
required_cols = {"lon", "lat", "y_true", "y_pred"}
X_test = data_df[features]
y_test = data_df["xco2"]
y_pred_test = best_model.predict(X_test)
data_df["y_pred"] = y_pred_test
data_df["y_true"] = data_df["xco2"]

missing = required_cols - set(data_df.columns)
if missing:
    raise KeyError(f"data_df is missing required columns: {sorted(missing)}")

# Wrap longitudes given on a 0..360 convention into -180..180.
lon = data_df["lon"].to_numpy()
if np.nanmax(lon) > 180:
    data_df["lon"] = ((lon + 180) % 360) - 180


def _clean_xy(y_true, y_pred):
    """Return both inputs as float arrays restricted to positions where both are finite."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    m = np.isfinite(y_true) & np.isfinite(y_pred)
    return y_true[m], y_pred[m]


def _metric_row(yt, yp):
    """Return sample count, RMSE, MAE, MB, R2 and both means for one group of pairs.

    All metrics are NaN for an empty group. R2 is NaN when it is undefined (fewer than two
    samples or a constant observation vector).
    """
    n = len(yt)
    if n == 0:
        return {"n": 0, "RMSE": np.nan, "MAE": np.nan, "MB": np.nan, "R2": np.nan,
                "y_true_mean": np.nan, "y_pred_mean": np.nan}
    return {
        "n": n,
        "RMSE": float(np.sqrt(mean_squared_error(yt, yp))),
        "MAE": float(mean_absolute_error(yt, yp)),
        "MB": float(np.mean(yp - yt)),
        "R2": float(r2_score(yt, yp)) if n > 1 and np.var(yt) > 0 else np.nan,
        "y_true_mean": float(np.mean(yt)),
        "y_pred_mean": float(np.mean(yp)),
    }


bins = np.arange(-90, 90 + BIN_DEG, BIN_DEG)
labels = [f"[{bins[i]},{bins[i+1]})" for i in range(len(bins)-1)]
labels[-1] = labels[-1].replace(")", "]")

# right=False gives half-open [lo, hi) bands, matching the labels. lat == 90 falls one past
# the last band and is clipped back into it, which is why the last label is closed.
lat_vals = data_df["lat"].to_numpy()
idx = np.digitize(lat_vals, bins, right=False) - 1
idx = np.clip(idx, 0, len(labels)-1)

data_df = data_df.assign(
    lat_band_idx=idx,
    lat_band=pd.Categorical([labels[i] for i in idx],
                            categories=labels, ordered=True)
)

rows = []
for i, band in enumerate(labels):
    sub = data_df[data_df["lat_band_idx"] == i]
    yt, yp = _clean_xy(sub["y_true"].to_numpy(), sub["y_pred"].to_numpy())
    rows.append({"lat_band": band, **_metric_row(yt, yp)})

metrics_by_latband = pd.DataFrame(rows).set_index("lat_band")

# Append an "All" row computed over every finite pair.
yt_all, yp_all = _clean_xy(data_df["y_true"].to_numpy(), data_df["y_pred"].to_numpy())
if len(yt_all) > 0:
    overall = pd.DataFrame([_metric_row(yt_all, yp_all)], index=["All"])
    metrics_by_latband = pd.concat([metrics_by_latband, overall], axis=0)

order = ["n", "RMSE", "MAE", "MB", "R2", "y_true_mean", "y_pred_mean"]
metrics_by_latband = metrics_by_latband[order]
metrics_by_latband[order[1:]] = metrics_by_latband[order[1:]].round(3)

# The band width in the message follows BIN_DEG instead of a hard-coded number.
print(f"every {BIN_DEG}° latitude belt performance：")
metrics_by_latband

# **3. SHAP Plotting**

In [None]:

# Explain a fixed random subset of the training rows.
X_sample = X.sample(n=10000, random_state=42)

explainer = shap.Explainer(best_model)
shap_values = explainer(X_sample)

# Mean absolute SHAP value per feature, also expressed as a share of the total.
vals = np.abs(shap_values.values)
mean_abs = np.mean(vals, axis=0)
total = mean_abs.sum()
pct = 100.0 * mean_abs / total

# Importance table, largest share first.
imp_df = pd.DataFrame({
    "feature": X_sample.columns,
    "mean_abs_shap": mean_abs,
    "percent": pct,
})
imp_df = imp_df.sort_values("percent", ascending=False)
imp_df = imp_df.reset_index(drop=True)

In [None]:
# Figure styling for the SHAP plots.
for _key, _value in {
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "xtick.major.size": 4,
    "ytick.major.size": 4,
    "legend.frameon": False,
    "figure.titlesize": 10,
}.items():
    mpl.rcParams[_key] = _value

# Column name -> short label shown on the plot axis.
colmap = {
    "t2m": "T2", "d2m": "DP2", "u10": "U10", "v10": "V10",
    "msl": "MSL", "sp": "Psfc", "skt": "TS", "tp": "Prec",
    "e": "E", "ssr": "SSR", "str": "STR", "tcw": "TCW", "blh": "BLH",
    "NO2": "NO2", "is_weekend": "WE",
    "population": "POP", "elevation": "ELE", "aspect": "ASP",
    "ndvi": "NDVI", "gpp": "GPP", "lai": "LAI", "ntl": "NTL",
    "evi": "EVI", "slope": "SLO",
    "odiac": "FFE", "CO2_fire": "BBE",
    "geo_x": "geo_x", "geo_y": "geo_y", "geo_z": "geo_z",
    "month_sin": "mon_sin", "month_cos": "mon_cos",
}

X_sample = X.sample(n=5000, random_state=42)  # or use X directly
explainer = shap.TreeExplainer(
    best_model,
    feature_perturbation="interventional",
    model_output="raw",
)
shap_values = explainer(X_sample)

# Same data with the short labels; columns without an entry in colmap keep their name.
X_plot = X_sample.rename(columns=colmap, errors="ignore")

shap.summary_plot(
    shap_values,
    X_plot,
    plot_type="dot",
    feature_names=list(X_plot.columns),
    cmap=mpl.colormaps['RdBu_r'],
    color_bar_label="Feature value",
    max_display=15,
    show=False,
)

# show=False leaves the figure open, so it can be resized and tidied before display.
fig = plt.gcf()
fig.set_size_inches(7, 7)
ax = plt.gca()
for _side in ("top", "right"):
    ax.spines[_side].set_visible(False)

plt.tight_layout()

plt.show()

In [None]:
# Combined figure: SHAP beeswarm (bottom x-axis) over a mean-|SHAP| bar chart (top x-axis).
# Relies on `explainer` and `colmap` defined in the preceding cell.
mpl.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
    "font.size": 14,
    "axes.linewidth": 0.8,
    "axes.labelsize": 8,
    "axes.titlesize": 8,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "xtick.major.size": 4,
    "ytick.major.size": 4,
    "legend.frameon": False,
})

X_sample = X.sample(n=500, random_state=42)
shap_values_numpy = explainer.shap_values(X_sample)
fig, ax1 = plt.subplots(figsize=(10, 8), dpi=300)
X_plot = X_sample.rename(columns=colmap, errors="ignore")

shap.summary_plot(
    shap_values_numpy, X_plot, plot_type="dot", feature_names=list(X_plot.columns),
    cmap=mpl.colormaps['RdBu_r'], color_bar_label="Feature value", max_display=15,
    show=False
)

# Locate the colorbar axes by its label. If no axes carries the label, skip the restyle
# instead of failing on None.
cbar_ax = next(
    (a for a in fig.axes if "Feature value" in (a.get_ylabel(), a.get_xlabel())),
    None,
)
if cbar_ax is not None:
    cbar_ax.set_ylabel("Feature value", fontsize=14)

# presumably the current axes is the beeswarm axes at this point — TODO confirm
plt.gca().set_position([0.5, 0.5, 0.65, 0.65])
ax1 = plt.gca()

# Shrink the beeswarm markers so the bars behind them stay visible.
for coll in ax1.collections:
    if isinstance(coll, mcoll.PathCollection):
        sizes = coll.get_sizes()
        coll.set_sizes(sizes * 0.2)

# Twin axes sharing the feature (y) axis; the bar chart is drawn after it is created.
ax2 = ax1.twiny()

shap.summary_plot(shap_values_numpy, X_plot, plot_type="bar", show=False, max_display=15, color="#BDBDBD")

plt.gca().set_position([0.5, 0.5, 0.65, 0.65])
# Beeswarm on top with a transparent background, bars underneath.
ax1.set_zorder(2)
ax1.patch.set_alpha(0)
ax2.set_zorder(1)

# Label and show the figure exactly once. These calls used to sit inside a loop over the
# bar patches, so they repeated per bar and were skipped entirely when there were none.
ax1.set_xlabel('Shapley Value Contribution (Bee Swarm)', fontsize=14)
ax2.set_xlabel('Mean Shapley Value (Feature Importance)', fontsize=14)
ax2.xaxis.set_label_position('top')
ax2.xaxis.tick_top()
ax1.set_ylabel('Features', fontsize=14)
plt.tight_layout()
plt.show()

# **4. Learning Curve**

In [None]:
# Learning curve: RMSE per boosting round on the train/validation rows and on the held-out rows,
# trained through the native xgboost API (needs `import xgboost as xgb`).
# NOTE(review): the "validation" set is df_test, so the early-stopping round shown on the plot is
# chosen on the held-out test data. This does not touch best_model, but confirm it is intended.
df_train = df_trainval
df_val = df_test

dtrain = xgb.DMatrix(df_train[features], label=df_train["xco2"])
dval = xgb.DMatrix(df_val[features], label=df_val["xco2"])

# Flatten the grid to scalars (first candidate of each sequence).
params = {k: (v[0] if isinstance(v, (list, tuple, np.ndarray)) else v)
          for k, v in param_grid.items()}
# n_estimators belongs to the scikit-learn wrapper; xgb.train takes the number of rounds via
# num_boost_round, so the key is dropped instead of being passed as an unused parameter.
# NOTE(review): the grid's n_estimators is 300 while this curve runs 100 rounds — confirm the
# shorter run is intended.
params.pop("n_estimators", None)
params.update({
    "objective": "reg:squarederror",
    "tree_method": "hist",
})

evals_result = {}
bst = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, "train"), (dval, "validation")],
    early_stopping_rounds=50,
    evals_result=evals_result,
    verbose_eval=True
)

train_rmse = evals_result["train"]["rmse"]
val_rmse = evals_result["validation"]["rmse"]
iters = list(range(len(train_rmse)))

plt.figure(figsize=(6,4))
plt.plot(iters, train_rmse, label="Train RMSE", lw=2)
plt.plot(iters, val_rmse,        label="Val   RMSE", lw=2)
plt.axvline(bst.best_iteration, linestyle="--", color="gray",
            label=f"Early Stop @ {bst.best_iteration}")
plt.xlabel("Boosting Round")
plt.ylabel("RMSE")
plt.title("Learning Curve (Block-based Train/Val Split)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()