# 模型評估 (Model Evaluation)
## 第三期大腸癌存活預測研究

本筆記本在測試集上評估已訓練的存活模型（Cox 比例風險模型與隨機存活森林）之效能，並輸出 C-index 評估報告、風險分組生存曲線、Kaplan-Meier 曲線及模型比較結果。

In [None]:
# Imports
import pandas as pd
import numpy as np
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Resolve the project root as the parent of the notebook's working directory
# and make the project's `src` package importable.
project_root = Path.cwd().parent
# Guard so that re-running this cell does not keep appending duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Project-local modules
from src.model_evaluation import SurvivalModelEvaluator
from src.model_training import SurvivalModelTrainer
from src.utils import load_config

# Plot style. The font override comes after sns.set_style because seaborn's
# style presets may also write font settings.
sns.set_style('whitegrid')
# All titles/labels in this notebook are Traditional Chinese, so list a CJK
# font for each common platform; matplotlib uses the first one installed.
# 'Microsoft JhengHei' stays first so Windows output is unchanged.
plt.rcParams['font.sans-serif'] = [
    'Microsoft JhengHei',  # Windows
    'PingFang TC',         # macOS
    'Noto Sans CJK TC',    # Linux (noto-cjk)
    'Arial',
]
# Render the minus sign as ASCII '-' (CJK fonts often lack U+2212).
plt.rcParams['axes.unicode_minus'] = False

print("套件載入完成")

In [None]:
# Load the YAML project configuration and the preprocessed test split.
config_file = project_root / 'config' / 'config.yaml'
config = load_config(str(config_file))

test_csv = project_root / 'data' / 'processed' / 'test_features.csv'
test_df = pd.read_csv(test_csv)

print(f"測試集形狀: {test_df.shape}")

In [None]:
# The evaluator computes metrics and draws plots. The trainer is instantiated
# only so that its load_model() helper can restore the pickled models below.
evaluator = SurvivalModelEvaluator(config)
trainer = SurvivalModelTrainer(config)
print("評估器已初始化")

In [None]:
# Split the test set into survival targets (time, event indicator) and features.
duration_col = 'survival_time'
event_col = 'event'
target_cols = [duration_col, event_col]

if all(c in test_df.columns for c in target_cols):
    y_test_time, y_test_event = (test_df[c].values for c in target_cols)

    # Every column that is not a target is treated as a model input.
    feature_cols = [c for c in test_df.columns if c not in target_cols]
    X_test = test_df[feature_cols]

    print(f"測試特徵數量: {len(feature_cols)}")
    print(f"測試樣本數量: {len(X_test)}")
else:
    print("警告: 找不到存活時間或事件欄位")

## 1. 載入並評估 Cox 模型

In [None]:
# Load the Cox proportional-hazards model, score the test set, and save its
# C-index report and a three-group risk-stratified survival plot.
import traceback

cox_model_path = project_root / 'models' / 'cox_ph.pkl'

if cox_model_path.exists():
    try:
        tables_dir = project_root / 'results' / 'tables'
        figures_dir = project_root / 'results' / 'figures'
        # The evaluator is handed paths inside these folders; make sure they exist.
        tables_dir.mkdir(parents=True, exist_ok=True)
        figures_dir.mkdir(parents=True, exist_ok=True)

        cox_model = trainer.load_model(str(cox_model_path))

        # Score the same feature matrix the RSF cell uses (X_test). Flatten the
        # result to a 1-D ndarray, the same type the RSF cell hands the evaluator.
        # NOTE(review): assumes a lifelines-style model where a larger partial
        # hazard means higher risk — confirm against src/model_training.py.
        risk_scores = np.asarray(cox_model.predict_partial_hazard(X_test)).ravel()

        # Concordance index on the held-out data.
        c_index = evaluator.calculate_c_index(y_test_time, y_test_event, risk_scores)

        # Evaluation report
        cox_metrics = {'C-index': c_index}
        report_path = tables_dir / 'cox_evaluation.txt'
        evaluator.generate_evaluation_report(
            'Cox Proportional Hazards',
            cox_metrics,
            str(report_path),
        )

        # Survival curves for three risk groups.
        fig_path = figures_dir / 'cox_risk_groups.png'
        evaluator.plot_risk_groups(
            risk_scores,
            y_test_time,
            y_test_event,
            n_groups=3,
            save_path=str(fig_path),
        )

    except Exception as e:
        # Best effort: keep the notebook running, but print the full traceback
        # so the failure can actually be diagnosed.
        print(f"Cox 模型評估失敗: {e}")
        traceback.print_exc()
else:
    print("未找到 Cox 模型")

## 2. 載入並評估隨機存活森林模型

In [None]:
# Load the random survival forest and score the test set. Save its C-index
# report, an optional feature-importance plot, and a three-group
# risk-stratified survival plot.
import traceback

rsf_model_path = project_root / 'models' / 'random_survival_forest.pkl'

if rsf_model_path.exists():
    try:
        # Fail fast with the friendly ImportError message below if
        # scikit-survival is missing.
        import sksurv  # noqa: F401

        tables_dir = project_root / 'results' / 'tables'
        figures_dir = project_root / 'results' / 'figures'
        # The evaluator is handed paths inside these folders; make sure they exist.
        tables_dir.mkdir(parents=True, exist_ok=True)
        figures_dir.mkdir(parents=True, exist_ok=True)

        rsf_model = trainer.load_model(str(rsf_model_path))

        # The forest is fed a plain ndarray of the test features.
        # NOTE(review): assumes a sksurv RandomSurvivalForest, whose predict()
        # returns one risk score per sample (higher = riskier) — confirm.
        X_test_array = X_test.values
        risk_scores = rsf_model.predict(X_test_array)

        # Concordance index on the held-out data.
        c_index = evaluator.calculate_c_index(
            y_test_time,
            y_test_event,
            risk_scores,
        )

        # Evaluation report
        rsf_metrics = {'C-index': c_index}
        report_path = tables_dir / 'rsf_evaluation.txt'
        evaluator.generate_evaluation_report(
            'Random Survival Forest',
            rsf_metrics,
            str(report_path),
        )

        # Feature importance is optional. Some scikit-survival versions define
        # feature_importances_ as a property that raises NotImplementedError.
        # hasattr() only swallows AttributeError, so that error would abort
        # this cell before the risk-group plot. Probe it explicitly instead.
        try:
            importances = rsf_model.feature_importances_
        except (AttributeError, NotImplementedError):
            importances = None

        if importances is not None:
            fig_path = figures_dir / 'rsf_feature_importance.png'
            evaluator.plot_feature_importance(
                feature_cols,
                importances,
                top_n=20,
                save_path=str(fig_path),
            )

        # Survival curves for three risk groups.
        fig_path = figures_dir / 'rsf_risk_groups.png'
        evaluator.plot_risk_groups(
            risk_scores,
            y_test_time,
            y_test_event,
            n_groups=3,
            save_path=str(fig_path),
        )

    except ImportError:
        print("警告: 需要 scikit-survival 套件")
    except Exception as e:
        # Best effort: keep the notebook running, but print the full traceback.
        print(f"隨機存活森林評估失敗: {e}")
        traceback.print_exc()
else:
    print("未找到隨機存活森林模型")

## 3. Kaplan-Meier 分析

In [None]:
# Overall Kaplan-Meier curve for the whole test set (no stratification).
# Only drawn when the survival targets were found in the test data.
has_targets = duration_col in test_df.columns and event_col in test_df.columns
if has_targets:
    figures_dir = project_root / 'results' / 'figures'
    fig_path = figures_dir / 'kaplan_meier_overall.png'
    evaluator.plot_kaplan_meier_curves(
        y_test_time,
        y_test_event,
        title="整體 Kaplan-Meier 生存曲線",
        save_path=str(fig_path),
    )

## 4. 模型比較

In [None]:
# Compare all models recorded in evaluator.results. Save the comparison table
# and a bar chart of the C-index values.
# NOTE(review): assumes evaluator.results is populated by the evaluation calls
# in the cells above (e.g. generate_evaluation_report) — confirm in src.
if len(evaluator.results) > 0:
    # Transposed so that, presumably, rows are models and columns are metrics.
    comparison_df = pd.DataFrame(evaluator.results).T
    print("\n模型效能比較:")
    print(comparison_df)

    # Save the comparison table (create the folder if it does not exist yet).
    tables_dir = project_root / 'results' / 'tables'
    tables_dir.mkdir(parents=True, exist_ok=True)
    comparison_path = tables_dir / 'model_comparison.csv'
    comparison_df.to_csv(comparison_path)
    print(f"\n比較結果已儲存至: {comparison_path}")

    # Visual comparison
    if 'C-index' in comparison_df.columns:
        # If any metric is non-numeric, the transposed columns become object
        # dtype and Series.plot refuses them, so coerce explicitly.
        c_index = pd.to_numeric(comparison_df['C-index'], errors='coerce')

        if c_index.notna().any():
            # Start the axis below 0.5 so a worse-than-random model is still
            # visible and the random-guess line is not drawn on the axis border.
            y_min = max(0.0, min(0.4, c_index.min() - 0.05))

            fig, ax = plt.subplots(figsize=(10, 6))
            c_index.plot(kind='bar', ax=ax)
            ax.set_title('模型 C-index 比較')
            ax.set_ylabel('C-index')
            ax.set_xlabel('模型')
            ax.set_ylim(y_min, 1.0)
            ax.axhline(y=0.5, color='r', linestyle='--', label='Random')
            ax.legend()
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()

            figures_dir = project_root / 'results' / 'figures'
            figures_dir.mkdir(parents=True, exist_ok=True)
            fig_path = figures_dir / 'model_comparison.png'
            fig.savefig(fig_path, dpi=300, bbox_inches='tight')
            print(f"比較圖已儲存至: {fig_path}")
            plt.show()
else:
    print("沒有可用的評估結果")

In [None]:
# Final summary: every artefact produced above is written under <project_root>/results.
results_root = project_root / 'results'
print("\n模型評估完成！")
print(f"所有結果已儲存至: {results_root}")