In [21]:
from mne.io import read_raw_edf
import numpy as np
import os
import pandas as pd
import scipy.signal

In [22]:
''' Create dataset '''
# Root of the local dataset copy; recordings are addressed as <main_path>/<patient dir>/<file>.
# NOTE(review): absolute, user-specific path — will need adjusting on any other machine.
main_path = "/Users/bryanmcelvy/Documents/physionet.org/files/chbmit/1.0.0"

# Names of the channels extracted from every recording.
channels = ["T7-P7", "T8-P8-0"]
patient = "chb01"

# Table of seizure annotations; the import loop reads its "filename", "start" and "end" columns.
seizure_idx = pd.read_csv("seizure_idx.csv")
filename_list = list(seizure_idx["filename"].unique())

# Open one recording without preloading its samples, just to read the sampling frequency.
file = read_raw_edf(
    input_fname=f"{main_path}/chb01/chb01_03.edf",
    preload=False,
    verbose='ERROR',
)
fs = int(file.info['sfreq'])

# Empty frame that the import cell extends with one labelled chunk per recording.
data = pd.DataFrame()


In [23]:
''' Filtering '''
# Cut-off frequencies, expressed in the same units as the sampling frequency `fs`.
fl = 30  # low-pass cut-off
fh = 1   # high-pass cut-off

# Digital Butterworth designs returned as transfer-function (b, a) coefficient pairs.
# NOTE(review): scipy's documentation recommends second-order sections for higher
# filter orders because (b, a) form can be numerically sensitive — confirm that the
# 15th-order low-pass response is acceptable before relying on it.
b_lpf, a_lpf = scipy.signal.butter(N=15, Wn=fl, btype='low', fs=fs)   # 15th-order low-pass
b_hpf, a_hpf = scipy.signal.butter(N=1, Wn=fh, btype='high', fs=fs)   # 1st-order high-pass

In [24]:
''' Import data '''
# Builds a class-balanced, filtered dataset of 1-second windows.
#
# For every file listed in `filename_list`:
#   1. convert the annotated seizure intervals to sample indices (class 1 windows),
#   2. take every other full 1-second window as class 0,
#   3. randomly down-sample the larger class so both classes have equal size,
#   4. low-pass then high-pass filter each selected channel and copy out the windows.
# The per-file frames are concatenated once at the end, so re-running this cell
# rebuilds `data` from scratch instead of appending duplicates to it.

window_offsets = np.arange(fs, dtype=int)  # sample offsets inside one window of fs samples
file_frames = []                           # one labelled DataFrame per processed file

for fname in filename_list:
    # Seizure [start, end] pairs for this file, truncated to whole seconds and
    # converted to sample indices.
    start_end_idx = seizure_idx.loc[seizure_idx["filename"] == fname, ["start", "end"]].astype(int).multiply(fs).values

    # Load data (the first five characters of the file name select the sub-directory)
    file = read_raw_edf(input_fname=f"{main_path}/{fname[:5]}/{fname}", preload=True, verbose='ERROR')

    # Omit files without selected channels
    hasChannels = all(channel in file.ch_names for channel in channels)
    if not hasChannels:
        print(f"Skipping {fname}")
        continue
    channel_idx = [file.ch_names.index(channel) for channel in channels]  # Idx of selected channels

    # Latest start index that still yields a complete window of fs samples.
    last_start = file.n_times - fs

    # Find row indices for start of 1-second windows for class 1
    window_1_idx = np.concatenate(
        [np.array([], dtype=int)]
        + [np.arange(start=n_start, stop=n_end, step=fs, dtype=int) for n_start, n_end in start_end_idx]
    )
    # Drop windows that would start before the recording or run past its end;
    # such windows would produce short slices and break the array build below.
    window_1_idx = window_1_idx[(window_1_idx >= 0) & (window_1_idx <= last_start)]

    # Find row indices for start of 1-second windows for class 0:
    # every full window on the 1-second grid that is not a class 1 window.
    window_0_idx = np.setdiff1d(np.arange(0, last_start + 1, fs, dtype=int), window_1_idx)

    # Balance dataset via random sampling. The generator is re-seeded for every
    # file, so the windows drawn for a file do not depend on processing order.
    rng = np.random.default_rng(seed=42)
    if len(window_0_idx) > len(window_1_idx):
        window_0_idx = np.sort(rng.choice(a=window_0_idx, size=len(window_1_idx), replace=False))
    elif len(window_0_idx) < len(window_1_idx):
        window_1_idx = np.sort(rng.choice(a=window_1_idx, size=len(window_0_idx), replace=False))

    # One label per sample: all class 0 windows first, then all class 1 windows.
    temp_state = np.repeat(
        np.array([0, 1], dtype=int),
        [len(window_0_idx) * fs, len(window_1_idx) * fs],
    )
    temp_fname = [fname] * len(temp_state)
    temp_df = pd.DataFrame(data={"filename": temp_fname,
                                 "state": temp_state})

    # Flat sample indices of every selected window, in the same order as the labels.
    window_starts = np.concatenate([window_0_idx, window_1_idx])
    sample_idx = (window_starts[:, None] + window_offsets).ravel()

    # Load labeled raw dataset, filter, and add to DataFrame.
    # The full array is fetched once rather than once per channel.
    raw_all = file.get_data()
    for ch_idx in channel_idx:
        data_filt = scipy.signal.filtfilt(b_lpf, a_lpf, raw_all[ch_idx])  # Apply low-pass filter
        data_filt = scipy.signal.filtfilt(b_hpf, a_hpf, data_filt)        # Apply high-pass filter
        # Cast to float32, then scale by 1e6
        # (presumably volts -> microvolts; confirm against the reader's units).
        temp_df[file.ch_names[ch_idx]] = data_filt[sample_idx].astype(np.float32) * 1e6

    file_frames.append(temp_df)

# Single concatenation: avoids quadratic copying and keeps the cell idempotent.
data = pd.concat(file_frames) if file_frames else pd.DataFrame()
    

Skipping chb12_27.edf
Skipping chb12_28.edf
Skipping chb12_29.edf
Skipping chb13_40.edf
Skipping chb16_18.edf


In [None]:
''' Save to CSV'''
# Write the assembled dataset without the index column. The output directory is
# created first so the cell also works on a fresh checkout where it does not exist.
output_path = "datasets/dataset_all_filt.csv"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
data.to_csv(path_or_buf=output_path, index=False)