In [26]:
import os
import mne
import glob
#
import numpy as np
import math
import matplotlib.pyplot as plt
%matplotlib inline

import copy
import pandas as pd
import torch 

from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [5]:
# Folder that holds the .edf recordings; it is handed to prep() further down,
# which globs it for "*.edf" files.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
DownPath = "C:\\Users\\TheSy\\Desktop\\dataverse_files"

# Pre-processing

In [24]:
# File names of the recordings that prep() labels 0 ("sane"): h01.edf ... h14.edf.
SANE = [f"h{number:02d}.edf" for number in range(1, 15)]

# File names of the recordings that prep() labels 1 ("abnormal"): s01.edf ... s14.edf.
ABNORMAL = [f"s{number:02d}.edf" for number in range(1, 15)]


In [8]:
def channel_select(data, channels):
    '''
    Keep only the requested channels of ``data``.

    Forwards to ``data.pick(channels, exclude="bads")`` and returns whatever
    that call returns.
    NOTE(review): for MNE objects ``pick`` presumably works in place, so the
    object passed in is itself reduced to ``channels`` — confirm.
    '''
    return data.pick(channels, exclude="bads")

def clip(data, channels,max= 500e-6):
    '''
    Limit the amplitude of the selected channels to ``[-max, max]``.

    Parameters:
        -data: object exposing ``apply_function(fun, picks=..., channel_wise=...)``
               (EDFprep passes the filtered recording). It is modified through
               that call; nothing is returned.
        -channels: forwarded unchanged as ``picks``.
        -max: absolute amplitude limit (default 500e-6). The name shadows the
              builtin ``max`` inside this function; it is kept so existing
              keyword calls keep working.
    '''
    limit = max

    def cliper(array):
        # Vectorised replacement for the former per-sample copysign loop:
        # values whose magnitude exceeds the limit are set to +/-limit, all
        # other values (including NaN) are left untouched. Written in place,
        # as the original did.
        np.clip(array, -limit, limit, out=array)
        return array

    data.apply_function(cliper, picks=channels, channel_wise= True)

def eeg_filter(data, lfreq = 1, hfreq= 70):
    '''
    Band-pass filter a recording with an IIR filter and return the filtered copy.

    Parameters:
        -data: recording exposing ``copy()`` and ``filter(l_freq, h_freq, method)``.
        -lfreq: lower cut-off passed as ``l_freq`` (default 1).
        -hfreq: upper cut-off passed as ``h_freq`` (default 70).
    Output:
        -The object returned by ``filter`` called on a copy of ``data``.

    The copy is taken with ``data.copy()`` (the same call EDFprep uses before
    re-referencing) rather than ``copy.copy``: a shallow copy shares the
    underlying sample buffer with the input.
    NOTE(review): MNE documents ``filter`` as operating in place, so with the
    shallow copy the caller's recording was filtered too — confirm.
    '''
    filtered = data.copy().filter(l_freq = lfreq,
                                  h_freq = hfreq,
                                  method = "iir",
                                  )
    return filtered

def temporal_crop(data, tin = 60, tfin = 12*60):
    '''
    Return a copy of ``data`` cropped to the interval from second ``tin`` to
    second ``tfin``.

    Parameters:
        -data: recording exposing ``times``, ``copy()`` and ``crop(tmin, tmax)``.
        -tin: start of the kept interval in seconds (default 60).
        -tfin: end of the kept interval in seconds (default 12*60). If the
               recording is shorter, the end is the last time stamp truncated
               to an integer.
    Output:
        -The object returned by ``crop`` called on a copy of ``data``.

    ``data.copy()`` is used instead of ``copy.copy`` so that the crop cannot
    touch state shared with the input (a shallow copy shares every attribute
    object with the original).
    '''
    last_second = int(data.times[-1])
    croped = data.copy().crop(tmin = tin, tmax = min(tfin, last_second))
    return croped

def get_epochs(data, channels, window = 10):
    '''
    Cut a continuous recording into consecutive epochs of ``window`` seconds.

    Parameters:
        -data: continuous recording handed to MNE.
        -channels: channels forwarded as ``picks``.
        -window: length of the time window in seconds (default 10).
    Output:
        -The preloaded ``mne.Epochs`` object, with its last epoch removed.

    Epochs run from ``tmin = 0`` to ``tmax = window`` with no baseline
    correction, and ``flat = dict(eeg = 1e-6)`` is passed as the flatness
    rejection criterion.
    '''
    source = copy.copy(data)

    # One event every `window` seconds marks the start of each epoch.
    fixed_events = mne.make_fixed_length_events(source, duration = window, first_samp = True)

    epoch_options = dict(
        raw = source,
        events = fixed_events,
        picks = channels,
        preload = True,
        tmin = 0.,
        tmax = window,
        baseline = None,
        flat = dict(eeg = 1e-6),
    )
    epochs = mne.Epochs(**epoch_options)

    # The final epoch is always discarded.
    epochs.drop(-1,reason = "Unfixed duration")
    return epochs

def downsample(epoch, freq = 100): # original 200
    '''
    Resample ``epoch`` and return the result of ``epoch.resample``.

    ``freq`` is forwarded as the first argument of ``resample`` with
    ``npad = "auto"``; it is the target sampling frequency, not a ratio between
    the new and the current frequency.
    '''
    return epoch.resample(freq, npad = "auto")

def normalization(epochs):
    '''
    Scale the epoch data with ``mne.decoding.Scaler(scalings='mean')``.

    The scaler is built from ``epochs.info``, fitted on ``epochs.get_data()``
    and the transformed array is returned.
    '''
    scaler = mne.decoding.Scaler(info = epochs.info, scalings='mean')
    return scaler.fit_transform(epochs.get_data())

def EDFprep(edf, n_channels = 19, norm = True, random = True, ):
    '''
    Full preprocessing pipeline for one recording.

    Steps, in order: channel selection, band-pass filter, amplitude clipping,
    temporal crop, average re-reference, downsampling and epoching.

    Parameters:
        -edf: recording as returned by ``mne.io.read_raw_edf``.
        -n_channels: number of channels to keep (default 19).
        -norm: when True the epochs are scaled and returned as an array,
               otherwise the epochs object is returned as is.
        -random: when True the channels are drawn at random (without
                 replacement) from all channels except the last three;
                 otherwise the first ``n_channels`` channels are used.
    Output:
        -The normalised array with the last entry along axis 2 removed
         (``norm = True``), or the epochs object (``norm = False``).
    '''
    all_names = edf.ch_names

    # Channel choice: random subset (last three names excluded) or leading names.
    if not random:
        ch = all_names[:n_channels]
    else:
        ch = np.random.choice(all_names[:-3], size = n_channels, replace=False)

    picked = channel_select(edf, ch)
    filtered = eeg_filter(picked)
    clip(filtered, ch)
    cropped = temporal_crop(filtered)
    referenced = cropped.copy().set_eeg_reference(ref_channels="average")
    resampled = downsample(referenced)
    epochs = get_epochs(resampled, resampled.ch_names)

    if not norm:
        return epochs

    # Drop the final entry along axis 2 of the scaled array.
    scaled = normalization(epochs)
    return np.delete(scaled, -1, 2)

In [9]:
def Save_win(data,loc_df, final_dir, patient_id,label, save = False):
    '''
    Register (and optionally store) one file per time window.

    Parameters:
        -data: iterable of windows; each window must be a numpy array when
               ``save`` is True.
        -loc_df: DataFrame with three columns (patient, window number, path)
                 that receives one new row per window; it is modified in place.
        -final_dir: directory used to build the target paths.
        -patient_id: identifier used as file-name prefix and first column.
        -label: object stored next to each window.
        -save: when True, each window is converted to a float tensor and
               ``(tensor, label)`` is written with ``torch.save``.
    Output:
        -``loc_df`` with the added rows.
    '''
    for i, win in enumerate(data):
        sdir = os.path.join(final_dir,f"{patient_id}_w{i+1}.pt")
        loc_df.loc[len(loc_df)] = [patient_id,i+1,sdir]

        if save:
            # Convert only when the tensor is actually written (same as Save_ch).
            win_save = torch.from_numpy(win).type(torch.FloatTensor)
            torch.save((win_save, label),sdir)
    return loc_df

def Save_ch(data,loc_df, final_dir, patient_id,label, save = False):
    '''
    Register (and optionally store) one file per channel of every window.

    Parameters:
        -data: iterable of windows, each an iterable of channels.
        -loc_df: DataFrame with three columns (patient, window number, path)
                 that receives one new row per channel; modified in place.
        -final_dir: directory used to build the target paths.
        -patient_id: identifier used as file-name prefix and first column.
        -label: object stored next to each channel.
        -save: when True, each channel is converted to a float tensor, given a
               leading dimension and written as ``(tensor, label)``.
    Output:
        -``loc_df`` with the added rows.
    '''
    for i, window in enumerate(data):
        for j, samples in enumerate(window):
            target = os.path.join(final_dir, f"{patient_id}_w{i+1}_ch{j+1}.pt")
            loc_df.loc[len(loc_df)] = [patient_id, i+1, target]

            if not save:
                continue

            tensor = torch.from_numpy(samples).type(torch.FloatTensor).unsqueeze(dim = 0)
            torch.save((tensor, label), target)
    return loc_df

In [15]:

def prep(path, save = False, mode = "per_win", save_dir = "data"):
    '''
    Read every EDF file found in ``path``, preprocess it, optionally save its
    time windows and build the table of saved locations.

    Inputs:
        -path: folder that is searched for "*.edf" files.
        -save: when True, tensors are written below ``save_dir/mode/<class>``
               and the location table is written to a CSV in the working
               directory ("down_prep_windows.csv" or "down_prep_channels.csv").
        -mode: "per_win" (one file per window) or "per_channel" (one file per
               channel of every window).
        -save_dir: root folder for the saved tensors.
    Output:
        -loc_df: DataFrame with columns Patient, N_Win and Dir, one row per
                 saved (or would-be saved) file.
    Raises:
        -ValueError: if ``mode`` is not one of the two supported values.

    Files whose name is in neither SANE nor ABNORMAL are skipped, and files
    that fail to load or preprocess are reported and skipped.
    '''
    csv_names = {"per_win": "down_prep_windows.csv",
                 "per_channel": "down_prep_channels.csv"}
    # Fail before any work is done instead of silently producing nothing.
    if mode not in csv_names:
        raise ValueError(f"mode must be one of {sorted(csv_names)}, got {mode!r}")

    folders = ["sane", "abnormal"]
    loc_df = pd.DataFrame(columns= ["Patient","N_Win", "Dir"], )
    save_dir = os.path.join(save_dir, mode)

    if save and not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print("Data directory created :D")

    # Sorted so the row order of the index does not depend on the file system.
    for patient in sorted(glob.glob(path + "/*.edf")):

        # Use the real file name rather than a fixed-length slice of the path.
        file_name = os.path.basename(patient)

        if file_name in SANE:
            label = torch.zeros(1)
        elif file_name in ABNORMAL:
            label = torch.ones(1)
        else:
            # Without this branch the label of the previous file would be reused.
            print(f"{file_name}_skipped (not listed in SANE or ABNORMAL)")
            continue

        # Patient id stored in the DataFrame: file name without extension.
        patient_id = os.path.splitext(file_name)[0]

        final_dir = os.path.join(save_dir, folders[int(label)])
        if save:
            os.makedirs(final_dir, exist_ok=True)

        try:
            raw = mne.io.read_raw_edf(patient,preload=True)
            data = EDFprep(raw,random = False)
        except Exception as err:
            # Best effort: report the recording and move on to the next one.
            print(f"{patient_id}_failed")
            print(f"    {type(err).__name__}: {err}")
            continue

        if mode == "per_win":
            loc_df = Save_win(data,loc_df,final_dir,patient_id,label, save)
        else:
            loc_df = Save_ch(data,loc_df,final_dir,patient_id,label, save)

    if save:
        loc_df.to_csv(csv_names[mode], encoding= "utf-8", index = False)

    return loc_df

In [16]:
# Preprocess every recording in DownPath, saving one tensor per window under "test/per_win".
prep(DownPath, save=True, mode="per_win", save_dir="test")

Data directory created :D
Extracting EDF parameters from C:\Users\TheSy\Desktop\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 70 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 1.00, 70.00 Hz: -6.02, -6.02 dB

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Not setting metadata
66 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 66 events and 1001 original time points ...
1 bad epochs dropped
Dropped 1 epoch: 64
Extracting EDF parameters from C:\Users\TheSy\Desktop\dataverse_files\h02.edf...
EDF file detected
Setti

Unnamed: 0,Patient,N_Win,Dir
0,h01,1,test\per_win\sane\h01_w1.pt
1,h01,2,test\per_win\sane\h01_w2.pt
2,h01,3,test\per_win\sane\h01_w3.pt
3,h01,4,test\per_win\sane\h01_w4.pt
4,h01,5,test\per_win\sane\h01_w5.pt
...,...,...,...
1787,s14,60,test\per_win\abnormal\s14_w60.pt
1788,s14,61,test\per_win\abnormal\s14_w61.pt
1789,s14,62,test\per_win\abnormal\s14_w62.pt
1790,s14,63,test\per_win\abnormal\s14_w63.pt


# Datasets 

In [42]:
# Patient ids (file names without extension) of the "sane" recordings: h01 ... h14.
SANE_P = [f"h{number:02d}" for number in range(1, 15)]

# Patient ids (file names without extension) of the "abnormal" recordings: s01 ... s14.
ABNORMAL_P = [f"s{number:02d}" for number in range(1, 15)]

In [44]:
class CustomEEGDataset(Dataset):
    '''
    Dataset that loads saved ``(tensor, label)`` files listed in an index CSV.

    Parameters:
        -csv_file: name of the index CSV, relative to ``root_dir``.
        -root_dir: folder that contains the CSV; the paths stored in the CSV
                   are also joined onto it.
        -transform: optional callable applied to the loaded tensor.
    '''
    def __init__(self, csv_file , root_dir , transform = None, ):

        index_df = pd.read_csv(os.path.join(root_dir,csv_file))
        # A CSV written with its index has an extra "Unnamed: 0" column;
        # drop it when present instead of relying on a bare except.
        self.loc_df = index_df.drop(columns="Unnamed: 0", errors="ignore")

        self.transform = transform
        self.root_dir = root_dir

    def __len__(self,):
        '''Number of rows in the index table.'''
        return len(self.loc_df)

    def __getitem__(self, idx):
        '''
        Load the file referenced by row ``idx``.

        Returns the loaded object unchanged, or ``(transform(obj[0]), obj[1])``
        when a transform was given.
        '''
        # The path is read from the third column of the index table.
        eeg_file = os.path.join(self.root_dir,
                                self.loc_df.iloc[idx, 2])
        # NOTE: torch.load unpickles the file — only use it on files you trust.
        eeg = torch.load(eeg_file)

        if self.transform is not None:
            ch = self.transform(eeg[0])
            return ch, eeg[1]

        return eeg
    
class DFSpliter():
    '''
    Split an index CSV into train and validation tables at patient level, so
    that all windows of one patient end up in the same split.

    Parameters:
        -train_size: fraction of the patients of each class used for training.
        -val_size: stored but not used; validation gets the remaining patients.
        -save: when True both tables are written to
               "<mode>_train_feats.csv" and "<mode>_val_feats.csv".
        -seed: seed passed to ``np.random.seed`` before shuffling.
        -mode: prefix of the output CSV names.
    '''
    def __init__(self, train_size= 0.8, val_size = 0.2, save = False, seed = 69, mode = "down_ch") -> None:
        self.train_size = train_size
        self.val_size = val_size
        self.save = save
        self.seed = seed
        self.mode = mode

    def __call__(self, csv_file, root_path):
        '''
        Read ``root_path/csv_file`` and return ``(train_df, val_df)``.

        Rows are grouped patient by patient in the shuffled patient order and
        both tables get a fresh 0..n-1 index.
        '''
        # Drop the index column of a CSV saved with its index, if present.
        loc_df = pd.read_csv(os.path.join(root_path,csv_file)).drop(columns="Unnamed: 0", errors="ignore")
        patients = loc_df["Patient"].unique()

        # The "Patient" column holds ids without extension ("h01"), so the ids
        # must be matched against SANE_P / ABNORMAL_P, not the "*.edf" lists.
        sanes = [pat for pat in patients if pat in SANE_P]
        abnormals = [pat for pat in patients if pat in ABNORMAL_P]

        np.random.seed(self.seed)

        np.random.shuffle(sanes)
        np.random.shuffle(abnormals)

        s_end_idx = round(len(sanes)*self.train_size)
        a_end_idx = round(len(abnormals)*self.train_size)

        train_patients = [*sanes[:s_end_idx], *abnormals[:a_end_idx]]
        val_patients = [*sanes[s_end_idx:], *abnormals[a_end_idx:]]

        def rows_of(selected):
            # One concat per split instead of growing a frame inside a loop.
            if not selected:
                return loc_df.iloc[0:0].copy()
            parts = [loc_df[loc_df["Patient"] == patient] for patient in selected]
            return pd.concat(parts).reset_index(drop=True)

        train_df = rows_of(train_patients)
        val_df = rows_of(val_patients)

        if self.save:
            train_df.to_csv(f"{self.mode}_train_feats.csv", encoding= "utf-8", index = False)
            val_df.to_csv(f"{self.mode}_val_feats.csv", encoding="utf-8", index=False)
            # Only announce the CSVs when they were actually written.
            print("CSVs creados")
        return train_df,val_df

In [45]:
# Splitter that writes "down_win_train_feats.csv" / "down_win_val_feats.csv".
spliter = DFSpliter(mode="down_win", save=True)

In [46]:
# Split the window index written by prep() (read from the working directory).
spliter(csv_file="down_prep_windows.csv", root_path=".")

CSVs creados


(     Patient  N_Win                               Dir
 0        h04      1       test\per_win\sane\h04_w1.pt
 1        h04      2       test\per_win\sane\h04_w2.pt
 2        h04      3       test\per_win\sane\h04_w3.pt
 3        h04      4       test\per_win\sane\h04_w4.pt
 4        h04      5       test\per_win\sane\h04_w5.pt
 ...      ...    ...                               ...
 1403     s11     60  test\per_win\abnormal\s11_w60.pt
 1404     s11     61  test\per_win\abnormal\s11_w61.pt
 1405     s11     62  test\per_win\abnormal\s11_w62.pt
 1406     s11     63  test\per_win\abnormal\s11_w63.pt
 1407     s11     64  test\per_win\abnormal\s11_w64.pt
 
 [1408 rows x 3 columns],
     Patient  N_Win                               Dir
 0       h10      1       test\per_win\sane\h10_w1.pt
 1       h10      2       test\per_win\sane\h10_w2.pt
 2       h10      3       test\per_win\sane\h10_w3.pt
 3       h10      4       test\per_win\sane\h10_w4.pt
 4       h10      5       test\per_win\san

In [17]:
# All recording file names: the SANE entries followed by the ABNORMAL entries.
xd = SANE + ABNORMAL