<a href="https://colab.research.google.com/github/funway/nid-imbalance-study/blob/main/utils/evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 评估函数封装
🚀 NYIT 880 | 🧑🏻‍💻 funway

- **该文件只定义 函数 或 类** ‼️
- 在其他 ipynb 文件中运行 `%run full_path/file_name.ipynb` 即可导入该文件中的函数与类
- 由于 Google Drive 的写入缓存, 所以修改该文件后, 可能需要等待一定时间(几十秒), 在引用处才会生效

In [None]:
if not callable(globals().get('now')):
    def now() -> str:
        """Return the current America/Vancouver time formatted with '%x %X %Z'."""
        from datetime import datetime
        from zoneinfo import ZoneInfo

        vancouver = ZoneInfo('America/Vancouver')
        current = datetime.now(tz=vancouver)
        return current.strftime('%x %X %Z')
# When a callable `now` already exists (e.g. this notebook was %run twice or
# another utility notebook defined it first), that definition is left untouched.

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, balanced_accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns

def generate_evaluation_report(y_true, y_predict,
                               label_mapping: dict,
                               figure_output: str,
                               figure_show: bool = True):
    """
    Build a text evaluation report and save a confusion-matrix heatmap.

    The report has three sections: the row-normalized confusion matrix, the
    multiclass metrics, and the binary metrics (class 0 versus every other
    class).

    Args:
        y_true: Ground-truth class labels (1-D array-like).
        y_predict: Predicted class labels, aligned with ``y_true``.
        label_mapping: Mapping whose items are rendered as ``"(key) value"``
            and ordered by value. The resulting list is used for the heatmap
            tick labels and as ``target_names`` of the multiclass report, so
            it needs one entry per class present in the data.
            NOTE(review): presumably ``{class name: encoded index}`` - confirm
            against callers.
        figure_output: Where the heatmap image is written. Both ``str`` and
            ``pathlib.Path`` are accepted; the file stem is shown in the title.
        figure_show: Whether to display the figure after it has been saved.

    Returns:
        The complete report as one string.

    Raises:
        ValueError: If the binary confusion matrix is not 2x2, i.e. the data
            does not contain both class 0 and at least one other class.
    """
    # Imported locally so the function does not depend on the notebook that
    # %run's this file having already imported numpy / pathlib.
    from pathlib import Path

    import numpy as np

    # Arrays are needed for the element-wise `== normal_class_label` below;
    # a plain list would compare as a single scalar.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_predict)

    ### Confusion matrix ###
    report_text_cm = f"[{now()}] ================= 📊 Confusion Matrix =================\n"

    ## Raw counts ##
    cm = confusion_matrix(y_true, y_pred)

    ## Row-normalize (share of each true class assigned to each prediction) ##
    # A class with no true samples has a zero row total; leave that row at 0.0
    # instead of producing NaN and a divide-by-zero warning.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)

    labels = [f"({k}) {v}" for k, v in sorted(label_mapping.items(), key=lambda x: x[1])]

    ## Draw the heatmap ##
    # Global numpy print settings: two decimals, no scientific notation.
    np.set_printoptions(precision=2, suppress=True)
    fig = plt.figure(figsize=(10, 8), dpi=150)
    try:
        sns.heatmap(cm_normalized, annot=True, fmt=".2f", cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.title(f"Confusion Matrix (Multiclass)\n{Path(figure_output).stem}")
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        # Save before showing: show() clears the current figure.
        plt.savefig(figure_output, bbox_inches='tight')
        if figure_show:
            plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures
        # and so figure_show=False really suppresses inline display.
        plt.close(fig)

    ## Confusion matrix as aligned text ##
    col_width = 7
    # Header cells use the part of each label that follows ") ".
    header_cells = [name[name.index(")") + 2:].ljust(col_width) for name in labels]
    report_text_cm += " " * 24 + "".join(header_cells) + "\n"
    for i, row in enumerate(cm_normalized):
        row_name = labels[i][:20].ljust(24)
        row_vals = "".join(
            "-----".ljust(col_width) if val < 1e-4 else f"{val:.3f}".ljust(col_width)
            for val in row
        )
        report_text_cm += row_name + row_vals + "\n"

    ### Multiclass report ###
    # zero_division=0 silences the warning raised when a class has no
    # predicted (or no true) samples.
    report_multiclass = classification_report(y_true, y_pred, target_names=labels, zero_division=0, digits=6)

    # Macro average: every class counts equally, regardless of its support.
    precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

    # Weighted average: classes are weighted by their number of samples.
    precision_weighted = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    rec_weighted = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    acc = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)

    report_text_multiclass = (
        f"[{now()}] ================= 📝 Multiclass Report =================\n"
        + report_multiclass + "\n"
        + "📊 Macro average\n"
        + f"    Precision: {precision_macro:.6f}\n"
        + f"    Recall:    {rec_macro:.6f}\n"
        + f"    F1 Score:  {f1_macro:.6f}\n\n"
        + "📊 Weighted average\n"
        + f"    Precision: {precision_weighted:.6f}\n"
        + f"    Recall:    {rec_weighted:.6f}\n"
        + f"    F1 Score:  {f1_weighted:.6f}\n\n"
        + f"🎯 Accuracy: {acc:.6f}\n"
        + f"🎯 Balanced Accuracy: {balanced_acc:.6f}\n"
    )

    ### Binary report ###
    # Collapse the multiclass labels: label 0 stays class 0 (normal traffic),
    # every other label becomes class 1 (attack).
    normal_class_label = 0
    y_true_bin = np.where(y_true == normal_class_label, 0, 1)
    y_pred_bin = np.where(y_pred == normal_class_label, 0, 1)

    report_binary = classification_report(y_true_bin, y_pred_bin, zero_division=0, digits=6)

    # Standard binary metrics; precision / recall / F1 refer to class 1 (attack).
    acc_bin = accuracy_score(y_true_bin, y_pred_bin)
    pre_bin = precision_score(y_true_bin, y_pred_bin, zero_division=0)
    rec_bin = recall_score(y_true_bin, y_pred_bin, zero_division=0)
    f1_bin = f1_score(y_true_bin, y_pred_bin, zero_division=0)

    # False positive rate = FP / (FP + TN)
    cm_bin = confusion_matrix(y_true_bin, y_pred_bin)
    if cm_bin.shape != (2, 2):
        print(f"[{now()}] ⚠️ confusion matrix shape is not (2, 2)")
        raise ValueError("confusion matrix shape is not (2, 2)")
    tn, fp, fn, tp = cm_bin.ravel()
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0

    report_text_binary = (
        f"[{now()}] ================= 📝 Binary Report (Normal vs Attack) =================\n"
        + report_binary + "\n"
        + f"🎯 Accuracy       : {acc_bin:.6f}\n"
        + f"✅ Precision      : {pre_bin:.6f}\n"
        + f"🔁 Recall / DR    : {rec_bin:.6f}\n"
        + f"🎯 F1 Score       : {f1_bin:.6f}\n"
        + f"🚨 FPR (误报率)    : {fpr:.6f}\n"
    )

    ### Assemble the final report ###
    report_text = report_text_cm + '\n' + report_text_multiclass + '\n' + report_text_binary
    return report_text