In [1]:
from modules.constants import *

In [None]:
def create_dir(path):
    """Create directory *path*, including any missing parents, if absent.

    ``exist_ok=True`` replaces a separate existence check, so there is no
    window between checking and creating in which another process could
    create the directory and make ``os.makedirs`` raise.

    Raises:
        FileExistsError: if *path* already exists but is not a directory
            (previously this was silently ignored and only surfaced later
            when writing into it).
    """
    os.makedirs(path, exist_ok=True)

# Subjects whose ID is at most this value are written to the training split;
# all later subjects go to the validation split.
_TRAIN_SUB_MAX = 16


def _find_spectral_bads(raw, picks, fmin=2, fmax=40, n_std=3):
    """Return the names of picked channels flagged by a spectral criterion.

    Welch PSDs are computed over [fmin, fmax] Hz for the channels in *picks*.
    A channel is flagged when its mean PSD exceeds ``n_std`` times the
    standard deviation of that same channel's PSD across frequencies.

    NOTE(review): this compares each channel against its own spectrum, not
    against the other channels; a cross-channel outlier test (e.g. z-scoring
    mean power across channels) is the more common criterion — confirm intent.
    """
    data = raw.get_data(picks=picks)
    psds, _ = psd_array_welch(data, sfreq=raw.info['sfreq'], fmin=fmin, fmax=fmax)
    psd_mean = psds.mean(axis=-1)
    psd_threshold = n_std * np.std(psds, axis=-1)
    # Rows of ``psds`` follow the order of ``picks``: select the PSD row by
    # position (k) and look up the channel name through the pick index (p).
    # Indexing the PSD arrays by p is wrong whenever non-EEG channels precede
    # EEG channels in raw.ch_names.
    return [raw.ch_names[p] for k, p in enumerate(picks)
            if psd_mean[k] > psd_threshold[k]]


create_dir(train_dir)
create_dir(val_dir)

for i in range(1, num_sub + 1):
    sub = f'sub-{i:03d}'
    for j in range(1, num_sess + 1):
        ses = f'ses-{j:02d}'
        data_path = f'../ds003774/{sub}/{ses}/eeg/{sub}_{ses}_task-MusicListening_run-{j}_eeg.set'
        raw = read_raw_eeglab(data_path, preload=True)

        # High-pass filter at 0.2 Hz
        raw.filter(l_freq=0.2, h_freq=None)

        # Remove 50 Hz line noise
        raw.notch_filter(freqs=[50])

        # Downsample the data to 256 Hz
        raw.resample(256)

        # Mark spectrally abnormal EEG channels as bad and interpolate them.
        # exclude=[] also picks channels already marked bad, so skip those to
        # keep info['bads'] free of duplicates.
        picks = pick_types(raw.info, eeg=True, exclude=[])
        new_bads = [ch for ch in _find_spectral_bads(raw, picks)
                    if ch not in raw.info['bads']]
        raw.info['bads'] += new_bads
        if raw.info['bads']:
            raw.interpolate_bads()

        # ICA decomposition.
        # NOTE(review): ica.exclude is never populated, so apply() removes no
        # components and this step does not actually reject artifacts. Mark
        # components (e.g. ica.find_bads_eog or manual inspection) before
        # apply(). MNE's ICA guidance also suggests fitting on data
        # high-passed at ~1 Hz rather than 0.2 Hz — confirm.
        ica = ICA(n_components=20, random_state=99, method='fastica')
        ica.fit(raw)
        ica.apply(raw)

        # Add an average-reference projector.
        # NOTE(review): with projection=True the projector is stored with the
        # saved file but not applied to the data; downstream code must call
        # apply_proj() (or this should use projection=False) for the saved
        # signals to be average-referenced — confirm intent.
        raw.set_eeg_reference('average', projection=True)

        # Save preprocessed data into the train or validation split by subject ID
        out_dir = train_dir if i <= _TRAIN_SUB_MAX else val_dir
        pre_path = os.path.join(out_dir, f'pre_eeg_{sub}_{ses}_eeg.fif')
        raw.save(pre_path, overwrite=True)