In [8]:
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import json


def calculate_metrics(y_true, y_pred, zero_division=0):
    """
    Compute weighted F1, overall accuracy, weighted precision and weighted recall.

    Args:
        y_true: Ground-truth labels, in any form accepted by the scikit-learn
            metric functions used below.
        y_pred: Predicted labels, aligned one-to-one with ``y_true``.
        zero_division: Value substituted for precision/recall/F1 terms whose
            denominator is zero (e.g. a label that is never predicted).
            The default of 0 yields the same numbers as scikit-learn's
            default "warn" setting, but without emitting a warning.

    Returns:
        dict: Keys "f1", "accuracy", "precision" and "recall", each mapped to
        a plain Python float.
    """
    # NOTE(review): the original docstring described the inputs as
    # List[List[int]] (multi-label); the caller in this file passes whatever is
    # stored in the "output" / "predict" JSON fields - confirm the label format.

    # Shared options for the three precision/recall-based scores.
    prf_options = {"average": "weighted", "zero_division": zero_division}

    # Key order is kept as f1, accuracy, precision, recall because the dict is
    # printed directly by the reporting code.
    metrics = {
        "f1": float(f1_score(y_true, y_pred, **prf_options)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, **prf_options)),
        "recall": float(recall_score(y_true, y_pred, **prf_options)),
    }

    return metrics


def get_label_pred(test_file, pred_file):
    """Load ground-truth labels and model predictions from disk.

    Args:
        test_file: Path to a JSON file holding a list of objects; the
            "output" field of each object is taken as the ground-truth label.
        pred_file: Path to a JSON Lines file (one JSON object per line); the
            "predict" field of each object is taken as the prediction.

    Returns:
        tuple: ``(labels, preds)`` - two lists, in file order.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If a file (or a non-blank line of ``pred_file``)
            is not valid JSON.
        KeyError: If a record lacks the "output" / "predict" field.
    """
    # Context managers make sure both handles are closed; the encoding is
    # pinned to UTF-8 so decoding does not depend on the platform default.
    with open(test_file, "r", encoding="utf-8") as test_handle:
        test_data = json.load(test_handle)
    labels = [item["output"] for item in test_data]

    with open(pred_file, "r", encoding="utf-8") as pred_handle:
        # Blank lines (typically a trailing newline at end of file) are
        # skipped instead of being fed to json.loads, which would raise.
        pred_data = [json.loads(line) for line in pred_handle if line.strip()]
    preds = [item["predict"] for item in pred_data]

    return labels, preds


def cal_acc(test_file, pred_file, split_index=300):
    """Score predictions overall and separately on two consecutive slices.

    Args:
        test_file: Path to the labelled JSON file (read via ``get_label_pred``).
        pred_file: Path to the JSON Lines predictions file.
        split_index: Position at which the samples are cut in two. Samples
            ``[:split_index]`` form the first group and ``[split_index:]``
            the second. Defaults to 300, the value that was previously
            hard-coded.

    Returns:
        tuple: ``(metrics, dlo_metrics, img_metrics)`` - metric dicts from
        ``calculate_metrics`` for all samples, the first slice and the second
        slice respectively.

    Raises:
        ValueError: If the number of labels and predictions differ.
    """
    # NOTE(review): the script that calls this prints the first slice as the
    # dialogue-intent score and the second as the image-scene score, so the
    # data is presumably ordered that way - confirm against the dataset.
    labels, preds = get_label_pred(test_file, pred_file)

    # Fail early with a clear message; otherwise the slices below would be
    # silently misaligned before the metric functions complain.
    if len(labels) != len(preds):
        raise ValueError(
            f"Label/prediction count mismatch: {len(labels)} labels vs {len(preds)} predictions"
        )

    metrics = calculate_metrics(y_true=labels, y_pred=preds)
    dlo_metrics = calculate_metrics(
        y_true=labels[:split_index], y_pred=preds[:split_index]
    )
    img_metrics = calculate_metrics(
        y_true=labels[split_index:], y_pred=preds[split_index:]
    )
    return metrics, dlo_metrics, img_metrics


if __name__ == "__main__":
    # Input locations: the labelled data file and the predictions generated
    # for it by the evaluation run.
    test_file, pred_file = (
        "./data/mire/train2.json",
        "./saves/Qwen2-VL-7B-Instruct/lora/eval_2024-11-26-22-05-54/generated_predictions.jsonl",
    )

    # cal_acc returns the overall scores followed by the two per-slice scores.
    metrics, dlo_metrics, img_metrics = cal_acc(test_file, pred_file)

    # Print one line per score group: overall, first slice, second slice.
    report_rows = (
        ("总分：", metrics),
        ("对话意图得分：", dlo_metrics),
        ("图片场景得分：", img_metrics),
    )
    for row_title, row_scores in report_rows:
        print(row_title, row_scores)


总分： {'f1': 0.8912284893924897, 'accuracy': 0.894, 'precision': 0.8999297428578823, 'recall': 0.894}
对话意图得分： {'f1': 0.9454438552901191, 'accuracy': 0.9466666666666667, 'precision': 0.9528636733636734, 'recall': 0.9466666666666667}
图片场景得分： {'f1': 0.8679933325792196, 'accuracy': 0.8714285714285714, 'precision': 0.8772437726411149, 'recall': 0.8714285714285714}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [7]:
test_file

'./data/mire/train2.json'