In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch.optim
from sklearn.metrics import *


def validate(model, dataloader, device="cuda"):
    """
    Run ``model`` over every batch of ``dataloader`` and score its predictions.

    The model is put in eval mode and moved to ``device``. Its raw outputs are
    used directly as positive-class scores, and all labels and scores are
    gathered on the CPU before scoring.

    :param model: torch module; its output for a batch is treated as the score per
        sample (assumed shape (batch,) or (batch, 1) of probabilities — TODO confirm).
    :param dataloader: iterable yielding ``(inputs, targets)`` tensor batches.
    :param device: device used for inference (default ``"cuda"``, as before).
    :return: ``(valid_auc, valid_prc)`` — area under the ROC curve and area under
        the precision-recall curve.
    :raises ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    model.to(device)
    true_labels = []
    predicted_probs = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            # Only the inputs are needed on the device; targets are scored on the CPU.
            outputs = model(inputs.to(device).float())
            true_labels.append(targets.cpu().numpy())
            predicted_probs.append(outputs.cpu().numpy())

    if not true_labels:
        raise ValueError("dataloader yielded no batches; cannot compute AUROC/AUPRC")

    # Metric computation is plain numpy/sklearn and does not need the no_grad context.
    true_labels_flat = np.concatenate(true_labels)
    predicted_probs_flat = np.concatenate(predicted_probs)

    valid_auc = _plot_roc_curve(true_labels_flat, predicted_probs_flat)
    valid_prc = _plot_prc_curve(true_labels_flat, predicted_probs_flat)

    return valid_auc, valid_prc

def _plot_roc_curve(true_labels_flat, predicted_probs_flat):
    """
    Return the area under the ROC curve for the given labels and scores.

    Despite its name, this helper draws nothing; it only computes the scalar AUROC
    from sklearn's ``roc_curve`` points using ``auc``.

    :param true_labels_flat: 1-D array of ground-truth binary labels.
    :param predicted_probs_flat: array of positive-class scores, aligned with the labels.
    :return: AUROC as a float.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true_labels_flat, predicted_probs_flat)
    return auc(false_pos_rate, true_pos_rate)

def _plot_prc_curve(true_labels_flat, predicted_probs_flat):
    """
    Return the area under the precision-recall curve for the given labels and scores.

    Despite its name, this helper draws nothing. It integrates sklearn's
    ``precision_recall_curve`` points with the trapezoidal rule (``auc``), using
    recall as x and precision as y.

    :param true_labels_flat: 1-D array of ground-truth binary labels.
    :param predicted_probs_flat: array of positive-class scores, aligned with the labels.
    :return: AUPRC as a float.
    """
    # NOTE(review): sklearn documents that trapezoidal interpolation of the PR curve
    # can be too optimistic and recommends average_precision_score instead. The
    # computation is kept unchanged so reported numbers stay comparable.
    curve_points = precision_recall_curve(true_labels_flat, predicted_probs_flat)
    precision_values, recall_values = curve_points[0], curve_points[1]
    return auc(recall_values, precision_values)

In [2]:
from torch import nn


class VotingEnsemble(nn.Module):
    """
    Combine several models by soft or hard voting.

    Every member is called on the same input, and the outputs are stacked along a
    new leading dimension. Members are assumed to return probability tensors of
    identical shape (TODO confirm for all checkpoints used).

    * ``'soft'``: element-wise mean of the member outputs.
    * ``'hard'``: each member output is thresholded (``> threshold``). The result
      is 1 where at least ceil(n_models / 2) members vote positive, else 0, as an
      int tensor. With an even number of members, ties resolve to positive.
    """

    def __init__(self, models, voting_type='soft', threshold=0.5):
        """
        :param models: iterable of ``nn.Module`` members.
        :param voting_type: ``'soft'`` or ``'hard'``.
        :param threshold: per-member decision threshold used by hard voting (default 0.5).
        :raises ValueError: on an unknown ``voting_type`` or an empty ``models``.
        """
        super(VotingEnsemble, self).__init__()
        if voting_type not in ('soft', 'hard'):
            raise ValueError("voting_type must be 'soft' or 'hard'")
        # Register members as submodules. With a plain list, .eval()/.to()/
        # .parameters()/state_dict() on the ensemble did not reach them, so e.g.
        # BatchNorm layers stayed in training mode during validation.
        self.models = nn.ModuleList(models)
        if len(self.models) == 0:
            raise ValueError("VotingEnsemble needs at least one model")
        self.voting_type = voting_type
        self.threshold = threshold

    def forward(self, x):
        """Run every member on ``x`` and combine the outputs according to ``voting_type``."""
        outputs = torch.stack([model(x) for model in self.models], dim=0)

        if self.voting_type == 'soft':
            return torch.mean(outputs, dim=0)

        if self.voting_type == 'hard':
            n_models = outputs.shape[0]
            votes = torch.sum((outputs > self.threshold).int(), dim=0)
            # ceil(n/2) votes needed: strict majority for odd n, half (ties win) for even n.
            required = n_models - n_models // 2
            return (votes >= required).int()

        # Reachable only if voting_type was changed after construction.
        raise ValueError("voting_type must be 'soft' or 'hard'")

In [6]:
from torch.utils.data import TensorDataset, DataLoader


def _evaluate_on_splits(model, loaders, auc_dict):
    """
    Run ``validate`` on each split in ``loaders``, appending and printing AUROC / AUPRC.

    :param model: model (or ensemble) to evaluate.
    :param loaders: mapping of split name ('train' / 'val' / 'test') -> DataLoader.
    :param auc_dict: results dict with '<split>_roc' / '<split>_prc' lists, appended in place.
    """
    for split, loader in loaders.items():
        auroc, auprc = validate(model, loader)
        auc_dict[f'{split}_roc'].append(auroc)
        auc_dict[f'{split}_prc'].append(auprc)
        print(f"\tAUROC {split}:", auroc)
        print(f"\tAUPRC {split}:", auprc)


def voting(data_dir, models_info, selected_model, batch_size, voting_types=()):
    """
    Evaluate selected saved models on the train / val / test splits.

    Optionally, voting ensembles of those models are evaluated first.

    :param data_dir: path of a ``torch.save``-d dict with keys 'data_tensor_train',
        'label_tensor_train', 'data_tensor_val', 'label_tensor_val',
        'data_tensor_test' and 'label_tensor_test' (a file path, despite the name).
    :param models_info: mapping key -> {'model_name': str, 'model_path': str}.
    :param selected_model: keys of ``models_info`` to evaluate, in the order given.
        Empty (or None) means all models, in ``models_info`` order.
    :param batch_size: evaluation DataLoader batch size.
    :param voting_types: iterable of 'soft' / 'hard'. For each, a ``VotingEnsemble``
        over the selected models is evaluated before the individual models.
        The default evaluates no ensembles.
    :return: dict with keys 'train_roc', 'train_prc', 'val_roc', 'val_prc',
        'test_roc' and 'test_prc'. Each value is a list with one score per
        evaluated model, in evaluation order.
    """
    if not selected_model:
        selected_model = [key for key in models_info]
    voting_types = tuple(voting_types)

    # Load in selected_model order so each printed name is paired with the model
    # loaded from its own path. The previous code loaded in models_info order but
    # labelled in selected_model order, which mislabelled results when they differed.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. Loading a full nn.Module also needs weights_only=False on
    # torch >= 2.6, where the default became True.
    models = [torch.load(models_info[key]['model_path']).to('cuda') for key in selected_model]

    data = torch.load(data_dir)
    loaders = {}
    for split in ('train', 'val', 'test'):
        dataset = TensorDataset(data[f'data_tensor_{split}'], data[f'label_tensor_{split}'])
        # AUROC / AUPRC do not depend on sample order, so evaluation data is not shuffled.
        loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    auc_dict = {f'{split}_{metric}': [] for split in loaders for metric in ('roc', 'prc')}

    for position, voting_type in enumerate(voting_types):
        print("\nVoting Type:" if position else "Voting Type:", voting_type)
        print("Voting Model Results:")
        _evaluate_on_splits(VotingEnsemble(models, voting_type=voting_type), loaders, auc_dict)
    if voting_types:
        print("\n")

    for number, key in enumerate(selected_model, start=1):
        print(f'{number}.', models_info[key]['model_name'])
        _evaluate_on_splits(models[number - 1], loaders, auc_dict)

    return auc_dict

In [7]:
# Checkpoints to evaluate: key -> display name and path of a saved full model.
models_info_zyy = {
    "model1": {
        "model_name": "BiLSTM_BN, undersample, 52",
        "model_path": "ZYY/zzz_saved_model/ZYY_BiLSTM_BN_model_undersample_FocalLoss_100_0.01_model_52.pth"
    },
    "model2": {
        "model_name": "BiLSTM_BN, undersample, 57",
        "model_path": "ZYY/zzz_saved_model/ZYY_BiLSTM_BN_model_undersample_FocalLoss_100_0.01_model_57.pth"
    },
    "model3": {
        "model_name": "BiLSTM_BN_larger, undersample, 21",
        "model_path": "ZYY/zzz_saved_model/ZYY_BiLSTM_BN_larger_model_undersample_FocalLoss_100_0.01_model_21.pth"
    }
}
# Raw string: the Windows path's backslashes (\d, \Z, \z) are invalid escape
# sequences in a normal literal (SyntaxWarning on Python 3.12+). The value is unchanged.
# NOTE(review): absolute machine-specific path; consider making it configurable.
data_dir = r'E:\deeplearning\Zhongda\zyy_tensor.pth'
selected_model = []  # empty list -> evaluate every model (select all by default)
auc_dict = voting(data_dir, models_info_zyy, selected_model, 512)

1. BiLSTM_BN, undersample, 52
	AUROC train: 0.6827885236433993
	AUPRC train: 0.2272901083862063
	AUROC val: 0.8400546914095379
	AUPRC val: 0.3417621887198033
	AUROC test: 0.6817865935577786
	AUPRC test: 0.09671451037319048
2. BiLSTM_BN, undersample, 57
	AUROC train: 0.5901832544400971
	AUPRC train: 0.17009256689477908
	AUROC val: 0.7744427889213422
	AUPRC val: 0.4315726727164132
	AUROC test: 0.5591761533359847
	AUPRC test: 0.10489862390386981
3. BiLSTM_BN_larger, undersample, 21
	AUROC train: 0.679106450485927
	AUPRC train: 0.1873166420682143
	AUROC val: 0.8442721081339178
	AUPRC val: 0.251515068467237
	AUROC test: 0.691536485028206
	AUPRC test: 0.11736421687618051
