# Modeling and Evaluation

This notebook trains and evaluates a model that predicts each patient's resistance mechanism:
1. Load features and labels
2. Train an XGBoost model
3. Evaluate predictions
4. Analyze feature importance
5. Run subgroup analysis


In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

# Add src to path
sys.path.insert(0, str(Path.cwd().parent))

from src.models import XGBResistancePredictor
from src.evaluate import generate_evaluation_report
from src.utils import load_adata


In [None]:
# Load the per-patient feature table.
features_df = pd.read_csv("../data/synthetic/sample_features.csv")
print(f"Loaded features: {features_df.shape}")

# Load labels from adata. obs can hold several rows per patient, so collapse
# to one label per patient; `.first()` keeps the first non-null label in each
# group. observed=True prevents unused categories (if patient_id is
# categorical) from producing phantom patients with NaN labels.
adata = load_adata("../data/synthetic/adata_with_trajectories.h5ad")
patient_labels = (
    adata.obs.groupby('patient_id', observed=True)['resistance_mechanism']
    .first()
    .reset_index()
)

# Report feature rows that the inner merge below will silently discard.
unlabeled_ids = set(features_df['patient_id']) - set(patient_labels['patient_id'])
if unlabeled_ids:
    print(f"Warning: {len(unlabeled_ids)} patient(s) in features have no label and are dropped")

# Merge
data = features_df.merge(patient_labels, on='patient_id', how='inner')

# A patient whose labels are all null still gets a row (with a NaN target);
# drop it rather than training on a missing label.
missing_label = data['resistance_mechanism'].isna()
if missing_label.any():
    print(f"Warning: dropping {int(missing_label.sum())} patient(s) with a missing resistance_mechanism")
    data = data.loc[~missing_label].reset_index(drop=True)
print(f"Merged data: {data.shape}")

# Prepare X and y. Every column except the id and the target is a feature;
# missing feature values are imputed as 0.
feature_cols = [col for col in data.columns if col not in ['patient_id', 'resistance_mechanism']]
X = data[feature_cols].fillna(0)
# y keeps patient_id alongside the label, as model.train() receives it below.
y = data[['patient_id', 'resistance_mechanism']]

print(f"Features: {X.shape}")
print(f"Labels: {y.shape}")


In [None]:
# Fit the predictor. validation_split=0.2 presumably reserves 20% of the rows
# for the val_* scores returned in the metrics dict (confirm in src.models).
model = XGBResistancePredictor(random_state=42)
metrics = model.train(X, y, validation_split=0.2)

# (display label, metrics key) pairs, printed in this order.
metric_rows = (
    ("Train AUC", "train_auc"),
    ("Val AUC", "val_auc"),
    ("Train AP", "train_ap"),
    ("Val AP", "val_ap"),
)

print("Training Metrics:")
for label, key in metric_rows:
    print(f"  {label}: {metrics[key]:.3f}")


In [None]:
# Score every patient in X.
# NOTE(review): X includes the rows the model was fitted on in the previous
# cell, so these predictions are not purely held-out.
probabilities = model.predict(X)
mechanisms = model.resistance_mechanisms

# One probability column per mechanism, in the model's class order.
predictions_df = pd.DataFrame(
    probabilities,
    columns=[f'prob_{m}' for m in mechanisms],
)

# Index of the most probable mechanism for each row (first index on ties).
best_idx = np.argmax(np.asarray(probabilities), axis=1)

predictions_df['patient_id'] = data['patient_id'].values
predictions_df['predicted_mechanism'] = [mechanisms[i] for i in best_idx]
predictions_df['true_mechanism'] = data['resistance_mechanism'].values

print("Predictions:")
print(predictions_df.head(10))


In [None]:
# Feature importance.
# get_feature_importance() may report several rows per feature (presumably one
# per resistance mechanism; confirm in src.models). Sum them so each feature
# gets a single total score, then keep the top_n highest.
top_n = 20
importance_df = model.get_feature_importance()
top_features = (
    importance_df.groupby('feature')['importance']
    .sum()
    .sort_values(ascending=False)
    .head(top_n)
)

fig, ax = plt.subplots(figsize=(10, 8))
# barh draws the first row at the bottom. Plot in ascending order so the
# most important feature appears at the top of the chart.
top_features.sort_values(ascending=True).plot(kind='barh', ax=ax)
ax.set_xlabel('Total Importance')
# Use the actual count: there may be fewer than top_n features.
ax.set_title(f'Top {len(top_features)} Most Important Features')
plt.tight_layout()
plt.show()


In [None]:
# Save results. Create the output directory first: DataFrame.to_csv does not
# create missing parent directories and would raise instead.
results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)

predictions_df.to_csv(results_dir / "predictions.csv", index=False)
importance_df.to_csv(results_dir / "feature_importance.csv", index=False)
# Passed as str in case model.save() does not accept Path objects.
model.save(str(results_dir / "xgb_model.pkl"))

print("Results saved to ../results/")
