In [1]:
import pickle
import numpy as np
import pandas as pd

from cuml.ensemble import RandomForestClassifier
# from xgboost import XGBClassifier
# from lightgbm import LGBMClassifier
from sklearn.metrics import classification_report, confusion_matrix
from model_utils import get_dataset

from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

In [2]:
# Load a pickled model from disk
def load_model(path):
    """Deserialize and return the object pickled at ``path``.

    Parameters
    ----------
    path : str or path-like
        Location of the pickle file; it is opened in binary read mode.

    Returns
    -------
    object
        Whatever object the pickle stream reconstructs.
    """
    # NOTE(review): unpickling can execute arbitrary code, so only point this
    # at model files that come from a trusted source.
    # The context manager closes the file handle as soon as loading finishes
    # (or raises) instead of leaving it open until garbage collection.
    with open(path, 'rb') as model_file:
        return pickle.load(model_file)

def roc(clf, x_test, y_test, plot_roc=False):
    """Print the ROC-derived cut-off and the AUC of ``clf`` on a test split.

    The second column (index 1) of ``clf.predict_proba(x_test)`` is used as
    the score. The reported cut-off is the ROC threshold at which
    ``tpr - fpr`` is largest.

    Parameters
    ----------
    clf : object
        Model exposing ``predict_proba``.
    x_test : object
        Features handed to ``clf.predict_proba`` unchanged.
    y_test : object
        True labels; must provide ``.to_numpy()`` (e.g. a pandas Series).
    plot_roc : bool, default False
        When true, additionally draw the ROC curve with matplotlib.

    Returns
    -------
    None
        Results are printed, not returned.
    """
    is_cuml_forest = isinstance(clf, RandomForestClassifier)
    class_proba = clf.predict_proba(x_test)
    if is_cuml_forest:
        # The cuML forest's probabilities are wrapped in a NumPy array before
        # slicing; every other model's output is sliced as returned.
        class_proba = np.array(class_proba)
    scores = class_proba[:, 1]
    labels = y_test.to_numpy()

    # False/true positive rates at every candidate threshold.
    fpr, tpr, thresholds = roc_curve(labels, scores)

    # Threshold with the largest gap between TPR and FPR.
    best_idx = np.argmax(tpr - fpr)
    print(f'optimal_threshold: {thresholds[best_idx]}')

    # Area under the ROC curve.
    auc = roc_auc_score(labels, scores)
    print(f'AUC: {auc}')

    if not plot_roc:
        return

    # ROC curve plus the diagonal reference line.
    plt.plot(fpr, tpr, label='ROC curve (AUC = %0.4f)' % auc)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.show()
        
def get_models_roc(F_cos, with_S=False):
    """Print the ROC cut-off and AUC of the saved RF, XGB and LGB models.

    The inputs are columns ``F1``..``F7`` plus ``F_cos``. Data is read through
    ``get_dataset`` from the module-level ``train_path`` and ``test_path``,
    which must be defined before this function is called. Each model is
    loaded from ``model/{clf_type}_model_{F_cos}{suffix}.pkl`` and scored
    with ``roc`` on the test split.

    Parameters
    ----------
    F_cos : str
        Name of the eighth feature column; also part of the model file name.
    with_S : bool, default False
        When true and ``F_cos == 'F10'``, the ``'section'`` column is placed
        first in the inputs and the ``_S`` model files are loaded. For any
        other ``F_cos`` the flag has no effect.

    Returns
    -------
    None
        Results are printed by ``roc``.
    """
    feature_cols = ['F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', F_cos]
    ws = ''

    # The 'section' variant is only applied for the F10 feature set.
    if with_S and F_cos == 'F10':
        feature_cols = ['section'] + feature_cols
        ws = '_S'

    # Only the test split is needed here; the train split is discarded.
    _, _, x_test, y_test = get_dataset(
        train_set=train_path,
        test_set=test_path,
        cols=feature_cols,  # input data (x)
        tgt='label',        # target label (y)
    )

    for clf_type in ('RF', 'XGB', 'LGB'):
        print(clf_type)
        roc(load_model(f'model/{clf_type}_model_{F_cos}{ws}.pkl'), x_test, y_test)
        print()

In [3]:
# Parquet files holding the train and test splits. get_models_roc reads these
# two names as module-level globals, so this cell must run before it is called.
train_path = '../../dataset/to_extractive/train.parquet'
test_path = '../../dataset/to_extractive/test.parquet'

In [4]:
# Load both splits with 'section' plus F1-F7 and F10 as inputs and 'label' as
# the target; x_test / y_test are used by the scoring cell that follows.
x_train, y_train, x_test, y_test = get_dataset(train_set = train_path,
                                               test_set = test_path,
                                               cols = ['section','F1','F2','F3','F4','F5','F6','F7','F10'], # input data (x)
                                               tgt = 'label' # target label (y)
                                               )

In [5]:
# Score the pickled model stored at model/LGB_model_F10_S.pkl on the test
# split loaded in the previous cell (prints the cut-off and the AUC).
res = load_model('model/LGB_model_F10_S.pkl')
roc(res, x_test, y_test)

optimal_threshold: 0.5105444268625421
AUC: 0.7808466631304664


In [6]:
# Score the saved RF / XGB / LGB models for F1-F7 + F10 with the 'section'
# column placed first (loads the '_S' model files).
get_models_roc('F10', with_S=True)

RF
optimal_threshold: 0.4933217465877533
AUC: 0.7432594688991462

XGB
optimal_threshold: 0.4986463487148285
AUC: 0.78036443099712

LGB
optimal_threshold: 0.5105444268625421
AUC: 0.7808466631304664



In [7]:
# Score the saved RF / XGB / LGB models that use F8 as the eighth input column.
get_models_roc('F8')

RF
optimal_threshold: 0.4922090172767639
AUC: 0.7327006171270529

XGB
optimal_threshold: 0.4916720688343048
AUC: 0.7379643380406058

LGB
optimal_threshold: 0.5040252138163217
AUC: 0.7381477437275191



In [8]:
# Score the saved RF / XGB / LGB models that use F9 as the eighth input column.
get_models_roc('F9')

RF
optimal_threshold: 0.49609261751174927
AUC: 0.7364963030195983

XGB
optimal_threshold: 0.5143572688102722
AUC: 0.7412624640871914

LGB
optimal_threshold: 0.5018649301113228
AUC: 0.7415262350630945



In [9]:
# Score the saved RF / XGB / LGB models that use F10 as the eighth input
# column, without the 'section' column.
get_models_roc('F10')

RF
optimal_threshold: 0.4844803810119629
AUC: 0.7373888208710646

XGB
optimal_threshold: 0.49048787355422974
AUC: 0.742466243023354

LGB
optimal_threshold: 0.4970260154858402
AUC: 0.742640951573551

