In [1]:
import mne
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.signal
import torch
from tqdm import tqdm
%matplotlib qt

In [2]:
# Folder holding both the recordings and their annotation files, plus the
# extension that identifies each kind of file.
data_folder = './v3.0.1/edf/02_tcp_le/'
csv_extension, edf_extension = '.csv', '.edf'

In [3]:
# List the folder once and sort the entries: os.listdir returns names in
# arbitrary order, and sorting keeps the processing order reproducible.
_data_folder_entries = sorted(
    f for f in os.listdir(data_folder) if '.ipynb_checkpoints' not in f
)
# Split the listing by extension, using the constants declared with data_folder.
csv_file_names = [f for f in _data_folder_entries if f.endswith(csv_extension)]
edf_file_names = [f for f in _data_folder_entries if f.endswith(edf_extension)]

In [4]:
# Pair annotations and recordings by base name (file name without extension).
csv_base_names = set(os.path.splitext(name)[0] for name in csv_file_names)
edf_base_names = set(os.path.splitext(name)[0] for name in edf_file_names)

if csv_base_names != edf_base_names:
    print("There are mismatches between .csv and .edf files.")
    # Base names that exist on one side only.
    missing_in_csv = edf_base_names.difference(csv_base_names)
    missing_in_edf = csv_base_names.difference(edf_base_names)
    if len(missing_in_csv) > 0:
        print(f"Files missing in .csv: {missing_in_csv}")
    if len(missing_in_edf) > 0:
        print(f"Files missing in .edf: {missing_in_edf}")
else:
    print("All .csv files have corresponding .edf files and vice versa.")

All .csv files have corresponding .edf files and vice versa.


In [5]:
def get_annotations_for_eeg_artifacts(path_root, file_name):
    """Load one annotation CSV and split its channel labels into electrodes.

    The first six lines of the file are skipped before the header row is read.
    Three columns are appended to the loaded table:

    * ``channel_original`` - a copy of the ``channel`` column.
    * ``channel_anode``    - the text before the first ``-`` of the channel
      label; values that are not strings or contain no ``-`` are kept as is.
    * ``channel_cathode``  - the text between the first and second ``-``
      (i.e. the second piece of the label); ``None`` when the label is not a
      string or contains no ``-``.

    Parameters
    ----------
    path_root : str
        Directory containing the annotation file.
    file_name : str
        Name of the CSV file inside ``path_root``.

    Returns
    -------
    pandas.DataFrame
        The annotation table with the three extra columns.
    """
    def _anode_of(label):
        # First piece of an "A-B" label; anything else passes through untouched.
        if isinstance(label, str) and '-' in label:
            return label.split('-')[0]
        return label

    def _cathode_of(label):
        # Second piece of an "A-B" label; no pair means no cathode.
        if isinstance(label, str) and '-' in label:
            return label.split('-')[1]
        return None

    annotations = pd.read_csv(os.path.join(path_root, file_name), skiprows=6)
    channel_labels = annotations["channel"]
    annotations["channel_original"] = channel_labels
    annotations["channel_anode"] = channel_labels.apply(_anode_of)
    annotations["channel_cathode"] = channel_labels.apply(_cathode_of)
    return annotations

In [6]:
# Maps raw channel labels of the form "EEG <electrode>-LE" to the bare
# electrode name.
channel_mapping = {
    f'EEG {electrode}-LE': electrode
    for electrode in (
        'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
        'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ',
    )
}

In [None]:
# Root folder for derived features; the raw time-domain segments get their own
# subfolder underneath it.
features_path = './features'
time_domain_features_path = os.path.join(features_path, 'time_domain')

def extract_artifact_from_eeg(edf_path, df):
    """Cut every annotated segment out of one EDF recording and save it to disk.

    The recording is reduced to the electrodes listed in ``channel_mapping``,
    band-pass filtered between 1 and 50 Hz, and renamed to bare electrode
    names. For each row of ``df`` the anode and cathode signals are read over
    the annotated window and subtracted (anode - cathode). The flattened
    difference is written with ``torch.save`` to
    ``<time_domain_features_path>/<label>/<channel_original>_<start_time>_<stop_time>.pt``.

    Rows are skipped when either electrode is not present in the recording or
    when the window contains no samples.

    Parameters
    ----------
    edf_path : str
        Path of the EDF recording to read.
    df : pandas.DataFrame
        Annotation rows with the columns ``channel_original``,
        ``channel_anode``, ``channel_cathode``, ``start_time``, ``stop_time``
        and ``label``. The times are multiplied by the sampling frequency, so
        they are expected in seconds.

    Returns
    -------
    None
        The function only writes files.
    """
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

    # Drop unmapped channels before filtering so the filter only runs on the
    # electrodes that can actually be used below.
    channels_to_keep = [ch for ch in raw.ch_names if ch in channel_mapping]
    raw.pick(channels_to_keep)
    raw.filter(1., 50., fir_design='firwin')

    # Rename only the electrodes present in this recording: rename_channels
    # rejects mapping keys that are not channel names of the recording.
    raw.rename_channels({ch: channel_mapping[ch] for ch in channels_to_keep})

    # Loop-invariant values, computed once per recording.
    sfreq = raw.info["sfreq"]
    available_channels = set(raw.ch_names)
    n_samples = raw.n_times

    for _, row in df.iterrows():
        anode = row["channel_anode"]
        cathode = row["channel_cathode"]

        # Both electrodes of the pair are needed to form the difference.
        if anode not in available_channels or cathode not in available_channels:
            continue

        start_sample = int(row["start_time"] * sfreq)
        # Clamp to the recording length so a window running past the end is
        # measured by the samples that really exist.
        stop_sample = min(int(row["stop_time"] * sfreq), n_samples)

        # An empty or inverted window would only produce an empty file.
        if stop_sample <= start_sample:
            continue

        data_anode = raw.get_data(picks=anode, start=start_sample, stop=stop_sample)
        data_cathode = raw.get_data(picks=cathode, start=start_sample, stop=stop_sample)
        data_diff = data_anode - data_cathode

        # One subfolder per annotation label.
        folder_path = os.path.join(time_domain_features_path, row["label"])
        os.makedirs(folder_path, exist_ok=True)

        # NOTE(review): the file name does not identify the recording, so two
        # recordings that share an identical channel/start/stop/label row
        # overwrite each other's file. Confirm whether downstream code parses
        # these names before adding the recording name to them.
        filename = f"{row['channel_original']}_{row['start_time']}_{row['stop_time']}.pt"
        filepath = os.path.join(folder_path, filename)

        # NOTE(review): get_data returns a NumPy array, so this stores an
        # ndarray rather than a tensor - confirm that is what the loader expects.
        torch.save(data_diff.flatten(), filepath)

In [8]:
# Process every annotation file together with the recording of the same base name.
for file in tqdm(csv_file_names, desc="Processing EEG files"):
    # splitext removes only the final extension, so a base name that itself
    # contains dots is kept intact (split('.')[0] would cut it at the first dot).
    base_name = os.path.splitext(file)[0]
    df = get_annotations_for_eeg_artifacts(data_folder, file)
    extract_artifact_from_eeg(os.path.join(data_folder, base_name + '.edf'), df)

Processing EEG files: 100%|████████████████████████████████████████████████████████████| 13/13 [00:14<00:00,  1.11s/it]
