In [1]:
import os
import glob
#
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
class CustomEEGDataset(Dataset):
    def __init__(self, loc_df , root_dir , transform = None, multi = False):
        # self.loc_df = pd.read_csv(csv_file)
        self.loc_df = loc_df
        self.transform = transform
        self.root_dir = root_dir
        self.multi = multi
    def __len__(self,):
        return len(self.loc_df)
        
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        batch = []   
        for i in idx:
            eeg_file = os.path.join(self.root_dir,
                                self.loc_df.iloc[i, 3])
            eeg = torch.from_numpy(torch.load(eeg_file)) # [0][0]

            if self.multi:
                eeg= eeg.unsqueeze(0)
            batch.append(eeg) 
        bat = torch.vstack(batch)
            
        if self.transform is not None:
            bat = self.transform(bat)

        return bat



In [3]:
class DFSpliter():
    def __init__(self, train_size= 0.8, val_size = 0.2, save = False, seed = 69) -> None:
        self.train_size = train_size
        self.val_size = val_size
        self.save = save
        self.seed = seed

    def __call__(self, csv_file):
        # loc_df = pd.read_csv(csv_file)
        loc_df = csv_file
        patients = loc_df["Patient"].unique()
        np.random.seed(self.seed)
        np.random.shuffle(patients)
        end_idx = round(len(patients)*self.train_size)

        train_patients = patients[:end_idx]
        val_patients = patients[end_idx:]
        
        train_df = pd.DataFrame()
        for patient in train_patients:
            train_df = pd.concat([train_df,loc_df[loc_df["Patient"] == patient]])

        val_df = pd.DataFrame()
        for patient in val_patients:
            val_df = pd.concat([val_df,loc_df[loc_df["Patient"] == patient]])
        
        val_df.reset_index(inplace=True, drop= True)
        train_df.reset_index(inplace=True, drop=True)

        if self.save:
            train_df.to_csv("train_feats.csv", encoding= "utf-8")
            val_df.to_csv("val_feats.csv", encoding="utf-8")
        print("CSVs creados")
        return train_df,val_df
        


In [113]:
def Masking(channel: np.array, window: int= 150):
    '''
    Set to zero 
    Input:  -channel = Numpy array
            -window = Number of samples to set to zero
    Output: Numpy array masked
    '''
    channel_size = len(channel)
    first = np.random.randint(0,channel_size- window)
    masked = channel.copy()
    masked[first:first+window] = 0

    return masked

def DCVoltage(channel : np.array, max_magnitude: float = 0.5):
    ''' 
    Add a DC component between [-max_mangitude, max_magnitude]
    Input:  -channel = Numpy array
            -max_magnitude = max value to be added
    Output: Numpy array 
    '''
    dc_comp = (np.random.random(1)*2 - 1)*max_magnitude
    dispaced_channel = channel + dc_comp
    return dispaced_channel    

def GaussianNoise(channel: np.array, std: float = 0.2):
    '''
    Add Gaussian Noise with zero mean and std deviation
    Input:  -channel = Numpy array
            -std = Gaussian std
    Output: Channel with additive gaussian noise added
    '''
    channel_size = len(channel)
    noise = np.random.normal(loc = 0, scale= std, size= channel_size)
    noisy_channel = channel + noise
    return noisy_channel

def Time_Shift(channel: np.array, min_shift: int = 0, max_shift: int = 50 ):
    ''' 
    Shifts the channel n samples between min_shift and max_shift using reflection pad
    Input:  -channel = Numpy array
            -min_shift = Min number of samples to shift
            -max_shhift = Max number of samples to shift  
    Output: Shifted channel
    '''
    n_shift = np.random.randint(min_shift,max_shift)
    channel_size = len(channel)
    padded_array = np.pad(channel,pad_width= n_shift, mode = "reflect")
    right_left = np.random.choice((0,2))
    shifted_array = padded_array[n_shift*right_left:channel_size + n_shift*right_left]
    return shifted_array
def Amplitude(channel :np.array, max_amplitude: float = 1.5):
    '''
    Modifies the ampliude of the channel values between [1+max_amplitude,1-max_amplitude]
    Input:  -channel = Numpy array
            -max_amplitude = Max aplitude to add
    Output: Boosted channel
    '''
    amplitude = 1 + ((np.random.random(1)*2 -1) * max_amplitude)
    boosted_channel = channel*amplitude
    return boosted_channel

def Permutation(channel: np.array, win_samples: int = 4):
    '''
    Permutates the arrays by secuences of win_samples len
    Ensure its divisible by the total len of the array or the len of the output secuence will be wrong
    Input:  -channel = Numpu array
            -win_samples = Number of samples per secuences (N_sec = len(channel)// win_samples)
    Output: Permutated secuence
    '''

    n_seqs = len(channel)// win_samples
    random_idx = np.random.choice(np.arange(0,n_seqs, 1), n_seqs, replace=False ) 
    permutated = np.concatenate([channel[win_samples*i: win_samples*(i+1)] for i in random_idx])
    return permutated
def Temporal_Invertion(channel: np.array):
    ''' 
    Return the array reversed
    Input:  -channel = Numpy array
    Output: Reversed array
    '''
    reversed = channel[::-1]
    return reversed

def Negation(channel: np.array):
    '''
    Inverts the full array
    Input: -channel = Numpy array
    Output: Inverted array
    '''
    negated = channel * (-1)
    return negated


In [114]:
#Augmentation set
AUGMENTATIONS = [Negation,
                 Time_Shift,
                 Amplitude,
                 DCVoltage,
                 GaussianNoise,
                 Temporal_Invertion,
                 Permutation,
                 Masking]


In [135]:
class Augmentations(nn.Module):
    def __init__(self, n_aug, multi = False, augmentations = None) -> None:
        self.n_aug = n_aug
        self.multi = multi
        self.augmentations = augmentations

    def __call__(self,batch):
        xbar_batch = []
        xhat_batch = []

        rbar_idxs = np.random.choice(np.arange(0,len(self.augmentations),1),size=self.n_aug, replace= False)
        rhat_idxs = np.random.choice(np.arange(0,len(self.augmentations),1),size=self.n_aug, replace= False)

        batch = batch.numpy()

        for channel in batch:
            xbar = channel
            xhat = channel
            for i in rbar_idxs:
                xbar = self.augmentations[i](xbar)
                print(self.augmentations[i])
            for j in rhat_idxs:
                xhat = self.augmentations[j](xhat)
                print(self.augmentations[j])
            xbar_batch.append(torch.from_numpy(xbar[:4000].copy()))
            xhat_batch.append(torch.from_numpy(xhat[0:4000].copy()))
        
        xbar_batch = torch.vstack(xbar_batch)
        xhat_batch = torch.vstack(xhat_batch)

        return xbar_batch,xhat_batch

        

In [13]:
loc_df = pd.DataFrame(columns= ["Patient", "Session","N_Win", "Dir"], )

In [14]:
root_path = "C:\\Users\\TheSy\\Desktop\\FinalEL7006"

In [32]:
#Esta cosa agrega elementos al dataframe vacio, prueba cambiando el primer elemento de la lista 
#Para agregar distintos pacientes c:
loc_df.loc[len(loc_df)] = ["aaal","session_id",1,"LSTMData-0.001.pt"]


In [33]:
loc_df

Unnamed: 0,Patient,Session,N_Win,Dir
0,aaap,session_id,1,LSTMData-0.001.pt
1,aaap,session_id,1,LSTMData-0.001.pt
2,aaap,session_id,1,LSTMData-0.001.pt
3,aaax,session_id,1,LSTMData-0.001.pt
4,aaax,session_id,1,LSTMData-0.001.pt
5,aaax,session_id,1,LSTMData-0.001.pt
6,aaaa,session_id,1,LSTMData-0.001.pt
7,aaaa,session_id,1,LSTMData-0.001.pt
8,aaaa,session_id,1,LSTMData-0.001.pt
9,aaag,session_id,1,LSTMData-0.001.pt


In [34]:
spliter = DFSpliter()
train ,val = spliter(loc_df)

CSVs creados


In [116]:
train,val

(   Patient     Session  N_Win                Dir
 0     aaap  session_id      1  LSTMData-0.001.pt
 1     aaap  session_id      1  LSTMData-0.001.pt
 2     aaap  session_id      1  LSTMData-0.001.pt
 3     aaai  session_id      1  LSTMData-0.001.pt
 4     aaai  session_id      1  LSTMData-0.001.pt
 5     aaai  session_id      1  LSTMData-0.001.pt
 6     aaaa  session_id      1  LSTMData-0.001.pt
 7     aaaa  session_id      1  LSTMData-0.001.pt
 8     aaaa  session_id      1  LSTMData-0.001.pt
 9     aaal  session_id      1  LSTMData-0.001.pt
 10    aaal  session_id      1  LSTMData-0.001.pt
 11    aaal  session_id      1  LSTMData-0.001.pt
 12    aaax  session_id      1  LSTMData-0.001.pt
 13    aaax  session_id      1  LSTMData-0.001.pt
 14    aaax  session_id      1  LSTMData-0.001.pt,
   Patient     Session  N_Win                Dir
 0    aaag  session_id      1  LSTMData-0.001.pt
 1    aaag  session_id      1  LSTMData-0.001.pt
 2    aaag  session_id      1  LSTMData-0.001.pt)

In [117]:
dataset = CustomEEGDataset(train,root_path, multi=False)

In [137]:
batch = dataset.__getitem__([1])[0]


In [138]:
augment = Augmentations(3,augmentations=AUGMENTATIONS)

In [165]:
aug_batch = augment(batch)
print(aug_batch[0].shape,aug_batch[1].shape)
aug_batch

<function DCVoltage at 0x00000175A437A3B0>
<function Temporal_Invertion at 0x00000175A44C4280>
<function Negation at 0x00000175A44C4C10>
<function Masking at 0x00000175A2F7D000>
<function Time_Shift at 0x00000175A4445240>
<function GaussianNoise at 0x00000175A44440D0>
<function DCVoltage at 0x00000175A437A3B0>
<function Temporal_Invertion at 0x00000175A44C4280>
<function Negation at 0x00000175A44C4C10>
<function Masking at 0x00000175A2F7D000>
<function Time_Shift at 0x00000175A4445240>
<function GaussianNoise at 0x00000175A44440D0>
<function DCVoltage at 0x00000175A437A3B0>
<function Temporal_Invertion at 0x00000175A44C4280>
<function Negation at 0x00000175A44C4C10>
<function Masking at 0x00000175A2F7D000>
<function Time_Shift at 0x00000175A4445240>
<function GaussianNoise at 0x00000175A44440D0>
<function DCVoltage at 0x00000175A437A3B0>
<function Temporal_Invertion at 0x00000175A44C4280>
<function Negation at 0x00000175A44C4C10>
<function Masking at 0x00000175A2F7D000>
<function Time_

(tensor([[-0.6372, -0.6618, -0.6613,  ..., -1.2341, -1.1787, -1.1013],
         [ 1.1144,  1.0168,  0.8019,  ...,  0.8002,  0.8955,  0.9877],
         [ 0.0414, -0.0200, -0.0478,  ...,  0.0779,  0.0763,  0.0540],
         ...,
         [ 0.9374,  0.9537,  0.9894,  ...,  1.4440,  1.4937,  1.4515],
         [ 0.5113,  0.5073,  0.4386,  ...,  0.6054,  0.6391,  0.6898],
         [ 0.0595,  0.0286, -0.0156,  ..., -0.3285, -0.3090, -0.3620]],
        dtype=torch.float64),
 tensor([[ 1.4536,  1.5054,  1.4562,  ...,  0.9768,  0.6682,  0.6493],
         [-0.1579,  0.1763,  0.0108,  ..., -0.2050, -0.3751, -0.3038],
         [-0.4084, -0.7187, -0.8615,  ..., -1.3342, -1.0952, -1.3659],
         ...,
         [-0.4248, -0.7077, -0.1583,  ..., -1.0437, -1.1164, -1.3308],
         [-0.4506, -0.1834, -0.2816,  ..., -0.0245, -0.0792, -0.5733],
         [-0.0879, -0.2644, -0.1827,  ..., -0.5085, -0.9852, -0.8978]],
        dtype=torch.float64))