In [173]:
import polars as pl
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
import xgboost as xgb
import joblib
import json
from pathlib import Path

In [174]:
# Load the processed per-seizure EEG feature table and report its size.
processed_df = pl.read_parquet("processed_data/comprehensive_eeg_features.parquet")

print("Shape:", processed_df.shape)
print("Columns:", processed_df.columns)

Shape: (245, 747)
Columns: ['patient_id', 'sampling_rate_hz', 'seizure_number', 'file_name', 'registration_start_time', 'registration_end_time', 'seizure_start_time', 'seizure_end_time', 'age', 'gender', 'num_channels', 'json_file_path', 'eeg_channel', 'seizure_type', 'localization', 'lateralization', 'file_path', 'sampling_rate', 'duration_seconds', 'de_delta_mean', 'de_delta_std', 'de_delta_median', 'de_delta_max', 'de_delta_min', 'de_delta_asymmetry_mean', 'de_delta_asymmetry_std', 'de_theta_mean', 'de_theta_std', 'de_theta_median', 'de_theta_max', 'de_theta_min', 'de_theta_asymmetry_mean', 'de_theta_asymmetry_std', 'de_alpha_mean', 'de_alpha_std', 'de_alpha_median', 'de_alpha_max', 'de_alpha_min', 'de_alpha_asymmetry_mean', 'de_alpha_asymmetry_std', 'de_low_beta_mean', 'de_low_beta_std', 'de_low_beta_median', 'de_low_beta_max', 'de_low_beta_min', 'de_low_beta_asymmetry_mean', 'de_low_beta_asymmetry_std', 'de_high_beta_mean', 'de_high_beta_std', 'de_high_beta_median', 'de_high_beta_

In [175]:
def prepare_classification_data(df, target_column, feature_columns=None):
    """
    Split a frame into labeled / unlabeled rows and extract model inputs.

    A row is *unlabeled* when ``target_column`` is null or the empty string;
    every other row is labeled.

    Parameters
    ----------
    df : pl.DataFrame
        Source frame containing ``target_column`` and the feature columns.
    target_column : str
        Name of the label column (String or Categorical).
    feature_columns : list[str] | None
        Columns to use as features. When None, every numeric (or boolean)
        column is used except identifiers and the three target columns.
        ``target_column`` itself is always removed from the feature set so
        the label can never leak into the inputs.

    Returns
    -------
    X_labeled, y_labeled, X_unlabeled, feature_columns
        ``y_labeled`` holds the labels as strings; ``feature_columns`` is the
        list actually used to build the two matrices.
    """
    # Cast to string so the emptiness test behaves the same for String and
    # Categorical targets, and map null -> "" so rows with a missing label end
    # up in the unlabeled split instead of being silently dropped by `filter`
    # (a comparison against null yields null, which `filter` discards).
    label = pl.col(target_column).cast(pl.Utf8).fill_null("")
    labeled_df = df.filter(label != "")
    unlabeled_df = df.filter(label == "")

    if feature_columns is None:
        exclude_cols = {'seizure_type', 'localization', 'lateralization', 'patient_id', 'seizure_id'}
        # Keep only numeric/boolean columns: string metadata such as file
        # paths would turn the feature matrix into an object array that
        # XGBoost cannot consume.
        feature_columns = [
            col
            for col, dtype in df.schema.items()
            if col not in exclude_cols and (dtype.is_numeric() or dtype == pl.Boolean)
        ]

    # Never let the label itself leak into the features.
    feature_columns = [col for col in feature_columns if col != target_column]

    X_labeled = labeled_df.select(feature_columns).to_numpy()
    y_labeled = labeled_df.get_column(target_column).cast(pl.Utf8).to_numpy()

    X_unlabeled = unlabeled_df.select(feature_columns).to_numpy()

    return X_labeled, y_labeled, X_unlabeled, feature_columns

In [176]:
def train_xgboost_classifier(X_train, y_train, X_test, y_test):
    """
    Fit an XGBoost classifier on label-encoded targets and predict the test split.

    The LabelEncoder is fitted on the union of train and test labels so that
    ``transform`` never meets a class it has not seen.

    Returns
    -------
    (model, y_pred, y_test_encoded, label_encoder)
        ``y_pred`` and ``y_test_encoded`` live in the encoded (integer) label
        space; decode them with ``label_encoder.inverse_transform``.
    """
    encoder = LabelEncoder().fit(np.concatenate([y_train, y_test]))
    train_codes = encoder.transform(y_train)
    test_codes = encoder.transform(y_test)

    num_labels = len(encoder.classes_)
    print(f"Number of classes detected: {num_labels}")
    print(f"Classes: {encoder.classes_}")

    # Hyper-parameters shared by both cases; only the objective (and the
    # explicit class count for the multi-class case) differ.
    params = dict(n_estimators=100, max_depth=6, learning_rate=0.1, random_state=42)
    if num_labels == 2:
        params['objective'] = 'binary:logistic'
    else:
        params.update(objective='multi:softprob', num_class=num_labels)
    classifier = xgb.XGBClassifier(**params)

    print(f'X train shape: {X_train.shape}')
    print(f'Y train encoded shape: {train_codes.shape}')

    classifier.fit(X_train, train_codes)
    return classifier, classifier.predict(X_test), test_codes, encoder

In [177]:
def evaluate_model(y_true, y_pred, label_encoder=None):
    """
    Print a classification report and a confusion matrix.

    If ``label_encoder`` is supplied, both label arrays are first decoded back
    to their original class names so the report is human-readable.
    """
    if label_encoder:
        y_true, y_pred = (label_encoder.inverse_transform(labels) for labels in (y_true, y_pred))

    print("Classification Report:")
    print(classification_report(y_true, y_pred))
    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred))

In [178]:
def cross_validate_model(X, y, cv_folds=4):
    """
    Estimate accuracy of the XGBoost configuration with k-fold cross-validation.

    Labels are always passed through a fresh LabelEncoder, so XGBoost receives
    contiguous integer classes ``0..k-1`` whether ``y`` holds strings or
    numbers (the original only fitted the encoder for string labels and then
    crashed reading ``classes_`` from the unfitted encoder).

    Parameters
    ----------
    X : array-like
        Feature matrix, one row per sample.
    y : array-like
        Class labels aligned with the rows of ``X``.
    cv_folds : int
        Number of folds, passed as ``cv`` to ``cross_val_score``.

    Returns
    -------
    numpy.ndarray
        Per-fold accuracy scores.

    Raises
    ------
    ValueError
        If ``y`` contains fewer than two distinct classes.

    NOTE(review): folds are not grouped by patient, so seizures from the same
    patient may appear in both training and validation folds — consider a
    grouped splitter if rows are not independent.
    """
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(np.asarray(y).ravel())

    n_classes = len(label_encoder.classes_)
    if n_classes < 2:
        raise ValueError(
            f"Cross-validation needs at least two classes, got {n_classes}"
        )

    if n_classes == 2:
        # Binary classification
        model = xgb.XGBClassifier(
            n_estimators=100,
            max_depth=6,
            learning_rate=0.1,
            objective='binary:logistic',
            random_state=42
        )
    else:
        # Multi-class classification
        model = xgb.XGBClassifier(
            n_estimators=100,
            max_depth=6,
            learning_rate=0.1,
            objective='multi:softprob',
            num_class=n_classes,
            random_state=42
        )

    scores = cross_val_score(model, X, y_encoded, cv=cv_folds, scoring='accuracy')
    print(f"Cross-validation scores: {scores}")
    print(f"Mean accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

    return scores

In [179]:
def predict_unlabeled_seizures(model, X_unlabeled, label_encoder=None):
    """
    Predict class labels and class probabilities for unlabeled rows.

    Predictions are decoded with ``label_encoder.inverse_transform`` when an
    encoder is given; probabilities are returned exactly as the model
    produces them.
    """
    class_ids = model.predict(X_unlabeled)
    class_probs = model.predict_proba(X_unlabeled)
    decoded = label_encoder.inverse_transform(class_ids) if label_encoder else class_ids
    return decoded, class_probs

In [180]:
def get_feature_importance(model, feature_names, top_n=10):
    """
    Rank features by the model's ``feature_importances_`` and print the top ones.

    Parameters
    ----------
    model : fitted estimator
        Must expose ``feature_importances_`` (one value per input column).
    feature_names : sequence of str
        Names in the same order as the model's input columns.
    top_n : int, default 10
        Number of highest-ranked features to print.

    Returns
    -------
    list[tuple[str, float]]
        Every ``(name, importance)`` pair, sorted by importance descending
        (ties keep their original column order).

    Raises
    ------
    ValueError
        If the number of names differs from the number of importances;
        ``zip`` would otherwise silently truncate and misattribute values.
    """
    importance = model.feature_importances_
    if len(feature_names) != len(importance):
        raise ValueError(
            f"Got {len(feature_names)} feature names for {len(importance)} importances"
        )

    feature_importance = sorted(
        zip(feature_names, importance), key=lambda pair: pair[1], reverse=True
    )

    print(f"Top {top_n} Most Important Features:")
    for feat, imp in feature_importance[:top_n]:
        print(f"{feat}: {imp:.4f}")

    return feature_importance

In [181]:
def build_seizure_classifier(df, target_column):
    """
    End-to-end pipeline for one target column.

    Steps: split labeled/unlabeled rows, hold out 20% of the labeled rows
    (stratified), fit XGBoost, print hold-out metrics, feature importance and
    cross-validated accuracy, then predict the unlabeled rows.

    Returns
    -------
    (model, label_encoder, predictions, probabilities, feature_importance)
        ``predictions`` and ``probabilities`` are None when there are no
        unlabeled rows.
    """
    X_labeled, y_labeled, X_unlabeled, feature_columns = prepare_classification_data(df, target_column)

    for label, arr in (('y labeled', y_labeled), ('x labeled', X_labeled), ('x unlabeled', X_unlabeled)):
        print(f'{label} shape: {arr.shape}')

    # Stratified 80/20 hold-out split of the labeled rows.
    X_train, X_test, y_train, y_test = train_test_split(
        X_labeled,
        y_labeled,
        test_size=0.2,
        random_state=42,
        stratify=y_labeled,
    )

    model, y_pred, y_test_encoded, label_encoder = train_xgboost_classifier(
        X_train, y_train, X_test, y_test
    )
    print(f'y pred shape: {y_pred.shape}')

    evaluate_model(y_test_encoded, y_pred, label_encoder)
    feature_importance = get_feature_importance(model, feature_columns)

    # Called for its printed report; the returned scores are not used here.
    cross_validate_model(X_labeled, y_labeled)

    predictions = probabilities = None
    if len(X_unlabeled) > 0:
        predictions, probabilities = predict_unlabeled_seizures(model, X_unlabeled, label_encoder)
        print(f"\nPredicted {len(predictions)} unlabeled seizures")

    return model, label_encoder, predictions, probabilities, feature_importance

In [182]:
# Save and load functions
def save_model(model, label_encoder, filepath_prefix):
    """
    Persist a fitted model and its label encoder with joblib.

    Writes ``<prefix>_model.pkl`` first, then ``<prefix>_encoder.pkl``.
    """
    outputs = (
        (model, f"{filepath_prefix}_model.pkl"),
        (label_encoder, f"{filepath_prefix}_encoder.pkl"),
    )
    for obj, path in outputs:
        joblib.dump(obj, path)

In [183]:
def load_model(filepath_prefix):
    """
    Restore a model and label encoder stored under ``filepath_prefix``.

    Reads ``<prefix>_model.pkl`` then ``<prefix>_encoder.pkl`` and returns
    them as ``(model, label_encoder)``.
    """
    # NOTE: joblib.load unpickles arbitrary objects; only load files you
    # produced yourself or otherwise trust.
    model_path = f"{filepath_prefix}_model.pkl"
    encoder_path = f"{filepath_prefix}_encoder.pkl"
    return joblib.load(model_path), joblib.load(encoder_path)

In [184]:
def train_all_classifiers(df, target_columns=('seizure_type', 'localization', 'lateralization')):
    """
    Train, evaluate and save one classifier per target column.

    Parameters
    ----------
    df : pl.DataFrame
        Frame handed unchanged to ``build_seizure_classifier`` for each target.
    target_columns : iterable of str
        Label columns to model. Defaults to all three seizure targets; pass a
        subset to (re)train only some of them.

    Returns
    -------
    dict
        ``{target: {'model', 'encoder', 'predictions', 'probabilities',
        'feature_importance'}}`` for every trained target.

    Each model/encoder pair is saved via ``save_model`` with the prefix
    ``"<target>_classifier"``.
    """
    results = {}

    for target in target_columns:
        print(f"\n{'='*50}")
        print(f"Training classifier for: {target}")
        print('='*50)

        model, encoder, predictions, probabilities, feature_importance = build_seizure_classifier(df, target)

        results[target] = {
            'model': model,
            'encoder': encoder,
            'predictions': predictions,
            'probabilities': probabilities,
            'feature_importance': feature_importance
        }

        # Save model
        save_model(model, encoder, f"{target}_classifier")

    return results

## Data Encoding

In [185]:
# The three label columns; cast each to a polars Categorical.
targets = ['seizure_type', 'localization', 'lateralization']
processed_df = processed_df.with_columns(
    [pl.col(name).cast(pl.Categorical) for name in targets]
)

In [186]:
processed_df.get_column('seizure_type').n_unique()  # distinct values, null counted as one

4

In [187]:
def encode_categoricals(df):
    """
    Integer-encode every String/Object column of a polars DataFrame.

    For each such column, its distinct non-null values are sorted and mapped
    to 1, 2, 3, ... (nulls stay null). Columns of any other dtype, including
    Categorical, are returned unchanged.

    Returns
    -------
    (encoded_df, encoding_mappings)
        ``encoding_mappings[col]`` is the ``{original_value: code}`` dict that
        was applied to ``col``.
    """
    # Work on a copy so the caller's frame is left untouched.
    encoded_df = df.clone()
    encoding_mappings = {}

    for col, dtype in df.schema.items():
        if dtype == pl.Utf8 or dtype == pl.Object:
            unique_vals = encoded_df[col].unique().drop_nulls().sort()
            mapping = {val: code for code, val in enumerate(unique_vals, start=1)}
            encoding_mappings[col] = mapping

            # `replace` keeps the column's original dtype in polars >= 1.0, so
            # the codes would come back as the strings "1", "2", ...; cast
            # explicitly so the encoded column is genuinely numeric.
            encoded_df = encoded_df.with_columns(
                pl.col(col).replace(mapping).cast(pl.Int64).alias(col)
            )

    return encoded_df, encoding_mappings

In [188]:
encoded_df, encoding_mappings = encode_categoricals(df=processed_df)  # also returns the per-column code mappings

In [189]:
list(encoded_df.columns)  # column names after encoding

['patient_id',
 'sampling_rate_hz',
 'seizure_number',
 'file_name',
 'registration_start_time',
 'registration_end_time',
 'seizure_start_time',
 'seizure_end_time',
 'age',
 'gender',
 'num_channels',
 'json_file_path',
 'eeg_channel',
 'seizure_type',
 'localization',
 'lateralization',
 'file_path',
 'sampling_rate',
 'duration_seconds',
 'de_delta_mean',
 'de_delta_std',
 'de_delta_median',
 'de_delta_max',
 'de_delta_min',
 'de_delta_asymmetry_mean',
 'de_delta_asymmetry_std',
 'de_theta_mean',
 'de_theta_std',
 'de_theta_median',
 'de_theta_max',
 'de_theta_min',
 'de_theta_asymmetry_mean',
 'de_theta_asymmetry_std',
 'de_alpha_mean',
 'de_alpha_std',
 'de_alpha_median',
 'de_alpha_max',
 'de_alpha_min',
 'de_alpha_asymmetry_mean',
 'de_alpha_asymmetry_std',
 'de_low_beta_mean',
 'de_low_beta_std',
 'de_low_beta_median',
 'de_low_beta_max',
 'de_low_beta_min',
 'de_low_beta_asymmetry_mean',
 'de_low_beta_asymmetry_std',
 'de_high_beta_mean',
 'de_high_beta_std',
 'de_high_beta_m

In [190]:
features = [
    'patient_id',
    #'sampling_rate_hz',
    #'seizure_number',
    #'file_name',
    'registration_start_time',
    'registration_end_time',
    'seizure_start_time',
    'seizure_end_time',
    'age',
    'gender',
    #'num_channels',
    #'json_file_path',
    #'eeg_channel',
    'seizure_type',
    'localization',
    'lateralization',
    #'file_path',
    'sampling_rate',
    'duration_seconds',
    'de_delta_mean',
    'de_delta_std',
    'de_delta_median',
    'de_delta_max',
    'de_delta_min',
    'de_delta_asymmetry_mean',
    'de_delta_asymmetry_std',
    'de_theta_mean',
    'de_theta_std',
    'de_theta_median',
    'de_theta_max',
    'de_theta_min',
    'de_theta_asymmetry_mean',
    'de_theta_asymmetry_std',
    'de_alpha_mean',
    'de_alpha_std',
    'de_alpha_median',
    'de_alpha_max',
    'de_alpha_min',
    'de_alpha_asymmetry_mean',
    'de_alpha_asymmetry_std',
    'de_low_beta_mean',
    'de_low_beta_std',
    'de_low_beta_median',
    'de_low_beta_max',
    'de_low_beta_min',
    'de_low_beta_asymmetry_mean',
    'de_low_beta_asymmetry_std',
    'de_high_beta_mean',
    'de_high_beta_std',
    'de_high_beta_median',
    'de_high_beta_max',
    'de_high_beta_min',
    'de_high_beta_asymmetry_mean',
    'de_high_beta_asymmetry_std',
    'de_gamma_mean',
    'de_gamma_std',
    'de_gamma_median',
    'de_gamma_max',
    'de_gamma_min',
    'de_gamma_asymmetry_mean',
    'de_gamma_asymmetry_std',
    'de_high_gamma_mean',
    'de_high_gamma_std',
    'de_high_gamma_median',
    'de_high_gamma_max',
    'de_high_gamma_min',
    'de_high_gamma_asymmetry_mean',
    'de_high_gamma_asymmetry_std',
    'psd_delta_mean',
    'psd_delta_std',
    'psd_delta_cv',
    'psd_theta_mean',
    'psd_theta_std',
    'psd_theta_cv',
    'psd_alpha_mean',
    'psd_alpha_std',
    'psd_alpha_cv',
    'psd_low_beta_mean',
    'psd_low_beta_std',
    'psd_low_beta_cv',
    'psd_high_beta_mean',
    'psd_high_beta_std',
    'psd_high_beta_cv',
    'psd_gamma_mean',
    'psd_gamma_std',
    'psd_gamma_cv',
    'psd_high_gamma_mean',
    'psd_high_gamma_std',
    'psd_high_gamma_cv',
    'psd_theta_alpha_ratio',
    'psd_delta_alpha_ratio',
    'psd_beta_ratio',
    'psd_sef50',
    'psd_sef75',
    'psd_sef90',
    'psd_sef95',
    'psd_spectral_centroid',
    'psd_spectral_spread',
    'psd_spectral_skewness',
    'psd_spectral_kurtosis',
    'wt_level0_energy_mean',
    'wt_level0_energy_std',
    'wt_level0_entropy_mean',
    'wt_level0_entropy_std',
    'wt_level0_mean_mean',
    'wt_level0_mean_std',
    'wt_level0_std_mean',
    'wt_level0_std_std',
    'wt_level0_max_mean',
    'wt_level0_max_std',
    'wt_level1_energy_mean',
    'wt_level1_energy_std',
    'wt_level1_entropy_mean',
    'wt_level1_entropy_std',
    'wt_level1_mean_mean',
    'wt_level1_mean_std',
    'wt_level1_std_mean',
    'wt_level1_std_std',
    'wt_level1_max_mean',
    'wt_level1_max_std',
    'wt_level2_energy_mean',
    'wt_level2_energy_std',
    'wt_level2_entropy_mean',
    'wt_level2_entropy_std',
    'wt_level2_mean_mean',
    'wt_level2_mean_std',
    'wt_level2_std_mean',
    'wt_level2_std_std',
    'wt_level2_max_mean',
    'wt_level2_max_std',
    'wt_level3_energy_mean',
    'wt_level3_energy_std',
    'wt_level3_entropy_mean',
    'wt_level3_entropy_std',
    'wt_level3_mean_mean',
    'wt_level3_mean_std',
    'wt_level3_std_mean',
    'wt_level3_std_std',
    'wt_level3_max_mean',
    'wt_level3_max_std',
    'wt_level4_energy_mean',
    'wt_level4_energy_std',
    'wt_level4_entropy_mean',
    'wt_level4_entropy_std',
    'wt_level4_mean_mean',
    'wt_level4_mean_std',
    'wt_level4_std_mean',
    'wt_level4_std_std',
    'wt_level4_max_mean',
    'wt_level4_max_std',
    'wt_level5_energy_mean',
    'wt_level5_energy_std',
    'wt_level5_entropy_mean',
    'wt_level5_entropy_std',
    'wt_level5_mean_mean',
    'wt_level5_mean_std',
    'wt_level5_std_mean',
    'wt_level5_std_std',
    'wt_level5_max_mean',
    'wt_level5_max_std',
    'wt_packet_entropy',
    'time_mean_mean',
    'time_mean_std',
    'time_mean_max',
    'time_mean_min',
    'time_std_mean',
    'time_std_std',
    'time_std_max',
    'time_std_min',
    'time_var_mean',
    'time_var_std',
    'time_var_max',
    'time_var_min',
    'time_skewness_mean',
    'time_skewness_std',
    'time_skewness_max',
    'time_skewness_min',
    'time_kurtosis_mean',
    'time_kurtosis_std',
    'time_kurtosis_max',
    'time_kurtosis_min',
    'time_rms_mean',
    'time_rms_std',
    'time_rms_max',
    'time_rms_min',
    'time_peak_to_peak_mean',
    'time_peak_to_peak_std',
    'time_peak_to_peak_max',
    'time_peak_to_peak_min',
    'time_zero_crossings_mean',
    'time_zero_crossings_std',
    'time_zero_crossings_max',
    'time_zero_crossings_min',
    'time_hjorth_activity_mean',
    'time_hjorth_activity_std',
    'time_hjorth_activity_max',
    'time_hjorth_activity_min',
    'time_hjorth_mobility_mean',
    'time_hjorth_mobility_std',
    'time_hjorth_mobility_max',
    'time_hjorth_mobility_min',
    'time_hjorth_complexity_mean',
    'time_hjorth_complexity_std',
    'time_hjorth_complexity_max',
    'time_hjorth_complexity_min',
    'time_line_length_mean',
    'time_line_length_std',
    'time_line_length_max',
    'time_line_length_min',
    'time_nonlinear_energy_mean',
    'time_nonlinear_energy_std',
    'time_nonlinear_energy_max',
    'time_nonlinear_energy_min',
    'connectivity_mean',
    'connectivity_std',
    'connectivity_max',
    'connectivity_min',
    'global_efficiency',
    'node_strength_mean',
    'node_strength_std',
    'node_strength_max',
    'clustering_coefficient',
    'pac_theta_gamma',
    'pac_theta_high_gamma',
    'pac_alpha_gamma',
    'pac_alpha_high_gamma',
    'pac_mean',
    'pac_max',
    'sample_entropy_mean',
    'sample_entropy_std',
    'permutation_entropy_mean',
    'permutation_entropy_std',
    'approx_entropy_mean',
    'approx_entropy_std',
    'rhythmic_theta_power_mean',
    'rhythmic_theta_power_std',
    'rhythmic_delta_slow_power_mean',
    'rhythmic_delta_slow_power_std',
    'rhythmic_fast_power_mean',
    'rhythmic_fast_power_std',
    'rhythmic_spike_rate_mean',
    'rhythmic_spike_rate_std',
    'rhythmic_theta_delta_ratio',
    'rhythmic_fast_theta_ratio',
    'pre_ictal_start',
    'pre_ictal_end',
    'post_ictal_start',
    'post_ictal_end',
    'preictal_de_delta_mean',
    'preictal_de_delta_std',
    'preictal_de_delta_median',
    'preictal_de_delta_max',
    'preictal_de_delta_min',
    'preictal_de_delta_asymmetry_mean',
    'preictal_de_delta_asymmetry_std',
    'preictal_de_theta_mean',
    'preictal_de_theta_std',
    'preictal_de_theta_median',
    'preictal_de_theta_max',
    'preictal_de_theta_min',
    'preictal_de_theta_asymmetry_mean',
    'preictal_de_theta_asymmetry_std',
    'preictal_de_alpha_mean',
    'preictal_de_alpha_std',
    'preictal_de_alpha_median',
    'preictal_de_alpha_max',
    'preictal_de_alpha_min',
    'preictal_de_alpha_asymmetry_mean',
    'preictal_de_alpha_asymmetry_std',
    'preictal_de_low_beta_mean',
    'preictal_de_low_beta_std',
    'preictal_de_low_beta_median',
    'preictal_de_low_beta_max',
    'preictal_de_low_beta_min',
    'preictal_de_low_beta_asymmetry_mean',
    'preictal_de_low_beta_asymmetry_std',
    'preictal_de_high_beta_mean',
    'preictal_de_high_beta_std',
    'preictal_de_high_beta_median',
    'preictal_de_high_beta_max',
    'preictal_de_high_beta_min',
    'preictal_de_high_beta_asymmetry_mean',
    'preictal_de_high_beta_asymmetry_std',
    'preictal_de_gamma_mean',
    'preictal_de_gamma_std',
    'preictal_de_gamma_median',
    'preictal_de_gamma_max',
    'preictal_de_gamma_min',
    'preictal_de_gamma_asymmetry_mean',
    'preictal_de_gamma_asymmetry_std',
    'preictal_de_high_gamma_mean',
    'preictal_de_high_gamma_std',
    'preictal_de_high_gamma_median',
    'preictal_de_high_gamma_max',
    'preictal_de_high_gamma_min',
    'preictal_de_high_gamma_asymmetry_mean',
    'preictal_de_high_gamma_asymmetry_std',
    'preictal_wt_level0_energy_mean',
    'preictal_wt_level0_energy_std',
    'preictal_wt_level0_entropy_mean',
    'preictal_wt_level0_entropy_std',
    'preictal_wt_level0_mean_mean',
    'preictal_wt_level0_mean_std',
    'preictal_wt_level0_std_mean',
    'preictal_wt_level0_std_std',
    'preictal_wt_level0_max_mean',
    'preictal_wt_level0_max_std',
    'preictal_wt_level1_energy_mean',
    'preictal_wt_level1_energy_std',
    'preictal_wt_level1_entropy_mean',
    'preictal_wt_level1_entropy_std',
    'preictal_wt_level1_mean_mean',
    'preictal_wt_level1_mean_std',
    'preictal_wt_level1_std_mean',
    'preictal_wt_level1_std_std',
    'preictal_wt_level1_max_mean',
    'preictal_wt_level1_max_std',
    'preictal_wt_level2_energy_mean',
    'preictal_wt_level2_energy_std',
    'preictal_wt_level2_entropy_mean',
    'preictal_wt_level2_entropy_std',
    'preictal_wt_level2_mean_mean',
    'preictal_wt_level2_mean_std',
    'preictal_wt_level2_std_mean',
    'preictal_wt_level2_std_std',
    'preictal_wt_level2_max_mean',
    'preictal_wt_level2_max_std',
    'preictal_wt_level3_energy_mean',
    'preictal_wt_level3_energy_std',
    'preictal_wt_level3_entropy_mean',
    'preictal_wt_level3_entropy_std',
    'preictal_wt_level3_mean_mean',
    'preictal_wt_level3_mean_std',
    'preictal_wt_level3_std_mean',
    'preictal_wt_level3_std_std',
    'preictal_wt_level3_max_mean',
    'preictal_wt_level3_max_std',
    'preictal_wt_level4_energy_mean',
    'preictal_wt_level4_energy_std',
    'preictal_wt_level4_entropy_mean',
    'preictal_wt_level4_entropy_std',
    'preictal_wt_level4_mean_mean',
    'preictal_wt_level4_mean_std',
    'preictal_wt_level4_std_mean',
    'preictal_wt_level4_std_std',
    'preictal_wt_level4_max_mean',
    'preictal_wt_level4_max_std',
    'preictal_wt_level5_energy_mean',
    'preictal_wt_level5_energy_std',
    'preictal_wt_level5_entropy_mean',
    'preictal_wt_level5_entropy_std',
    'preictal_wt_level5_mean_mean',
    'preictal_wt_level5_mean_std',
    'preictal_wt_level5_std_mean',
    'preictal_wt_level5_std_std',
    'preictal_wt_level5_max_mean',
    'preictal_wt_level5_max_std',
    'preictal_wt_packet_entropy',
    'preictal_time_mean_mean',
    'preictal_time_mean_std',
    'preictal_time_mean_max',
    'preictal_time_mean_min',
    'preictal_time_std_mean',
    'preictal_time_std_std',
    'preictal_time_std_max',
    'preictal_time_std_min',
    'preictal_time_var_mean',
    'preictal_time_var_std',
    'preictal_time_var_max',
    'preictal_time_var_min',
    'preictal_time_skewness_mean',
    'preictal_time_skewness_std',
    'preictal_time_skewness_max',
    'preictal_time_skewness_min',
    'preictal_time_kurtosis_mean',
    'preictal_time_kurtosis_std',
    'preictal_time_kurtosis_max',
    'preictal_time_kurtosis_min',
    'preictal_time_rms_mean',
    'preictal_time_rms_std',
    'preictal_time_rms_max',
    'preictal_time_rms_min',
    'preictal_time_peak_to_peak_mean',
    'preictal_time_peak_to_peak_std',
    'preictal_time_peak_to_peak_max',
    'preictal_time_peak_to_peak_min',
    'preictal_time_zero_crossings_mean',
    'preictal_time_zero_crossings_std',
    'preictal_time_zero_crossings_max',
    'preictal_time_zero_crossings_min',
    'preictal_time_hjorth_activity_mean',
    'preictal_time_hjorth_activity_std',
    'preictal_time_hjorth_activity_max',
    'preictal_time_hjorth_activity_min',
    'preictal_time_hjorth_mobility_mean',
    'preictal_time_hjorth_mobility_std',
    'preictal_time_hjorth_mobility_max',
    'preictal_time_hjorth_mobility_min',
    'preictal_time_hjorth_complexity_mean',
    'preictal_time_hjorth_complexity_std',
    'preictal_time_hjorth_complexity_max',
    'preictal_time_hjorth_complexity_min',
    'preictal_time_line_length_mean',
    'preictal_time_line_length_std',
    'preictal_time_line_length_max',
    'preictal_time_line_length_min',
    'preictal_time_nonlinear_energy_mean',
    'preictal_time_nonlinear_energy_std',
    'preictal_time_nonlinear_energy_max',
    'preictal_time_nonlinear_energy_min',
    'ictal_de_delta_mean',
    'ictal_de_delta_std',
    'ictal_de_delta_median',
    'ictal_de_delta_max',
    'ictal_de_delta_min',
    'ictal_de_delta_asymmetry_mean',
    'ictal_de_delta_asymmetry_std',
    'ictal_de_theta_mean',
    'ictal_de_theta_std',
    'ictal_de_theta_median',
    'ictal_de_theta_max',
    'ictal_de_theta_min',
    'ictal_de_theta_asymmetry_mean',
    'ictal_de_theta_asymmetry_std',
    'ictal_de_alpha_mean',
    'ictal_de_alpha_std',
    'ictal_de_alpha_median',
    'ictal_de_alpha_max',
    'ictal_de_alpha_min',
    'ictal_de_alpha_asymmetry_mean',
    'ictal_de_alpha_asymmetry_std',
    'ictal_de_low_beta_mean',
    'ictal_de_low_beta_std',
    'ictal_de_low_beta_median',
    'ictal_de_low_beta_max',
    'ictal_de_low_beta_min',
    'ictal_de_low_beta_asymmetry_mean',
    'ictal_de_low_beta_asymmetry_std',
    'ictal_de_high_beta_mean',
    'ictal_de_high_beta_std',
    'ictal_de_high_beta_median',
    'ictal_de_high_beta_max',
    'ictal_de_high_beta_min',
    'ictal_de_high_beta_asymmetry_mean',
    'ictal_de_high_beta_asymmetry_std',
    'ictal_de_gamma_mean',
    'ictal_de_gamma_std',
    'ictal_de_gamma_median',
    'ictal_de_gamma_max',
    'ictal_de_gamma_min',
    'ictal_de_gamma_asymmetry_mean',
    'ictal_de_gamma_asymmetry_std',
    'ictal_de_high_gamma_mean',
    'ictal_de_high_gamma_std',
    'ictal_de_high_gamma_median',
    'ictal_de_high_gamma_max',
    'ictal_de_high_gamma_min',
    'ictal_de_high_gamma_asymmetry_mean',
    'ictal_de_high_gamma_asymmetry_std',
    'ictal_wt_level0_energy_mean',
    'ictal_wt_level0_energy_std',
    'ictal_wt_level0_entropy_mean',
    'ictal_wt_level0_entropy_std',
    'ictal_wt_level0_mean_mean',
    'ictal_wt_level0_mean_std',
    'ictal_wt_level0_std_mean',
    'ictal_wt_level0_std_std',
    'ictal_wt_level0_max_mean',
    'ictal_wt_level0_max_std',
    'ictal_wt_level1_energy_mean',
    'ictal_wt_level1_energy_std',
    'ictal_wt_level1_entropy_mean',
    'ictal_wt_level1_entropy_std',
    'ictal_wt_level1_mean_mean',
    'ictal_wt_level1_mean_std',
    'ictal_wt_level1_std_mean',
    'ictal_wt_level1_std_std',
    'ictal_wt_level1_max_mean',
    'ictal_wt_level1_max_std',
    'ictal_wt_level2_energy_mean',
    'ictal_wt_level2_energy_std',
    'ictal_wt_level2_entropy_mean',
    'ictal_wt_level2_entropy_std',
    'ictal_wt_level2_mean_mean',
    'ictal_wt_level2_mean_std',
    'ictal_wt_level2_std_mean',
    'ictal_wt_level2_std_std',
    'ictal_wt_level2_max_mean',
    'ictal_wt_level2_max_std',
    'ictal_wt_level3_energy_mean',
    'ictal_wt_level3_energy_std',
    'ictal_wt_level3_entropy_mean',
    'ictal_wt_level3_entropy_std',
    'ictal_wt_level3_mean_mean',
    'ictal_wt_level3_mean_std',
    'ictal_wt_level3_std_mean',
    'ictal_wt_level3_std_std',
    'ictal_wt_level3_max_mean',
    'ictal_wt_level3_max_std',
    'ictal_wt_level4_energy_mean',
    'ictal_wt_level4_energy_std',
    'ictal_wt_level4_entropy_mean',
    'ictal_wt_level4_entropy_std',
    'ictal_wt_level4_mean_mean',
    'ictal_wt_level4_mean_std',
    'ictal_wt_level4_std_mean',
    'ictal_wt_level4_std_std',
    'ictal_wt_level4_max_mean',
    'ictal_wt_level4_max_std',
    'ictal_wt_level5_energy_mean',
    'ictal_wt_level5_energy_std',
    'ictal_wt_level5_entropy_mean',
    'ictal_wt_level5_entropy_std',
    'ictal_wt_level5_mean_mean',
    'ictal_wt_level5_mean_std',
    'ictal_wt_level5_std_mean',
    'ictal_wt_level5_std_std',
    'ictal_wt_level5_max_mean',
    'ictal_wt_level5_max_std',
    'ictal_wt_packet_entropy',
    'ictal_time_mean_mean',
    'ictal_time_mean_std',
    'ictal_time_mean_max',
    'ictal_time_mean_min',
    'ictal_time_std_mean',
    'ictal_time_std_std',
    'ictal_time_std_max',
    'ictal_time_std_min',
    'ictal_time_var_mean',
    'ictal_time_var_std',
    'ictal_time_var_max',
    'ictal_time_var_min',
    'ictal_time_skewness_mean',
    'ictal_time_skewness_std',
    'ictal_time_skewness_max',
    'ictal_time_skewness_min',
    'ictal_time_kurtosis_mean',
    'ictal_time_kurtosis_std',
    'ictal_time_kurtosis_max',
    'ictal_time_kurtosis_min',
    'ictal_time_rms_mean',
    'ictal_time_rms_std',
    'ictal_time_rms_max',
    'ictal_time_rms_min',
    'ictal_time_peak_to_peak_mean',
    'ictal_time_peak_to_peak_std',
    'ictal_time_peak_to_peak_max',
    'ictal_time_peak_to_peak_min',
    'ictal_time_zero_crossings_mean',
    'ictal_time_zero_crossings_std',
    'ictal_time_zero_crossings_max',
    'ictal_time_zero_crossings_min',
    'ictal_time_hjorth_activity_mean',
    'ictal_time_hjorth_activity_std',
    'ictal_time_hjorth_activity_max',
    'ictal_time_hjorth_activity_min',
    'ictal_time_hjorth_mobility_mean',
    'ictal_time_hjorth_mobility_std',
    'ictal_time_hjorth_mobility_max',
    'ictal_time_hjorth_mobility_min',
    'ictal_time_hjorth_complexity_mean',
    'ictal_time_hjorth_complexity_std',
    'ictal_time_hjorth_complexity_max',
    'ictal_time_hjorth_complexity_min',
    'ictal_time_line_length_mean',
    'ictal_time_line_length_std',
    'ictal_time_line_length_max',
    'ictal_time_line_length_min',
    'ictal_time_nonlinear_energy_mean',
    'ictal_time_nonlinear_energy_std',
    'ictal_time_nonlinear_energy_max',
    'ictal_time_nonlinear_energy_min',
    'postictal_de_delta_mean',
    'postictal_de_delta_std',
    'postictal_de_delta_median',
    'postictal_de_delta_max',
    'postictal_de_delta_min',
    'postictal_de_delta_asymmetry_mean',
    'postictal_de_delta_asymmetry_std',
    'postictal_de_theta_mean',
    'postictal_de_theta_std',
    'postictal_de_theta_median',
    'postictal_de_theta_max',
    'postictal_de_theta_min',
    'postictal_de_theta_asymmetry_mean',
    'postictal_de_theta_asymmetry_std',
    'postictal_de_alpha_mean',
    'postictal_de_alpha_std',
    'postictal_de_alpha_median',
    'postictal_de_alpha_max',
    'postictal_de_alpha_min',
    'postictal_de_alpha_asymmetry_mean',
    'postictal_de_alpha_asymmetry_std',
    'postictal_de_low_beta_mean',
    'postictal_de_low_beta_std',
    'postictal_de_low_beta_median',
    'postictal_de_low_beta_max',
    'postictal_de_low_beta_min',
    'postictal_de_low_beta_asymmetry_mean',
    'postictal_de_low_beta_asymmetry_std',
    'postictal_de_high_beta_mean',
    'postictal_de_high_beta_std',
    'postictal_de_high_beta_median',
    'postictal_de_high_beta_max',
    'postictal_de_high_beta_min',
    'postictal_de_high_beta_asymmetry_mean',
    'postictal_de_high_beta_asymmetry_std',
    'postictal_de_gamma_mean',
    'postictal_de_gamma_std',
    'postictal_de_gamma_median',
    'postictal_de_gamma_max',
    'postictal_de_gamma_min',
    'postictal_de_gamma_asymmetry_mean',
    'postictal_de_gamma_asymmetry_std',
    'postictal_de_high_gamma_mean',
    'postictal_de_high_gamma_std',
    'postictal_de_high_gamma_median',
    'postictal_de_high_gamma_max',
    'postictal_de_high_gamma_min',
    'postictal_de_high_gamma_asymmetry_mean',
    'postictal_de_high_gamma_asymmetry_std',
    'postictal_wt_level0_energy_mean',
    'postictal_wt_level0_energy_std',
    'postictal_wt_level0_entropy_mean',
    'postictal_wt_level0_entropy_std',
    'postictal_wt_level0_mean_mean',
    'postictal_wt_level0_mean_std',
    'postictal_wt_level0_std_mean',
    'postictal_wt_level0_std_std',
    'postictal_wt_level0_max_mean',
    'postictal_wt_level0_max_std',
    'postictal_wt_level1_energy_mean',
    'postictal_wt_level1_energy_std',
    'postictal_wt_level1_entropy_mean',
    'postictal_wt_level1_entropy_std',
    'postictal_wt_level1_mean_mean',
    'postictal_wt_level1_mean_std',
    'postictal_wt_level1_std_mean',
    'postictal_wt_level1_std_std',
    'postictal_wt_level1_max_mean',
    'postictal_wt_level1_max_std',
    'postictal_wt_level2_energy_mean',
    'postictal_wt_level2_energy_std',
    'postictal_wt_level2_entropy_mean',
    'postictal_wt_level2_entropy_std',
    'postictal_wt_level2_mean_mean',
    'postictal_wt_level2_mean_std',
    'postictal_wt_level2_std_mean',
    'postictal_wt_level2_std_std',
    'postictal_wt_level2_max_mean',
    'postictal_wt_level2_max_std',
    'postictal_wt_level3_energy_mean',
    'postictal_wt_level3_energy_std',
    'postictal_wt_level3_entropy_mean',
    'postictal_wt_level3_entropy_std',
    'postictal_wt_level3_mean_mean',
    'postictal_wt_level3_mean_std',
    'postictal_wt_level3_std_mean',
    'postictal_wt_level3_std_std',
    'postictal_wt_level3_max_mean',
    'postictal_wt_level3_max_std',
    'postictal_wt_level4_energy_mean',
    'postictal_wt_level4_energy_std',
    'postictal_wt_level4_entropy_mean',
    'postictal_wt_level4_entropy_std',
    'postictal_wt_level4_mean_mean',
    'postictal_wt_level4_mean_std',
    'postictal_wt_level4_std_mean',
    'postictal_wt_level4_std_std',
    'postictal_wt_level4_max_mean',
    'postictal_wt_level4_max_std',
    'postictal_wt_level5_energy_mean',
    'postictal_wt_level5_energy_std',
    'postictal_wt_level5_entropy_mean',
    'postictal_wt_level5_entropy_std',
    'postictal_wt_level5_mean_mean',
    'postictal_wt_level5_mean_std',
    'postictal_wt_level5_std_mean',
    'postictal_wt_level5_std_std',
    'postictal_wt_level5_max_mean',
    'postictal_wt_level5_max_std',
    'postictal_wt_packet_entropy',
    'postictal_time_mean_mean',
    'postictal_time_mean_std',
    'postictal_time_mean_max',
    'postictal_time_mean_min',
    'postictal_time_std_mean',
    'postictal_time_std_std',
    'postictal_time_std_max',
    'postictal_time_std_min',
    'postictal_time_var_mean',
    'postictal_time_var_std',
    'postictal_time_var_max',
    'postictal_time_var_min',
    'postictal_time_skewness_mean',
    'postictal_time_skewness_std',
    'postictal_time_skewness_max',
    'postictal_time_skewness_min',
    'postictal_time_kurtosis_mean',
    'postictal_time_kurtosis_std',
    'postictal_time_kurtosis_max',
    'postictal_time_kurtosis_min',
    'postictal_time_rms_mean',
    'postictal_time_rms_std',
    'postictal_time_rms_max',
    'postictal_time_rms_min',
    'postictal_time_peak_to_peak_mean',
    'postictal_time_peak_to_peak_std',
    'postictal_time_peak_to_peak_max',
    'postictal_time_peak_to_peak_min',
    'postictal_time_zero_crossings_mean',
    'postictal_time_zero_crossings_std',
    'postictal_time_zero_crossings_max',
    'postictal_time_zero_crossings_min',
    'postictal_time_hjorth_activity_mean',
    'postictal_time_hjorth_activity_std',
    'postictal_time_hjorth_activity_max',
    'postictal_time_hjorth_activity_min',
    'postictal_time_hjorth_mobility_mean',
    'postictal_time_hjorth_mobility_std',
    'postictal_time_hjorth_mobility_max',
    'postictal_time_hjorth_mobility_min',
    'postictal_time_hjorth_complexity_mean',
    'postictal_time_hjorth_complexity_std',
    'postictal_time_hjorth_complexity_max',
    'postictal_time_hjorth_complexity_min',
    'postictal_time_line_length_mean',
    'postictal_time_line_length_std',
    'postictal_time_line_length_max',
    'postictal_time_line_length_min',
    'postictal_time_nonlinear_energy_mean',
    'postictal_time_nonlinear_energy_std',
    'postictal_time_nonlinear_energy_max',
    'postictal_time_nonlinear_energy_min',
    'seizure_start_seconds',
    'seizure_end_seconds',
    'seizure_duration',
    'processing_success',
    'error_message',
    'mean_propagation_speed',
    'median_propagation_speed',
    'std_propagation_speed',
    'max_propagation_speed',
    'min_propagation_speed',
    'num_propagation_events',
    'mean_onset_delay',
    'max_onset_delay'
]

In [191]:
# Narrow the encoded frame to the curated `features` column list (same names, same order).
# polars raises if any listed column is missing, so a stale list fails loudly here.
# Re-running this cell is safe: after selection the frame already holds exactly these columns.
# NOTE(review): the list includes pipeline bookkeeping columns ('processing_success',
# 'error_message'); confirm they are intended to reach the classifier as inputs.
encoded_df = encoded_df.select(*features)

In [192]:
# Train one classifier per target column via train_all_classifiers (defined in an earlier cell).
# `results` is indexed by target name, e.g. results['seizure_type'], and each entry exposes
# at least 'model' and 'predictions' (both read in the next cell).
# NOTE(review): the seizure_type run had only 47 labeled rows and a 10-row hold-out split,
# so the perfect report below is not a reliable performance estimate; consider cross-validation.
results = train_all_classifiers(encoded_df)


Training classifier for: seizure_type
y labeled shape: (47,)
x labeled shape: (47, 736)
x unlabeled shape: (198, 736)
Number of classes detected: 3
Classes: ['FBTC' 'IAS' 'WIAS']
X train shape: (37, 736)
Y train encoded shape: (37,)
y pred shape: (10,)
Classification Report:
              precision    recall  f1-score   support

        FBTC       1.00      1.00      1.00         2
         IAS       1.00      1.00      1.00         7
        WIAS       1.00      1.00      1.00         1

    accuracy                           1.00        10
   macro avg       1.00      1.00      1.00        10
weighted avg       1.00      1.00      1.00        10


Confusion Matrix:
[[2 0 0]
 [0 7 0]
 [0 0 1]]
Top 10 Most Important Features:
time_zero_crossings_std: 0.3436
ictal_wt_level5_entropy_std: 0.1565
age: 0.1543
pac_theta_high_gamma: 0.0859
de_theta_asymmetry_mean: 0.0796
permutation_entropy_std: 0.0327
ictal_wt_level3_entropy_std: 0.0327
ictal_time_hjorth_complexity_max: 0.0308
ictal_time_ze

In [193]:
# Unpack the seizure-type entry of `results` into its fitted model and its predictions.
seizure_type_model, seizure_type_predictions = (
    results['seizure_type'][key] for key in ('model', 'predictions')
)