In [None]:
import os
import random
import numpy as np
import pandas as pd
import mne
import torch
from tqdm import tqdm

# Input recordings (.edf + annotation .csv pairs) and output folders.
data_folder = './v3.0.1/edf/01_tcp_ar'
features_path = './features'
clean_eeg_path = os.path.join(features_path, 'clean_eeg')
random_seg_path = os.path.join(features_path, 'random_segment')

# Extensions used to pair each annotation CSV with its EDF recording.
csv_ext, edf_ext = '.csv', '.edf'

# Defaults for random-segment extraction.
n_random_segments = 5000   # windows drawn per recording
segment_samples = 5000     # length of each window, in samples

# Raw EDF channel label ("EEG <name>-REF") -> short channel name.
# Only channels listed here are kept by the extraction functions.
channel_mapping = {
    f'EEG {name}-REF': name
    for name in (
        'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
        'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ',
    )
}


def get_annotations_for_eeg_artifacts(path_root, file_name):
    """Read an annotation CSV into a DataFrame.

    The first six lines of the file are skipped and the following line is
    used as the column header.

    Args:
        path_root: Directory that contains the CSV file.
        file_name: Name of the CSV file inside ``path_root``.

    Returns:
        pandas.DataFrame with one row per annotation line.
    """
    csv_path = os.path.join(path_root, file_name)
    annotations = pd.read_csv(csv_path, skiprows=6)
    return annotations


def extract_clean_segments_from_eeg(edf_path, annotations_df,
                                    segment_duration=5.0,
                                    max_segments_per_interval=None):
    """Save fixed-length windows taken from the unannotated parts of a recording.

    The EDF file is loaded and restricted to the channels listed in
    ``channel_mapping``, which are renamed to their short names. Every
    ``(start_time, stop_time)`` row of ``annotations_df`` is treated as an
    excluded span; all other columns are ignored. The gaps between the merged
    spans are tiled with back-to-back windows of ``segment_duration`` seconds.

    Each window is saved with ``torch.save`` as a 2-D tensor of shape
    ``(n_channels, n_samples)`` in ``clean_eeg_path``. All windows have the
    same number of samples. File names include the recording's base name, so
    segments from different recordings do not overwrite each other.

    Args:
        edf_path: Path to the ``.edf`` recording.
        annotations_df: DataFrame with ``start_time`` / ``stop_time`` columns,
            in seconds from the start of the recording.
        segment_duration: Window length in seconds.
        max_segments_per_interval: If truthy, keep at most this many windows
            per gap, chosen with ``random.sample``; otherwise keep them all.

    Raises:
        ValueError: If none of the ``channel_mapping`` channels is present.
    """
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

    chans = [ch for ch in raw.ch_names if ch in channel_mapping]
    if not chans:
        raise ValueError(f"Ningún canal de channel_mapping en {edf_path}.")
    raw.pick(chans)
    # Rename only channels that are present; MNE rejects mapping entries for
    # channels that are not in the recording.
    raw.rename_channels({ch: channel_mapping[ch] for ch in chans})

    sfreq = raw.info['sfreq']
    total_dur = raw.times[-1]
    n_times = raw.n_times

    # Sort the annotated spans by start time and merge overlaps on the fly.
    # Record every gap between spans, plus the tail up to the end of the file.
    intervals = sorted(zip(annotations_df['start_time'],
                           annotations_df['stop_time']))
    clean_intervals = []
    prev_end = 0.0
    for st, sp in intervals:
        if st > prev_end:
            clean_intervals.append((prev_end, st))
        prev_end = max(prev_end, sp)
    if prev_end < total_dur:
        clean_intervals.append((prev_end, total_dur))

    # Use a fixed sample count per window so every saved tensor has the same
    # shape. Truncating s*sfreq and e*sfreq separately can yield windows whose
    # lengths differ by one sample.
    win_smpl = int(round(segment_duration * sfreq))
    base = os.path.splitext(os.path.basename(edf_path))[0]

    os.makedirs(clean_eeg_path, exist_ok=True)
    for ci_start, ci_stop in clean_intervals:
        n_segs = int((ci_stop - ci_start) // segment_duration)
        if n_segs <= 0:
            continue
        starts = [ci_start + i * segment_duration for i in range(n_segs)]
        if max_segments_per_interval:
            starts = random.sample(starts, min(max_segments_per_interval, len(starts)))
        for s in starts:
            e = s + segment_duration
            s_smpl = int(round(s * sfreq))
            e_smpl = s_smpl + win_smpl
            if e_smpl > n_times:
                # Rounding pushed the window past the last sample; skip it
                # rather than save a shorter tensor.
                continue
            data = raw.get_data(start=s_smpl, stop=e_smpl)
            fname = f"clean_{base}_{s:.2f}_{e:.2f}.pt"
            torch.save(torch.from_numpy(data),
                       os.path.join(clean_eeg_path, fname))


def extract_random_segments_from_eeg(edf_path,
                                     n_segments=n_random_segments,
                                     segment_samples=segment_samples):
    """Save randomly placed single-channel windows from an EDF recording.

    The recording is restricted to the ``channel_mapping`` channels, which are
    renamed to their short names. For each of ``n_segments`` independent
    draws, the ``random`` module picks a channel and a start sample. The
    ``segment_samples``-long window is saved with ``torch.save`` as a 1-D
    tensor in ``random_seg_path``. Tensors are used for consistency with the
    clean segments.

    File names include the recording's base name, so recordings do not
    overwrite each other.

    Args:
        edf_path: Path to the ``.edf`` recording.
        n_segments: Number of windows to draw (with replacement).
        segment_samples: Window length in samples.

    Raises:
        ValueError: If no ``channel_mapping`` channel is present, or if the
            recording has fewer than ``segment_samples`` samples.
    """
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

    chans = [ch for ch in raw.ch_names if ch in channel_mapping]
    if not chans:
        raise ValueError(f"Ningún canal de channel_mapping en {edf_path}.")
    raw.pick(chans)
    # Rename only channels that are present; MNE rejects mapping entries for
    # channels that are not in the recording.
    raw.rename_channels({ch: channel_mapping[ch] for ch in chans})

    total_samples = raw.n_times
    if total_samples < segment_samples:
        raise ValueError(f"Archivo con {total_samples} muestras < {segment_samples} requeridas.")

    # Fetch all picked channels once and slice in memory, instead of calling
    # raw.get_data() once per segment.
    all_data = raw.get_data()
    ch_names = list(raw.ch_names)
    ch_index = {name: idx for idx, name in enumerate(ch_names)}
    base = os.path.splitext(os.path.basename(edf_path))[0]

    os.makedirs(random_seg_path, exist_ok=True)
    for i in range(n_segments):
        channel = random.choice(ch_names)
        start = random.randint(0, total_samples - segment_samples)
        stop = start + segment_samples
        # 1-D vector of length segment_samples. The copy gives the saved
        # tensor its own compact buffer instead of a view into all_data.
        signal = all_data[ch_index[channel], start:stop].copy()
        fname = f"random_{base}_{channel}_{start}_{stop}_{i}.pt"
        # Save a tensor, not a bare ndarray, so it can be read back with
        # torch.load(..., weights_only=True) like the clean segments.
        torch.save(torch.from_numpy(signal), os.path.join(random_seg_path, fname))


def main():
    """Extract clean and random segments for every annotated recording.

    Pairs each ``<base>.csv`` in ``data_folder`` with ``<base>.edf`` and
    processes the pairs in sorted order. A recording whose processing raises
    ``ValueError`` (e.g. it is shorter than ``segment_samples``) is reported
    and skipped, instead of aborting the whole run.
    """
    os.makedirs(clean_eeg_path, exist_ok=True)
    os.makedirs(random_seg_path, exist_ok=True)

    # List the folder once and match CSV/EDF files by base name.
    entries = os.listdir(data_folder)
    csv_bases = {os.path.splitext(f)[0] for f in entries if f.endswith(csv_ext)}
    edf_bases = {os.path.splitext(f)[0] for f in entries if f.endswith(edf_ext)}
    common = sorted(csv_bases & edf_bases)

    for base in tqdm(common, desc="Procesando EEG"):
        ann_df = get_annotations_for_eeg_artifacts(data_folder, base + csv_ext)
        edf_path = os.path.join(data_folder, base + edf_ext)
        try:
            extract_clean_segments_from_eeg(edf_path, ann_df)
            extract_random_segments_from_eeg(edf_path)
        except ValueError as err:
            # One unusable recording should not abort a multi-hour batch.
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"Omitiendo {base}: {err}")


if __name__ == '__main__':
    main()


Procesando EEG:   0%|▏                                                               | 1/290 [00:30<2:29:02, 30.94s/it]


KeyboardInterrupt: 