In [9]:
import torch
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve,
    confusion_matrix, roc_auc_score, roc_curve, average_precision_score
)
import matplotlib.pyplot as plt

In [10]:
def evaluate_model(model, data_loader, device, pos_label=1):
    """
    Evaluate a binary classifier over every batch of ``data_loader``.

    The model output (or its first element, if the model returns a tuple) is
    passed through a softmax over dim 1. The arg-max is the predicted label, and
    softmax column ``pos_label`` is the positive-class score used for the ROC/PR
    metrics. Class labels are therefore assumed to be the integer column indices
    of the model output.

    :param model: trained model; it is switched to eval mode and left there
    :param data_loader: iterable of ``(features, targets)`` batches
    :param device: device (CPU or GPU) the features are moved to
    :param pos_label: label / output-column index of the positive class, default 1
    :return: dict with 'accuracy', 'precision', 'recall', 'f1',
        'confusion_matrix' (nested list), 'roc_auc' (None when the targets
        contain only one class), 'mAP' (sklearn average precision), and the
        curve arrays 'fpr'/'tpr' (None when 'roc_auc' is None) and
        'precision_curve'/'recall_curve'
    :raises ValueError: if ``data_loader`` yields no batches
    """
    model.eval()
    pred_chunks, target_chunks, score_chunks = [], [], []
    num_classes = 0

    with torch.no_grad():
        for features, targets in data_loader:
            outputs = model(features.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            probs = torch.softmax(outputs, dim=1)
            num_classes = probs.shape[1]
            # Keep whole batches on the CPU and concatenate once at the end.
            pred_chunks.append(probs.argmax(dim=1).cpu())
            target_chunks.append(targets.reshape(-1).cpu())
            # Score the column of the class that is treated as positive, so the
            # ranking metrics honour pos_label instead of always using column 1.
            score_chunks.append(probs[:, pos_label].cpu())

    if not target_chunks:
        raise ValueError('data_loader yielded no batches; nothing to evaluate')

    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(target_chunks).numpy()
    y_score = torch.cat(score_chunks).numpy()

    # The ranking metrics below treat label 1 as positive, so binarize against
    # pos_label explicitly instead of relying on sklearn's default convention.
    y_true_bin = (y_true == pos_label).astype(int)
    has_both_classes = 0 < y_true_bin.sum() < len(y_true_bin)

    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        # zero_division=0 keeps the default value (0) but avoids a warning flood
        # when the model predicts no positives, e.g. early in training.
        'precision': precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
        'recall': recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
        'f1': f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
        # A fixed label list keeps the matrix num_classes x num_classes even when
        # a class is absent from both targets and predictions.
        'confusion_matrix': confusion_matrix(
            y_true, y_pred, labels=list(range(num_classes))
        ).tolist(),
        # AUC is undefined when only one class is present in the targets.
        'roc_auc': roc_auc_score(y_true_bin, y_score) if has_both_classes else None,
        'mAP': average_precision_score(y_true_bin, y_score),
        'fpr': None,
        'tpr': None,
        'precision_curve': None,
        'recall_curve': None
    }

    # ROC curve (only when AUC is defined).
    if metrics['roc_auc'] is not None:
        fpr, tpr, _ = roc_curve(y_true_bin, y_score)
        metrics['fpr'] = fpr
        metrics['tpr'] = tpr

    # Precision-recall curve.
    precision, recall, _ = precision_recall_curve(y_true_bin, y_score)
    metrics['precision_curve'] = precision
    metrics['recall_curve'] = recall

    return metrics

In [11]:
def evaluate_final(model, data_loader, device, pos_label=1):
    """
    Evaluate a model with ``evaluate_model``, print a summary, and plot its curves.

    Prints accuracy, precision, recall, F1, the confusion matrix, ROC AUC and
    mAP, then draws the ROC curve (when it is defined) and the
    precision-recall curve.

    :param model: trained model
    :param data_loader: data loader with ``(features, targets)`` batches
    :param device: device (CPU or GPU)
    :param pos_label: positive-class label, default 1
    :return: the metrics dict produced by ``evaluate_model``. In a notebook, end
        the call with ``;`` if the dict should not be displayed.
    """
    metrics = evaluate_model(model, data_loader, device, pos_label)
    roc_auc = metrics['roc_auc']
    # roc_auc is None when the data contains a single class; formatting None
    # with :.2f would raise TypeError.
    roc_auc_text = f'{roc_auc:.2f}' if roc_auc is not None else 'N/A (only one class present)'

    # Print the evaluation summary.
    print(f'Evaluation for model: {model.__class__.__name__}')
    print(f'{{')
    print(f'    Accuracy: {metrics["accuracy"] * 100:.2f}%')
    print(f'    Precision: {metrics["precision"]:.2f}')
    print(f'    Recall: {metrics["recall"]:.2f}')
    print(f'    F1 Score: {metrics["f1"]:.2f}')
    print(f'    Confusion Matrix: {metrics["confusion_matrix"]}')
    print(f'    ROC AUC: {roc_auc_text}')
    print(f'    mAP: {metrics["mAP"]:.2f}')
    print(f'}}')

    # ROC curve.
    if metrics['fpr'] is not None and metrics['tpr'] is not None:
        fig, ax = plt.subplots()
        ax.plot(metrics['fpr'], metrics['tpr'], label=f'ROC curve (area = {roc_auc:.2f})')
        ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Receiver Operating Characteristic')
        ax.legend(loc="lower right")
        plt.show()

    # Precision-recall curve.
    if metrics['precision_curve'] is not None and metrics['recall_curve'] is not None:
        fig, ax = plt.subplots()
        # The "mAP" value is sklearn's average precision, not the trapezoidal
        # area under the curve, so label it as AP.
        ax.plot(metrics['recall_curve'], metrics['precision_curve'], label=f'PR curve (AP = {metrics["mAP"]:.2f})')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall Curve')
        # PR curves typically drop toward the lower right, so put the legend
        # bottom-left to avoid covering the curve.
        ax.legend(loc="lower left")
        plt.show()

    return metrics

In [12]:
def _average_loss(model, data_loader, device, criterion):
    """Mean of the per-batch ``criterion`` values over ``data_loader``, computed without autograd.

    Returns NaN when the loader yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    # no_grad: this is pure evaluation, so there is no reason to build (and hold
    # in memory) an autograd graph for every batch.
    with torch.no_grad():
        for features, targets in data_loader:
            outputs = model(features.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            total_loss += criterion(outputs, targets.to(device)).item()
            num_batches += 1
    # Count batches instead of using len(loader): this avoids ZeroDivisionError
    # on an empty loader and works for loaders without __len__.
    return total_loss / num_batches if num_batches else float('nan')


def evaluate_epoch(model, train_loader, valid_loader, device, epoch, num_epochs, criterion, pos_label=1):
    """
    Evaluate one epoch on both the training and the validation set.

    Computes classification metrics with ``evaluate_model`` and the mean
    per-batch loss with ``criterion``, prints a summary, and returns the metrics.

    :param model: model under training (switched to eval mode)
    :param train_loader: training data loader
    :param valid_loader: validation data loader
    :param device: device (CPU or GPU)
    :param epoch: current epoch (0-based; printed as epoch + 1)
    :param num_epochs: total number of epochs
    :param criterion: loss function called as ``criterion(outputs, targets)``
    :param pos_label: positive-class label, default 1
    :return: ``(train_metrics, valid_metrics)``; each is the ``evaluate_model``
        dict with an added ``'loss'`` entry
    """
    model.eval()

    # Training set.
    train_metrics = evaluate_model(model, train_loader, device, pos_label)
    train_loss = _average_loss(model, train_loader, device, criterion)

    # Validation set.
    valid_metrics = evaluate_model(model, valid_loader, device, pos_label)
    valid_loss = _average_loss(model, valid_loader, device, criterion)

    # Print the summary.
    print(f'Evaluation at epoch : {epoch + 1:03d}/{num_epochs:03d}')
    print(f'{{')
    print(f'    Loss: train = {train_loss:.4f}; validation = {valid_loss:.4f}')
    print(f'    Accuracy: train = {train_metrics["accuracy"] * 100:.2f}%; validation = {valid_metrics["accuracy"] * 100:.2f}%')
    print(f'    Precision: train = {train_metrics["precision"]:.2f}; validation = {valid_metrics["precision"]:.2f}')
    print(f'    Recall: train = {train_metrics["recall"]:.2f}; validation = {valid_metrics["recall"]:.2f}')
    print(f'    F1 Score: train = {train_metrics["f1"]:.2f}; validation = {valid_metrics["f1"]:.2f}')
    print(f'}}')

    # Attach the losses to the returned metric dicts.
    train_metrics['loss'] = train_loss
    valid_metrics['loss'] = valid_loss

    return train_metrics, valid_metrics