In [None]:
import pandas as pd
import numpy as np
import xgboost as xgb
import optuna
import joblib
import shap
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score, 
                           roc_auc_score, roc_curve, classification_report, confusion_matrix, auc)
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
import os
from datetime import datetime

In [None]:
# Run configuration is fixed here rather than read from the command line.
needle_height = '1.3'
conjugate = 'chlr'
n_trials = 50
dataset_key = "_".join((needle_height, conjugate))

# Every run writes into its own timestamped results directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"results/{dataset_key}_{timestamp}"
os.makedirs(output_dir, exist_ok=True)

# Root folder of the experiment and the cell-level CSV for the 1.3_chlr dataset.
base_dir = r"D:\20241129_solid_nN_1.3_2.4_mdck_siRNA_tnsfn_chlr"
dataset_path = base_dir + r"\20241129_solid_nN_1.3_mdck_chlr_dataset\solid_1.3_chlr_cell_level.csv"

# Whole-cell shape descriptors.
cell_morph_features = [
    'area',
    'perimeter',
    'major_axis_length',
    'minor_axis_length',
    'eccentricity',
    'circularity',
    'solidity',
    'orientation',
]

# The same shape descriptors measured on the nucleus.
nuclear_morph_features = [
    'nuclear_area',
    'nuclear_perimeter',
    'nuclear_major_axis_length',
    'nuclear_minor_axis_length',
    'nuclear_eccentricity',
    'nuclear_circularity',
    'nuclear_solidity',
    'nuclear_orientation',
]

# Per-channel intensity percentile suffixes.
channel_feature_suffixes = [
    'intensity_p10',
    'intensity_p25',
    'intensity_p50',
    'intensity_p75',
    'intensity_p90',
]

protein_channels = ['actin', 'caveolin', 'clathrin_hc', 'nuclei']

# Final column order: cell morphology, nuclear morphology, the caveolin
# intensity features, then the intensity features of every other channel in
# `protein_channels` order.
feature_list = [*cell_morph_features, *nuclear_morph_features]
feature_list.extend(f"caveolin_{suffix}" for suffix in channel_feature_suffixes)
feature_list.extend(
    f"{ch}_{suffix}"
    for ch in protein_channels
    if ch != 'caveolin'
    for suffix in channel_feature_suffixes
)

def process_dataset(dataset_path, dataset_name, area_percentiles=(2, 98),
                    intensity_threshold=300, num_bins=5):
    """Train and evaluate an XGBoost classifier with nested, image-grouped CV.

    The cell-level CSV is filtered by cell area and conjugate intensity, the
    conjugate intensity is discretised into quantile bins, and a classifier is
    tuned with Optuna inside each of five outer folds. Per-fold plots and the
    fitted models are written to the module-level ``output_dir``.

    Relies on the module-level names ``feature_list`` (input columns),
    ``n_trials`` (Optuna trials per outer fold) and ``output_dir``.

    Args:
        dataset_path: Path of the CSV to load. It must contain ``area``,
            ``nuclear_area``, ``image_id``, ``<conjugate>_intensity_mean`` and
            every column named in ``feature_list``.
        dataset_name: Key of the form ``"<needle_height>_<conjugate>"``; the
            part after the first underscore selects the intensity column and
            the whole key is embedded in the output file names.
        area_percentiles: ``(low, high)`` percentiles used as inclusive bounds
            for both cell area and nuclear area.
        intensity_threshold: Cells whose conjugate mean intensity is not
            strictly above this value are discarded.
        num_bins: Number of quantile bins (classes) the conjugate intensity is
            split into.

    Returns:
        dict: Mean and standard deviation over the outer folds for
        ``accuracy``, ``f1_weighted``, ``precision_weighted``,
        ``recall_weighted`` and ``roc_auc`` (std keys carry a ``_std`` suffix).
    """
    print(f"\n=== Processing {dataset_name} ===")

    # The conjugate name is the second underscore-separated token of the key.
    conjugate_type = dataset_name.split('_')[1]
    intensity_column = f"{conjugate_type}_intensity_mean"
    print(f"Using intensity column: {intensity_column}")

    df = pd.read_csv(dataset_path)

    # Area bounds are computed on the unfiltered data.
    cell_area_min, cell_area_max = np.percentile(df['area'], area_percentiles)
    nuclear_area_min, nuclear_area_max = np.percentile(df['nuclear_area'], area_percentiles)

    # Keep cells with a plausible area and enough conjugate signal.
    df_filtered = df[
        (df['area'] >= cell_area_min) &
        (df['area'] <= cell_area_max) &
        (df[intensity_column] > intensity_threshold)
    ].copy()

    # Cells with an out-of-range nucleus are kept, but all of their
    # `nuclear_*` measurements are blanked out.
    nuclei_in_range = (
        (df_filtered['nuclear_area'] >= nuclear_area_min) &
        (df_filtered['nuclear_area'] <= nuclear_area_max)
    )
    nuclear_cols = [col for col in df_filtered.columns if col.startswith('nuclear_')]
    df_filtered.loc[~nuclei_in_range, nuclear_cols] = np.nan

    # Discretise the target into quantile bins and encode them as 0..k-1.
    df_filtered['conjugate_category'] = pd.qcut(
        df_filtered[intensity_column], q=num_bins, labels=False
    )
    y = df_filtered['conjugate_category']
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    X = df_filtered[feature_list]
    images = df_filtered['image_id']

    # Accumulators filled by the outer folds.
    all_fold_metrics = []
    class_distributions = []
    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    aucs = []

    # Outer CV: stratified by class, grouped by image so that cells from one
    # image never appear on both sides of a split.
    outer_cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(
        outer_cv.split(X, y_encoded, groups=images), start=1
    ):
        print(f"\n=== Outer Fold {fold} ===")

        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y_encoded[train_idx], y_encoded[test_idx]
        # Image ids of the training cells, needed to group the inner CV.
        groups_train = images.iloc[train_idx].to_numpy()

        class_distributions.append({
            "train": np.bincount(y_train, minlength=num_bins),
            "test": np.bincount(y_test, minlength=num_bins)
        })

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        def objective(trial):
            """Return the mean inner-CV accuracy for one sampled configuration."""
            params = {
                'max_depth': trial.suggest_int('max_depth', 3, 10),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
                'gamma': trial.suggest_float('gamma', 0, 5),
                'reg_alpha': trial.suggest_float('reg_alpha', 0.0, 10.0),
                'reg_lambda': trial.suggest_float('reg_lambda', 0.0, 10.0),
                'n_estimators': trial.suggest_int('n_estimators', 50, 200)
            }

            model = xgb.XGBClassifier(random_state=42, **params)

            # The inner CV is grouped by image as well; an ungrouped split
            # would validate on cells from images the model was trained on.
            inner_cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=42)
            inner_scores = []

            for inner_train_idx, inner_valid_idx in inner_cv.split(
                X_train_scaled, y_train, groups=groups_train
            ):
                model.fit(X_train_scaled[inner_train_idx], y_train[inner_train_idx])
                y_pred_inner = model.predict(X_train_scaled[inner_valid_idx])
                inner_scores.append(
                    accuracy_score(y_train[inner_valid_idx], y_pred_inner)
                )

            return np.mean(inner_scores)

        # Seed the sampler so the search is repeatable, like the CV splitters
        # and the models.
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(seed=42),
        )
        study.optimize(objective, n_trials=n_trials)

        best_params = study.best_params

        # Refit on the whole outer-training fold with the selected parameters.
        best_model = xgb.XGBClassifier(random_state=42, **best_params)
        best_model.fit(X_train_scaled, y_train)

        y_test_pred = best_model.predict(X_test_scaled)
        y_test_proba = best_model.predict_proba(X_test_scaled)

        all_fold_metrics.append({
            "fold": fold,
            "accuracy": accuracy_score(y_test, y_test_pred),
            "f1_weighted": f1_score(y_test, y_test_pred, average='weighted'),
            "precision_weighted": precision_score(y_test, y_test_pred, average='weighted'),
            "recall_weighted": recall_score(y_test, y_test_pred, average='weighted'),
            "roc_auc": roc_auc_score(y_test, y_test_proba, multi_class='ovr')
        })

        # SHAP values are computed on the scaled test features.
        explainer = shap.TreeExplainer(best_model)
        shap_values = explainer.shap_values(X_test_scaled)

        # One-vs-rest ROC curve per class; only classes that have a
        # probability column can be plotted.
        n_plot_classes = min(len(np.unique(y_encoded)), y_test_proba.shape[1])
        plt.figure(figsize=(10, 8))
        for i in range(n_plot_classes):
            fpr, tpr, _ = roc_curve((y_test == i).astype(int), y_test_proba[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'Class {i} (AUC = {roc_auc:.2f})')

            # Only class 0 feeds the aggregate ROC curve.
            if i == 0:
                interp_tpr = np.interp(mean_fpr, fpr, tpr)
                interp_tpr[0] = 0.0
                tprs.append(interp_tpr)
                aucs.append(roc_auc)

        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Per-Class ROC Curves - Fold {fold}')
        plt.legend()
        plt.savefig(f"{output_dir}/per_class_roc_fold_{fold}_{dataset_name}.png")
        plt.close()

        # Confusion matrix
        cm = confusion_matrix(y_test, y_test_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - Fold {fold}')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.savefig(f"{output_dir}/confusion_matrix_fold_{fold}_{dataset_name}.png")
        plt.close()

        # Correlation of each (unscaled) test feature with the encoded target.
        correlation_df = X_test.copy()
        correlation_df['target'] = y_test
        plt.figure(figsize=(15, 10))
        corr_with_target = correlation_df.corr()['target'].sort_values(ascending=False)
        sns.heatmap(pd.DataFrame(corr_with_target).T, annot=True, cmap='coolwarm')
        plt.title(f'Feature-Target Correlations - Fold {fold}')
        plt.savefig(f"{output_dir}/feature_target_corr_fold_{fold}_{dataset_name}.png")
        plt.close()

        # SHAP bar summary; X_test supplies the feature names.
        plt.figure(figsize=(12, 10))
        shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
        plt.title(f'SHAP Feature Importance - Fold {fold}')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/shap_summary_fold_{fold}_{dataset_name}.png")
        plt.close()

        # XGBoost importance plot. The axis is passed explicitly: without
        # `ax`, plot_importance opens its own figure, so a figure created
        # beforehand would be left open and its size ignored.
        importance_fig, importance_ax = plt.subplots(figsize=(12, 8))
        xgb.plot_importance(best_model, ax=importance_ax, max_num_features=20)
        importance_ax.set_title(f"Feature Importance - {conjugate_type.upper()} - Fold {fold}")
        importance_fig.savefig(
            f"{output_dir}/feature_importance_{conjugate_type}_fold_{fold}_{dataset_name}.png"
        )
        plt.close(importance_fig)

        # Persist the fitted model of this fold.
        model_filename = f"{output_dir}/model_{dataset_name}_fold_{fold}.joblib"
        joblib.dump(best_model, model_filename)
        print(f"Model saved as {model_filename}")

    # 1. Aggregate performance metrics (mean and std across outer folds)
    metrics_df = pd.DataFrame(all_fold_metrics)
    avg_metrics = {}
    for metric in ('accuracy', 'f1_weighted', 'precision_weighted',
                   'recall_weighted', 'roc_auc'):
        avg_metrics[metric] = metrics_df[metric].mean()
        avg_metrics[f'{metric}_std'] = metrics_df[metric].std()

    # 2. Aggregate ROC curve for class 0
    plt.figure(figsize=(10, 8))

    # Individual fold curves (faded)
    for i, tpr in enumerate(tprs):
        plt.plot(mean_fpr, tpr, alpha=0.3, label=f'ROC fold {i+1} (AUC = {aucs[i]:.2f})')

    # Mean curve
    mean_tpr = np.mean(tprs, axis=0)
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    plt.plot(mean_fpr, mean_tpr, 'b-', label=f'Mean ROC (AUC = {mean_auc:.2f} ± {std_auc:.2f})', lw=2)

    # +/- one standard deviation band, clipped to [0, 1]
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.2, label=r'± 1 std. dev.')

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Aggregate ROC Curve')
    plt.legend(loc="lower right")
    plt.savefig(f"{output_dir}/aggregate_roc_{dataset_name}.png")
    plt.close()

    # 3. Class distribution analysis: one row per (fold, class), built with a
    # single concat instead of growing a frame inside the loop.
    class_df = pd.concat(
        [
            pd.DataFrame({
                'fold': fold_number,
                'class': range(len(dist['train'])),
                'train_count': dist['train'],
                'test_count': dist['test']
            })
            for fold_number, dist in enumerate(class_distributions, 1)
        ],
        ignore_index=True,
    )

    plt.figure(figsize=(12, 8))
    sns.boxplot(x='class', y='train_count', data=class_df)
    plt.title('Class Distribution Across Folds (Train)')
    plt.savefig(f"{output_dir}/class_distribution_train_{dataset_name}.png")
    plt.close()

    return avg_metrics

# Execute the full pipeline for the single configured dataset (1.3_chlr).
avg_metrics = process_dataset(dataset_path, dataset_key)

# Report every aggregated metric as "mean ± std" over the outer folds.
print("\n=== Final Results ===")
print(f"Dataset: {dataset_key}")
for label, key in (
    ("Accuracy", "accuracy"),
    ("F1 Score (weighted)", "f1_weighted"),
    ("Precision (weighted)", "precision_weighted"),
    ("Recall (weighted)", "recall_weighted"),
    ("ROC AUC", "roc_auc"),
):
    mean_value = avg_metrics[key]
    std_value = avg_metrics[key + "_std"]
    print(f"{label}: {mean_value:.4f} ± {std_value:.4f}")
