# Módulo de preprocesamiento 
Ammi Beltrán & Fernanda Borja

## Importación de librerías

In [1]:
# ! pip install -U scikit-learn

In [2]:
import os
import mne
import glob
#
import numpy as np
import math
import matplotlib.pyplot as plt
%matplotlib inline

import copy
import pandas as pd

import torch 

Se crea el directorio `data` (si aún no existe)

In [3]:
# Create the local "data" folder on first run; the message is printed only
# when the folder did not exist yet.
if not os.path.exists("data"):
    os.makedirs("data")
    print("Data directory created :D")

Dirección del dataset

In [4]:
# Root folder of the EDF dataset.
# NOTE(review): this is a machine-specific absolute path; it has to be edited
# before the notebook can run on another computer.
# EDFDIR = "D:\\OneDrive\\OneDrive - Universidad de Chile\\Semestre X\\Inteligencia\\Proyecto\\dataset\\tuh_eeg"
EDFDIR = "c:\\Users\\TheSy\\Desktop\\tuh_eeg"
# Recursively collect every .edf file found below EDFDIR.
files = glob.glob(EDFDIR + '/**/*.edf', recursive=True)

## Preprocesamiento

In [5]:
# Usamos MNE
# data = mne.io.read_raw_edf(files[0])
# raw_data = data.get_data()
# info = data.info
# channels = data.ch_names

### Funciones
* Una vez que funcionen, moverlas a un archivo `.py` e importarlas como librería.

Selección de canales

In [6]:
def channel_select(data, channels):
    '''
    Keep only the requested channels of a recording.

    Parameters
    ----------
    data : object exposing ``pick(picks, exclude=...)``
        In this notebook, the MNE recording being preprocessed.
    channels : sequence of str
        Channel names forwarded unchanged to ``data.pick``.

    Returns
    -------
    Whatever ``data.pick`` returns; channels flagged as "bads" are excluded.
    NOTE(review): MNE's ``pick`` presumably works in place, so the caller's
    object would be reduced as well - confirm before reusing ``data``.
    '''
    return data.pick(channels, exclude="bads")

Recorte de amplitud (clipping)

In [7]:
def clip(data, channels,max= 500e-6):
    '''
    Saturate the amplitude of the selected channels to the range [-max, max].

    Parameters
    ----------
    data : object exposing ``apply_function(fun, picks=..., channel_wise=...)``
        In this notebook, the MNE recording being preprocessed.
    channels : sequence of str
        Channels handed to ``apply_function`` through ``picks``.
    max : float
        Absolute amplitude limit (expected to be positive). Samples whose
        magnitude exceeds it are replaced by ``max`` with the sample's sign.

    Returns
    -------
    None. The result is delivered through ``data.apply_function``.
    '''
    # Local alias: the public parameter name shadows the builtin max().
    limit = max

    def cliper(array):
        # Vectorised saturation; gives the same values as checking every
        # sample in a Python loop, without the per-sample overhead.
        return np.clip(array, -limit, limit)

    data.apply_function(cliper, picks=channels, channel_wise= True)

Filtrado

In [8]:
def eeg_filter(data, lfreq = 0.1, hfreq= 30):
    '''
    Band-pass filter a recording with an IIR filter, leaving the input intact.

    Parameters
    ----------
    data : object exposing ``copy()`` and ``filter(l_freq, h_freq, method)``
        In this notebook, the MNE recording being preprocessed.
    lfreq : float
        Lower cut-off, forwarded as ``l_freq``.
    hfreq : float
        Upper cut-off, forwarded as ``h_freq``.

    Returns
    -------
    The filtered copy returned by ``filter``.
    '''
    # Use the object's own copy() instead of copy.copy(): a shallow copy still
    # shares the sample buffer with the original, so an in-place filter()
    # would silently modify the caller's recording too.
    return data.copy().filter(l_freq = lfreq,
                              h_freq = hfreq,
                              method = "iir",
                              )

Recorte del primer minuto y del tiempo máximo

In [9]:
def temporal_crop(data, tin = 60, tfin = 12*60):
    '''
    Keep only the interval from second ``tin`` to second ``tfin``.

    Parameters
    ----------
    data : object exposing ``copy()``, ``crop(tmin, tmax)`` and ``times``
        In this notebook, the MNE recording being preprocessed.
    tin : float
        Start of the kept interval, in seconds.
    tfin : float
        Requested end of the kept interval, in seconds. It is capped at the
        last time stamp of the recording truncated to whole seconds.

    Returns
    -------
    The cropped copy returned by ``crop``.
    '''
    # The object's own copy() is used instead of copy.copy(): a shallow copy
    # shares its mutable attributes with the original, so an in-place crop()
    # could leak changes back into the caller's recording.
    segment = data.copy()
    last_whole_second = int(data.times[-1])
    return segment.crop(tmin = tin, tmax = min(tfin, last_whole_second))

Marcado de épocas

In [10]:
# eventos
def get_epochs(data, channels, window = 20):
    '''
    Cut a continuous recording into consecutive windows of ``window`` seconds.

    Parameters
    ----------
    data : MNE recording
        Continuous signal; a shallow copy of it is handed to MNE.
    channels : sequence of str
        Channels kept in the epochs (forwarded as ``picks``).
    window : float
        Window length in seconds. It is used both as the spacing between
        events and as ``tmax`` of each epoch.

    Returns
    -------
    The ``mne.Epochs`` object, after its last epoch has been dropped.
    '''
    working = copy.copy(data)

    # One event every `window` seconds marks the start of each epoch.
    markers = mne.make_fixed_length_events(working, duration = window, first_samp = True)

    segments = mne.Epochs(
        raw = working,
        events = markers,
        picks = channels,
        preload = True,
        tmin = 0.,
        tmax = window,
        baseline = None,
        flat = {"eeg": 1e-6},
    )

    # The trailing epoch is always discarded, whatever its actual length.
    segments.drop(-1, reason = "Unfixed duration")
    return segments

Downsample

In [11]:
def downsample(epoch, freq = 200):
    '''
    Resample the signal to a new sampling frequency.

    Parameters
    ----------
    epoch : object exposing ``resample(sfreq, npad=...)``
        Signal to resample (the pipeline passes the cropped recording).
    freq : float
        Target sampling frequency handed to ``resample`` - an absolute rate,
        not a down-sampling factor.

    Returns
    -------
    Whatever ``epoch.resample`` returns.
    '''
    return epoch.resample(freq, npad = "auto")

Normalización

In [12]:
def normalization(epochs):
    '''
    Scale the epoched data with ``mne.decoding.Scaler(scalings='mean')``.

    Parameters
    ----------
    epochs : object exposing ``info`` and ``get_data()``
        In this notebook, the ``mne.Epochs`` produced by ``get_epochs``.

    Returns
    -------
    The array returned by ``Scaler.fit_transform`` on ``epochs.get_data()``.
    NOTE(review): the scaler is fitted on the very data it transforms, so the
    statistics are specific to each recording - confirm this is intended.
    '''
    scaler = mne.decoding.Scaler(info = epochs.info, scalings='mean')
    return scaler.fit_transform(epochs.get_data())

***

Unión

## Uso de funciones

***

In [13]:
def EDFprep(edf, n_channels = 19, norm = True, random = True, ):
    '''
    Full preprocessing pipeline for one recording.

    Steps: channel selection -> amplitude clipping -> band-pass filter ->
    temporal crop -> resampling -> fixed-length epochs -> optional scaling.

    Parameters
    ----------
    edf : MNE recording (loaded with ``preload=True`` by the callers here)
        NOTE: channel selection and clipping act on this object directly.
    n_channels : int
        Number of channels to keep.
    norm : bool
        When True the epochs are passed through ``normalization``; otherwise
        the raw epoch data array is used.
    random : bool
        When True, ``n_channels`` channels are drawn without replacement from
        all channel names except the last three; otherwise the first
        ``n_channels`` names are taken.

    Returns
    -------
    numpy.ndarray
        Array indexed as (epoch, channel, sample) with the last sample of
        every epoch removed, or an empty array (``np.empty(0)``) when the
        recording's last time stamp is below 100 s.
    '''
    # Reject short recordings before any work is done, so they are neither
    # modified in place nor needlessly clipped.
    if int(edf.times[-1]) < 100:
        return np.empty(0)

    # Channel choice (random subset or the leading channels).
    names = edf.ch_names
    if random:
        ch = np.random.choice(names[:-3], size = n_channels, replace=False)
    else:
        ch = names[:n_channels]

    channel_data = channel_select(edf, ch)
    clip(channel_data, ch)
    filtered = eeg_filter(channel_data)
    trimmed_data = temporal_crop(filtered)
    down_data = downsample(trimmed_data)
    epochs = get_epochs(down_data, ch)

    # Both branches must end with a plain array: np.delete needs the sample
    # array, not the Epochs container itself.
    if norm:
        windows = normalization(epochs)
    else:
        windows = epochs.get_data()

    # Remove the last sample along axis 2 of every epoch.
    return np.delete(windows, -1, 2)


In [14]:
# Smoke test of the pipeline on the first EDF file found.
# NOTE(review): random=True draws channels with np.random and no seed is set
# in this notebook, so the selected channels change from run to run.
raw = mne.io.read_raw_edf(files[0], preload=True)
procesed = EDFprep(raw, random = True,)
# raw.ch_names

Extracting EDF parameters from c:\Users\TheSy\Desktop\tuh_eeg\aaaaaaaa\s001_2015_12_30\01_tcp_ar\aaaaaaaa_s001_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 323839  =      0.000 ...  1264.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.10, 30.00 Hz: -6.02, -6.02 dB

Not setting metadata
33 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 33 events and 4001 original time points ...
0 bad epochs dropped
Dropped 1 epoch: 32


In [15]:
# torch.save(procesed,"LSTMData-0.001.pt")
# train_loader(batch_size= 2*19, multi = False)

In [16]:
def Save_win(data,loc_df, save_dir, patient_id,session_id, save = False):
    '''
    Register (and optionally store) every window of a recording.

    Parameters
    ----------
    data : iterable
        Windows of one recording; each item is saved as one file.
    loc_df : pandas.DataFrame
        Location table with four columns (patient, session, window number,
        path). One row per window is appended to it in place.
    save_dir : str
        Root output folder.
    patient_id, session_id : str
        Identifiers used in the sub-folder and file names.
    save : bool
        When True each window is written with ``torch.save``; the patient
        folder must already exist.

    Returns
    -------
    pandas.DataFrame
        ``loc_df`` with the new rows.
    '''
    for number, win in enumerate(data, start=1):
        # os.path.join instead of a hard-coded "\\": same result on Windows,
        # and a real sub-folder path (not a file name containing backslashes)
        # on other systems.
        sdir = os.path.join(save_dir, patient_id,
                            f"{patient_id}_{session_id}_w{number}.pt")
        loc_df.loc[len(loc_df)] = [patient_id, session_id, number, sdir]

        if save:
            torch.save(win, sdir)
    return loc_df

def Save_ch(data,loc_df, save_dir, patient_id,session_id, save = False):
    '''
    Register (and optionally store) every channel of every window.

    Parameters
    ----------
    data : iterable of iterables
        Windows of one recording; each window is iterated channel by channel
        and each channel is saved as one file.
    loc_df : pandas.DataFrame
        Location table with four columns (patient, session, window number,
        path). One row per (window, channel) is appended to it in place; the
        channel number appears only inside the path.
    save_dir : str
        Root output folder.
    patient_id, session_id : str
        Identifiers used in the sub-folder and file names.
    save : bool
        When True each channel is written with ``torch.save``; the patient
        folder must already exist.

    Returns
    -------
    pandas.DataFrame
        ``loc_df`` with the new rows.
    '''
    for win_number, win in enumerate(data, start=1):
        for ch_number, ch in enumerate(win, start=1):
            # os.path.join instead of a hard-coded "\\": same result on
            # Windows, and a real sub-folder path on other systems.
            sdir = os.path.join(
                save_dir, patient_id,
                f"{patient_id}_{session_id}_w{win_number}_ch{ch_number}.pt")
            loc_df.loc[len(loc_df)] = [patient_id, session_id, win_number, sdir]

            if save:
                torch.save(ch, sdir)
    return loc_df

In [19]:
def prep(path, save = False,mode = "per_win", save_dir = "data", sep = "\\"):
    '''
    Preprocess every EDF of every patient and build the table of window files.

    The dataset is expected to be laid out as
    ``path/<patient>/<session>/<montage>/<file>.edf``.

    Parameters
    ----------
    path : str
        Root folder of the dataset.
    save : bool
        When True the windows are written to disk and the table is exported
        to ``prep_windows.csv`` / ``prep_channels.csv`` in the working folder.
    mode : {"per_win", "per_channel"}
        One file per window, or one file per channel of each window.
    save_dir : str
        Prefix of the output folder (``mode`` is appended to it).
    sep : str
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    pandas.DataFrame
        Columns "Patient", "Session", "N_Win" and "Dir", one row per file.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    '''
    LEN_PAT = 8       # the patient id is taken as the last 8 characters of the folder path
    SESION_LEN = 15   # session folders are assumed to be 15 characters long, e.g. "s001_2015_12_30"

    # Fail early: an unknown mode used to process the whole dataset and then
    # silently return an empty table.
    if mode not in ("per_win", "per_channel"):
        raise ValueError(f"mode must be 'per_win' or 'per_channel', got {mode!r}")

    loc_df = pd.DataFrame(columns= ["Patient", "Session","N_Win", "Dir"], )
    # NOTE(review): plain concatenation yields e.g. "dataper_win" (no path
    # separator). Left as is because already exported tables may point to
    # these folders - confirm whether "data/per_win" was intended.
    save_dir = save_dir + mode

    if save and not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print("Data directory created :D")

    # sorted(): glob gives no ordering guarantee; sorting keeps the row order
    # of the table reproducible between runs and machines.
    for patient in sorted(glob.glob(path + '/**')):

        # Patient id stored in the table
        patient_id = patient[-LEN_PAT:]

        if save:
            os.makedirs(os.path.join(save_dir, patient_id), exist_ok=True)

        for session in sorted(glob.glob(patient + '/**')):

            # First four characters of the session folder name, e.g. "s001"
            session_id = session[-SESION_LEN:-(SESION_LEN - 4)]

            # NOTE(review): every EDF of a session numbers its windows from 1
            # and shares the same patient/session file-name prefix, so a
            # second EDF in the same session produces the same paths as the
            # first one - confirm whether the file token should be included.
            for edf in sorted(glob.glob(session+ "/**/*.edf")):

                # Reading is inside the try as well, so one unreadable file is
                # skipped instead of aborting the whole run. Only Exception is
                # caught: Ctrl-C must still stop the loop.
                try:
                    raw = mne.io.read_raw_edf(edf,preload=True)
                    data = EDFprep(raw,random = False)
                except Exception as err:
                    print(f"{patient_id}_{session_id} failed")
                    print(f"    {edf}: {err!r}")
                    continue

                if mode == "per_win":
                    loc_df = Save_win(data,loc_df,save_dir,patient_id,session_id, save)
                else:
                    loc_df = Save_ch(data,loc_df,save_dir,patient_id,session_id,save)

    if save:
        if mode == "per_channel":
            loc_df.to_csv("prep_channels.csv", encoding= "utf-8" )
        else:
            loc_df.to_csv("prep_windows.csv", encoding= "utf-8")

    return loc_df

In [20]:
# Dry run over the whole dataset: save=False only builds the table of
# per-channel file locations, nothing is written to disk.
df = prep(EDFDIR,save = False, mode = "per_channel", save_dir = "data", )

Extracting EDF parameters from c:\Users\TheSy\Desktop\tuh_eeg\aaaaaaaa\s001_2015_12_30\01_tcp_ar\aaaaaaaa_s001_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 323839  =      0.000 ...  1264.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.10, 30.00 Hz: -6.02, -6.02 dB

Not setting metadata
33 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 33 events and 4001 original time points ...
0 bad epochs dropped
Dropped 1 epoch: 32
Extracting EDF parameters from c:\Users\TheSy\Desktop\tuh_eeg\aaaaaaab\s001_2002_12_30\02_tcp_le\aaaaaaab_s001_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info stru

In [21]:
# Display the location table (pandas truncates the view for large frames).
df

Unnamed: 0,Patient,Session,N_Win,Dir
0,aaaaaaaa,s001,1,dataper_channel\aaaaaaaa\aaaaaaaa_s001_w1_ch1.pt
1,aaaaaaaa,s001,1,dataper_channel\aaaaaaaa\aaaaaaaa_s001_w1_ch2.pt
2,aaaaaaaa,s001,1,dataper_channel\aaaaaaaa\aaaaaaaa_s001_w1_ch3.pt
3,aaaaaaaa,s001,1,dataper_channel\aaaaaaaa\aaaaaaaa_s001_w1_ch4.pt
4,aaaaaaaa,s001,1,dataper_channel\aaaaaaaa\aaaaaaaa_s001_w1_ch5.pt
...,...,...,...,...
162825,aaaaaadv,s001,32,dataper_channel\aaaaaadv\aaaaaadv_s001_w32_ch1...
162826,aaaaaadv,s001,32,dataper_channel\aaaaaadv\aaaaaadv_s001_w32_ch1...
162827,aaaaaadv,s001,32,dataper_channel\aaaaaadv\aaaaaadv_s001_w32_ch1...
162828,aaaaaadv,s001,32,dataper_channel\aaaaaadv\aaaaaadv_s001_w32_ch1...


In [None]:
def EDFplot(edf_path, vs_prep = False, win_s = 20, f_prep = 200,n_epoch = 0 ):
    '''
    Plot one window of the first channel of an EDF file, optionally next to
    its preprocessed version.

    Parameters
    ----------
    edf_path : str
        EDF file to read.
    vs_prep : bool
        When True a second panel shows the same window after ``EDFprep``
        (first channel only, normalised).
    win_s : int
        Window length in seconds.
    f_prep : int
        Not used by the function body.
    n_epoch : int
        Index of the window to show; window 0 starts at second 60.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.
    '''
    print(f"Ploting {edf_path}")
    raw = mne.io.read_raw_edf(edf_path, preload= True)
    raw_freq = int(raw.info["sfreq"])
    first_channel = raw.get_data()[0]

    # squeeze=False always yields a 2-D grid of axes, so the one-panel and
    # two-panel cases can share the same drawing loop.
    n_axes = 1 + vs_prep * 1
    fig, grid = plt.subplots(nrows = n_axes, figsize=(30, 15), squeeze = False)
    axes = grid[:, 0]

    # Sample range of the requested window in the untouched recording.
    start = (60 + n_epoch*win_s)*raw_freq
    stop = (60 + n_epoch*win_s + win_s)*raw_freq
    panels = [("Original", first_channel[start:stop], "black")]

    if vs_prep:
        processed = np.squeeze(EDFprep(raw,n_channels = 1,norm = True, random = False)[n_epoch])
        panels.append(("Prep", processed, "gray"))

    # NOTE(review): the y label reads "uVoltage [Vs]" for both panels although
    # the second one is normalised - the label text probably needs a revision.
    for ax, (label, signal, colour) in zip(axes, panels):
        ax.plot(signal, color = colour, label = label)
        ax.set_title(f"{label} edf")
        ax.set_ylabel("uVoltage [Vs]")

    # Only the last drawn panel carries the x label.
    axes[len(panels) - 1].set_xlabel("Samples [s*Hz]")
    plt.show()

    


