In [None]:
import polars as pl
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    balanced_accuracy_score, f1_score, matthews_corrcoef
)
import xgboost as xgb
import optuna
import joblib
import json
from pathlib import Path
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings('ignore')

In [None]:
# Directory receiving every artefact written below (plots, metrics,
# tuned parameters, serialized model). Safe to re-run: an existing
# directory is reused.
output_dir = Path("model_output")
output_dir.mkdir(parents=False, exist_ok=True)

In [None]:
# Configuration
USE_OPTUNA = False  # Toggle for hyperparameter tuning
N_TRIALS = 10  # Number of Optuna trials
RANDOM_STATE = 42
TARGET_COLUMN = 'seizure_type'

In [None]:
# Feature list
# Ordered list of model input columns. The order is load-bearing:
# prepare_data selects exactly these columns in this order to build X, so any
# consumer of the trained model must supply features in the same order.
# Commented-out entries are candidate features that were excluded. Trailing
# notes such as "high for all" / "high for seizure type" are the author's
# observations — NOTE(review): presumably about correlation with the label
# columns or overlap between classes; confirm and spell out their meaning.
main_features = [
            'seizure_de_delta_std',
            'seizure_de_delta_max',
            'seizure_de_delta_asymmetry_std',
            'seizure_de_theta_std',
            #'seizure_de_theta_max',
            'seizure_de_theta_asymmetry_std',
            'seizure_de_alpha_std',
            'seizure_de_alpha_max',
            'seizure_de_alpha_asymmetry_std',
            'seizure_de_low_beta_std',
            'seizure_de_low_beta_max',
            'seizure_de_low_beta_asymmetry_std',
            'seizure_de_high_beta_std',
            'seizure_de_high_beta_max',
            'seizure_de_high_beta_asymmetry_std',
            'seizure_de_gamma_std',
            'seizure_de_gamma_max',
            'seizure_de_gamma_asymmetry_std',
            'seizure_de_high_gamma_std',
            'seizure_de_high_gamma_max',
            'seizure_de_high_gamma_asymmetry_std',
            #'seizure_psd_sef50',
            #'seizure_psd_sef75', # high for all
            #'seizure_psd_sef90', # high for all
            #'seizure_psd_sef95',
            'frontal_seizure_psd_delta_cv',
            #'frontal_seizure_psd_theta_cv',
            'frontal_seizure_psd_alpha_cv',
            'frontal_seizure_psd_low_beta_cv',
            'frontal_seizure_psd_high_beta_cv',
            #'frontal_seizure_psd_gamma_mean', # higher for seizure type
            #'frontal_seizure_psd_gamma_std',
            #'frontal_seizure_psd_gamma_cv',
            #'frontal_seizure_psd_high_gamma_cv',
            #'frontal_seizure_psd_sef50', # high for seizure type
            'frontal_seizure_psd_sef75', # high for seizure type, highish for other two
            #'frontal_seizure_psd_sef90',
            #'frontal_seizure_psd_spectral_centroid', # high for seizure type
            'frontal_seizure_psd_spectral_spread',
            'temporal_seizure_psd_high_gamma_cv',
            'temporal_seizure_psd_sef50',
            'parietal_seizure_psd_high_gamma_cv',
            #'occipital_seizure_psd_sef50', # high for all
            #'occipital_seizure_psd_sef75', # high for all
            #'occipital_seizure_psd_sef90', # high for all
            'occipital_seizure_psd_sef95',
            'central_seizure_psd_gamma_cv',
            'central_seizure_psd_high_gamma_cv',
            #'central_seizure_psd_sef50', # high for seizure type
            #'central_seizure_psd_sef75', # high for seizure type
            'central_seizure_psd_sef90', # high for seizure type and lateralization
            #'central_seizure_psd_sef95',
            'central_seizure_psd_spectral_centroid', # high for seizure type
            'central_seizure_psd_spectral_spread',
            #'left_seizure_psd_delta_cv', 
            'left_seizure_psd_alpha_cv',
            #'left_seizure_psd_high_beta_cv',
            'left_seizure_psd_gamma_cv',
            'left_seizure_psd_high_gamma_cv',
            'left_seizure_psd_sef50', # high for seizure type
            'left_seizure_psd_sef75', # high for seizure type
            'left_seizure_psd_sef90',
            #'left_seizure_psd_spectral_centroid', # high for seizure type
            #'left_seizure_psd_spectral_spread',
            'right_seizure_psd_sef50', # high for all
            'right_seizure_psd_sef75', # high for all
            'right_seizure_psd_sef90', # high for all
            'right_seizure_psd_sef95',
            'seizure_wt_level0_mean_std',
            'seizure_wt_level0_std_std',
            'seizure_wt_level0_max_std',
            'seizure_wt_level1_entropy_std',
            'seizure_wt_level1_std_std',
            'seizure_wt_level1_max_std',
            #'seizure_wt_level2_entropy_std',
            'seizure_wt_level2_mean_std',
            'seizure_wt_level2_std_std',
            'seizure_wt_level2_max_std',
            #'seizure_wt_level3_entropy_std', # high for all
            'seizure_wt_level3_mean_mean', # high for all
            'seizure_wt_level3_mean_std', # high for all
            'seizure_wt_level3_std_mean', # high for all
            'seizure_wt_level3_std_std',
            #'seizure_wt_level3_max_std', # high for all
            'seizure_wt_level4_energy_mean',
            'seizure_wt_level4_mean_std',
            'seizure_wt_level4_std_std',
            'seizure_wt_level5_mean_std',
            #'seizure_time_mean_std', # high for seizure type
            #'seizure_time_mean_max',
            #'seizure_time_std_std', # high for all
            #'seizure_time_std_max',
            'seizure_time_rms_std', # high for all
            'seizure_time_rms_max',
            'seizure_time_peak_to_peak_std', # high for all
            #'seizure_time_peak_to_peak_max',
            'seizure_time_zero_crossings_mean', # high for all
            'seizure_time_zero_crossings_std', # high for all
            #'seizure_time_zero_crossings_max',
            'seizure_time_line_length_std', # high for all
            'seizure_time_line_length_max', # high for all
            'seizure_time_line_length_min',
            'seizure_rhythmic_fast_theta_ratio',
            'seizure_mean_propagation_speed',
            'seizure_std_propagation_speed', # high for all
            'seizure_max_propagation_speed',
]  

In [None]:
# Load the precomputed EEG feature table and report its (rows, columns) size.
df = pl.read_parquet("processed_data/comprehensive_eeg_features.parquet")
print("Data shape:", df.shape)

In [None]:
def prepare_data(df, target_column, features):
    """Build the feature matrix and encoded labels from the labeled rows of ``df``.

    Parameters
    ----------
    df : pl.DataFrame
        Table holding ``target_column`` and the requested feature columns.
    target_column : str
        Name of the label column. Rows whose label is ``""`` are discarded.
        Rows whose label is null are discarded as well: the ``!=`` comparison
        evaluates to null and ``filter`` keeps only rows that are true.
    features : list of str or None
        Feature columns, in the order they should appear in ``X``. When empty
        or None, every column except the identifier/label columns (and
        ``target_column`` itself) is used.

    Returns
    -------
    X : np.ndarray
        Feature matrix, one row per labeled sample, columns in ``features`` order.
    y_encoded : np.ndarray
        Integer class ids produced by ``label_encoder``.
    label_encoder : LabelEncoder
        Fitted encoder; ``label_encoder.classes_[i]`` is the name of class id ``i``.
    features : list of str
        The feature names actually used.
    class_weights : dict[int, float]
        Per-class weight ``n_samples / (n_classes * count_c)``, so that every
        class contributes the same total weight during training.

    Raises
    ------
    ValueError
        If no labeled rows remain after filtering.
    """
    labeled_df = df.filter(pl.col(target_column) != "")
    if labeled_df.height == 0:
        raise ValueError(f"No labeled rows found in column {target_column!r}")

    if not features:
        # Never feed identifiers or any label column to the model. This
        # includes the chosen target, even when its name is not one of the
        # known label columns listed here.
        exclude_cols = {'seizure_id', 'seizure_index', 'seizure_type',
                        'localization', 'lateralization', target_column}
        features = [col for col in df.columns if col not in exclude_cols]

    X = labeled_df.select(features).to_numpy()
    y = labeled_df.select(target_column).to_numpy().ravel()

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    # "Balanced" class weights: n_samples / (n_classes * count_c).
    class_ids, counts = np.unique(y_encoded, return_counts=True)
    n_samples, n_classes = len(y_encoded), len(class_ids)
    # Use plain Python int keys / float values. Lookups with numpy integer
    # labels still work, because they hash and compare equal, and the
    # printed dict stays readable.
    class_weights = {
        int(cls): n_samples / (n_classes * int(count))
        for cls, count in zip(class_ids, counts)
    }

    print(f"Features: {len(features)}")
    print(f"Samples: {len(X)}")
    print(f"Classes: {label_encoder.classes_}")
    print(f"Class distribution: {dict(zip(class_ids.tolist(), counts.tolist()))}")
    print(f"Class weights: {class_weights}")

    return X, y_encoded, label_encoder, features, class_weights

In [None]:
def objective(trial, X_train, y_train, X_val, y_val, class_weights):
    """Optuna objective: train one sampled XGBoost configuration and score it.

    Parameters
    ----------
    trial : optuna.Trial
        Trial used to sample the hyperparameters below.
    X_train, y_train :
        Training data; ``y_train`` holds integer class ids.
    X_val, y_val :
        Validation data, used both for early stopping and for scoring.
    class_weights : dict
        Maps class id -> sample weight, applied to every training row.

    Returns
    -------
    float
        Balanced accuracy on ``(X_val, y_val)``; the study maximizes it.
    """
    n_classes = len(np.unique(y_train))
    is_multiclass = n_classes > 2

    params = {
        'objective': 'multi:softprob' if is_multiclass else 'binary:logistic',
        'max_depth': trial.suggest_int('max_depth', 1, 4),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.03, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 0.9),
        # XGBoost's min_child_weight is a real-valued hessian threshold. The
        # previous suggest_int(..., 0.1, 5) passed a float bound to an
        # integer distribution.
        'min_child_weight': trial.suggest_float('min_child_weight', 0.1, 5.0),
        'gamma': trial.suggest_float('gamma', 0.1, 0.5),
        'reg_alpha': trial.suggest_float('reg_alpha', 0.01, 0.4),
        'reg_lambda': trial.suggest_float('reg_lambda', 1.0, 4.0),
        'n_estimators': trial.suggest_int('n_estimators', 1000, 3000),
        # In the XGBoost sklearn API, early stopping is a constructor argument.
        # Passing it to fit() is deprecated and rejected by recent releases.
        'early_stopping_rounds': 100,
        'random_state': RANDOM_STATE,
        # Use the same tree method as train_model's default parameters, so the
        # tuned values are evaluated under the algorithm they will be used with.
        'tree_method': 'hist',
        'device': 'cpu',
    }
    if is_multiclass:
        params['num_class'] = n_classes

    # Class-balancing weights. These already give every class the same total
    # weight, so no scale_pos_weight is added in the binary case; doing so
    # would re-weight the positive class a second time.
    sample_weights = np.array([class_weights[label] for label in y_train])

    model = xgb.XGBClassifier(**params)
    model.fit(
        X_train, y_train,
        sample_weight=sample_weights,
        eval_set=[(X_val, y_val)],
        verbose=False,
    )

    # After early stopping, the sklearn wrapper predicts with the best iteration.
    y_pred = model.predict(X_val)
    return balanced_accuracy_score(y_val, y_pred)

In [None]:
def train_model(X, y, label_encoder, features, class_weights):
    """Train the final XGBoost classifier, optionally after Optuna tuning.

    Steps:
      1. Make a stratified 80/20 split into train and held-out test sets.
      2. If ``USE_OPTUNA`` is True, run ``N_TRIALS`` trials of ``objective``.
         Tuning uses a further stratified split of the *training* set, so the
         returned test set is never used for model selection. The best trial's
         parameters are written to ``output_dir / 'best_params.json'``.
      3. Fit the final model on the whole training split with class-balancing
         sample weights.

    ``label_encoder`` and ``features`` are accepted for interface symmetry and
    are not used here.

    Returns
    -------
    model : xgb.XGBClassifier
        The fitted classifier.
    X_test, y_test :
        Held-out test split.
    best_params : dict
        Constructor parameters used for the final model (None values removed).
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
    )
    n_classes = len(np.unique(y))

    # Default parameters; used as-is unless Optuna tuning overrides them.
    best_params = {
        'objective': 'multi:softprob' if n_classes > 2 else 'binary:logistic',
        'num_class': n_classes if n_classes > 2 else None,
        'max_depth': 1,
        'learning_rate': 0.017945399517256322,
        'subsample': 0.8142447990706547,
        'colsample_bytree': 0.8067253171735458,
        'min_child_weight': 4,
        'gamma': 0.3506098841684803,
        'reg_alpha': 0.20782777353226756,
        'reg_lambda': 2.972148679745178,
        'n_estimators': 958,
        'random_state': RANDOM_STATE,
        'tree_method': 'hist',
        'device': 'cpu'
    }

    if USE_OPTUNA:
        # Model selection must not see X_test; otherwise the test metrics
        # reported later are optimistically biased. Tune on an inner
        # validation split of the training data instead.
        X_tune, X_val, y_tune, y_val = train_test_split(
            X_train, y_train, test_size=0.2,
            random_state=RANDOM_STATE, stratify=y_train
        )

        print("\nStarting Optuna hyperparameter tuning...")
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(seed=RANDOM_STATE),  # reproducible search
        )
        study.optimize(
            lambda trial: objective(trial, X_tune, y_tune, X_val, y_val, class_weights),
            n_trials=N_TRIALS,
            show_progress_bar=True,
        )

        best_params.update(study.best_params)
        print(f"Best trial score: {study.best_value:.4f}")
        print(f"Best parameters: {study.best_params}")

        params_path = output_dir / 'best_params.json'
        with open(params_path, 'w') as f:
            json.dump(study.best_params, f, indent=2)
        print(f"Parameters saved to {params_path}")

    # XGBoost must not receive explicit None values (e.g. num_class for binary).
    best_params = {k: v for k, v in best_params.items() if v is not None}

    print("\nTraining final model...")
    # Class-balancing weights. They already equalise total class weight, so
    # no extra scale_pos_weight is applied in the binary case.
    sample_weights = np.array([class_weights[label] for label in y_train])

    model = xgb.XGBClassifier(**best_params)
    model.fit(
        X_train, y_train,
        sample_weight=sample_weights,
        # Monitoring only (no early stopping): the test loss is logged but not
        # used to choose anything. Log every 100 rounds instead of every round.
        eval_set=[(X_test, y_test)],
        verbose=100,
    )

    return model, X_test, y_test, best_params

In [None]:
def evaluate_model(model, X_test, y_test, label_encoder):
    """Report test-set metrics, plot the confusion matrix, and save both.

    Prints accuracy, balanced accuracy, macro/weighted F1, MCC and a
    per-class classification report. Writes
    ``output_dir / 'confusion_matrix.png'`` and ``output_dir / 'metrics.json'``.

    Every class known to ``label_encoder`` appears in the report and the
    confusion matrix, even if it is absent from ``y_test`` and the
    predictions, so matrix rows/columns always line up with the tick labels.

    Returns
    -------
    dict
        Keys ``accuracy``, ``balanced_accuracy``, ``f1_macro``,
        ``f1_weighted`` and ``mcc``, with plain float values.
    """
    y_pred = model.predict(X_test)

    # str() so non-string class labels still work as report/tick names.
    class_names = [str(name) for name in label_encoder.classes_]
    class_ids = np.arange(len(class_names))

    print("\n" + "="*50)
    print("MODEL EVALUATION")
    print("="*50)

    metrics = {
        'accuracy': float(accuracy_score(y_test, y_pred)),
        'balanced_accuracy': float(balanced_accuracy_score(y_test, y_pred)),
        'f1_macro': float(f1_score(y_test, y_pred, average='macro')),
        'f1_weighted': float(f1_score(y_test, y_pred, average='weighted')),
        'mcc': float(matthews_corrcoef(y_test, y_pred)),
    }

    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
    print(f"F1 Score (Macro): {metrics['f1_macro']:.4f}")
    print(f"F1 Score (Weighted): {metrics['f1_weighted']:.4f}")
    print(f"Matthews Correlation Coefficient: {metrics['mcc']:.4f}")

    print("\nClassification Report:")
    print(classification_report(y_test, y_pred,
                                labels=class_ids, target_names=class_names))

    cm = confusion_matrix(y_test, y_pred, labels=class_ids)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    fig.tight_layout()
    fig.savefig(output_dir / 'confusion_matrix.png')
    plt.show()

    with open(output_dir / 'metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)

    return metrics

In [None]:
def plot_feature_importance(model, features, top_n=50):
    """Plot the ``top_n`` most important features and save all importances.

    Writes ``output_dir / 'feature_importance.png'`` (bar chart with the most
    important feature at the top) and ``output_dir / 'feature_importance.json'``
    (all features, sorted by importance, descending). Then prints the top
    features.

    Parameters
    ----------
    model :
        Fitted XGBoost classifier exposing ``feature_importances_``.
    features : list of str
        Feature names in the model's column order.
    top_n : int
        Maximum number of features to plot and print.

    Raises
    ------
    ValueError
        If ``features`` and the model's importances differ in length.
    """
    importance = np.asarray(model.feature_importances_)
    if len(importance) != len(features):
        raise ValueError(
            f"{len(features)} feature names given for {len(importance)} importances"
        )

    # Descending order; a stable sort keeps tied features in column order.
    order = np.argsort(-importance, kind='stable')
    indices = order[:top_n]
    n_shown = len(indices)
    names_shown = [features[i] for i in indices]

    # Grow the figure with the number of bars so the labels stay legible.
    fig, ax = plt.subplots(figsize=(10, max(4, 0.25 * n_shown)))
    ax.barh(range(n_shown), importance[indices])
    ax.set_yticks(range(n_shown))
    ax.set_yticklabels(names_shown)
    ax.invert_yaxis()  # barh draws index 0 at the bottom; show the top feature first
    ax.set_xlabel('Feature Importance')
    ax.set_title(f'Top {n_shown} Feature Importances')
    fig.tight_layout()
    fig.savefig(output_dir / 'feature_importance.png', bbox_inches='tight')
    plt.show()

    importance_sorted = {features[i]: float(importance[i]) for i in order}
    with open(output_dir / 'feature_importance.json', 'w') as f:
        json.dump(importance_sorted, f, indent=2)

    print(f"\nTop {n_shown} features:")
    for name, idx in zip(names_shown, indices):
        print(f"{name}: {importance[idx]:.4f}")

In [None]:
def cross_validate_model(X, y, features, class_weights, best_params, n_splits=12):
    """Estimate generalisation with stratified k-fold cross-validation.

    For each of ``n_splits`` shuffled, stratified folds, a fresh
    ``xgb.XGBClassifier(**best_params)`` is trained with class-balancing
    sample weights and scored on the held-out fold. Per-fold scores and the
    mean (+/- population std) of each metric are printed.

    If ``best_params`` were tuned on part of this same data, these scores are
    somewhat optimistic. ``features`` is accepted for interface symmetry and
    is not used.

    Returns
    -------
    dict[str, list[float]]
        Per-fold ``accuracy``, ``balanced_accuracy``, ``f1_macro`` and ``mcc``.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)

    scorers = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_macro': lambda y_true, y_hat: f1_score(y_true, y_hat, average='macro'),
        'mcc': matthews_corrcoef,
    }
    scores = {metric: [] for metric in scorers}

    print(f"\n{n_splits}-Fold Cross-Validation")
    print("="*50)

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Class-balancing weights. They already equalise total class weight,
        # so no extra scale_pos_weight is applied in the binary case.
        sample_weights = np.array([class_weights[label] for label in y_train])

        model = xgb.XGBClassifier(**best_params)
        model.fit(
            X_train, y_train,
            sample_weight=sample_weights,
            eval_set=[(X_val, y_val)],
            verbose=False,
        )

        y_pred = model.predict(X_val)
        for metric, scorer in scorers.items():
            scores[metric].append(float(scorer(y_val, y_pred)))

        print(f"Fold {fold}: Bal Acc={scores['balanced_accuracy'][-1]:.4f}, "
              f"F1={scores['f1_macro'][-1]:.4f}, MCC={scores['mcc'][-1]:.4f}")

    print("\nCross-Validation Summary:")
    for metric, values in scores.items():
        print(f"{metric}: {np.mean(values):.4f} (+/- {np.std(values):.4f})")

    return scores

In [None]:
# Prepare data: select the configured features and encode the target labels
X, y, label_encoder, features, class_weights = prepare_data(
    df, target_column=TARGET_COLUMN, features=main_features
)

In [None]:
# Train model (runs Optuna tuning first when USE_OPTUNA is enabled)
model, X_test, y_test, best_params = train_model(
    X, y,
    label_encoder=label_encoder,
    features=features,
    class_weights=class_weights,
)

In [None]:
# Evaluate the trained model on the held-out test split
metrics = evaluate_model(
    model, X_test=X_test, y_test=y_test, label_encoder=label_encoder
)

In [None]:
# Plot and save feature importances (top 50 shown)
plot_feature_importance(model, features=features, top_n=50)

In [None]:
# Stratified k-fold cross-validation using the final hyperparameters
cv_scores = cross_validate_model(
    X, y,
    features=features,
    class_weights=class_weights,
    best_params=best_params,
)

In [None]:
# Save model together with what is needed to reuse it:
#  - the label encoder, which maps class ids back to names,
#  - the ordered feature list (the model expects columns in exactly this order),
#  - the final constructor parameters, which exist even when Optuna did not run.
# NOTE: label_encoder.pkl is a pickle — only load it from a trusted source.
model.save_model(str(output_dir / 'xgboost_model.json'))
joblib.dump(label_encoder, output_dir / 'label_encoder.pkl')
with open(output_dir / 'features.json', 'w') as f:
    json.dump(list(features), f, indent=2)
with open(output_dir / 'model_params.json', 'w') as f:
    json.dump(best_params, f, indent=2)
print(f"\nModel saved to {output_dir}")

print("\n" + "="*50)
print("TRAINING COMPLETE")
print("="*50)