# Clinical Trial Adverse Event Prediction - Model Training

This notebook trains multiple machine learning models to predict serious adverse events in clinical trial patients. We'll compare the performance of different algorithms and select the best model for deployment in clinical decision support systems.

In [None]:
# Import required libraries for clinical model training
import sys
import os

# Add the current directory to Python path for local imports
sys.path.append('.')

from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import AdaBoostClassifier
from xgboost import XGBClassifier

# Import our clinical training framework
from generic_trainer import train_adverse_event_model
from domino_short_id import domino_short_id

print("✓ Successfully imported clinical modeling libraries")

## Configure Clinical Trial Data

We'll use the transformed clinical features created in the data engineering step. This dataset includes patient demographics, clinical measurements, treatment information, and derived risk scores.
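Before training, it can help to confirm that the target column is present and to see how rare serious adverse events are, since class balance affects how the metrics later in this notebook should be read. The following is a minimal sketch using only the standard library. The helper name `summarize_target` and the inline sample rows are hypothetical; only the `AdverseEvent` column name comes from this notebook.

```python
import csv
import io
from collections import Counter


def summarize_target(csv_file, target_col="AdverseEvent"):
    """Return (class counts, positive-class prevalence) for an open CSV file object."""
    reader = csv.DictReader(csv_file)
    if reader.fieldnames is None or target_col not in reader.fieldnames:
        raise KeyError(f"Column '{target_col}' not found in dataset")
    # Accept values written as "0"/"1" or "0.0"/"1.0"
    counts = Counter(int(float(row[target_col])) for row in reader)
    total = sum(counts.values())
    prevalence = counts.get(1, 0) / total if total else 0.0
    return dict(counts), prevalence


# Hypothetical sample standing in for the real transformed dataset
sample_csv = io.StringIO(
    "age,bmi,AdverseEvent\n"
    "54,27.1,0\n"
    "61,31.4,1\n"
    "47,24.8,0\n"
    "70,29.9,0\n"
)
class_counts, ae_prevalence = summarize_target(sample_csv)
print(f"Class counts: {class_counts}")
print(f"Serious AE prevalence: {ae_prevalence:.1%}")
```

With the real file, an object such as `open(transformed_df_filename, newline='')` could be passed in place of the sample.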

In [None]:
# Specify the transformed clinical dataset
transformed_df_filename = 'transformed_trial_data.csv'  # Output of the data engineering step

print(f"📊 Using clinical dataset: {transformed_df_filename}")
print("📋 This dataset contains preprocessed patient features including:")
print("   • Patient demographics (age, BMI)")
print("   • Clinical history (years with condition, comorbidities)")
print("   • Treatment information (adherence, prior medications)")
print("   • Derived risk scores (clinical risk, adherence balance)")
print("   • Target variable: AdverseEvent (0=No, 1=Serious AE)")

## Define Clinical Prediction Models

We'll train three different algorithms commonly used in clinical prediction:

1. **AdaBoost**: Ensemble method that reweights misclassified samples at each boosting round, which can help with hard-to-classify adverse event cases
2. **Gaussian Naive Bayes**: Simple probabilistic model for baseline comparison
3. **XGBoost**: Gradient boosting for capturing complex clinical interactions

In [None]:
# Configure clinical prediction models with appropriate hyperparameters

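# AdaBoost: reweights misclassified samples at each boosting round
ada_model = {
    'model': AdaBoostClassifier(
        n_estimators=100,  # More boosting rounds than the scikit-learn default of 50
        learning_rate=0.1,
        algorithm="SAMME",  # Note: this parameter is deprecated in recent scikit-learn releases (1.6+)
        random_state=42
    ),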
    'name': "AdaBoost",
    'description': "Adaptive Boosting for adverse event prediction"
}

# Gaussian Naive Bayes: Simple baseline model
gnb_model = {
    'model': GaussianNB(), 
    'name': "GaussianNB",
    'description': "Probabilistic model assuming feature independence"
}

# XGBoost: Advanced ensemble method for complex clinical patterns
xgb_model = {
    'model': XGBClassifier(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=4,  # Prevent overfitting on clinical data
        subsample=0.8,
        colsample_bytree=0.8,
        tree_method='hist',
        n_jobs=-1,
        use_label_encoder=False,  # Only relevant to older XGBoost releases; newer versions ignore it
        eval_metric="auc",
        random_state=42
    ), 
    'name': "XGBoost",
    'description': "Gradient boosting for complex clinical interactions"
}

# List of all models to train
clinical_models = [ada_model, gnb_model, xgb_model]

print(f"🔬 Configured {len(clinical_models)} clinical prediction models:")
for model_dict in clinical_models:
    print(f"   • {model_dict['name']}: {model_dict['description']}")

## Train Clinical Prediction Models

Each model will be trained on the clinical trial data and evaluated using metrics relevant to patient safety:

- **Sensitivity (Recall)**: Ability to identify patients who will have adverse events
- **Specificity**: Ability to correctly identify safe patients
- **ROC-AUC**: Overall discriminative ability
- **Precision**: Proportion of high-risk predictions that are correct
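
The four metrics above can be computed directly from true labels, predicted labels, and predicted scores. The sketch below is an illustration in plain Python, not the implementation inside `train_adverse_event_model`. The function names and sample values are hypothetical, and `specificity` is an assumed key alongside the `recall_adverse`, `precision_adverse`, and `roc_auc` keys read from the training results in this notebook.

```python
def safety_metrics(y_true, y_pred):
    """Sensitivity, specificity, and precision from binary labels (1 = serious AE)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return {
        "recall_adverse": tp / (tp + fn) if (tp + fn) else 0.0,     # sensitivity
        "specificity": tn / (tn + fp) if (tn + fp) else 0.0,
        "precision_adverse": tp / (tp + fp) if (tp + fp) else 0.0,
    }


def roc_auc(y_true, y_score):
    """Probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative case")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical labels and risk scores for eight patients
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_score = [0.9, 0.6, 0.3, 0.7, 0.4, 0.2, 0.1, 0.05]
y_pred = [1 if s >= 0.5 else 0 for s in y_score]  # 0.5 is an arbitrary cut-off

metrics = safety_metrics(y_true, y_pred)
metrics["roc_auc"] = roc_auc(y_true, y_score)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

The pairwise ROC-AUC is quadratic in the number of patients, which is fine for illustration; a library routine is the better choice for real cohorts.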

All training results and model artifacts are logged to MLflow for regulatory compliance and reproducibility.

In [None]:
# Train all clinical prediction models
training_results = []

print("🚀 Starting clinical model training pipeline...\n")

for i, model_dict in enumerate(clinical_models, 1):
    model_name = model_dict['name']
    model_obj = model_dict['model']
    
    print(f"[{i}/{len(clinical_models)}] Training {model_name} for adverse event prediction...")
    print(f"📝 Model: {model_dict['description']}")
    
    try:
        # Train the model using our clinical training framework
        result = train_adverse_event_model(model_obj, model_name, transformed_df_filename)
        training_results.append({
            'model_name': model_name,
            'result': result,
            'status': 'success'
        })
        
        print(f"✅ {model_name} training completed successfully")
        
        # Display key clinical metrics
        if result and 'roc_auc' in result:
            print(f"   📊 ROC-AUC: {result['roc_auc']:.3f}")
        if result and 'recall_adverse' in result:
            print(f"   🎯 Sensitivity: {result['recall_adverse']:.3f}")
        if result and 'precision_adverse' in result:
            print(f"   ⚖️  Precision: {result['precision_adverse']:.3f}")
            
    except Exception as e:
        print(f"❌ Error training {model_name}: {str(e)}")
        training_results.append({
            'model_name': model_name,
            'result': None,
            'status': 'failed',
            'error': str(e)
        })
    
    print(f"{'='*60}\n")

print("🏁 Clinical model training pipeline completed!")
print(f"✅ {sum(1 for r in training_results if r['status'] == 'success')} models trained successfully")
print(f"❌ {sum(1 for r in training_results if r['status'] == 'failed')} models failed")

## Clinical Model Performance Summary

Review the performance of all trained models to select the best candidate for clinical deployment.

In [None]:
# Summarize clinical model performance
print("📈 CLINICAL MODEL PERFORMANCE SUMMARY")
print("="*50)

successful_models = [r for r in training_results if r['status'] == 'success' and r['result']]

if successful_models:
    # Sort by ROC-AUC score (primary clinical metric)
    successful_models.sort(key=lambda x: x['result'].get('roc_auc', 0), reverse=True)
    
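    best_model = successful_models[0]
    best_auc = best_model['result'].get('roc_auc')
    print(f"\n🏆 Best performing model: {best_model['model_name']}")
    # Apply the numeric format only when the metric exists; a string fallback cannot take .3f
    print(f"   ROC-AUC: {best_auc:.3f}" if best_auc is not None else "   ROC-AUC: N/A")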
    
    print("\n📊 All Model Performance:")
    print(f"{'Model':<12} {'ROC-AUC':<8} {'Sensitivity':<12} {'Precision':<10} {'Interpretation':<26}")
    print("-" * 72)
    
    for model_result in successful_models:
        name = model_result['model_name']
        result = model_result['result']
        roc_auc = result.get('roc_auc', 0)
        sensitivity = result.get('recall_adverse', 0)
        precision = result.get('precision_adverse', 0)
        
        # Describe discrimination only; the deployment target is stated separately below
        if roc_auc >= 0.8:
            clinical_use = "Excellent discrimination"
        elif roc_auc >= 0.7:
            clinical_use = "Good discrimination"
        elif roc_auc >= 0.6:
            clinical_use = "Acceptable discrimination"
        else:
            clinical_use = "Needs improvement"
        
        print(f"{name:<12} {roc_auc:<8.3f} {sensitivity:<12.3f} {precision:<10.3f} {clinical_use:<26}")
    
    print("\n💡 Clinical Interpretation:")
    print("   • Sensitivity: Ability to identify patients who will have adverse events")
    print("   • Precision: Proportion of high-risk predictions that are correct")
    print("   • ROC-AUC: Overall ability to distinguish safe from at-risk patients")
    print("   • Target: ROC-AUC ≥ 0.75 for clinical deployment")
    
else:
    print("❌ No models trained successfully. Please check the data and try again.")

print("\n🔬 Next Steps:")
print("   1. Review model performance in MLflow experiments")
print("   2. Select best model for clinical validation")
print("   3. Deploy selected model for real-time patient risk scoring")
print("   4. Set up monitoring for model performance in clinical use")

## Regulatory and Clinical Considerations

### Model Validation Requirements
- **Internal Validation**: Cross-validation on training data (completed)
- **External Validation**: Test on independent patient cohort (recommended)
- **Temporal Validation**: Test on data from different time periods
- **Site Validation**: Test across different clinical sites
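
Temporal and site validation both come down to how the data is split before evaluation. The sketch below shows one way to build those splits in plain Python. The record fields `site` and `enrolled`, the helper names, and the sample patients are hypothetical and are not taken from the trial dataset.

```python
from datetime import date


def temporal_split(records, cutoff):
    """Train on patients enrolled before `cutoff`, test on those enrolled on or after it."""
    train = [r for r in records if r["enrolled"] < cutoff]
    test = [r for r in records if r["enrolled"] >= cutoff]
    return train, test


def leave_one_site_out(records):
    """Yield (held_out_site, train, test) once per clinical site."""
    for site in sorted({r["site"] for r in records}):
        train = [r for r in records if r["site"] != site]
        test = [r for r in records if r["site"] == site]
        yield site, train, test


# Hypothetical enrollment records
records = [
    {"patient": "P1", "site": "A", "enrolled": date(2021, 3, 1)},
    {"patient": "P2", "site": "B", "enrolled": date(2021, 9, 15)},
    {"patient": "P3", "site": "A", "enrolled": date(2022, 2, 10)},
    {"patient": "P4", "site": "C", "enrolled": date(2022, 6, 30)},
]

train_t, test_t = temporal_split(records, cutoff=date(2022, 1, 1))
print(f"Temporal split: {len(train_t)} train, {len(test_t)} test")

site_folds = list(leave_one_site_out(records))
for site, train_s, test_s in site_folds:
    print(f"Held-out site {site}: {len(train_s)} train, {len(test_s)} test")
```

In both cases the model is fitted on the training portion only, so the held-out period or site behaves like unseen data.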

### Clinical Implementation
- **Sensitivity Priority**: For patient safety, prioritize high sensitivity over precision
- **Alert Fatigue**: Balance between catching adverse events and avoiding too many false alarms
- **Clinical Workflow**: Integrate predictions into existing clinical decision support systems
- **Human Oversight**: Always require clinical review of model predictions
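
The trade-off between sensitivity and alert fatigue can be made concrete by choosing the decision threshold from a sensitivity target and then reading off how many patients would be flagged. This is a minimal sketch under assumed inputs: the function name, the 0.9 target, and the sample scores are hypothetical, and the appropriate target is a clinical decision rather than something this code determines.

```python
def threshold_for_sensitivity(y_true, y_score, target_sensitivity=0.9):
    """Return (threshold, sensitivity, alert_rate) for the highest threshold meeting the target."""
    if not 0 < target_sensitivity <= 1:
        raise ValueError("target_sensitivity must be in (0, 1]")
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("At least one adverse event is needed to measure sensitivity")
    # Scan candidate thresholds from strictest to most lenient
    for threshold in sorted(set(y_score), reverse=True):
        flagged = [s >= threshold for s in y_score]
        tp = sum(1 for t, f in zip(y_true, flagged) if t == 1 and f)
        sensitivity = tp / n_pos
        if sensitivity >= target_sensitivity:
            alert_rate = sum(flagged) / len(flagged)  # share of patients raising an alert
            return threshold, sensitivity, alert_rate
    # Not reached: the lowest score as threshold flags every patient (sensitivity 1.0)
    raise RuntimeError("No threshold met the target")


# Hypothetical labels (1 = serious AE) and model risk scores
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_score = [0.9, 0.6, 0.3, 0.7, 0.4, 0.2, 0.1, 0.05]

threshold, sensitivity, alert_rate = threshold_for_sensitivity(y_true, y_score, 0.9)
print(f"Threshold: {threshold}")
print(f"Sensitivity: {sensitivity:.3f}")
print(f"Alert rate: {alert_rate:.1%}")
```

Reporting the alert rate next to the sensitivity makes the review workload visible to the clinical team before a threshold is adopted.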

### Regulatory Documentation
- All model training runs are logged in MLflow to provide an audit trail
- Feature importance helps explain model decisions to clinicians
- Performance metrics align with clinical validation standards
- Model artifacts stored for regulatory submission if required
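
One model-agnostic way to produce the feature importance mentioned above is permutation importance: shuffle one feature at a time and measure how much a score drops. The sketch below is a plain-Python illustration. The `predict` function, the toy data, and the use of accuracy as the score are hypothetical stand-ins; scikit-learn provides a fuller version as `sklearn.inspection.permutation_importance`.

```python
import random


def accuracy(predict, X, y):
    """Fraction of rows where the prediction matches the label."""
    return sum(1 for row, label in zip(X, y) if predict(row) == label) / len(y)


def permutation_importance(predict, X, y, n_repeats=10, seed=42):
    """Mean drop in accuracy when each feature column is shuffled independently."""
    rng = random.Random(seed)
    baseline = accuracy(predict, X, y)
    importances = []
    for col in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            shuffled = [row[col] for row in X]
            rng.shuffle(shuffled)
            # Rebuild the rows with only this column permuted
            X_perm = [row[:col] + [v] + row[col + 1:] for row, v in zip(X, shuffled)]
            drops.append(baseline - accuracy(predict, X_perm, y))
        importances.append(sum(drops) / n_repeats)
    return importances


def predict(row):
    """Hypothetical stand-in model: flags a patient when the first feature exceeds 0.5."""
    return 1 if row[0] > 0.5 else 0


# Toy data: column 0 drives the label, column 1 is noise the model never reads
X = [[0.9, 3], [0.8, 1], [0.7, 4], [0.6, 1], [0.1, 5], [0.2, 9], [0.3, 2], [0.4, 6]]
y = [1, 1, 1, 1, 0, 0, 0, 0]

importances = permutation_importance(predict, X, y)
for name, value in zip(["risk_feature", "noise_feature"], importances):
    print(f"{name}: {value:.3f}")
```

A feature the model ignores shows zero importance, which is the kind of check a clinician can follow without reading the model internals.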