In [12]:
import os
import numpy as np
import h5py
from scipy.signal import resample
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [13]:
import mne
from mne.preprocessing import ICA
from mne_icalabel import label_components

In [14]:
# Target sampling rate of the processed recordings, in Hz.
SAMPLE_RATE = 128
# Disabled windowing settings:
# SAMPLE_LEN = 1.0   # sample seconds
# OVERLAPPING = 0.8  # overlapping seconds

# Sub-folder name derived from the sampling rate (shown as the cell output).
sub_folder_path = f"{SAMPLE_RATE}Hz"
sub_folder_path

'128Hz'

In [15]:
# Relative path of the directory that holds the raw .mat recordings.
root = "P-ADIC/"

In [16]:
# Dry run over every .mat file: count the subject recordings so the label
# array can be sized (from `sub_id`) before the slow preprocessing pass.
sub_id = 1
# sorted() makes the traversal order -- and therefore the subject numbering --
# independent of the directory order returned by the filesystem.
for file_name in sorted(os.listdir(root)):
    if not file_name.endswith(".mat"):
        continue
    file_path = os.path.join(root, file_name)
    print(f"Read file: {file_name}")
    per_class_sub_num = 0
    # open v7.3 MATLAB file
    with h5py.File(file_path, "r") as f:
        print("Keys in the file:", list(f.keys()))
        class_group = f["#refs#"]
        # Iterate through the groups and datasets
        for key in class_group.keys():
            node = class_group[key]
            # Groups expose no `shape`, and scalar/empty datasets have an empty
            # or missing one; skip those instead of raising on `shape[-1]`.
            shape = getattr(node, "shape", None)
            # a last dimension of 19 (channels) marks a subject recording
            if shape and shape[-1] == 19:
                print("Subject ID:", sub_id)
                sub_id += 1
                per_class_sub_num += 1
                print("---------------------")
    print(f"Subject number in this class({file_name}) is {per_class_sub_num}")
    print("------------------------------------------------\n")

Read file: alz_c1_new.mat
Keys in the file: ['#refs#', 'alz_r']
Subject ID: 1
---------------------
Subject ID: 2
---------------------
Subject ID: 3
---------------------
Subject ID: 4
---------------------
Subject ID: 5
---------------------
Subject ID: 6
---------------------
Subject ID: 7
---------------------
Subject ID: 8
---------------------
Subject ID: 9
---------------------
Subject ID: 10
---------------------
Subject ID: 11
---------------------
Subject ID: 12
---------------------
Subject ID: 13
---------------------
Subject ID: 14
---------------------
Subject ID: 15
---------------------
Subject ID: 16
---------------------
Subject ID: 17
---------------------
Subject ID: 18
---------------------
Subject ID: 19
---------------------
Subject ID: 20
---------------------
Subject ID: 21
---------------------
Subject ID: 22
---------------------
Subject ID: 23
---------------------
Subject ID: 24
---------------------
Subject ID: 25
---------------------
Subject ID: 26
-----

In [17]:
# `sub_id` is left one past the last subject by the counting cell, so the
# number of subjects found is sub_id - 1.
n_subjects = sub_id - 1
# One row per subject: column 0 = class label, column 1 = subject ID.
# Pre-filled with -1 (instead of np.empty) so a row that is never written is
# recognisable rather than holding arbitrary memory contents.
labels = np.full(shape=(n_subjects, 2), fill_value=-1, dtype='int32')
labels.shape

(249, 2)

In [18]:
def auto_artifact_removal_iclabel_to_numpy(
    eeg_data: np.ndarray,
    sfreq: float,
    ch_names: list,
    resample_sfreq=128,
    verbose=True
):
    """
    Clean a continuous EEG recording and return it at a new sampling rate.

    Steps, in order: build an MNE ``RawArray``, attach the ``standard_1020``
    montage, band-pass filter (0.5-45 Hz), re-reference to the average,
    fit ICA, exclude the components that ICLabel assigns to an artifact class
    with at least the per-class threshold probability, and resample.
    No bad-channel detection, epoching or z-score normalisation is done here.

    Args:
        eeg_data (np.ndarray): Continuous EEG, shape (T, C) = (samples, channels).
        sfreq (float): Original sampling frequency.
        ch_names (list): Channel names, one per column of ``eeg_data``.
        resample_sfreq (float): Target sampling frequency.
        verbose (bool): Print a progress line after each step.

    Returns:
        np.ndarray: Cleaned continuous EEG, shape (T', C), where T' is the
        number of samples after resampling.

    Raises:
        ValueError: If ``eeg_data`` is not 2-D or its number of columns
            differs from ``len(ch_names)``.
    """
    # Fail early with a readable message instead of a downstream MNE error.
    eeg_data = np.asarray(eeg_data)
    if eeg_data.ndim != 2 or eeg_data.shape[1] != len(ch_names):
        raise ValueError(
            f"eeg_data must have shape (T, {len(ch_names)}) to match ch_names, "
            f"got {eeg_data.shape}"
        )

    # 1. Construct MNE Raw object (transposed: RawArray receives channels x samples)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    raw = mne.io.RawArray(eeg_data.T, info)

    # 2. Set Montage
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("✔ Montage set: 'standard_1020'.")

    # 3. Bandpass Filter (0.5–45 Hz)
    raw.filter(l_freq=0.5, h_freq=45.0, verbose=False)
    if verbose:
        print("✔ Bandpass filter applied (0.5–45 Hz).")

    # 4. Set average reference for ICA
    raw.set_eeg_reference('average', projection=False)
    if verbose:
        print("✔ EEG re-referenced (average) before ICA.")

    # 5. ICA + ICLabel
    # NOTE(review): the mne-icalabel documentation describes ICLabel as designed
    # for extended-infomax ICA on 1-100 Hz, average-referenced data; this
    # pipeline uses the default ICA method on 0.5-45 Hz data -- confirm intended.
    raw_ica = raw.copy()
    ica = ICA(n_components=0.99, random_state=97, max_iter='auto')
    ica.fit(raw_ica)
    if verbose:
        print("✔ ICA fitted.")

    try:
        ic_labels = label_components(raw_ica, ica, method='iclabel')
        # Local names chosen so they do not shadow the notebook-level `labels` array.
        ic_names = ic_labels['labels']
        ic_probs = ic_labels['y_pred_proba']

        # Minimum predicted probability required to drop a component of each class.
        artifact_thresholds = {
            'eye blink': 0.7,
            'muscle artifact': 0.6,
            'heart beat': 0.5,
            'line noise': 0.8,
            'channel noise': 0.9
        }

        to_exclude = [
            idx for idx, name in enumerate(ic_names)
            if name in artifact_thresholds and ic_probs[idx] >= artifact_thresholds[name]
        ]
        if to_exclude:
            ica.exclude = to_exclude
            ica.apply(raw_ica)
            if verbose:
                print(f"✔ ICA applied. Excluded components: {to_exclude}")
        else:
            if verbose:
                print("No ICs exceeded artifact thresholds. No components excluded.")

    except Exception as e:
        # Deliberate best-effort: if labelling or applying fails, continue with
        # the filtered, re-referenced signal instead of aborting the subject.
        if verbose:
            print(f"⚠ ICLabel failed: {e}. Proceeding without ICA-based removal.")

    # 6. Resample
    raw_ica.resample(resample_sfreq, npad="auto")
    if verbose:
        print(f"✔ Resampled to {resample_sfreq} Hz.")

    # Back to (samples, channels) to match the input layout.
    return raw_ica.get_data().T

In [19]:
# Output directory for the per-subject feature arrays.
feature_path = 'Processed/' + sub_folder_path + '/P-ADIC/Feature'
# exist_ok makes this a no-op when the folder is already there (no check-then-create race).
os.makedirs(feature_path, exist_ok=True)

# Class tag searched for in each file name -> integer class label.
label_map = {'controls':0, 'alz':1, 'mci':2, 'dep':3, 'schiz':4}

RAW_SAMPLING_RATE = 500  # Original sampling rate, according to the readme file
# Names given to the 19 columns, in order (hoisted out of the loop; it never changes).
# NOTE(review): column order is assumed to match this list -- confirm against the dataset readme.
ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4', 'T4', 'T5', 'P3', 'Pz', 'P4', 'T6', 'O1', 'O2']
sub_id = 1
feature_list = []  # never populated in this cell; kept in case another cell references the name
# sorted() makes the subject-ID / label assignment reproducible: os.listdir()
# returns entries in arbitrary, filesystem-dependent order.
for file_name in sorted(os.listdir(root)):
    if not file_name.endswith(".mat"):
        continue
    file_path = os.path.join(root, file_name)
    print(f"Read file: {file_name}")
    class_label = -1
    for disease_class in label_map.keys():
        if disease_class in file_name:
            print(f"Class: {disease_class}, Label: {label_map[disease_class]}")
            class_label = label_map[disease_class]
    if class_label == -1:
        # Make the fallback visible instead of silently storing -1 as a class.
        print(f"Warning: no known class tag found in {file_name}; label -1 will be stored.")
    # count the number of subjects per class
    per_class_sub_num = 0

    # open v7.3 MATLAB file
    with h5py.File(file_path, "r") as f:
        print("Keys in the file:", list(f.keys()))
        class_group = f["#refs#"]
        # Iterate through the groups and datasets
        for key in class_group.keys():
            node = class_group[key]
            # Groups expose no `shape`, and scalar/empty datasets have an empty
            # or missing one; skip those instead of raising on `shape[-1]`.
            shape = getattr(node, "shape", None)
            # a last dimension of 19 (channels) marks a subject recording
            if not shape or shape[-1] != 19:
                continue
            print("Subject ID:", sub_id)
            print("Data shape:", shape)
            # Convert to numpy array
            raw_eeg = np.array(node)
            # remove the artifacts; resample to SAMPLE_RATE so the data always
            # matches the '<SAMPLE_RATE>Hz' folder it is written into
            cleaned = auto_artifact_removal_iclabel_to_numpy(
                raw_eeg, RAW_SAMPLING_RATE, ch_names, resample_sfreq=SAMPLE_RATE
            )
            # Check the shape of the data
            print("Downsampling data shape ", cleaned.shape)
            # NOTE(review): '{:02d}' pads to two digits only, so with 100+ subjects
            # the names no longer sort numerically; left unchanged because other
            # code may rely on these file names.
            np.save(feature_path + '/feature_{:02d}.npy'.format(sub_id), cleaned)
            labels[sub_id-1,0] = class_label  # sub_id start from 1, need to -1 for indexing
            labels[sub_id-1,1] = sub_id
            per_class_sub_num += 1
            sub_id += 1
            print("---------------------")
    print(f"Subject number in this class is {per_class_sub_num}")
    print("------------------------------------------------\n")

Read file: alz_c1_new.mat
Class: alz, Label: 1
Keys in the file: ['#refs#', 'alz_r']
Subject ID: 1
Data shape: (680502, 19)
Creating RawArray with float64 data, n_channels=19, n_times=680502
    Range : 0 ... 680501 =      0.000 ...  1361.002 secs
Ready.
✔ Montage set: 'standard_1020'.
✔ Bandpass filter applied (0.5–45 Hz).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
✔ EEG re-referenced (average) before ICA.
Fitting ICA to data using 19 channels (please be patient, this may take a while)
Selecting by explained variance: 15 components
Fitting ICA took 9.0s.
✔ ICA fitted.
No ICs exceeded artifact thresholds. No components excluded.
✔ Resampled to 128 Hz.
Downsampling data shape  (174209, 19)
---------------------
Subject ID: 2
Data shape: (506202, 19)
Creating RawArray with float64 data, n_channels=19, n_times=506202
    Range : 0 ... 506201 =      0.000 ...  1012.402 secs
Ready.
✔ Montage set: 'standard_1020'.
✔ Bandpass

In [20]:
# Directory that receives the label array for this sampling rate.
label_path = 'Processed/' + sub_folder_path + '/P-ADIC/Label'
# exist_ok replaces the check-then-create pattern: no race between the test
# and the creation, and a no-op when the folder already exists.
os.makedirs(label_path, exist_ok=True)
# `labels` holds one (class label, subject ID) row per subject.
np.save(label_path + '/label.npy', labels)

In [21]:
# Read the saved file back so the stored labels can be inspected.
np.load(f"{label_path}/label.npy")

array([[  1,   1],
       [  1,   2],
       [  1,   3],
       [  1,   4],
       [  1,   5],
       [  1,   6],
       [  1,   7],
       [  1,   8],
       [  1,   9],
       [  1,  10],
       [  1,  11],
       [  1,  12],
       [  1,  13],
       [  1,  14],
       [  1,  15],
       [  1,  16],
       [  1,  17],
       [  1,  18],
       [  1,  19],
       [  1,  20],
       [  1,  21],
       [  1,  22],
       [  1,  23],
       [  1,  24],
       [  1,  25],
       [  1,  26],
       [  1,  27],
       [  1,  28],
       [  1,  29],
       [  1,  30],
       [  1,  31],
       [  1,  32],
       [  1,  33],
       [  1,  34],
       [  1,  35],
       [  1,  36],
       [  1,  37],
       [  1,  38],
       [  1,  39],
       [  1,  40],
       [  1,  41],
       [  1,  42],
       [  1,  43],
       [  1,  44],
       [  1,  45],
       [  1,  46],
       [  1,  47],
       [  1,  48],
       [  1,  49],
       [  0,  50],
       [  0,  51],
       [  0,  52],
       [  0,

In [22]:
# Test the saved npy files: print the shape of every feature array and the
# total number of time samples over all of them.

path = feature_path

total_length = 0
# sorted() gives a stable (lexicographic) listing order between runs.
for file in sorted(os.listdir(path)):
    if not file.endswith('.npy'):
        continue  # ignore anything that is not a saved feature array
    sub_path = os.path.join(path, file)
    # Open each file once, memory-mapped, so the samples are not read into
    # memory (twice, as before) just to look at the shape.
    shape = np.load(sub_path, mmap_mode='r').shape
    print(shape)
    total_length += shape[0]
print("\nTotal length:", total_length)

(174209, 19)
(129588, 19)
(159629, 19)
(166055, 19)
(323201, 19)
(155444, 19)
(106881, 19)
(66765, 19)
(156519, 19)
(123393, 19)
(137985, 19)
(57089, 19)
(42497, 19)
(55245, 19)
(155226, 19)
(58420, 19)
(133889, 19)
(299649, 19)
(69172, 19)
(137729, 19)
(192001, 19)
(164660, 19)
(76545, 19)
(115815, 19)
(136257, 19)
(134657, 19)
(164289, 19)
(288001, 19)
(155905, 19)
(141236, 19)
(54017, 19)
(170753, 19)
(93492, 19)
(169601, 19)
(66100, 19)
(153985, 19)
(67482, 19)
(165889, 19)
(172801, 19)
(166529, 19)
(276609, 19)
(147073, 19)
(161153, 19)
(68199, 19)
(67073, 19)
(260532, 19)
(69121, 19)
(52429, 19)
(137997, 19)
(149377, 19)
(238964, 19)
(127361, 19)
(154625, 19)
(154497, 19)
(68865, 19)
(166311, 19)
(312577, 19)
(157441, 19)
(48231, 19)
(139162, 19)
(115751, 19)
(124442, 19)
(105601, 19)
(117940, 19)
(125569, 19)
(156007, 19)
(62055, 19)
(103437, 19)
(163610, 19)
(142081, 19)
(143745, 19)
(135271, 19)
(57652, 19)
(165377, 19)
(46132, 19)
(125185, 19)
(160487, 19)
(63898, 19)
(126298