In [4]:
import os
import pywt
import torch
import pickle
import numpy as np
from scipy.signal import resample

### Checking the path to the directories

In [5]:
train_out_dir = '/Volumes/PHILIPS/train_files_TUEV'
eval_out_dir = '/Volumes/PHILIPS/test_files_TUEV'

data_path = "/media/public/Datasets/TUEV/tuev/edf/processed_banana_half"

train_files = os.listdir(data_path + '/processed_train_banana')
val_files = os.listdir(data_path + '/processed_eval_banana')
test_files = os.listdir(data_path + '/processed_test_banana')

print(f'length of train files: {len(train_files)}')
print(f'length of eval files: {len(val_files)}')
print(f'length of test files: {len(test_files)}')

length of train files: 65290
length of eval files: 18642
length of test files: 28305


### Wavelet tranformation for a single file

In [6]:
test_file = data_path + '/processed_train_banana/aaaaablw_00000001-0.pkl' # path to a pickle file
sample = pickle.load(open(os.path.join(test_file), "rb"))

X = sample["signal"]
coeffs = pywt.dwt(X, 'haar')  # Perform discrete Haar wavelet transform
X = coeffs[0]

Y = int(sample["label"][0] - 1)

In [7]:
print(X.shape)
print(type(X))

(8, 500)
<class 'numpy.ndarray'>


In [10]:
class TUEVLoader(torch.utils.data.Dataset):
    def __init__(self, root, files, sampling_rate=200):
        self.root = root
        self.files = files
        self.default_rate = 200
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        sample = pickle.load(open(os.path.join(self.root, self.files[index]), "rb"))
        X = sample["signal"]
        if self.sampling_rate != self.default_rate:
            X = resample(X, 5 * self.sampling_rate, axis=-1)

        coefficients = pywt.dwt(X, 'haar')  # Perform discrete Haar wavelet transform
        X = coefficients[0]
        Y = int(sample["label"][0] - 1)
        
        return X, Y
    

def prepare_TUEV_dataset():
    # set random seed
    seed = 4523
    np.random.seed(seed)

    # path to train, val, test files. Might need to be changed depending on your file organisation
    # train_files = os.listdir("/Volumes/PHILIPS/train_files_TUEV/train_files")
    # val_files = os.listdir("/Volumes/PHILIPS/train_files_TUEV/eval_files")
    # test_files = os.listdir("/Volumes/PHILIPS/test_files_TUEV/test_files")
    data_path = "/media/public/Datasets/TUEV/tuev/edf/processed_banana_half"
    
    train_files = os.listdir(data_path + '/processed_train_banana')
    val_files = os.listdir(data_path + '/processed_eval_banana')
    test_files = os.listdir(data_path + '/processed_test_banana')

    # prepare training and test data loader
    train_dataset = TUEVLoader(
        os.path.join(data_path + '/processed_train_banana'), train_files
    )
    test_dataset = TUEVLoader(
        os.path.join(data_path + '/processed_test_banana'), test_files
    )
    val_dataset = TUEVLoader(
        os.path.join(data_path + '/processed_eval_banana'), val_files
    )
    print(len(train_files), len(val_files), len(test_files))
    return train_dataset, test_dataset, val_dataset

In [11]:
def get_TUEV_dataset():
    train_dataset, test_dataset, val_dataset = prepare_TUEV_dataset()
    ch_names = ['EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF', 'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF', 'EEG F7-REF', \
                'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF', 'EEG T6-REF', 'EEG A1-REF', 'EEG A2-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF', 'EEG T1-REF', 'EEG T2-REF']
    ch_names_after_convert = ['FP1-F7', 'F7-T3', 'T3-T5', 'T5-O1',
                              'FP2-F8', 'F8-T4', 'T4-T6', 'T6-O2',
                              'FP1-F3', 'F3-C3', 'C3-P3', 'P3-O1',
                              'FP2-F4', 'F4-C4', 'C4-P4', 'P4-O2']

    new_ch_names = ["FP1-F7", "F7-T7", "T7-P7", "P7-O1",
                    "FP2-F8", "F8-T8", "T8-P8", "P8-O2",
                    "FP1-F3", "F3-C3", "C3-P3", "P3-O1",
                    "FP2-F4", "F4-C4", "C4-P4", "P4-O2"]

    new_ch_names_to_128 = ["FP1-F7", "F7-T7", "T7-P7", "P7-O1",
                    "FP2-F8", "F8-T8", "T8-P8", "P8-O2"]


    ch_names = [name.split(' ')[-1].split('-')[0] for name in ch_names_after_convert]
    # args.nb_classes = 6
    metrics = ["accuracy", "balanced_accuracy", "cohen_kappa"]
    return train_dataset, test_dataset, val_dataset, new_ch_names_to_128, metrics

In [12]:
dataset_train, dataset_test, dataset_val, ch_names, metrics = get_TUEV_dataset()

65290 18642 28305


In [13]:
X_list, y_list = [], []
for X_batch, y_batch in dataset_train:
    X_list.append(X_batch)
    y_list.append(y_batch)

In [14]:
X_list_test, y_list_test = [], []
for X_batch, y_batch in dataset_test:
    X_list_test.append(X_batch)
    y_list_test.append(y_batch)
    
X_list_eval, y_list_eval = [], []
for X_batch, y_batch in dataset_val:
    X_list_eval.append(X_batch)
    y_list_eval.append(y_batch)

In [15]:
print(len(X_list))
print(len(X_list[0]))
print(len(X_list_test))
print(len(X_list_test[0]))
print(len(X_list_eval))
print(len(X_list_eval[0]))

65290
8
28305
8
18642
8


In [19]:
tmp = np.array(X_list)
print(tmp.shape)
X = tmp.reshape(65290, 4000)
print(X.shape)
y_list = np.array(y_list)
print(y_list.shape)

(65290, 8, 500)
(65290, 4000)
(65290,)


In [22]:
tmp = np.array(X_list_test)
print(tmp.shape)
X_test = tmp.reshape(28305, 4000)
print(X_test.shape)
y_list_test = np.array(y_list_test)
print(y_list_test.shape)

(28305, 8, 500)
(28305, 4000)
(28305,)


In [23]:
tmp = np.array(X_list_eval)
X_eval = tmp.reshape(18642, 4000)
print(X_eval.shape)
y_list_eval = np.array(y_list_eval)
print(y_list_eval.shape)

(18642, 4000)
(18642,)


In [30]:
data_path = "/media/public/Datasets/TUEV/tuev/edf/wavelet_preprocess_half_banana/"
np.save(data_path + "X_train_values_DWT.npy", X)
np.save(data_path + "y_train_values_DWT.npy", y_list)

np.save(data_path + "/X_test_values_DWT.npy", X_test)
np.save(data_path + "/y_test_values_DWT.npy", y_list_test)

np.save(data_path + "/X_val_values_DWT.npy", X_eval)
np.save(data_path + "/y_val_values_DWT.npy", y_list_eval)

### XGBoost

In [26]:
import pickle
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

In [31]:
data_path = "/media/public/Datasets/TUEV/tuev/edf/wavelet_preprocess_half_banana/"

X_train = np.load(data_path + "X_train_values_DWT.npy")
y_train = np.load(data_path + "y_train_values_DWT.npy")

X_test = np.load(data_path + "X_test_values_DWT.npy")
y_test = np.load(data_path + "y_test_values_DWT.npy")

X_eval = np.load(data_path + "X_val_values_DWT.npy")
y_eval = np.load(data_path + "y_val_values_DWT.npy")

In [None]:
xgb_clf = GradientBoostingClassifier()
xgb_clf.fit(X_train, y_train)

with open("xgb_model.pkl", "wb") as file:
    pickle.dump(xgb_clf, file)

### Test metrics


In [15]:
from sklearn.metrics import hamming_loss, accuracy_score, classification_report
from sklearn.metrics import precision_recall_fscore_support

y_pred = xgb_clf.predict(X_test)

# Hamming Loss
print("Hamming Loss:", hamming_loss(y_test, y_pred))

# Accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Classification Report
# report = classification_report(y_test, y_pred, target_names=[f"Class {i}" for i in range(y_test.shape[1])])
# print("Classification Report:\n", report)

# Precision, Recall, F1
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='macro')
print(f"Precision: {precision}, Recall: {recall}, F1-Score: {f1}")


Hamming Loss: 0.2576929871047518
Accuracy: 0.7423070128952481
Precision: 0.4446869639219509, Recall: 0.4067162525821333, F1-Score: 0.4165019914809527


In [18]:
# Classification Report
report = classification_report(y_test, y_pred, target_names=['spsw', 'gped', 'pled', 'eyem', 'artf', 'backg'])
print("Classification Report:\n", report)

Classification Report:
               precision    recall  f1-score   support

        spsw       0.05      0.03      0.04       567
        gped       0.66      0.46      0.54      3561
        pled       0.29      0.15      0.20      1998
        eyem       0.38      0.48      0.42       329
        artf       0.47      0.40      0.43      2204
       backg       0.81      0.92      0.86     19646

    accuracy                           0.74     28305
   macro avg       0.44      0.41      0.42     28305
weighted avg       0.71      0.74      0.72     28305



### Train metrics

In [21]:
train_predict = xgb_clf.predict(X_train)
train_report = classification_report(y_train, train_predict, target_names=['spsw', 'gped', 'pled', 'eyem', 'artf', 'backg'])
print("Classification Report:\n", train_report)

Classification Report:
               precision    recall  f1-score   support

        spsw       0.97      0.98      0.97       475
        gped       0.97      0.87      0.92     10654
        pled       0.97      0.73      0.83      4683
        eyem       0.92      0.87      0.89       977
        artf       0.98      0.80      0.88      9870
       backg       0.91      1.00      0.95     43187

    accuracy                           0.93     69846
   macro avg       0.95      0.87      0.91     69846
weighted avg       0.93      0.93      0.93     69846



### Eval metrics

In [23]:
eval_predict = xgb_clf.predict(X_eval)
eval_report = classification_report(y_eval, eval_predict, target_names=['spsw', 'gped', 'pled', 'eyem', 'artf', 'backg'])
print("Classification Report:\n", eval_report)

Classification Report:
               precision    recall  f1-score   support

        spsw       0.04      0.04      0.04       170
        gped       0.38      0.45      0.41       600
        pled       0.89      0.37      0.52      1501
        eyem       0.22      0.42      0.29        93
        artf       0.55      0.33      0.42      1183
       backg       0.86      0.95      0.90     10539

    accuracy                           0.80     14086
   macro avg       0.49      0.43      0.43     14086
weighted avg       0.80      0.80      0.79     14086



### Binary classification 

In [27]:
from sklearn.metrics import balanced_accuracy_score

group_1 = {1, 2, 3}  # Group 1 (mapped to 0)
group_2 = {4, 5, 6}  # Group 2 (mapped to 1)

true_labels = [0 if cls in group_1 else 1 for cls in y_test]
predicted_labels = [0 if cls in group_1 else 1 for cls in y_pred]

balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)
print("Balanced Accuracy:", balanced_acc)

Balanced Accuracy: 0.7182684758243453
