In [9]:
import pandas as pd
import numpy as np
import xgboost as xgb
import optuna
import joblib
import shap
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score, 
 roc_auc_score, roc_curve, classification_report, confusion_matrix, auc)
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
import os
from datetime import datetime

In [10]:
# 1. Configuration and Setup

# Parameters
needle_height = '1.3'
conjugate = 'chlr'
n_trials = 50  # Optuna trials per outer CV fold
dataset_key = f"{needle_height}_{conjugate}"

# Create output directory with timestamp so reruns never overwrite earlier results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"results/{dataset_key}_{timestamp}"
os.makedirs(output_dir, exist_ok=True)

# Base directory of the exported cell-level tables; set DATA_BASE_DIR to use
# another location without editing the notebook.
base_dir = os.environ.get(
    "DATA_BASE_DIR", r"D:\20241129_solid_nN_1.3_2.4_mdck_siRNA_tnsfn_chlr"
)
# os.path.join supplies the separator between the directory and the file name
# (plain string concatenation glued them together into a non-existent path).
dataset_path = os.path.join(
    base_dir, f"solid_{needle_height}_{conjugate}_cell_level.csv"
)

# Define morphological and intensity features
cell_morph_features = [
    'area', 'perimeter', 'major_axis_length', 'minor_axis_length',
    'eccentricity', 'circularity', 'solidity', 'orientation'
]

nuclear_morph_features = [
    'nuclear_area', 'nuclear_perimeter', 'nuclear_major_axis_length',
    'nuclear_minor_axis_length', 'nuclear_eccentricity', 'nuclear_circularity',
    'nuclear_solidity', 'nuclear_orientation'
]

channel_feature_suffixes = [
    'intensity_p10', 'intensity_p25', 'intensity_p50',
    'intensity_p75', 'intensity_p90'
]

protein_channels = ['actin', 'caveolin', 'clathrin_hc', 'nuclei']

# Feature order: cell morphology, nuclear morphology, caveolin intensity
# percentiles, then the remaining channels in protein_channels order.
# Column order only fixes feature positions; it is not expected to change how
# much weight the model gives to caveolin.
other_channels = [ch for ch in protein_channels if ch != 'caveolin']
feature_list = (
    cell_morph_features
    + nuclear_morph_features
    + [f"caveolin_{suffix}" for suffix in channel_feature_suffixes]
    + [f"{ch}_{suffix}" for ch in other_channels for suffix in channel_feature_suffixes]
)


In [11]:
# 2. Model Training and Evaluation

def _filter_and_label(df, intensity_column, area_percentiles=(2, 98), intensity_threshold=300):
    """Apply area-based QC and add the binary ``conjugate_category`` target.

    Returns a filtered copy of ``df`` (the input frame is not modified):

    * cells whose ``area`` lies outside the ``area_percentiles`` window are dropped;
    * for kept cells whose ``nuclear_area`` lies outside its own percentile
      window, every ``nuclear_*`` column is set to NaN (the cell is kept);
    * ``conjugate_category`` is 1 where ``intensity_column`` > ``intensity_threshold``,
      otherwise 0.
    """
    # nanpercentile: with any missing area np.percentile returns NaN bounds,
    # and every comparison against NaN is False.
    cell_area_min, cell_area_max = np.nanpercentile(df['area'], area_percentiles)
    nuclear_area_min, nuclear_area_max = np.nanpercentile(df['nuclear_area'], area_percentiles)

    # Rows must not also be filtered on intensity_column > intensity_threshold:
    # the label is defined by that same cutoff, so such a filter leaves a
    # single class and nothing to learn.
    # NOTE(review): if a background/low-signal filter was intended, give it its
    # own cutoff strictly below intensity_threshold.
    df_filtered = df.loc[df['area'].between(cell_area_min, cell_area_max)].copy()

    nucleus_in_range = df_filtered['nuclear_area'].between(nuclear_area_min, nuclear_area_max)
    nuclear_cols = [col for col in df_filtered.columns if col.startswith('nuclear_')]
    df_filtered.loc[~nucleus_in_range, nuclear_cols] = np.nan

    df_filtered['conjugate_category'] = (
        df_filtered[intensity_column] > intensity_threshold
    ).astype(int)
    return df_filtered


def _tune_xgb_params(X_train, y_train, groups_train, n_trials, random_state=42):
    """Search XGBoost hyperparameters with Optuna, scored by grouped inner-CV accuracy.

    The inner split is grouped by image, like the outer split, so cells from one
    image never sit on both sides of an inner train/validation boundary (which
    would inflate the inner score). Requires at least 3 distinct images in
    ``groups_train``. Returns ``study.best_params``.
    """
    inner_cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=random_state)
    # The split is deterministic, so compute it once instead of once per trial.
    inner_splits = list(inner_cv.split(X_train, y_train, groups=groups_train))

    def objective(trial):
        params = {
            'max_depth': trial.suggest_int('max_depth', 3, 10),
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
            'subsample': trial.suggest_float('subsample', 0.6, 1.0),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
            'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
            'gamma': trial.suggest_float('gamma', 0, 5),
            'reg_alpha': trial.suggest_float('reg_alpha', 0.0, 10.0),
            'reg_lambda': trial.suggest_float('reg_lambda', 0.0, 10.0),
            'n_estimators': trial.suggest_int('n_estimators', 50, 200)
        }
        model = xgb.XGBClassifier(random_state=random_state, **params)

        inner_scores = []
        for inner_train_idx, inner_valid_idx in inner_splits:
            model.fit(X_train.iloc[inner_train_idx], y_train[inner_train_idx])
            y_pred_inner = model.predict(X_train.iloc[inner_valid_idx])
            inner_scores.append(accuracy_score(y_train[inner_valid_idx], y_pred_inner))
        return np.mean(inner_scores)

    # Seeded sampler so the hyperparameter search is reproducible.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=random_state),
    )
    study.optimize(objective, n_trials=n_trials)
    return study.best_params


def process_dataset(dataset_path, dataset_name, area_percentiles=(2, 98), intensity_threshold=300):
    """Train and evaluate an XGBoost classifier with nested, image-grouped CV.

    The target is whether a cell's ``<conjugate>_intensity_mean`` exceeds
    ``intensity_threshold``; the conjugate is the second ``_``-separated token of
    ``dataset_name`` (e.g. ``"1.3_chlr"`` -> ``chlr_intensity_mean``).

    Uses the module-level ``feature_list``, ``n_trials`` and ``output_dir``.
    Writes one model per outer fold to ``output_dir`` and prints progress.
    Models are trained on the raw (unscaled) feature columns. Tree splits do
    not depend on per-feature rescaling, and the scaler was never persisted,
    so the returned/saved models can be applied to ``X`` directly.

    Returns:
        (avg_metrics, best_model, best_fold, tprs, aucs, mean_fpr,
         shap_values_list, X, test_idx) where ``best_fold`` is the 1-based fold
        with the highest ROC AUC, ``best_model`` is that fold's model and
        ``test_idx`` are that fold's positional test indices into ``X``.
        ``tprs``/``aucs``/``shap_values_list`` hold one entry per fold in fold
        order. All folds' test indices are also stored in
        ``X.attrs['fold_test_indices']``.

    Raises:
        ValueError: if, after filtering, the target contains only one class.
    """
    print(f"\n=== Processing {dataset_name} ===")

    conjugate_type = dataset_name.split('_')[1]
    intensity_column = f"{conjugate_type}_intensity_mean"
    print(f"Using intensity column: {intensity_column}")

    df = pd.read_csv(dataset_path)
    df_filtered = _filter_and_label(df, intensity_column, area_percentiles, intensity_threshold)

    print("Unique conjugate_category values:", df_filtered['conjugate_category'].unique())
    print("Value counts:\n", df_filtered['conjugate_category'].value_counts())

    # Labels are already 0/1, so no LabelEncoder is needed.
    y_encoded = df_filtered['conjugate_category'].to_numpy()
    if np.unique(y_encoded).size < 2:
        raise ValueError(
            f"Target for {dataset_name} has a single class after filtering; "
            f"check intensity_threshold={intensity_threshold}"
        )

    X = df_filtered[feature_list]
    groups = df_filtered['image_id'].to_numpy()

    all_fold_metrics = []
    shap_values_list = []
    fold_models = []
    fold_test_indices = []
    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    aucs = []

    # Outer CV: folds never split the cells of one image.
    outer_cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(outer_cv.split(X, y_encoded, groups=groups), start=1):
        print(f"\n=== Outer Fold {fold} ===")
        print(f"Fold {fold}: n_test={len(test_idx)}")
        print(f"Class distribution: Class 0: {np.sum(y_encoded[test_idx] == 0)}, Class 1: {np.sum(y_encoded[test_idx] == 1)}")

        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y_encoded[train_idx], y_encoded[test_idx]

        best_params = _tune_xgb_params(X_train, y_train, groups[train_idx], n_trials)
        best_model = xgb.XGBClassifier(random_state=42, **best_params)
        best_model.fit(X_train, y_train)

        y_test_pred = best_model.predict(X_test)
        # Binary metrics need the positive-class score as a 1-D array; passing
        # the 2-column predict_proba output makes roc_auc_score raise.
        y_test_score = best_model.predict_proba(X_test)[:, 1]

        all_fold_metrics.append({
            "fold": fold,
            "accuracy": accuracy_score(y_test, y_test_pred),
            "f1_weighted": f1_score(y_test, y_test_pred, average='weighted'),
            "precision_weighted": precision_score(y_test, y_test_pred, average='weighted'),
            "recall_weighted": recall_score(y_test, y_test_pred, average='weighted'),
            "roc_auc": roc_auc_score(y_test, y_test_score)
        })

        explainer = shap.TreeExplainer(best_model)
        shap_values_list.append(explainer.shap_values(X_test))

        model_filename = f"{output_dir}/model_{dataset_name}_fold_{fold}.joblib"
        joblib.dump(best_model, model_filename)
        print(f"Model saved as {model_filename}")

        fpr, tpr, _ = roc_curve(y_test, y_test_score)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        aucs.append(auc(fpr, tpr))

        fold_models.append(best_model)
        fold_test_indices.append(test_idx)

    metrics_df = pd.DataFrame(all_fold_metrics)
    avg_metrics = {}
    for metric in ('accuracy', 'f1_weighted', 'precision_weighted', 'recall_weighted', 'roc_auc'):
        avg_metrics[metric] = metrics_df[metric].mean()
        avg_metrics[f"{metric}_std"] = metrics_df[metric].std()

    best_pos = int(np.argmax(metrics_df['roc_auc'].to_numpy()))
    # Read the fold number from its own int column: a mixed-dtype row from
    # .iloc[...] is float, which produced paths like "fold_3.0.joblib".
    best_fold = int(metrics_df['fold'].iloc[best_pos])
    best_model = fold_models[best_pos]

    # Expose every fold's test rows so downstream plots can align SHAP values
    # with feature rows without re-running the split.
    X.attrs['fold_test_indices'] = fold_test_indices

    return avg_metrics, best_model, best_fold, tprs, aucs, mean_fpr, shap_values_list, X, fold_test_indices[best_pos]


In [None]:
# 3. Visualisation and Analysis

def _render_and_save(draw, path, figsize, description):
    """Create a figure, let ``draw(ax)`` fill it, and save it to ``path``.

    Errors are printed rather than raised so one failing plot does not stop
    the others; the figure is closed in every case so repeated calls do not
    accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        draw(ax)
        fig.savefig(path)
    except Exception as e:
        print(f"Error generating {description}: {str(e)}")
    finally:
        plt.close(fig)


def _draw_aggregate_roc(ax, tprs, aucs, mean_fpr):
    """Plot each fold's interpolated ROC, their mean, and a ±1 std band."""
    for i, tpr in enumerate(tprs):
        ax.plot(mean_fpr, tpr, alpha=0.3, label=f'ROC fold {i+1} (AUC = {aucs[i]:.2f})')

    mean_tpr = np.mean(tprs, axis=0)
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    ax.plot(mean_fpr, mean_tpr, 'b-', label=f'Mean ROC (AUC = {mean_auc:.2f} ± {std_auc:.2f})', lw=2)

    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.2, label=r'± 1 std. dev.')

    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Aggregate ROC Curve for Class 1')
    ax.legend(loc="lower right")


def create_visualizations(avg_metrics, best_model, best_fold, tprs, aucs, mean_fpr, shap_values_list, X, test_idx, dataset_name, output_dir, fold_test_indices=None):
    """Save ROC and SHAP figures for one dataset and print the CV summary.

    Every figure is built from results already computed during cross-validation
    (per-fold interpolated ROC curves, AUCs and SHAP values). The function
    therefore needs no labels, scaler or splitter. The previous version read
    ``y_encoded``/``scaler``/``outer_cv``/``images``, which are not in scope
    here, so those plots always failed.

    Args:
        avg_metrics: dict with mean/std metrics (keys as printed below).
        best_model: best fold's model; kept for interface compatibility (its
            predictions and SHAP values are already in ``tprs``/``shap_values_list``).
        best_fold: 1-based index of the best fold.
        tprs, aucs: per-fold TPRs interpolated on ``mean_fpr`` and their AUCs,
            in fold order.
        shap_values_list: per-fold SHAP value arrays, in fold order.
        X: feature frame the fold indices refer to (positional).
        test_idx: positional test indices of the best fold.
        dataset_name, output_dir: used to build the output file names.
        fold_test_indices: optional list of every fold's test indices, used to
            align the aggregate SHAP plot; falls back to
            ``X.attrs['fold_test_indices']``. The plot is skipped if neither is set.
    """
    best_pos = int(best_fold) - 1  # folds are numbered from 1

    # Aggregate ROC curve
    if len(tprs) > 0 and len(aucs) > 0:
        _render_and_save(
            lambda ax: _draw_aggregate_roc(ax, tprs, aucs, mean_fpr),
            f"{output_dir}/aggregate_roc_class1_{dataset_name}.png",
            (10, 8),
            "aggregate ROC curve",
        )

    # Best fold ROC curve (the fold's curve as interpolated onto mean_fpr during CV)
    def draw_best_roc(ax):
        if not 0 <= best_pos < len(tprs):
            raise ValueError(f"no stored ROC curve for fold {best_fold}")
        ax.plot(mean_fpr, tprs[best_pos], label=f'Class 1 (AUC = {aucs[best_pos]:.2f})')
        ax.plot([0, 1], [0, 1], 'k--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f'ROC Curve for Class 1 - Best Fold {best_fold}')
        ax.legend()

    _render_and_save(
        draw_best_roc,
        f"{output_dir}/best_fold_roc_class1_{dataset_name}.png",
        (10, 8),
        "best fold ROC curve",
    )

    # Aggregate SHAP plot: SHAP rows must be paired with the matching feature rows.
    if fold_test_indices is None:
        fold_test_indices = X.attrs.get('fold_test_indices')

    def draw_aggregate_shap(ax):
        if fold_test_indices is None or len(fold_test_indices) != len(shap_values_list):
            raise ValueError("per-fold test indices are required to align SHAP values with feature rows")
        shap_values_combined = np.vstack(shap_values_list)
        X_test_combined = pd.concat([X.iloc[idx] for idx in fold_test_indices])
        if shap_values_combined.shape[0] != len(X_test_combined):
            raise ValueError("SHAP values and test rows have different lengths")
        # "dot" is summary_plot's beeswarm style; "beeswarm" is not a valid plot_type.
        shap.summary_plot(shap_values_combined, X_test_combined, plot_type="dot", show=False)
        plt.title('Aggregate SHAP Feature Importance')
        plt.tight_layout()

    _render_and_save(
        draw_aggregate_shap,
        f"{output_dir}/aggregate_shap_beeswarm_{dataset_name}.png",
        (12, 10),
        "aggregate SHAP plot",
    )

    # Best fold SHAP plot, reusing the SHAP values computed for that fold during CV.
    def draw_best_shap(ax):
        if not 0 <= best_pos < len(shap_values_list):
            raise ValueError(f"no stored SHAP values for fold {best_fold}")
        shap_values = np.asarray(shap_values_list[best_pos])
        X_test = X.iloc[test_idx]
        if shap_values.shape[0] != len(X_test):
            raise ValueError("test_idx does not match the best fold's SHAP values")
        shap.summary_plot(shap_values, X_test, plot_type="dot", show=False)
        plt.title(f'SHAP Feature Importance - Best Fold {best_fold}')
        plt.tight_layout()

    _render_and_save(
        draw_best_shap,
        f"{output_dir}/best_fold_shap_beeswarm_{dataset_name}.png",
        (12, 10),
        "best fold SHAP plot",
    )

    # Print final results
    print("\n=== Final Results ===")
    print(f"Dataset: {dataset_name}")
    print(f"Accuracy: {avg_metrics['accuracy']:.4f} ± {avg_metrics['accuracy_std']:.4f}")
    print(f"F1 Score (weighted): {avg_metrics['f1_weighted']:.4f} ± {avg_metrics['f1_weighted_std']:.4f}")
    print(f"Precision (weighted): {avg_metrics['precision_weighted']:.4f} ± {avg_metrics['precision_weighted_std']:.4f}")
    print(f"Recall (weighted): {avg_metrics['recall_weighted']:.4f} ± {avg_metrics['recall_weighted_std']:.4f}")
    print(f"ROC AUC: {avg_metrics['roc_auc']:.4f} ± {avg_metrics['roc_auc_std']:.4f}")
    print(f"Best Fold: {best_fold}")
