# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import tempfile
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (accuracy_score, precision_score, recall_score, 
                             f1_score, roc_auc_score, average_precision_score,
                             roc_curve, precision_recall_curve, confusion_matrix,
                             ConfusionMatrixDisplay, classification_report)
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
from joblib import dump, load

# ====================== 配置参数 ======================
# Central configuration read by main() and save_individual_plots().
CONFIG = {
    'data_file': "data.xlsx",
    'target': "1yearegfr",
    'numerical_features': ['ePWV', 'SII', '24h-UP', 'eGFR'],
    'test_size': 0.25,
    # Passed to train_test_split and StratifiedKFold in main(); None leaves
    # both unseeded.
    'random_state': None,
    'manual_threshold': 0.46,  # manually chosen decision threshold
    # Keyword arguments forwarded unchanged to RandomForestClassifier.
    'manual_params': {
        'n_estimators': 500,
        'max_depth': 5,
        'min_samples_split': 2,
        'min_samples_leaf': 5,
        'max_features': 0.1,
        'class_weight': 'balanced',
        'bootstrap': True,
        'oob_score': False,
        'random_state': 42,
        'criterion': 'entropy',
        'n_jobs': -1
    },
    # Forward slash on purpose: inside a normal string literal "\6" is an
    # octal escape (chr(6)), which silently put a control character into the
    # directory name. "/" is accepted as a separator on Windows as well.
    'output_dir': "results/6_ablation"
}

# ====================== 工具函数 ======================
def load_and_preprocess_data(file_path):
    """Read the Excel workbook, print a quick overview, and drop incomplete rows.

    The first five rows, the ``describe()`` summary and the per-column
    missing-value counts are printed before any filtering takes place.

    Args:
        file_path: Path of the workbook handed to ``pandas.read_excel``.

    Returns:
        The loaded DataFrame without any row that contains a missing value.
    """
    frame = pd.read_excel(file_path)

    # (heading, summary builder) pairs; each summary is built right before it
    # is printed, so the output order matches the computation order.
    overview = (
        ("数据前5行：", frame.head),
        ("\n数据描述统计：", frame.describe),
        ("\n缺失值检查：", lambda: frame.isnull().sum()),
    )
    for heading, build_summary in overview:
        print(heading)
        print(build_summary())

    complete_rows = frame.dropna()
    return complete_rows

def preprocess_features(data, numerical_features, target):
    """Split a DataFrame into a feature matrix and a target column.

    Args:
        data: Source DataFrame.
        numerical_features: Column labels that make up the feature matrix.
        target: Label of the target column.

    Returns:
        tuple: ``(X, y)`` where ``X`` is an independent copy of the selected
        feature columns and ``y`` is the target column taken from ``data``.
    """
    return data[numerical_features].copy(), data[target]

def evaluate_by_ckd_group(model, X, y_true, ckd_groups, threshold=0.5):
    """Evaluate the classifier separately for every CKD group.

    Args:
        model: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the positive-class score.
        X: Feature rows to score.
        y_true: True binary labels, in the same row order as ``X``.
        ckd_groups: pandas Series with the group label of every row of ``X``,
            in the same row order.
        threshold: Score at or above which a sample is predicted positive.

    Returns:
        dict: Maps each group label to a dict with the keys ``n_samples``,
        ``accuracy``, ``precision``, ``recall``, ``f1``, ``roc_auc`` and
        ``pr_auc``. ``roc_auc`` is NaN when the group holds a single class;
        ``pr_auc`` is NaN when the group has no positive sample.
    """
    y_prob = model.predict_proba(X)[:, 1]
    y_pred = (y_prob >= threshold).astype(int)
    # The scores are plain positional arrays, so the labels are handled
    # positionally too; every per-group selection then uses one and the same
    # boolean mask.
    y_true_arr = np.asarray(y_true)

    group_metrics = {}
    for group in sorted(ckd_groups.unique()):
        mask = np.asarray(ckd_groups == group)
        n_samples = int(mask.sum())  # vectorised count instead of Python sum()
        if n_samples == 0:
            # Happens for labels that never compare equal to themselves (NaN).
            continue

        group_y_true = y_true_arr[mask]
        group_y_prob = y_prob[mask]
        group_y_pred = y_pred[mask]

        # ROC AUC needs both classes. Average precision needs at least one
        # positive sample, otherwise recall is undefined and the reported
        # value would be meaningless.
        # NOTE(review): assumes the positive class is labelled 1 (the
        # scikit-learn default pos_label) — confirm against the data.
        has_both_classes = len(np.unique(group_y_true)) > 1
        has_positives = bool(np.any(group_y_true == 1))

        group_metrics[group] = {
            'n_samples': n_samples,
            'accuracy': accuracy_score(group_y_true, group_y_pred),
            'precision': precision_score(group_y_true, group_y_pred, zero_division=0),
            'recall': recall_score(group_y_true, group_y_pred, zero_division=0),
            'f1': f1_score(group_y_true, group_y_pred, zero_division=0),
            'roc_auc': roc_auc_score(group_y_true, group_y_prob) if has_both_classes else np.nan,
            'pr_auc': average_precision_score(group_y_true, group_y_prob) if has_positives else np.nan,
        }

    return group_metrics

def calculate_metrics(y_true, y_pred, y_prob, prefix=''):
    """Compute the standard binary-classification metrics.

    Args:
        y_true: True labels.
        y_pred: Hard class predictions.
        y_prob: Positive-class scores, used for the two ranking metrics.
        prefix: Text prepended to every key of the result.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall``, ``f1``, ``roc_auc`` and
        ``pr_auc`` (in that order), each key carrying ``prefix``.
    """
    binary_kwargs = {'average': 'binary', 'zero_division': 0}
    label_scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )

    scores = {}
    scores[f'{prefix}accuracy'] = accuracy_score(y_true, y_pred)
    # Metrics derived from the hard predictions share the same keyword set.
    for metric_name, scorer in label_scorers:
        scores[f'{prefix}{metric_name}'] = scorer(y_true, y_pred, **binary_kwargs)
    # Ranking metrics are computed from the scores rather than the labels.
    scores[f'{prefix}roc_auc'] = roc_auc_score(y_true, y_prob)
    scores[f'{prefix}pr_auc'] = average_precision_score(y_true, y_prob)
    return scores

def save_individual_plots(best_rf, X_test, y_test, y_prob, X_train, y_train, y_train_prob,
                         manual_threshold, test_roc_auc, test_pr_auc):
    """Render each diagnostic figure and save it as a PNG in CONFIG['output_dir'].

    Five files are written: feature importance, ROC curve (train vs test),
    precision-recall curve, confusion matrix at the manual threshold, and the
    F1-versus-threshold curve.

    Args:
        best_rf: Fitted model exposing ``feature_importances_``.
        X_test: Test features (not read here; kept for interface stability).
        y_test: True test labels.
        y_prob: Positive-class scores for the test set.
        X_train: Training features; only the column names are used.
        y_train: True training labels.
        y_train_prob: Positive-class scores for the training set.
        manual_threshold: Decision threshold for the confusion matrix and the
            marker line in the threshold plot.
        test_roc_auc: Pre-computed test ROC AUC shown in the ROC legend.
        test_pr_auc: Pre-computed test average precision shown in the PR legend.
    """
    output_dir = CONFIG['output_dir']
    os.makedirs(output_dir, exist_ok=True)

    def save_and_close(fig, filename):
        # Finalise the layout, write the PNG and release the figure.
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, filename), dpi=300)
        plt.close(fig)

    # Feature importance
    feature_importance = pd.Series(best_rf.feature_importances_, index=X_train.columns)
    fig, ax = plt.subplots(figsize=(6, 4))
    feature_importance.sort_values().plot(kind='barh', ax=ax)
    ax.set_title('Feature Importance(Gini)')
    save_and_close(fig, "01_feature_importance.png")

    # ROC curve
    fpr_test, tpr_test, _ = roc_curve(y_test, y_prob)
    fpr_train, tpr_train, _ = roc_curve(y_train, y_train_prob)
    train_roc_auc = roc_auc_score(y_train, y_train_prob)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fpr_test, tpr_test, label=f'Test ROC (AUC = {test_roc_auc:.2f})')
    ax.plot(fpr_train, tpr_train, label=f'Train ROC (AUC = {train_roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve (Train vs Test)')
    ax.legend(loc='lower right')
    save_and_close(fig, "02_roc_curve.png")

    # Precision-recall curve
    precision, recall, _ = precision_recall_curve(y_test, y_prob)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(recall, precision, label=f'Test PR (AUC = {test_pr_auc:.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    save_and_close(fig, "03_pr_curve.png")

    # Confusion matrix - manual threshold only
    y_pred = (y_prob >= manual_threshold).astype(int)
    cm = confusion_matrix(y_test, y_pred)
    fig, ax = plt.subplots(figsize=(5, 4))
    # Pass the axes explicitly: without `ax`, ConfusionMatrixDisplay.plot()
    # opens a figure of its own, which ignored the requested figsize and left
    # an empty figure behind that was never closed.
    ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=ax, cmap='Blues', colorbar=False)
    ax.set_title(f'Test CM (Manual Threshold={manual_threshold:.2f})')
    save_and_close(fig, "04_cm_manual.png")

    # Threshold selection plot - marks where the manual threshold sits
    thresholds_arr = np.linspace(0, 1, 100)
    # zero_division=0 matches the other metric calls in this file and avoids
    # a warning for thresholds at which no sample is predicted positive.
    f1_scores_curve = [
        f1_score(y_test, (y_prob >= t).astype(int), zero_division=0)
        for t in thresholds_arr
    ]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(thresholds_arr, f1_scores_curve, label='F1 Score')
    ax.axvline(manual_threshold, color='r', ls='--',
               label=f'Manual Threshold = {manual_threshold:.2f}')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('F1 Score')
    ax.set_title('Threshold Selection (F1 Score)')
    ax.legend()
    save_and_close(fig, "05_threshold_selection.png")

# ====================== 主执行流程 ======================
def main():
    """Run the whole pipeline: load data, train, evaluate, plot and export.

    Steps: load the workbook named in CONFIG, make a stratified train/test
    split, cross-validate and fit a random forest with CONFIG['manual_params'],
    save the diagnostic plots, and write all result tables to
    ``RF_results.xlsx`` inside CONFIG['output_dir'].
    """
    # Show the threshold currently in use
    print(f"使用手动设置的阈值: {CONFIG['manual_threshold']}")

    # 1. Data preparation
    data = load_and_preprocess_data(CONFIG['data_file'])
    X, y = preprocess_features(data, CONFIG['numerical_features'], CONFIG['target'])

    # 2. Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=CONFIG['test_size'], random_state=CONFIG['random_state'], stratify=y
    )

    # 3. Cross-validation on the training part only
    rf = RandomForestClassifier(**CONFIG['manual_params'])
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=CONFIG['random_state'])
    cv_scores = cross_val_score(rf, X_train, y_train, cv=cv, scoring='roc_auc', n_jobs=-1)

    print(f"\n交叉验证结果（ROC AUC）：平均得分: {cv_scores.mean():.4f} (±{cv_scores.std():.4f})")

    # 4. Fit the final model with the manual parameters
    best_rf = RandomForestClassifier(**CONFIG['manual_params'])
    best_rf.fit(X_train, y_train)

    print("\n使用手动参数训练模型:")
    for param, value in CONFIG['manual_params'].items():
        print(f"  {param}: {value}")

    # 5. Positive-class scores
    y_train_prob = best_rf.predict_proba(X_train)[:, 1]
    y_prob = best_rf.predict_proba(X_test)[:, 1]

    # 6. Manual decision threshold
    manual_threshold = CONFIG['manual_threshold']

    # 7. Threshold-independent test metrics (shown in the plot legends)
    test_roc_auc = roc_auc_score(y_test, y_prob)
    test_pr_auc = average_precision_score(y_test, y_prob)

    # 8. Plots
    save_individual_plots(best_rf, X_test, y_test, y_prob, X_train, y_train,
                         y_train_prob, manual_threshold, test_roc_auc, test_pr_auc)

    # 9. Collect the results and export them to Excel
    os.makedirs(CONFIG['output_dir'], exist_ok=True)
    excel_path = os.path.join(CONFIG['output_dir'], "RF_results.xlsx")

    with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
        # 9.1 Cross-validation scores, followed by their mean and std
        cv_df = pd.DataFrame({
            "fold": np.arange(1, len(cv_scores)+1),
            "roc_auc": cv_scores
        })
        cv_df.loc["mean"] = ["mean", cv_scores.mean()]
        cv_df.loc["std"]  = ["std",  cv_scores.std()]
        cv_df.to_excel(writer, sheet_name="CV_result", index=False)

        # 9.2 Manual hyper-parameters
        manual_params_df = pd.Series(CONFIG['manual_params']).to_frame("value")
        manual_params_df.to_excel(writer, sheet_name="Manual_params")

        # 9.3 Train and test performance
        # NOTE(review): the training metrics use the default 0.5 cut-off while
        # the test metrics use the manual threshold — confirm this is intended.
        y_train_pred_default = (y_train_prob >= 0.5).astype(int)
        y_test_pred_manual = (y_prob >= manual_threshold).astype(int)

        # Both Series are keyed by the bare metric name so that they line up
        # row by row. With "train_"/"test_" prefixes the two indexes shared no
        # label and the combined sheet had twelve half-empty rows.
        train_metrics = pd.Series(
            calculate_metrics(y_train, y_train_pred_default, y_train_prob)
        )
        test_metrics = pd.Series(
            calculate_metrics(y_test, y_test_pred_manual, y_prob)
        )
        performance_df = pd.DataFrame({'train': train_metrics, 'test': test_metrics})
        performance_df.to_excel(writer, sheet_name="Performance_metrics")

        # 9.4 Per-CKD-group evaluation (manual threshold)
        if "CKD" in data.columns:
            ckd_groups = data.loc[X_test.index, "CKD"]
            group_metrics = evaluate_by_ckd_group(best_rf, X_test, y_test,
                                                  ckd_groups,
                                                  threshold=manual_threshold)
            group_df = pd.DataFrame(group_metrics).T
            group_df.index.name = "CKD_stage"
            group_df.to_excel(writer, sheet_name="Metrics_by_CKD")

        # 9.5 Feature importance, most important first
        feat_imp = (
            pd.Series(best_rf.feature_importances_, index=X.columns)
            .sort_values(ascending=False)
            .to_frame("importance")
        )
        feat_imp.to_excel(writer, sheet_name="Feature_importance")

        # 9.6 Threshold information
        threshold_info = pd.Series({
            'manual_threshold': manual_threshold
        }).to_frame("value")
        threshold_info.to_excel(writer, sheet_name="Threshold_info")

    print(f"\n所有结果已汇总并保存至：{excel_path}")

    print("\n所有分析完成！")

if __name__ == "__main__":
    main()