In [1]:
import mne
import pywt
from autoreject import AutoReject

import scipy.stats
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

In [2]:
mne.viz.set_browser_backend('matplotlib')

Using matplotlib as 2D backend.


## Data Loading

In [11]:
X = []
y = []

files = [
    "../data/edf/autism/Bader_Autism_24_11_2011S001R01.edf",
    "../data/edf/autism/Bader_Autism_24_11_2011S001R09.edf",
    "../data/edf/autism/Bader_Autism_24_11_2011S001R10.edf",
    "../data/edf/autism/Mohammed_Autism_9_11_2011S001R01.edf",
    "../data/edf/autism/Nour_Autism_2_10_2011S001R01.edf",
    "../data/edf/autism/Nour_Autism_2_10_2011S001R02.edf",
    "../data/edf/autism/Saud_Autism_1_5_2011S001R01.edf",
    "../data/edf/autism/Shahad_Autism_5_6_2011S001R01.edf",
    "../data/edf/autism/Yahia_Autism_1_5_2011S001R01.edf",
    "../data/edf/autism/Yahia_Autism_1_5_2011S001R02.edf",


    "../data/edf/normal/Amer_Normal_5_5_2011S001R01.edf",
    "../data/edf/normal/Amer_Normal_5_5_2011S001R02.edf",
    "../data/edf/normal/Amer_Normal_5_5_2011S001R03.edf",
    "../data/edf/normal/Dhelal_Normal_15_6_2011S001R01.edf",
    "../data/edf/normal/Dhelal_Normal_15_6_2011S001R02.edf",
    "../data/edf/normal/Mahmud_Normal_5_5_2011S001R01.edf",
    "../data/edf/normal/Mahmud_Normal_5_5_2011S001R02.edf",
    "../data/edf/normal/Omran_Normal_5_5_2011S001R01.edf",
]

# files = glob.glob("../data/sampled/**/*.edf")

for fpath in files:
    try:
        print(fpath)
    
        # load data
        raw = mne.io.read_raw_edf(fpath, verbose=False)
        raw.rename_channels({"FP2": "Fp2"}, verbose=False)
        raw.set_montage("standard_1020", verbose=False)

        # band-pass filter
        raw_filt = raw.copy().load_data().filter(l_freq=0.1, h_freq=60, verbose=False)

        # notch filter
        raw_notch = raw_filt.copy().load_data().notch_filter(freqs=(60), verbose=False)

        # autoreject
        raw_autoreject = raw_notch.copy()

        # epoch data
        epochs = mne.make_fixed_length_epochs(raw_autoreject, duration=10, preload=True, verbose=False)
        if len(epochs) < 5:
            print("SKIPPED:", fpath)
            continue
        
        # perform autoreject
        ar = AutoReject(n_interpolate=None, random_state=11, n_jobs=1, verbose=False)
        epochs_ar, reject_log = ar.fit_transform(epochs, return_log=True)

        # perform ICA
        ica = mne.preprocessing.ICA(random_state=99, verbose=False)
        ica.fit(epochs[~reject_log.bad_epochs], verbose=False)
        
        ica.exclude = []
        ica.apply(epochs, exclude=ica.exclude, verbose=False)

        # get data
        data = epochs[~reject_log.bad_epochs].get_data(copy=True, verbose=False) * 1e6
        target_class = 1 if "Autism" in fpath else 0

        X.append(data)
        y.extend([target_class] * data.shape[0])
    except Exception as e:
        print("ERROR:", fpath)
        print(e)
    
    print("---------------------------------------------")

../data/edf/autism/Bader_Autism_24_11_2011S001R01.edf
Reading 0 ... 47087  =      0.000 ...   183.922 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 5 epochs: 6, 8, 9, 10, 12
---------------------------------------------
../data/edf/autism/Bader_Autism_24_11_2011S001R09.edf
Reading 0 ... 212143  =      0.000 ...   828.631 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 45 epochs: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 40, 41, 46, 47, 51, 52, 56, 57, 58, 59, 60, 66, 67, 68, 69, 70, 71, 72, 75, 76, 77, 78, 79, 80, 81
---------------------------------------------
../data/edf/autism/Bader_Autism_24_11_2011S001R10.edf
Reading 0 ... 88023  =      0.000 ...   343.818 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 20 epochs: 0, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 15, 23, 24, 25, 26, 27, 28, 29, 33
---------------------------------------------
../data/edf/autism/Mohammed_Autism_9_11_2011S001R01.edf
Reading 0 ... 283439  =      0.000 ...  1107.113 secs...
Dropped 11 epochs: 34, 39, 43, 53, 54, 59, 67, 84, 86, 91, 100
---------------------------------------------
../data/edf/autism/Nour_Autism_2_10_2011S001R01.edf
Reading 0 ... 196015  =      0.000 ...   765.635 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 33 epochs: 14, 15, 16, 17, 18, 21, 22, 27, 35, 36, 42, 43, 46, 48, 51, 54, 55, 56, 57, 60, 61, 62, 63, 64, 65, 68, 69, 70, 71, 72, 73, 74, 75
---------------------------------------------
../data/edf/autism/Nour_Autism_2_10_2011S001R02.edf
Reading 0 ... 78831  =      0.000 ...   307.914 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 2 epochs: 11, 13
---------------------------------------------
../data/edf/autism/Saud_Autism_1_5_2011S001R01.edf
Reading 0 ... 192095  =      0.000 ...   750.323 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 29 epochs: 2, 7, 8, 9, 10, 11, 15, 16, 17, 18, 26, 27, 28, 30, 33, 34, 35, 41, 42, 43, 44, 45, 52, 53, 65, 68, 69, 71, 74
---------------------------------------------
../data/edf/autism/Shahad_Autism_5_6_2011S001R01.edf
Reading 0 ... 314903  =      0.000 ...  1230.011 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 85 epochs: 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19, 23, 24, 26, 27, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49, 51, 54, 55, 56, 57, 59, 60, 61, 62, 63, 67, 68, 69, 70, 76, 77, 78, 79, 80, 81, 82, 83, 84, 88, 91, 93, 97, 100, 101, 102, 103, 104, 105, 106, 110, 111, 112, 113, 114, 115, 116, 117, 119, 121, 122
---------------------------------------------
../data/edf/autism/Yahia_Autism_1_5_2011S001R01.edf
Reading 0 ... 55663  =      0.000 ...   217.420 secs...
Dropped 7 epochs: 1, 3, 6, 11, 14, 15, 18
---------------------------------------------
../data/edf/autism/Yahia_Autism_1_5_2011S001R02.edf
Reading 0 ... 185527  =      0.000 ...   724.668 secs...
Dropped 11 epochs: 1, 3, 4, 9, 14, 29, 45, 52, 61, 65, 69
---------------------------------------------
../data/edf/normal/Amer_Normal_5_5_2011S001R01.edf
Reading 0 ... 84783  =      0.000 ...   331.162 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 1 epoch: 5
---------------------------------------------
../data/edf/normal/Amer_Normal_5_5_2011S001R02.edf
Reading 0 ... 93927  =      0.000 ...   366.879 secs...
Dropped 10 epochs: 1, 9, 12, 18, 19, 20, 23, 29, 31, 32
---------------------------------------------
../data/edf/normal/Amer_Normal_5_5_2011S001R03.edf
Reading 0 ... 81239  =      0.000 ...   317.320 secs...
Dropped 4 epochs: 5, 8, 18, 29
---------------------------------------------
../data/edf/normal/Dhelal_Normal_15_6_2011S001R01.edf
Reading 0 ... 77183  =      0.000 ...   301.477 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 4 epochs: 0, 5, 9, 11
---------------------------------------------
../data/edf/normal/Dhelal_Normal_15_6_2011S001R02.edf
Reading 0 ... 230791  =      0.000 ...   901.470 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 32 epochs: 2, 3, 4, 5, 12, 16, 23, 24, 25, 27, 31, 32, 34, 35, 43, 46, 50, 51, 57, 61, 62, 73, 74, 76, 77, 78, 81, 82, 84, 86, 87, 88
---------------------------------------------
../data/edf/normal/Mahmud_Normal_5_5_2011S001R01.edf
Reading 0 ... 116575  =      0.000 ...   455.342 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 24 epochs: 13, 14, 18, 19, 20, 21, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44
---------------------------------------------
../data/edf/normal/Mahmud_Normal_5_5_2011S001R02.edf
Reading 0 ... 81943  =      0.000 ...   320.069 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 9 epochs: 4, 5, 17, 22, 23, 24, 25, 26, 27
---------------------------------------------
../data/edf/normal/Omran_Normal_5_5_2011S001R01.edf
Reading 0 ... 231143  =      0.000 ...   902.845 secs...


  radius_init = radii.mean()
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Dropped 26 epochs: 4, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 29, 30, 36, 41, 42, 52, 54, 57, 58, 61, 62, 63, 71, 85
---------------------------------------------


In [12]:
X_arr = np.concatenate(X)
y_arr = np.array(y)

X_arr.shape, y_arr.shape

((670, 16, 2560), (670,))

## Preprocessing

In [13]:
def calculate_statistics(list_values):
    return [
        np.nanpercentile(list_values, 5),
        np.nanpercentile(list_values, 25),
        np.nanpercentile(list_values, 75),
        np.nanpercentile(list_values, 95),
        np.nanpercentile(list_values, 50),
        np.nanmean(list_values),
        np.nanstd(list_values),
        np.nanvar(list_values),
        np.nanmean(np.sqrt(list_values**2)),
    ]

In [16]:
X_dwt = []

for segment_idx in range(X_arr.shape[0]):
    features = []
    decomposed = pywt.wavedec(X_arr[segment_idx, :, :], "db4", axis=-1, level=4)
    for dec in decomposed:
        features.extend(calculate_statistics(dec))
    
    X_dwt.append(features)

X_dwt = np.array(X_dwt)
X_dwt.shape

(670, 45)

In [18]:
X_train, X_test, y_train, y_test = train_test_split(X_dwt, y_arr, test_size=0.33, stratify=y)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((448, 45), (222, 45), (448,), (222,))

### DWT+Random Forest

In [19]:
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

In [20]:
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.91      0.80      0.86        92
           1       0.87      0.95      0.91       130

    accuracy                           0.89       222
   macro avg       0.89      0.88      0.88       222
weighted avg       0.89      0.89      0.89       222



### DWT+ANN

In [21]:
import tensorflow as tf

2024-08-30 20:32:08.493741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-30 20:32:08.530580: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-30 20:32:08.541163: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-30 20:32:08.625026: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [22]:
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer((45,)),
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

I0000 00:00:1725024733.941361  150872 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1725024734.193661  150872 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1725024734.193734  150872 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1725024734.195842  150872 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1725024734.195900  150872 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

In [23]:
model.fit(X_train, y_train, batch_size=16, epochs=100, validation_data=(X_test, y_test))

Epoch 1/100


I0000 00:00:1725024754.370644  240997 service.cc:146] XLA service 0x7f6af0005310 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725024754.370806  240997 service.cc:154]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2024-08-30 20:32:34.421815: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-30 20:32:34.597181: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907


[1m 1/28[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:25[0m 3s/step - accuracy: 0.6875 - loss: 1221.1353

I0000 00:00:1725024756.752220  240997 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.



[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 69ms/step - accuracy: 0.5379 - loss: 4485.9434 - val_accuracy: 0.5856 - val_loss: 1539.8015
Epoch 2/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.5393 - loss: 838.4384 - val_accuracy: 0.4775 - val_loss: 415.5085
Epoch 3/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.5487 - loss: 206.7712 - val_accuracy: 0.6171 - val_loss: 39.3782
Epoch 4/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.5403 - loss: 112.6592 - val_accuracy: 0.5856 - val_loss: 122.0910
Epoch 5/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.5573 - loss: 96.4391 - val_accuracy: 0.5901 - val_loss: 222.6428
Epoch 6/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.5729 - loss: 135.2994 - val_accuracy: 0.6216 - val_loss: 38.5201
Epoch 7/100
[1m28/28

<keras.src.callbacks.history.History at 0x7f6bbbf3d070>

### CNN1D

In [24]:
X_train, X_test, y_train, y_test = train_test_split(X_arr, y_arr, test_size=0.33, stratify=y)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((448, 16, 2560), (222, 16, 2560), (448,), (222,))

In [25]:
X_train = np.moveaxis(X_train, 1, 2)
X_test = np.moveaxis(X_test, 1, 2)

In [26]:
np.unique(y, return_counts=True)

(array([0, 1]), array([277, 393]))

In [27]:
inputs = tf.keras.layers.Input(shape=(X_train.shape[1], X_train.shape[2],))
# x = tf.keras.layers.Conv1D(32, kernel_size=8, strides=2, activation="relu", use_bias=False)(inputs)
x = tf.keras.layers.Conv1D(32, kernel_size=6, strides=2, activation="relu", use_bias=False)(inputs)

x = tf.keras.layers.GlobalAveragePooling1D()(x)
x = tf.keras.layers.Dense(1024, activation="relu")(x)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

In [28]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # loss=tf.keras.losses.BinaryFocalCrossentropy(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
      tf.keras.metrics.TruePositives(name='tp'),
      tf.keras.metrics.FalsePositives(name='fp'),
      tf.keras.metrics.TrueNegatives(name='tn'),
      tf.keras.metrics.FalseNegatives(name='fn'),
      tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    #   tf.keras.metrics.Precision(name='precision'),
    #   tf.keras.metrics.Recall(name='recall'),
    #   tf.keras.metrics.AUC(name='auc'),
    #   tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
])

In [29]:
history = model.fit(X_train, y_train, batch_size=16, epochs=100, validation_data=(X_test, y_test))

Epoch 1/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 66ms/step - accuracy: 0.5051 - fn: 52.5517 - fp: 60.6552 - loss: 1.6855 - tn: 41.3793 - tp: 84.8621 - val_accuracy: 0.4820 - val_fn: 110.0000 - val_fp: 5.0000 - val_loss: 0.8170 - val_tn: 87.0000 - val_tp: 20.0000
Epoch 2/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.6067 - fn: 33.9310 - fp: 54.4138 - loss: 0.6259 - tn: 48.8276 - tp: 102.2759 - val_accuracy: 0.6802 - val_fn: 1.0000 - val_fp: 70.0000 - val_loss: 0.5937 - val_tn: 22.0000 - val_tp: 129.0000
Epoch 3/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7045 - fn: 21.4483 - fp: 46.8966 - loss: 0.5492 - tn: 57.5862 - tp: 113.5172 - val_accuracy: 0.6892 - val_fn: 0.0000e+00 - val_fp: 69.0000 - val_loss: 0.5430 - val_tn: 23.0000 - val_tp: 130.0000
Epoch 4/100
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7092 - fn: 13.2759 - fp: 53

In [30]:
model.evaluate(X_test, y_test)

[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 125ms/step - accuracy: 0.9271 - fn: 4.7500 - fp: 5.6250 - loss: 0.8694 - tn: 51.2500 - tp: 77.8750


[0.7833694815635681, 124.0, 9.0, 83.0, 6.0, 0.9324324131011963]