# Módulo de preprocesamiento 
Ammi Beltrán & Fernanda Borja

## Importación de librerías

In [None]:
# ! pip install -U scikit-learn

In [1]:
import os
import mne
import glob
#
import numpy as np
import math
import matplotlib.pyplot as plt
%matplotlib inline

import copy
import pandas as pd

# import torch 

Se crea el path de data

In [None]:
# Create the local output directory on first run. Attempting the creation and
# handling FileExistsError avoids the window between an os.path.exists() check
# and os.makedirs(); the message is still printed only when the directory is
# actually created.
try:
    os.makedirs("data")
except FileExistsError:
    pass
else:
    print("Data directory created :D")

Dirección del dataset

In [2]:
# Root folder of the EEG dataset on the local machine.
# Alternative location kept for reference:
# EDFDIR = "D:\\OneDrive\\OneDrive - Universidad de Chile\\Semestre X\\Inteligencia\\Proyecto\\dataset\\tuh_eeg"
EDFDIR = "c:\\Users\\TheSy\\Desktop\\tuh_eeg"

# Every .edf file found at any depth below the dataset root.
files = glob.glob(f"{EDFDIR}/**/*.edf", recursive=True)

## Preprocesamiento

In [None]:
# Usamos MNE
# data = mne.io.read_raw_edf(files[0])
# raw_data = data.get_data()
# info = data.info
# channels = data.ch_names

### Funciones
* Una vez que funcionen, mover a un archivo .py e importar como librería

Selección de canales

In [3]:
def channel_select(data, channels):
    '''
    Keep only the requested channels of a recording.

    The work is delegated to ``data.pick(channels, exclude="bads")`` and the
    value returned by that call is handed back unchanged.

    NOTE(review): MNE's ``pick`` presumably operates in place, in which case
    ``data`` itself is reduced as well -- confirm against the MNE docs.
    '''
    return data.pick(channels, exclude="bads")

Clipeo

In [4]:
def clip(data, channels,max= 500e-6):
    '''
    Saturate the amplitude of the selected channels to the range [-max, max].

    The clipping is applied through ``data.apply_function`` one channel at a
    time (``channel_wise=True``); nothing is returned.

    Parameters
    ----------
    data : object exposing ``apply_function`` (an MNE Raw/Epochs object is
        expected by the rest of this notebook).
    channels : channel selection forwarded as ``picks``.
    max : non-negative amplitude limit. The name shadows the builtin but is
        kept because it is part of the public signature.
        NOTE(review): the default 500e-6 presumably means 500 uV with data
        stored in volts -- confirm.
    '''
    def cliper(array):
        # Vectorised saturation: samples outside [-max, max] are replaced by
        # the nearest bound (sign preserved), all others are left untouched.
        return np.clip(array, -max, max)
    data.apply_function(cliper, picks=channels, channel_wise= True)

Filtrado

In [5]:
def eeg_filter(data, lfreq = 0.1, hfreq= 30):
    '''
    Return a band-pass filtered copy of ``data``.

    Parameters
    ----------
    data : object exposing ``filter(l_freq, h_freq, method)`` (an MNE
        Raw/Epochs object is expected by the rest of this notebook).
    lfreq : lower cut-off, forwarded as ``l_freq``.
    hfreq : upper cut-off, forwarded as ``h_freq``.

    Returns
    -------
    The value returned by ``filter`` called on the copy with
    ``method="iir"``.
    '''
    # A shallow copy would still share its sample buffer with ``data``, so a
    # filter that works in place would also alter the caller's object.
    # Deep-copy to keep the input untouched.
    data_copy = copy.deepcopy(data)
    filtered = data_copy.filter(l_freq = lfreq,
                                h_freq = hfreq,
                                method = "iir",
                                )
    return filtered

Corte del primer minuto y límite de duración máxima

In [6]:
def temporal_crop(data, tin = 60, tfin = 12*60):
    '''
    Return a copy of ``data`` cropped to the interval [tin, tfin] seconds.

    The upper bound is limited to the whole-second duration of the recording
    (``int(data.times[-1])``) so that ``tfin`` may exceed the recording length.

    Parameters
    ----------
    data : object exposing ``times`` and ``crop(tmin, tmax)`` (an MNE Raw
        object is expected by the rest of this notebook).
    tin : start of the kept interval, in seconds.
    tfin : latest allowed end of the kept interval, in seconds.
    '''
    # Deep copy: a shallow copy shares mutable state with ``data``, so an
    # in-place crop could leak changes into the caller's object.
    data_copy = copy.deepcopy(data)
    croped = data_copy.crop(tmin = tin, tmax = min(tfin, int(data.times[-1])))
    return croped

Marcado de épocas

In [7]:
# eventos
def get_epochs(data, channels, window = 20.0):
    '''
    Split a continuous recording into consecutive fixed-length epochs.

    Parameters
    ----------
    data : continuous recording accepted by ``mne.Epochs``.
    channels : channel selection forwarded as ``picks``.
    window : length of the time window; used both as the spacing between the
        generated events and as ``tmax`` of every epoch.

    Returns
    -------
    Preloaded ``mne.Epochs`` without baseline correction, after dropping the
    epoch at index -1 with the reason "Unfixed duration".

    NOTE(review): with ``tmin = 0`` and ``tmax = window`` each epoch
    presumably includes both end samples (one more than window * sfreq) --
    confirm whether that is intended.
    '''
    recording = copy.copy(data)

    # One event every ``window`` seconds.
    fixed_events = mne.make_fixed_length_events(recording, duration = window, first_samp = True)

    # ``flat`` asks MNE to reject epochs whose EEG amplitude stays below 1e-6.
    epoch_options = dict(picks = channels, preload = True,
                         tmin = 0., tmax = window, baseline = None,
                         flat = dict(eeg = 1e-6))
    epochs = mne.Epochs(raw = recording, events = fixed_events, **epoch_options)

    epochs.drop(-1,reason = "Unfixed duration")
    return epochs

Downsample

In [8]:
def downsample(epoch, freq = 200):
    '''
    Resample ``epoch`` to ``freq``.

    ``freq`` is the value handed to ``epoch.resample`` as the new sampling
    rate (not a decimation factor); padding is left to ``npad = "auto"``.
    Whatever ``resample`` returns is passed back to the caller.
    '''
    return epoch.resample(freq, npad = "auto")

Normalización

In [9]:
def normalization(epochs):
    '''
    Scale the epoch data with ``mne.decoding.Scaler`` (``scalings='mean'``).

    The scaler is built from ``epochs.info`` and fitted on the very array it
    transforms (``epochs.get_data()``); the transformed array is returned.
    '''
    epoch_array = epochs.get_data()
    scaler = mne.decoding.Scaler(info = epochs.info, scalings='mean')
    return scaler.fit_transform(epoch_array)

***

Union

## Uso de funciones

***

In [10]:
def EDFprep(edf, n_channels = 19, norm = True, random = True, long = False):
    '''
    Run the whole preprocessing pipeline on one recording.

    Steps: channel selection -> amplitude clipping -> filtering -> temporal
    crop -> downsampling -> epoching -> optional normalisation.

    Parameters
    ----------
    edf : loaded recording (as returned by ``mne.io.read_raw_edf``).
    n_channels : number of channels kept.
    norm : when True the epochs are passed through ``normalization``.
    random : when True the channels are chosen at random (NumPy global RNG);
        otherwise the first ``n_channels`` names are used.
    long : currently unused; kept so existing calls keep working.

    Returns
    -------
    Array with the epoch data, or ``np.empty(0)`` when the recording lasts
    less than 100 whole seconds.

    NOTE(review): channel selection and clipping presumably modify ``edf``
    itself (MNE in-place methods) -- confirm if callers reuse ``edf``.
    '''
    # Reject short recordings before touching them: cropping from second 60
    # would fail on anything shorter than that.
    if int(edf.times[-1]) < 100:
        return np.empty(0)

    # Shuffle a copy of the names; shuffling ``edf.ch_names`` directly would
    # reorder the recording's own channel-name list.
    channels = list(edf.ch_names)
    if random:
        np.random.shuffle(channels)
    channels = channels[:n_channels]

    channel_data = channel_select(edf, channels)
    clip(channel_data, channels)
    filtered = eeg_filter(channel_data)
    trimmed_data = temporal_crop(filtered)
    down_data = downsample(trimmed_data)
    epochs = get_epochs(down_data, channels)
    if norm:
        return normalization(epochs)
    return epochs.get_data()


In [None]:
def prep(path, save = False, save_dir = "data", sep = "\\"):
    '''
    Preprocess every EDF of every patient below ``path`` and index the
    resulting time windows.

    The folder layout is expected to be ``path/<patient>/<session>/<any>/*.edf``.
    The patient id is taken from the last 8 characters of the patient folder
    path and the session id from the first 4 characters of a 15-character
    session folder name.

    Parameters
    ----------
    path : root folder holding one sub-folder per patient.
    save : when True every window is written with ``torch.save``.
    save_dir : root folder for the saved windows.
    sep : path separator used to build the output paths.

    Returns
    -------
    pandas.DataFrame with columns ["Patient", "Session", "N_Win", "Dir"], one
    row per window. "Dir" follows the format
    ``<save_dir><sep><session><sep><patient>_<session>_w<n>.pt`` and is the
    exact path used when ``save`` is True.
    '''
    LEN_PAT = 8
    SESION_LEN = 15
    rows = []

    if save:
        # Imported lazily: torch is only needed when windows are written.
        import torch

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            print("Data directory created :D")

    for patient in glob.glob(f"{path}/**"):
        # Patient id stored in the table.
        patient_id = patient[-LEN_PAT:]

        for session in glob.glob(patient + '/**'):
            # Session id stored in the table.
            session_id = session[-SESION_LEN:-(SESION_LEN - 4)]
            session_dir = f"{save_dir}{sep}{session_id}"

            if save:
                os.makedirs(session_dir, exist_ok=True)

            # Running counter per patient/session so that windows coming from
            # several EDFs of the same session never share a file name.
            n_win = 0
            for edf in glob.glob(session + "/**/*.edf"):
                raw = mne.io.read_raw_edf(edf, preload=True)

                for window in EDFprep(raw):
                    n_win += 1
                    sdir = f"{session_dir}{sep}{patient_id}_{session_id}_w{n_win}.pt"
                    rows.append([patient_id, session_id, n_win, sdir])

                    if save:
                        torch.save(window, sdir)

    # TODO: when ``save`` is True, also write the table to a CSV file.
    return pd.DataFrame(rows, columns=["Patient", "Session", "N_Win", "Dir"])
            
    # format ["aaaa","s001",7,"data/s001/aaaa_s001_w7.pt"]

   

In [34]:
def EDFplot(edf_path, vs_prep = False, win_s = 20, f_prep = 200,n_epoch = 0 ):
    '''
    Plot one time window of the first channel of an EDF file, optionally
    together with the matching preprocessed window.

    Parameters
    ----------
    edf_path : path of the EDF file to read.
    vs_prep : when True a second axis shows window ``n_epoch`` of
        ``EDFprep(raw, n_channels = 1, norm = True, random = False)``.
    win_s : window length in seconds.
    f_prep : not used by this function.
    n_epoch : index of the window to show, counted after the first 60 s.
    '''
    print(f"Ploting {edf_path}")
    title_original, title_prep = "Original", "Prep"

    raw = mne.io.read_raw_edf(edf_path, preload= True)
    sfreq = int(raw.info["sfreq"])
    channel_0 = raw.get_data()[0]

    fig, axes = plt.subplots(nrows = 1 + vs_prep * 1, figsize=(30, 15))

    # Skip the first 60 s plus ``n_epoch`` whole windows, then keep ``win_s`` s.
    start_s = 60 + n_epoch*win_s
    window_original = channel_0[start_s*sfreq:(start_s + win_s)*sfreq]

    if not vs_prep:
        # Single axis: only the raw signal.
        axes.plot(window_original,color = "black", label = title_original)
        axes.set_title(f"{title_original} edf")
        axes.set_ylabel("uVoltage [Vs]")
        axes.set_xlabel("Samples [s*Hz]")
        plt.show()
        return

    prep_windows = EDFprep(raw,n_channels = 1,norm = True, random = False)
    window_processed = np.squeeze(prep_windows[n_epoch])

    top, bottom = axes[0], axes[1]
    top.plot(window_original, color = "black", label = title_original)
    top.set_title(f"{title_original} edf")
    top.set_ylabel("uVoltage [Vs]")

    bottom.plot(window_processed, color = "gray", label = title_prep)
    bottom.set_title(f"{title_prep} edf")
    bottom.set_ylabel("uVoltage [Vs]")
    bottom.set_xlabel("Samples [s*Hz]")
    plt.show()

    


