In [1]:
import os
import glob
import numpy as np
import mne
from scipy.signal import stft
import re

def extractTarget(summary_file_path, edf_file_path):
    """Read the seizure intervals of one EDF recording from a patient summary file.

    The summary is scanned for the ``File Name: <basename>`` line of the given
    recording; parsing stops at the next ``File Name:`` line.  Inside that
    section, each ``Seizure [k] Start Time: N`` opens an interval and the next
    ``Seizure [k] End Time: M`` closes it.

    Parameters
    ----------
    summary_file_path : str
        Path to the patient summary text file.
    edf_file_path : str
        Path to the EDF recording; only its basename is matched.

    Returns
    -------
    list of [int, int]
        ``[start, end]`` pairs as written in the summary.  Empty when the file
        reports zero seizures or is not listed at all.  Intervals without an
        end time are dropped.
    """
    target_header = f"File Name: {os.path.basename(edf_file_path)}"
    start_pattern = re.compile(r"Seizure\s*\d*\s*Start Time:\s*(\d+)")
    end_pattern = re.compile(r"Seizure\s*\d*\s*End Time:\s*(\d+)")

    with open(summary_file_path, 'r') as handle:
        summary_lines = handle.readlines()

    intervals = []
    inside_section = False
    for text in summary_lines:
        if target_header in text:
            inside_section = True
        elif inside_section and text.startswith("File Name:"):
            # The next recording's section begins; ours is finished.
            break

        if not inside_section:
            continue

        if "Number of Seizures in File: 0" in text:
            return []

        onset = start_pattern.search(text)
        if onset:
            intervals.append([int(onset.group(1)), None])

        offset = end_pattern.search(text)
        # Only close an interval that is still open.
        if offset and intervals and intervals[-1][1] is None:
            intervals[-1][1] = int(offset.group(1))

    return [pair for pair in intervals if None not in pair]

In [2]:
from collections import Counter

# Substrings marking channels that are excluded from the EEG inputs.  They are
# matched against ``ch.upper()``, so every entry MUST be upper-case.
_EXCLUDED_CHANNEL_TOKENS = (
    'ECG', 'EKG', 'CHIN', 'LOC', 'ROC', 'VNS', 'LUE-RAE',
    'FC1-REF', 'FC2-REF', 'FC5-REF', 'FC6-REF',
    'CP1-REF', 'CP2-REF', 'CP5-REF', 'CP6-REF',
)


def _is_candidate_eeg_channel(ch):
    """Return True if channel name ``ch`` should be kept as an EEG input."""
    # '-', '--*' and '.-*' names are presumably dummy/placeholder channels.
    if ch == '-' or ch.startswith('--') or ch.startswith('.-'):
        return False
    upper = ch.upper()
    return not any(token in upper for token in _EXCLUDED_CHANNEL_TOKENS)


def get_common_channels_for_patient(edf_files):
    """Return the EEG channel names present in every readable EDF file.

    Each file header is read without loading data.  Excluded and placeholder
    channels are filtered out, and the per-file channel multisets are
    intersected.

    Parameters
    ----------
    edf_files : iterable of str
        Paths to the EDF files of one patient.

    Returns
    -------
    list of str
        Sorted channel names common to all files that could be read.  A name
        appears as many times as its minimum count across those files.  The
        list is empty if no file could be read.  Unreadable files are reported
        with a print and ignored.
    """
    common = None
    for edf_file in edf_files:
        try:
            raw = mne.io.read_raw_edf(edf_file, preload=False, verbose=False)
            channels = Counter(ch for ch in raw.ch_names if _is_candidate_eeg_channel(ch))
        except Exception as e:
            print(f"Mistake in {edf_file}: {e}")
            continue
        # Counter '&' keeps keys present in both, with the minimum count.
        common = channels if common is None else common & channels

    if common is None:
        return []
    return sorted(common.elements())

In [35]:
def preprocess_spectrogram_raw(file_name, target_channels, window_length=2.0, max_freq=45,
                               l_freq=1., h_freq=45.):
    """Load an EDF file and compute an STFT power spectrogram for each channel.

    The recording is restricted to ``target_channels`` and band-pass filtered
    to [l_freq, h_freq] (FIR, ``firwin`` design).  It is then split into
    windows of ``window_length`` seconds with 50 % overlap.

    Parameters
    ----------
    file_name : str
        Path to the EDF file.
    target_channels : list of str
        Channel names to keep; must be non-empty.
    window_length : float
        STFT window length in seconds.
    max_freq : float
        Highest frequency bin (inclusive, Hz) kept in the output.
    l_freq, h_freq : float
        Band-pass filter edges in Hz.  The defaults reproduce the 1-45 Hz
        filter.

    Returns
    -------
    spectrograms : ndarray, shape (n_windows, n_freqs, n_channels)
        Power ``|STFT|**2`` of each kept channel.
    timestamps : ndarray, shape (n_windows,)
        Segment times in seconds, as returned by ``scipy.signal.stft``.
    freqs : ndarray, shape (n_freqs,)
        Frequencies of the kept bins, all ``<= max_freq``.

    Raises
    ------
    ValueError
        If ``target_channels`` is empty.
    """
    if not target_channels:
        raise ValueError(f"No target channels to pick from {file_name}")

    raw = mne.io.read_raw_edf(file_name, preload=True, verbose=False)
    # Pick before filtering.  The filter works channel by channel, so the kept
    # channels come out the same, and discarded channels are not filtered.
    # NOTE(review): depending on the MNE version, pick_channels may keep the
    # file's channel order instead of target_channels' order.  Confirm this
    # before relying on the channel axis being aligned across files.
    raw.pick_channels(target_channels)
    raw.filter(l_freq, h_freq, fir_design='firwin')

    sfreq = raw.info['sfreq']
    window_samples = int(window_length * sfreq)

    # Transform all channels at once.  stft works along the last (time) axis
    # and returns Zxx with shape (n_channels, n_freqs, n_windows).
    f, timestamps, Zxx = stft(
        raw.get_data(),
        fs=sfreq,
        nperseg=window_samples,
        noverlap=window_samples // 2,
        scaling='spectrum',
        axis=-1,
    )
    freq_mask = f <= max_freq
    power = np.abs(Zxx[:, freq_mask, :]) ** 2
    # (n_channels, n_freqs, n_windows) -> (n_windows, n_freqs, n_channels)
    spectrograms = np.transpose(power, (2, 1, 0))
    return spectrograms, timestamps, f[freq_mask]

In [36]:
def process_patient_folder(patient_folder, summary_file_path):
    """Build spectrogram datasets from every EDF recording of one patient.

    Files whose summary lists no seizure interval (``extractTarget`` returns an
    empty list) go to the training set.  Files with seizures go to the test
    set, together with per-window 0/1 labels: 1 when the window time lies
    inside any seizure interval, bounds inclusive.  Files that fail to process
    are reported with a print and skipped.

    Parameters
    ----------
    patient_folder : str
        Directory containing the patient's ``*.edf`` files.
    summary_file_path : str
        Path to the patient's summary text file.

    Returns
    -------
    X_train : ndarray
        Training windows concatenated along axis 0, or ``np.empty((0,))``.
    X_test : ndarray
        Test windows concatenated along axis 0, or ``np.empty((0,))``.
    y_test : ndarray
        Integer labels aligned with ``X_test`` rows, or ``np.empty((0,))``.
    freqs : ndarray
        Frequency bins returned by the last successfully preprocessed file.

    Raises
    ------
    ValueError
        If no EDF file in the folder could be preprocessed, so no frequency
        axis is available.
    """
    edf_files = sorted(glob.glob(os.path.join(patient_folder, "*.edf")))
    target_channels = get_common_channels_for_patient(edf_files)
    if not target_channels:
        print(f"alarm {patient_folder}")

    X_train_all, X_test_all, y_test_all = [], [], []
    freqs = None  # set by the first file that preprocesses successfully

    for edf_file in edf_files:
        try:
            X, timestamps, freqs = preprocess_spectrogram_raw(edf_file, target_channels)
            seizure_intervals = extractTarget(summary_file_path, edf_file)
            print(seizure_intervals)
        except Exception as e:
            print(f"Mistake in {edf_file}: {e}")
            continue

        if not seizure_intervals:
            X_train_all.append(X)
            continue

        # Mark every window whose time falls inside at least one seizure.
        in_seizure = np.zeros(len(timestamps), dtype=bool)
        for start, end in seizure_intervals:
            in_seizure |= (timestamps >= start) & (timestamps <= end)
        X_test_all.append(X)
        y_test_all.append(in_seizure.astype(int))

    if freqs is None:
        raise ValueError(f"No EDF file in {patient_folder} could be processed")

    X_train = np.concatenate(X_train_all, axis=0) if X_train_all else np.empty((0,))
    X_test = np.concatenate(X_test_all, axis=0) if X_test_all else np.empty((0,))
    y_test = np.concatenate(y_test_all, axis=0) if y_test_all else np.empty((0,))
    return X_train, X_test, y_test, freqs

In [38]:
def average_over_bands(X, freqs, bands=None):
    """Average spectral power inside frequency bands.

    Parameters
    ----------
    X : ndarray, shape (n_windows, n_freqs, n_channels)
        Power spectrogram.
    freqs : array_like, shape (n_freqs,)
        Frequency of each bin along axis 1 of ``X``.  The default bands assume
        Hz.
    bands : dict[str, tuple[float, float]], optional
        Ordered mapping ``name -> (low, high)``.  A bin belongs to a band when
        ``low <= f < high``.  Defaults to delta (1-4), theta (4-8),
        alpha (8-13), beta (13-30) and gamma (30-45).

    Returns
    -------
    ndarray, shape (n_windows, n_bands, n_channels)
        Mean power per band, in the iteration order of ``bands``.

    Raises
    ------
    ValueError
        If a band contains no frequency bin; its mean would otherwise be a
        silent NaN.
    """
    if bands is None:
        bands = {
            "delta": (1, 4),
            "theta": (4, 8),
            "alpha": (8, 13),
            "beta": (13, 30),
            "gamma": (30, 45),
        }
    freqs = np.asarray(freqs)

    band_features = []
    for name, (low, high) in bands.items():
        mask = (freqs >= low) & (freqs < high)
        if not mask.any():
            raise ValueError(f"Band '{name}' [{low}, {high}) contains no frequency bins")
        band_features.append(np.mean(X[:, mask, :], axis=1))

    return np.stack(band_features, axis=1)

In [37]:
from sklearn.preprocessing import StandardScaler

def prepare_concept(freqs, name, X_train, X_test, y_test):
    """Turn one patient's spectrograms into standardized band-power features.

    Each spectrogram is reduced to per-band mean power by
    ``average_over_bands`` and then averaged over channels, giving one row of
    band features per window.  A ``StandardScaler`` is fitted on the training
    rows only and applied to both sets.

    Parameters
    ----------
    freqs : ndarray
        Frequency of each spectrogram bin.
    name : str
        Patient identifier; not used by the computation.
    X_train, X_test : ndarray
        Spectrograms, as expected by ``average_over_bands``.
    y_test : ndarray
        Labels for ``X_test``; returned unchanged.

    Returns
    -------
    tuple
        ``(X_train_scaled, X_test_scaled, y_test)``.
    """
    def to_features(spectrograms):
        # (windows, bands, channels) -> (windows, bands)
        return np.mean(average_over_bands(spectrograms, freqs), axis=2)

    train_features = to_features(X_train)
    test_features = to_features(X_test)

    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(train_features)
    test_scaled = scaler.transform(test_features)
    return train_scaled, test_scaled, y_test

In [None]:
# Root of the CHB-MIT dataset download.  Set the CHBMIT_DIR environment
# variable to point elsewhere instead of editing this path.
base_dir = os.environ.get(
    "CHBMIT_DIR",
    "/net/afscra/people/plgayahorava/physionet.org/files/chbmit/1.0.0",
)
# Patients processed in this run; patient folders not listed here are skipped.
patients = [
    "chb01", "chb02", "chb03", "chb04", "chb05", "chb06", "chb07", "chb08", "chb09", "chb10",
    "chb11", "chb14",
    "chb20", "chb21", "chb22", "chb23",
]

# One dict per patient.  "train_data" holds features of files without listed
# seizures; "test_data"/"test_labels" hold features of files with seizures.
concepts_for_npy = []

for patient in patients:
    folder = os.path.join(base_dir, patient)
    summary_file = os.path.join(folder, f"{patient}-summary.txt")

    X_train, X_test, y_test, freqs = process_patient_folder(folder, summary_file)
    X_train_scaled, X_test_scaled, y_test = prepare_concept(
        freqs=freqs, name=patient, X_train=X_train, X_test=X_test, y_test=y_test
    )

    concept_dict = {
        "name": patient,
        "train_data": X_train_scaled,
        "test_data": X_test_scaled,
        "test_labels": y_test,
    }
    concepts_for_npy.append(concept_dict)



In [None]:
# Bundle the per-patient dicts into an object array and write it under
# base_dir.  Object arrays are pickled, so read the file back with
# np.load(path, allow_pickle=True).
concepts_array = np.array(concepts_for_npy, dtype=object)
np.save(
    file=os.path.join(base_dir, "concepts_freq_ch_mean.npy"),
    arr=concepts_array,
    allow_pickle=True,
)