# PKSmart Model Validation Report

이 노트북은 커스텀 학습된 PKSmart 모델들의 검증 결과를 시각화합니다.

- **PK 모델**: Human CL, VDss, fup, MRT, thalf
- **LD50 모델**: 급성 경구 독성
- **Tox21 모델**: 12개 독성 endpoint

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc
import warnings
warnings.filterwarnings('ignore')

# Global plot styling applied to every figure in this report.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 12,
})

# Folder containing the input CSVs; every figure below is also saved here.
RESULTS_DIR = 'results'

## 1. PK Model Validation

In [None]:
# Read the PK summary-metrics table and the per-compound prediction table.
pk_results, pk_predictions = (
    pd.read_csv(f'{RESULTS_DIR}/{stem}.csv')
    for stem in ('pk_validation_results', 'pk_predictions')
)

print("=== PK Model Validation Results ===")
display(pk_results)

In [None]:
# PK metrics bar chart: R², RMSE and 2-fold accuracy for each PK model.
fig, axes = plt.subplots(1, 3, figsize=(14, 5))

# R² Score: blue when above the 0.5 reference line, red otherwise.
ax = axes[0]
r2_values = pk_results['R2']
colors = ['#2196F3' if r > 0.5 else '#f44336' for r in r2_values]
ax.barh(pk_results['Model'], r2_values, color=colors)
ax.axvline(x=0.5, color='gray', linestyle='--', label='R²=0.5')
ax.set_xlabel('R² Score')
ax.set_title('PK Models: R² Score')
# R² is unbounded below (a model worse than predicting the mean scores < 0),
# so extend the axis leftwards instead of clipping negative bars at 0.
# Putting 0.0 first in min() keeps the limit at 0 if every R² is NaN.
ax.set_xlim(min(0.0, r2_values.min()), 1)
ax.legend()

# RMSE
ax = axes[1]
ax.barh(pk_results['Model'], pk_results['RMSE'], color='#FF9800')
ax.set_xlabel('RMSE')
ax.set_title('PK Models: RMSE')

# 2-Fold Accuracy (percent): green when above 50%, red otherwise.
ax = axes[2]
colors = ['#4CAF50' if acc > 50 else '#f44336' for acc in pk_results['2-Fold%']]
ax.barh(pk_results['Model'], pk_results['2-Fold%'], color=colors)
ax.axvline(x=50, color='gray', linestyle='--', label='50%')
ax.set_xlabel('2-Fold Accuracy (%)')
ax.set_title('PK Models: 2-Fold Accuracy')
ax.set_xlim(0, 100)
# Show the reference-line label set above.
ax.legend()

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/pk_metrics_bar.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Predicted vs. actual scatter plots, one panel per PK model, 3 panels per row.
models = pk_predictions['model'].unique()
n_models = len(models)
n_cols = 3
# Ceiling division, with at least one row so plt.subplots never receives 0 rows.
n_rows = max(1, (n_models + n_cols - 1) // n_cols)

# squeeze=False always returns a 2-D array of Axes, so ravel() gives a flat
# array of Axes for any model count (including exactly one model).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.ravel()

for idx, model in enumerate(models):
    ax = axes[idx]
    data = pk_predictions[pk_predictions['model'] == model]
    # Per-model flag (taken from the first row) telling whether values are log10-scaled.
    is_log = bool(data['log_transformed'].iloc[0])

    ax.scatter(data['y_true'], data['y_pred'], alpha=0.5, s=30)

    # y=x reference line spanning the combined range of actual and predicted.
    lims = [min(data['y_true'].min(), data['y_pred'].min()),
            max(data['y_true'].max(), data['y_pred'].max())]
    ax.plot(lims, lims, 'r--', linewidth=2, label='y=x')

    # 2-fold band: a constant ±log10(2) offset only makes sense on a log10 scale.
    if is_log:
        fold_2 = np.log10(2)
        ax.plot(lims, [lims[0]+fold_2, lims[1]+fold_2], 'g--', alpha=0.5, label='2-fold')
        ax.plot(lims, [lims[0]-fold_2, lims[1]-fold_2], 'g--', alpha=0.5)

    ax.set_xlabel('Actual (log10)' if is_log else 'Actual')
    ax.set_ylabel('Predicted (log10)' if is_log else 'Predicted')
    ax.set_title(f'{model}')
    ax.legend(loc='upper left', fontsize=8)

# Hide the unused panels in the last row of the grid.
for idx in range(n_models, len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/pk_pred_vs_actual.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Residual (predicted - actual) histograms for at most the first five PK models.
n_panels = min(5, n_models)
# squeeze=False keeps a 2-D Axes array even for a single panel; row 0 holds them all.
fig, panel_grid = plt.subplots(1, n_panels, figsize=(15, 4), squeeze=False)
panels = panel_grid[0]

for ax, model in zip(panels, models[:n_panels]):
    subset = pk_predictions.loc[pk_predictions['model'] == model]
    errors = subset['y_pred'] - subset['y_true']

    ax.hist(errors, bins=30, edgecolor='black', alpha=0.7, color='#2196F3')
    # Zero-residual reference line.
    ax.axvline(x=0, color='red', linestyle='--', linewidth=2)
    ax.set_xlabel('Residual (Pred - Actual)')
    ax.set_ylabel('Count')
    # Panel title is the part of the model name after its last underscore.
    ax.set_title(f'{model.split("_")[-1]}')

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/pk_residuals.png', dpi=150, bbox_inches='tight')
plt.show()

## 2. LD50 Model Validation

In [None]:
# Read the LD50 summary-metrics table and the per-compound prediction table.
ld50_results, ld50_predictions = (
    pd.read_csv(f'{RESULTS_DIR}/{stem}.csv')
    for stem in ('ld50_validation_results', 'ld50_predictions')
)

print("=== LD50 Model Validation Results ===")
display(ld50_results)

In [None]:
# LD50: predicted vs. actual scatter (left panel) and residual histogram (right panel).
fig, (ax_scatter, ax_resid) = plt.subplots(1, 2, figsize=(12, 5))
y_obs = ld50_predictions['y_true']
y_hat = ld50_predictions['y_pred']

# Left panel: one point per compound, plus y=x and ±2-fold guide lines.
ax_scatter.scatter(y_obs, y_hat, alpha=0.3, s=20)
lo = min(y_obs.min(), y_hat.min())
hi = max(y_obs.max(), y_hat.max())
lims = [lo, hi]
ax_scatter.plot(lims, lims, 'r--', linewidth=2, label='y=x')

# Values are plotted as log10(LD50+1), so a 2-fold error is a ±log10(2) shift.
fold_2 = np.log10(2)
ax_scatter.plot(lims, [lo + fold_2, hi + fold_2], 'g--', alpha=0.5, label='2-fold')
ax_scatter.plot(lims, [lo - fold_2, hi - fold_2], 'g--', alpha=0.5)

ax_scatter.set_xlabel('Actual log10(LD50+1)')
ax_scatter.set_ylabel('Predicted log10(LD50+1)')
ax_scatter.set_title(f"LD50: Predicted vs Actual (R²={ld50_results['R2'].iloc[0]:.3f})")
ax_scatter.legend()

# Right panel: distribution of (predicted - actual) errors.
ax_resid.hist(y_hat - y_obs, bins=50, edgecolor='black', alpha=0.7, color='#FF9800')
ax_resid.axvline(x=0, color='red', linestyle='--', linewidth=2)
ax_resid.set_xlabel('Residual (Pred - Actual)')
ax_resid.set_ylabel('Count')
ax_resid.set_title('LD50: Residual Distribution')

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/ld50_validation.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Tox21 Model Validation

In [None]:
# Read the Tox21 summary-metrics table and the per-compound prediction table.
tox21_results, tox21_predictions = (
    pd.read_csv(f'{RESULTS_DIR}/{stem}.csv')
    for stem in ('tox21_validation_results', 'tox21_predictions')
)

print("=== Tox21 Model Validation Results ===")
display(tox21_results)

In [None]:
# Tox21 metrics bar chart: AUC-ROC (traffic-light coloured) and F1 per endpoint.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# AUC-ROC: green above 0.7, orange above 0.6, red otherwise.
# The loop variable is `score` rather than `auc` so it does not shadow
# sklearn.metrics.auc, which is imported at the top and used for the ROC curves.
ax = axes[0]
colors = ['#4CAF50' if score > 0.7 else '#FF9800' if score > 0.6 else '#f44336'
          for score in tox21_results['AUC']]
ax.barh(tox21_results['Model'], tox21_results['AUC'], color=colors)
ax.axvline(x=0.7, color='green', linestyle='--', alpha=0.7, label='Good (0.7)')
ax.axvline(x=0.5, color='red', linestyle='--', alpha=0.7, label='Random (0.5)')
ax.set_xlabel('AUC-ROC')
ax.set_title('Tox21 Models: AUC-ROC')
ax.set_xlim(0, 1)
ax.legend()

# F1 Score
ax = axes[1]
ax.barh(tox21_results['Model'], tox21_results['F1'], color='#9C27B0')
ax.set_xlabel('F1 Score')
ax.set_title('Tox21 Models: F1 Score')
ax.set_xlim(0, 1)

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/tox21_metrics_bar.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ROC curves for every Tox21 endpoint, overlaid on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 8))

endpoint_names = list(tox21_results['Model'])
# One colour per endpoint. Sizing the palette by the actual endpoint count
# (instead of a fixed 12) avoids an IndexError if more endpoints are present.
colors = plt.cm.tab20(np.linspace(0, 1, max(len(endpoint_names), 1)))

for idx, model_name in enumerate(endpoint_names):
    data = tox21_predictions[tox21_predictions['model'] == model_name]
    # roc_curve rejects NaN input; drop rows lacking a label or a score
    # (NOTE(review): unverified whether the predictions file ever contains such rows).
    data = data.dropna(subset=['y_true', 'y_pred_proba'])
    # Too few points for a meaningful curve.
    if len(data) < 10:
        continue
    # ROC is undefined with a single class present (TPR or FPR would divide by zero).
    if data['y_true'].nunique() < 2:
        continue

    fpr, tpr, _ = roc_curve(data['y_true'], data['y_pred_proba'])
    roc_auc = auc(fpr, tpr)

    ax.plot(fpr, tpr, color=colors[idx], lw=2,
            label=f'{model_name} (AUC={roc_auc:.2f})')

# Chance-level diagonal.
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Tox21: ROC Curves (All Endpoints)')
ax.legend(loc='lower right', fontsize=8)

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/tox21_roc_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Heatmap of the Tox21 classification metrics (rows = models, columns = metrics).
fig, ax = plt.subplots(figsize=(8, 10))

metric_columns = ['AUC', 'Accuracy', 'F1', 'Precision', 'Recall']
metrics_df = tox21_results.set_index('Model')[metric_columns]

# Fixed 0-1 colour scale so cells are comparable across every metric.
heatmap_kwargs = {
    'annot': True,
    'fmt': '.3f',
    'cmap': 'RdYlGn',
    'vmin': 0,
    'vmax': 1,
    'linewidths': 0.5,
    'cbar_kws': {'label': 'Score'},
}
sns.heatmap(metrics_df, ax=ax, **heatmap_kwargs)
ax.set_title('Tox21 Models: Performance Metrics Heatmap')

plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/tox21_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Summary

In [None]:
# Plain-text summary of headline metrics for each model family.
rule = "=" * 60
print(rule)
print("PKSmart Model Validation Summary")
print(rule)

# PK: averages across all PK models.
print("\n--- PK Models ---")
print(f"Average R²: {pk_results['R2'].mean():.4f}")
print(f"Average 2-Fold Accuracy: {pk_results['2-Fold%'].mean():.1f}%")

# LD50: a single model, so report its first (only) results row.
print("\n--- LD50 Model ---")
ld50_row = ld50_results.iloc[0]
print(f"R²: {ld50_row['R2']:.4f}")
print(f"2-Fold Accuracy: {ld50_row['2-Fold%']:.1f}%")

# Tox21: averages across all endpoints, one line per metric.
print("\n--- Tox21 Models ---")
for caption, column in (("Average AUC", 'AUC'),
                        ("Average Accuracy", 'Accuracy'),
                        ("Average F1", 'F1')):
    print(f"{caption}: {tox21_results[column].mean():.4f}")

print("\n" + rule)