In [None]:
import numpy as np
import sys
sys.path.append("../")
from data_utils import AMSMQTPDataset
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import pandas as pd
import os

# Output directory for the pooled per-layer cross-validation arrays.
save_dir = "../cv_results"
os.makedirs(save_dir, exist_ok=True)

# Dataset handle. The last element of its first item is kept as `bound`
# (used below to build the pixel mask), and `models` lists the names whose
# arrays are pooled alongside the network predictions.
data_set = AMSMQTPDataset()
first_item = data_set[0]
bound = first_item[-1]
models = data_set.models

# Fold and layer identifiers, both 1-based (five of each).
folds, layers = range(1, 6), range(1, 6)

# Pool, for every output layer, the masked pixels of all CV folds into flat
# arrays: ground truth ("true"), network prediction ("pred") and one array per
# entry of `models` (sliced from each fold's "lrs" array). One npz per layer.

# Spatial pixel mask (bound == 1); it is constant across layers and folds, so
# it is built once. np.asarray also accepts array-likes such as CPU tensors.
region = np.asarray(bound) == 1

for l in layers:
    res = {x: [] for x in models}
    res["pred"] = []
    res["true"] = []
    for fold in folds:
        filename = f"../training_results/finetune_Generator_UNet_epochs_250_Glr_0.00035_Dlr_7e-07_patience_10_batch_size_50_warmup_epochs_0_perceptual_loss_vgg16_kernel_size_3_exp_ratio_2_squeeze_ratio_2_point_alpha_0.01_fold_n_{fold}/test_results.npz"
        # `with` closes the zip handle held by the NpzFile. Each key access on
        # an NpzFile re-reads the member from disk, so read every array once.
        with np.load(filename) as data:
            preds = data["preds"][:, l - 1, :, :]
            trues = data["trues"][:, l - 1, :, :]
            lrs = data["lrs"]

        # Boolean-index the last two (spatial) axes. Flattening gives the same
        # sample-major element order as indexing with the mask tiled over axis 0.
        res["pred"].append(preds[:, region].ravel())
        res["true"].append(trues[:, region].ravel())
        # NOTE(review): assumes axis 1 of "lrs" follows the order of `models`.
        for j, model in enumerate(models):
            res[model].append(lrs[:, j][:, region].ravel())

    for k, v in res.items():
        res[k] = np.concatenate(v)

    np.savez(os.path.join(save_dir, f"layer{l}.npz"), **res)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from sklearn.metrics import PredictionErrorDisplay
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import os

save_dir = "../plot"
# savefig fails if the target directory does not exist.
os.makedirs(save_dir, exist_ok=True)


def _center_half_width(ax):
    """Shrink `ax` to half its current width while keeping it horizontally centred."""
    box = ax.get_position()
    ax.set_position([box.x0 + box.width * 0.25, box.y0, box.width * 0.5, box.height])


# Five panels on a 3-row x 2-column grid.
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1])

# First two rows: two panels each.
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])

# Last row: one panel spanning both columns, narrowed to the central half
# (i.e. from "column 0.5" to "column 1.5").
ax5 = fig.add_subplot(gs[2, :])
_center_half_width(ax5)

axes = [ax1, ax2, ax3, ax4, ax5]

for ax, l in zip(axes, range(1, 6)):
    # Read the pooled arrays and release the npz file handle immediately.
    with np.load(f"../cv_results/layer{l}.npz") as data:
        pred = data["pred"]
        true = data["true"]

    PredictionErrorDisplay.from_predictions(
        y_true=true,
        y_pred=pred,
        kind="actual_vs_predicted",
        ax=ax,
        scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
        line_kwargs={"color": "tab:red"},
    )

    # RMSE-type metrics use sqrt(MSE): the `squared=False` argument of
    # mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
    r2 = r2_score(true, pred)
    mae = mean_absolute_error(true, pred)
    # Unbiased RMSE: RMSE after removing each series' own mean.
    ubrmse = np.sqrt(mean_squared_error(true - true.mean(), pred - pred.mean()))
    rmse = np.sqrt(mean_squared_error(true, pred))
    r = np.corrcoef(true, pred)[0, 1]

    # Invisible artists so that the metrics show up as legend entries.
    ax.plot([], [], " ", label=f"$R^2$: {r2:.3f}")
    ax.plot([], [], " ", label=f"R: {r:.3f}")
    ax.plot([], [], " ", label=f"MAE: {mae:.3f}")
    ax.plot([], [], " ", label=f"RMSE: {rmse:.3f}")
    ax.plot([], [], " ", label=f"ubRMSE: {ubrmse:.3f}")
    ax.legend(loc="upper left")

    ax.set_title(f"Layer{l}")

plt.tight_layout()
# tight_layout resets the positions of gridspec-managed axes, so re-apply the
# narrowing of the bottom panel afterwards.
_center_half_width(ax5)

plt.savefig(os.path.join(save_dir, "cv_result.pdf"), format="pdf", bbox_inches="tight")

In [None]:
import skill_metrics as sm 
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error

import os

# Load every pooled layer-1 array once. Each NpzFile key access re-reads and
# decompresses from disk, and the `with` block closes the underlying zip handle.
with np.load("../cv_results/layer1.npz") as npz:
    data = {key: npz[key] for key in npz.files}

models = ["access_cm2", "bcc_csm2_mr", "canesm5_canoe", "cnrm_cm6_1_hr", "miroc6", "miroc_es2l", "mri_esm2_0",
        "noresm2_mm", "taiesm1", "cesm2", "cmcc_cm2_sr5", "cnrm_esm2_1", "ec_earth3_veg_lr", "kace_1_0_g", "mpi_esm1_2_lr",
        "noresm2_lm", "ukesm1_0_ll"]
items = ["true", "pred"] + models

# Display names for the diagram legend; each mirrors its key.
model_map = {
    "access_cm2": "ACCESS-CM2", "bcc_csm2_mr": "BCC-CSM2-MR", "canesm5_canoe": "CanESM5-CanOE", "cnrm_cm6_1_hr": "CNRM-CM6-1-HR",
    "miroc6": "MIROC6", "miroc_es2l": "MIROC-ES2L", "mri_esm2_0": "MRI-ESM2-0", "noresm2_mm": "NorESM2-MM", "taiesm1": "TaiESM1",
    "cesm2": "CESM2", "cmcc_cm2_sr5": "CMCC-CM2-SR5", "cnrm_esm2_1": "CNRM-ESM2-1", "ec_earth3_veg_lr": "EC-Earth3-Veg-LR", "kace_1_0_g": "KACE-1-0-G",
    # Previously labelled "MPI-ESM1-2-HR", which contradicts the `_lr` key.
    "mpi_esm1_2_lr": "MPI-ESM1-2-LR", "noresm2_lm": "NorESM2-LM", "ukesm1_0_ll": "UKESM1-0-LL"
}
label = ['AMSMQTP','UNet-Gan'] + [model_map[x] for x in models]

# Multi-model mean of the listed models.
ensemble = np.stack([data[x] for x in models], axis=0).mean(axis=0)

# Taylor-diagram statistics, ordered (reference, candidate 1, candidate 2, ...).
# The first candidate is the reference "true" series itself, so its CC is 1 and
# its RMSD is 0. SD is absolute (not normalised by the reference SD).
# RMSD uses sqrt(MSE) because `squared=False` was removed in scikit-learn 1.6.
reference = data["true"]
candidates = [data[x] for x in items] + [ensemble]
sd = np.array([c.std() for c in candidates])
cc = np.array([np.corrcoef(reference, c)[0, 1] for c in candidates])
rmsd = np.array([np.sqrt(mean_squared_error(reference, c)) for c in candidates])

fig = plt.figure(figsize=(12,8))
ax = fig.add_axes([0.1, 0.6, 0.6, 0.6])

# Core plotting call; draws on the current axes (`ax`, just added).
sm.taylor_diagram(sd,rmsd,cc,markerLabel = label + ["Ensemble"],markercolor="k",markerSize=6,markerLegend = 'on',  # basic options
                  colCOR="k",styleCOR="--",widthCOR=.4,  # correlation lines
                  colSTD="k",widthSTD=.9,styleSTD="--", showlabelsSTD="on",  # standard-deviation arcs
                  widthRMS=0.5,labelRMS='RMSE',colRMS='k', rmsLabelFormat="0:.2f",  # RMSD contours
                  colOBS="r",styleOBS="-",widthOBS=1,markerObs="",titleOBS="AMSMQTP",  # reference (observation) point
                 )

ax.grid(False)

# Manually placed axis title (data coordinates tuned by hand for this figure).
ax.text(-0.025, -0.013, 'Standard Deviation', fontsize=12, color='black', fontweight='bold')

# Enlarge the legend entries created by taylor_diagram.
legend = ax.get_legend()
for text in legend.get_texts():
    text.set_fontsize(14)

os.makedirs("../plot", exist_ok=True)  # savefig fails if the directory is missing
plt.savefig("../plot/taylor.pdf", format="pdf", bbox_inches="tight")