In [84]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
from sklearn.metrics import classification_report, confusion_matrix
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler
import seaborn as sns

In [24]:
import matplotlib as mpl
import pylab  # not used in this cell; kept in case another cell still refers to `pylab`

# Notebook-wide plotting defaults. All settings go through `mpl.rcParams`
# so there is a single place to look for them.
mpl.rcParams['lines.linewidth'] = 2
mpl.rcParams['lines.color'] = 'r'

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases removed the old names. Prefer the new name and fall back to
# the old one on installs that predate the rename.
_grid_style = 'seaborn-v0_8-whitegrid'
if _grid_style not in plt.style.available:
    _grid_style = 'seaborn-whitegrid'
plt.style.use(_grid_style)
plt.rc('figure', figsize=(8, 8))

# Fonts (set after the style sheet so they are not overridden by it).
mpl.rcParams['font.family'] = "serif"
mpl.rcParams['font.weight'] = "semibold"

# Axes, ticks and legend.
mpl.rcParams['ytick.major.pad'] = 15
mpl.rcParams['xtick.major.pad'] = 15
mpl.rcParams['axes.labelsize'] = 20
mpl.rcParams['axes.linewidth'] = 4
mpl.rcParams['xtick.labelsize'] = 20
mpl.rcParams['ytick.labelsize'] = 20
mpl.rcParams['axes.edgecolor'] = 'black'
mpl.rcParams['axes.titlesize'] = 20
mpl.rcParams['legend.fontsize'] = 15

In [58]:
def plot_cm(y_true, y_pred, labels, class_names=None, normalize='true', png_path=None, show=False):
    """Compute a confusion matrix, draw it as a heatmap and return the matrix.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels, forwarded to ``confusion_matrix``.
    labels : sequence
        Label values forwarded to ``confusion_matrix``; also used as tick
        labels when ``class_names`` is not supplied.
    class_names : sequence, optional
        Tick labels for both axes. ``labels`` is used when this is ``None``
        or empty.
    normalize : optional
        Forwarded unchanged to ``confusion_matrix``.
    png_path : str or path-like, optional
        When truthy, the figure is written to this path.
    show : bool
        When true, ``plt.show()`` is called before the figure is closed.

    Returns
    -------
    The array returned by ``confusion_matrix``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels, normalize=normalize)

    # Test emptiness with len() rather than truthiness so that a NumPy array
    # of names does not raise "truth value of an array is ambiguous".
    has_names = class_names is not None and len(class_names) > 0
    tick_labels = class_names if has_names else labels

    # Integer matrices hold raw counts; only fractional ones need decimals.
    cell_format = "d" if cm.dtype.kind in "iu" else ".2f"

    fig, ax = plt.subplots(figsize=(6, 5))
    # Draw on and save the figure created here instead of relying on
    # pyplot's implicit "current" axes/figure.
    sns.heatmap(cm, annot=True, fmt=cell_format, cmap="Blues",
                xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    if png_path:
        fig.savefig(png_path, bbox_inches='tight')
    if show:
        plt.show()
    # Always close so repeated calls in a loop do not accumulate open figures.
    plt.close(fig)
    return cm

In [59]:
def load_data(file_path):
    """Read the CSV file at ``file_path`` and return it as a pandas DataFrame."""
    return pd.read_csv(file_path)

In [60]:
def preprocess_data(data, label_column="label"):
    """Split a table into standardised features and integer-encoded labels.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the feature columns plus one label column.
    label_column : str, default "label"
        Name of the label column. Pass a different name (for example
        "Cognitive_Load_Label") instead of editing this function.

    Returns
    -------
    X : array
        Output of ``StandardScaler.fit_transform`` on every column except
        ``label_column``.
    y : array
        Output of ``LabelEncoder.fit_transform`` on ``label_column``.
    label_encoder : LabelEncoder
        The fitted encoder, so callers can map encoded values back.

    Raises
    ------
    KeyError
        If ``label_column`` is not a column of ``data``.
    """
    if label_column not in data.columns:
        raise KeyError(f"label column {label_column!r} not found in data")

    X = data.drop(columns=[label_column])
    y = data[label_column]

    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(y)

    # NOTE(review): the scaler is fit on every row of `data`, so if the result
    # is split into train/test afterwards, the test rows have already
    # contributed to the scaling statistics.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    return X, y, label_encoder

In [61]:
def split_data(X, y, test_size=0.20, random_state=42, stratify=False):
    """Split features and labels into train and test parts.

    Parameters
    ----------
    X, y : array-like
        Features and labels, forwarded to ``train_test_split``.
    test_size : float, default 0.20
        Forwarded to ``train_test_split``.
    random_state : int, default 42
        Forwarded to ``train_test_split``.
    stratify : bool, default False
        When true, the split is stratified on ``y`` so both parts keep the
        class proportions of ``y``. The default keeps the original,
        non-stratified behaviour.

    Returns
    -------
    X_train, X_test, y_train, y_test
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if stratify else None,
    )
    return X_train, X_test, y_train, y_test

In [62]:
def analyze_labels(data, label_column):
    """Print how many distinct values ``label_column`` has and how often each occurs."""
    label_values = data[label_column]
    n_classes = len(label_values.unique())
    per_class = label_values.value_counts()

    print(f"Number of unique classes: {n_classes}")
    print("\nClasses and their sample counts:")
    print(per_class)

In [63]:
# Load the EEG feature table; the path is relative to the working directory.
data = load_data("EEGfeatures_with_labels.csv")
print("Data loaded successfully.")

Data loaded successfully.


In [65]:
#data = load_data("PPGfeatures_with_labels.csv")
#print("Data loaded successfully.")

In [66]:
# Preview the first five rows of the loaded table.
data.head(5)

Unnamed: 0,ShannonEntropy_0,ShannonEntropy_1,ShannonEntropy_2,ShannonEntropy_3,MedianFreq_0,MedianFreq_1,MedianFreq_2,MedianFreq_3,Std_0,Std_1,...,BandPower_beta_1,BandPower_beta_2,BandPower_beta_3,BandPower_gamma_0,BandPower_gamma_1,BandPower_gamma_2,BandPower_gamma_3,MI_0,PLI_0,label
0,6.360649,6.306755,10.002474,7.191911,85.8,114.6,56.6,84.4,32.061341,27.603852,...,1.444642,120.873901,0.503801,0.625572,3.308251,938.98337,0.957668,0.838047,0.117057,1
1,6.71883,5.959228,9.984402,7.338912,83.0,31.1,76.1,38.9,32.026274,23.0107,...,1.063105,126.451029,0.416003,0.650411,2.372382,943.085065,0.943409,3.412117,0.027865,1
2,6.614333,5.946551,9.995242,7.332993,90.6,95.3,100.4,54.3,30.748537,22.257319,...,0.93713,108.797352,0.566946,0.683323,2.175736,954.340933,1.002169,1.640939,0.027344,1
3,6.181034,5.897674,10.036926,7.025405,82.4,51.2,20.4,58.9,24.771528,24.504576,...,1.245104,135.301583,0.321927,0.458879,3.001322,948.288957,0.813986,2.610624,0.126172,1
4,6.65162,6.004883,10.006555,7.496894,101.0,84.4,27.5,18.3,32.95474,23.439127,...,1.182783,115.742823,0.548304,0.625329,2.436957,961.504731,1.194271,1.051656,0.120182,1


In [67]:
# List the column names of the loaded table.
print(data.columns)

Index(['ShannonEntropy_0', 'ShannonEntropy_1', 'ShannonEntropy_2',
       'ShannonEntropy_3', 'MedianFreq_0', 'MedianFreq_1', 'MedianFreq_2',
       'MedianFreq_3', 'Std_0', 'Std_1', 'Std_2', 'Std_3',
       'ShannonEntropy_delta_0', 'ShannonEntropy_delta_1',
       'ShannonEntropy_delta_2', 'ShannonEntropy_delta_3',
       'ShannonEntropy_theta_0', 'ShannonEntropy_theta_1',
       'ShannonEntropy_theta_2', 'ShannonEntropy_theta_3',
       'ShannonEntropy_alpha_0', 'ShannonEntropy_alpha_1',
       'ShannonEntropy_alpha_2', 'ShannonEntropy_alpha_3',
       'ShannonEntropy_beta_0', 'ShannonEntropy_beta_1',
       'ShannonEntropy_beta_2', 'ShannonEntropy_beta_3',
       'ShannonEntropy_gamma_0', 'ShannonEntropy_gamma_1',
       'ShannonEntropy_gamma_2', 'ShannonEntropy_gamma_3', 'HjorthMob_0',
       'HjorthMob_1', 'HjorthMob_2', 'HjorthMob_3', 'HjorthComp_0',
       'HjorthComp_1', 'HjorthComp_2', 'HjorthComp_3', 'BandPower_alpha_0',
       'BandPower_alpha_1', 'BandPower_alpha_2', 'Band

In [68]:
# Separate features from the label column, encode the labels and standardise the features.
X, y, label_encoder = preprocess_data(data)
print("Data preprocessed successfully.")

Data preprocessed successfully.


In [69]:
# Display the encoded label array.
y

array([1, 1, 1, ..., 1, 1, 1])

In [70]:
# Dimensions of the feature matrix returned by preprocess_data.
X.shape

(2520, 54)

In [71]:
# Dimensions of the encoded label array.
y.shape

(2520,)

In [72]:
# Report the number of classes and the per-class sample counts of the 'label' column.
#analyze_labels(data, 'Cognitive_Load_Label') 
analyze_labels(data, 'label') 

Number of unique classes: 3

Classes and their sample counts:
1    1317
0     728
2     475
Name: label, dtype: int64


In [73]:
# Re-check the feature matrix dimensions before splitting.
X.shape

(2520, 54)

In [74]:
# Hold out a test set using split_data's defaults (test_size=0.20, random_state=42).
X_train, X_test, y_train, y_test = split_data(X, y)
print("Data split successfully.")

Data split successfully.


In [75]:
# Shapes of the train and test feature matrices.
print(X_train.shape, X_test.shape)

(2016, 54) (504, 54)


In [76]:
# Shapes of the train and test label arrays.
print(y_train.shape, y_test.shape)

(2016,) (504,)


In [77]:
# Confirm the container type of each of the four split outputs.
print(type(X_train), type(y_train), type(X_test), type(y_test))

<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [80]:
def train_models(X_train, X_test, y_train, y_test, class_map, save_dir="cm_outputs", random_state=0):
    """Fit several classifiers, report their test performance and save confusion matrices.

    Parameters
    ----------
    X_train, X_test : array-like
        Training and test feature matrices.
    y_train, y_test : array-like
        Training and test labels; their values are expected to be the keys
        of ``class_map``.
    class_map : dict
        Maps each label value to its display name. The sorted keys define
        the class order used in the reports and confusion matrices.
    save_dir : str, default "cm_outputs"
        Directory (created if missing) that receives one
        ``<model>_cm_norm.png`` per model.
    random_state : int, default 0
        Seed passed to every estimator that accepts one.

    Returns
    -------
    dict
        ``{model_name: {"report": dict, "cm": array, "cm_path": str}}`` where
        ``report`` is ``classification_report(..., output_dict=True)``,
        ``cm`` is the row-normalised confusion matrix returned by ``plot_cm``
        and ``cm_path`` is the path of the saved figure.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Class order and display names do not depend on the model, so they are
    # computed once up front.
    labels = sorted(class_map)
    class_names = [class_map[i] for i in labels]

    models = {
        "NB": GaussianNB(),
        "RF": RandomForestClassifier(n_estimators=400, n_jobs=-1, random_state=random_state),
        "MLP": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", MLPClassifier(hidden_layer_sizes=(128, 64), activation="relu", max_iter=400, random_state=random_state))
        ]),
        "DT": DecisionTreeClassifier(random_state=random_state),
        "SVM": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", SVC(kernel="linear", probability=True, random_state=random_state))
        ])
    }

    results = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # `labels=` pins the report rows to class_map's classes. Without it,
        # a class missing from both y_test and y_pred makes `target_names`
        # disagree with the number of classes found in the data.
        print(f"\n=== {name} ===")
        print(classification_report(y_test, y_pred, labels=labels, target_names=class_names))

        cm_path = os.path.join(save_dir, f"{name.lower()}_cm_norm.png")
        cm = plot_cm(y_test, y_pred, labels=labels, class_names=class_names, normalize='true', png_path=cm_path)

        results[name] = {
            "report": classification_report(
                y_test, y_pred, labels=labels, target_names=class_names, output_dict=True
            ),
            "cm": cm,
            "cm_path": cm_path
        }

    return results

In [None]:
# Display names for the encoded label values.
class_map = {0: "Low", 1: "Medium", 2: "High"}
results = train_models(X_train, X_test, y_train, y_test, class_map)

# Collect the F1 score from every dict-valued row of each report, i.e. the
# per-class rows and the averaged rows. Scalar entries such as "accuracy"
# are skipped by the isinstance check.
f1_scores = {
    model: {
        str(row): stats["f1-score"]
        for row, stats in result["report"].items()
        if isinstance(stats, dict)
    }
    for model, result in results.items()
}
f1_df = pd.DataFrame(f1_scores).T
# The table holds per-class, macro-averaged and weighted-averaged F1, so the
# header must not claim it is macro F1 only.
print("\nF1 scores by model (per class and averaged):")
print(f1_df)


=== NB ===
              precision    recall  f1-score   support

         Low       0.40      0.78      0.53       152
      Medium       0.68      0.41      0.51       254
        High       0.54      0.29      0.37        98

    accuracy                           0.50       504
   macro avg       0.54      0.49      0.47       504
weighted avg       0.57      0.50      0.49       504


=== RF ===
              precision    recall  f1-score   support

         Low       0.87      0.64      0.73       152
      Medium       0.70      0.92      0.79       254
        High       0.85      0.51      0.64        98

    accuracy                           0.75       504
   macro avg       0.80      0.69      0.72       504
weighted avg       0.78      0.75      0.75       504

