In [2]:
#%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from sklearn.decomposition import PCA
from sklearn import datasets
from sklearn.model_selection import train_test_split
from scipy.fftpack import fft, fftfreq
from scipy import signal
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot
import h5py
from simple_edf_preprocessing import Preprocessor
from datetime import datetime

In [9]:
#create a ECOG PCA class for its PCA object, hyperparas and other stuff
class PCA_Ecog_preprocessed:
    def __init__(self,path=None,pca_obj=None,wsize=30):
        self.pca=pca_obj
        #hyperpara in s for how large the time window should be on whihc we calculate our fourier trafo
        self.wsize=wsize
        #sampling frequency and last sample taken
        if(path != None):
            self.df=h5py.File(path)
            #preprocess data
            self.Preprocessor=Preprocessor(self.df)
            self.data,self.bad_chan,self.bad_idx=self.Preprocessor.preprocess()
            self.data=self.data[self.bad_chan!=True]
            self.sfreq=int(self.df['f_sample'][()])
            #how many samples in this dataset?
            self.end=self.data.shape[1]-1
            self.bin_data()
            
        else:
            self.df=None
            self.sfreq=None
            self.end=None
    
    #this function restructures the data into a 3D structure, where each row presents a channel, each column one second
    #and the depth is the amount of samples per seconds (sfreq). 
    #This is to discard seconds where bad_idx are present and to be on par with the labels in the end
    #This function also creates a mask of bins to discard from the bad_idx array
    def bin_data(self):
        #where to end?
        end=self.end//self.sfreq
        self.data_bin=self.data[:,:end*self.sfreq].reshape(self.data.shape[0],end,self.sfreq)
        self.mask_bin=np.all(self.bad_idx[:end*self.sfreq].reshape(end,self.sfreq),axis=1)

        
      
    def standardize(self, data,ax=0):
        self.data_mean=np.mean(data,axis=ax)
        data_dem=data-self.data_mean
        std=np.std(data,axis=ax)
        data_stand=data_dem/std
        #self.data_scal=1000000
        #data_scal=self.data_scal*data_dem
        #return data_scal
        return data_stand
    
#this function caps at 200Hz, then bins the data in a logarithmic fashion to account for smaller psd values in higher freqs
    def bin_psd(self,fr,psd):
        fr_trun=fr[fr<=200]
        fr_total=len(fr_trun)
        fr_bins=np.arange(int(np.log2(fr_total)+1))
        #truncate everythin above 200Hz
        psd=psd[:,fr<=200]
        psd_bins=np.zeros((psd.shape[0],len(fr_bins)))
        prev=0
        max_psd_per_bin=np.exp2(fr_bins).astype('int')
        prev=0
        for b in fr_bins:
            if (b==len(fr_bins) or max_psd_per_bin[b]>=psd.shape[1]):
                psd_bins[:,b]+=np.sum(psd[:,prev:],axis=1)
            else:
                psd_bins[:,b]=np.sum(psd[:,prev:max_psd_per_bin[b]],axis=1)
            prev=max_psd_per_bin[b]
        return fr_bins, psd_bins

    
        
    #create matrix as follows:
    #columns: channels, for each channel the 200 frequencies (0-200Hz) (hece freq*cha length) BUT BINNED logarithmically
    #rows: Time steps
    #resulting matrix is 2D, Time Stepsx(Freq*Channels)
    #note that this matrix is prone to constant change. Save the current data as member variable
    def data_in_range(self,time_sta,time_stp):
        if(self.df==None):
            raise ValueError('Raw Data not set.')
        time_it=time_sta
        while True:
            stop=time_it+self.wsize
            if stop>=self.data_bin.shape[1]-1:
                print('Not enough data for set end %d. Returning all data that is available in given range.'% time_stp)
                break
            #Note that each column is exactly one second.
            #get data in range of ALL channels, applying the mask
            curr_data=self.data_bin[:,range(time_it,stop)[self.mask_bin[start:stop]],:].reshape(self.data[0],-1)
            #is this thing empty? continue
            if not curr_data.size:
                continue
            #welch method 
            fr,psd=signal.welch(curr_data,self.sfreq)
            fr_bin,psd_bin=self.bin_psd(fr,psd)
            if time_it==time_sta:
                self.fr_bin=fr_bin
                #first time. create first column, flatten w/o argument is row major 
                mat=psd_bin.flatten()
            else:
                #after, add column for each time step
                mat=np.column_stack((mat,psd_bin.flatten()))
            time_it+=self.wsize
            if time_it+self.wsize >= time_stp:
                break
        data_scal=self.standardize(mat.T)
        self.curr_data=data_scal
        return data_scal
    
    def vis_raw_data(self, start, stop, chans=None):
        if(self.df==None):
            raise ValueError('Raw Data not set (yet)')
        if chans is None:
            chans=range(self.data.shape[0])
        st=int(start*self.sfreq)
        stp=int(stop*self.sfreq)
        data=self.data[:,st:stp]
        for p in range(0,len(chans)-1):
            plt.plot(data[p])
        plt.show()
    
    def vis_welch_data(self,start,stop,no_chan=None):
        #account for wsize
        start=int(start/self.wsize)
        stop=int(stop/self.wsize)
        rem=self.curr_data[:,start:stop]
        plt.imshow(np.log(rem),cmap='viridis',aspect='auto')
        print(rem[4000,:])
        
    def vis_pc(self):
        if self.PCA is None:
            raise ValueError('PCA not set up yet. Please call setup_PCA first.')
        for p in range(self.PCA.n_components):
            plt.plot(self.PCA.transform(wut)[:,p])
        plt.xlabel('Time (in w_size)')
        plt.ylabel('PC Value')
        plt.title('First %d principal components' % self.PCA.n_components)
        plt.show()

    #get elbow curve. This also outputs the optimal n_components for the given desired explained variancce.
    def __elbow_curve(self,datapart,expl_var_lim):
        components = range(1, datapart.shape[1] + 1)
        explained_variance = []
        #till where?
        lim=min(50, datapart.shape[1])
        count=0
        for component in tqdm(components[:lim]):
            pca = PCA(n_components=component)
            pca.fit(datapart)
            expl_var=sum(pca.explained_variance_ratio_)
            explained_variance.append(expl_var)
            count+=1
            if(expl_var>(expl_var_lim/100.)):
                optimal_no_comps=count
                break
        if(explained_variance[-1:][0]<(expl_var_lim/100.)):
            print('Could not explain more than %3f percent of the variance. n_comps is therefore equal to the input dimension. Consider increasing data range or lowering demanded explained variance' % expl_var)
            optimal_no_comps=datapart.shape[1]
        sns_plot = sns.regplot(
            x=np.array(components[:count]), y=explained_variance,
            fit_reg=False).get_figure()
        return optimal_no_comps
    
    def setup_PCA(self,expl_variance):
        print('Setting up PCA on current data range...')
        if self.curr_data is None:
            raise ValueError('Data matrix not set yet. Please call data_in_range Func first.')
        no_comps=self.__elbow_curve(self.curr_data,expl_variance)
        self.pca=PCA(n_components=no_comps)
        self.pca.fit(self.curr_data)
        self.princ_components=self.pca.transform(self.curr_data)
    
        

In [10]:
pecog=PCA_Ecog_preprocessed('/data2/users/stepeter/Preprocessing/Reref/processed_cb46fd46_4.h5')

('df', (97, 43200500))


In [37]:
mnestuff=mne.io.read_raw_edf('/data1/ecog_project/ajile_data_release/cb46fd46_day_4.edf',eog=[4],preload=True)

Extracting EDF parameters from /data1/ecog_project/ajile_data_release/cb46fd46_day_4.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 86400999  =      0.000 ... 86400.999 secs...


In [40]:
mnestuff.info

<Info | 17 non-empty fields
    bads : list | 0 items
    buffer_size_sec : float | 1.0
    ch_names : list | GRID1, GRID2, GRID3, GRID4, GRID5, GRID6, GRID7, GRID8, ...
    chs : list | 98 items (EEG: 96, STIM: 1, EOG: 1)
    comps : list | 0 items
    custom_ref_applied : bool | False
    dev_head_t : Transform | 3 items
    events : list | 0 items
    highpass : float | 0.0 Hz
    hpi_meas : list | 0 items
    hpi_results : list | 0 items
    lowpass : float | 500.0 Hz
    meas_date : int | 1427987220
    nchan : int | 98
    proc_history : list | 0 items
    projs : list | 0 items
    sfreq : float | 1000.0 Hz
    acq_pars : NoneType
    acq_stim : NoneType
    ctf_head_t : NoneType
    description : NoneType
    dev_ctf_t : NoneType
    dig : NoneType
    experimenter : NoneType
    file_id : NoneType
    gantry_angle : NoneType
    hpi_subsystem : NoneType
    kit_system_id : NoneType
    line_freq : NoneType
    meas_id : NoneType
    proj_id : NoneType
    proj_name : NoneType


In [11]:
wut=pecog.data_in_range(0,3000)

NameError: global name 'start' is not defined

In [None]:
pecog.setup_PCA(94)

In [None]:
pecog.pca.fit(wut)

In [None]:
tr=pecog.pca.transform(pecog.curr_data)

In [61]:
pecog.PCA=PCA(n_components=com)
pecog.PCA.fit(wut)
pecog.PCA.n_components
pecog.PCA.transform(wut)

array([[ -0.81157505, -14.71728662,  -2.34965355, ...,   1.65272278,
         -1.98378791,  -1.03052725],
       [ -3.62768328, -16.40192335,  -2.04784517, ...,   2.41212364,
         -2.19696157,  -0.27153308],
       [-13.58255533,   3.27350952,   3.77294152, ...,   1.80580357,
         -0.62542247,   1.32854282],
       ..., 
       [  9.21359824,   0.9008957 ,   3.80799144, ...,  -0.08426016,
         -0.97584345,   1.37733634],
       [  8.33452942,   1.90464146,  -5.3162557 , ...,  -1.14187604,
          1.20489029,  -0.85552777],
       [ 11.08960914,   7.44386117,  -3.32637685, ...,  -0.93442396,
         -0.95458346,   0.13851337]])

In [None]:
for p in range(pecog.PCA.n_components):
    plt.plot(pecog.PCA.transform(wut)[:,p])
plt.xlabel('Time (in w_size)')
plt.ylabel('PC Value')
plt.title('First %d principal components' % pecog.PCA.n_components)
plt.show()

In [None]:
pecog.vis_raw_data(0,3000)

In [None]:
plt.scatter(data_trafo[:,0],data_trafo[:,1])
print(data_trafo)

In [None]:
#bin the frequencies logarithmitically (higher frequencies are much lpwer -> bin them together)
#standardize (mean and unit variance feature
#go up to 150Hz
#since we bin emotions as well, do a) regression on percentage, b) classification at cutoff
#

In [None]:
#what are the outstsanding points here?
time_idx=np.argwhere(data_trafo>2)[:,0]
#what is happening at these points?
#calculate back the time window:
idx=time_idx*pecog.wsize
print(idx)
#calculate from time windows to sampling points
start, stop = pecog.raw.time_as_index([idx[0]-5,idx[1]+5])
inter=pecog.raw.get_data(picks=range(0,40), start=start, stop=stop,reject_by_annotation=None, return_times=False)
print(inter.shape)

In [None]:
print(pecog.curr_data)

In [None]:
#pecog.vis_raw_data(idx[0]-5,idx[1]+5)
pecog.vis_raw_data(0,idx[1]+400)
pecog.vis_welch_data(50,time_idx[1]+200)

In [None]:
good_data=pecog.calc_data_mat(idx[1]+100,idx[1]+400)
pecog.pca.fit(good_data)
good_data_trafo=pecog.pca.transform(good_data)
print(good_data_trafo.shape)
print(good_data.shape)


In [None]:
print(good_data.shape)
print(good_data_trafo.shape)

In [None]:
comps=pecog.pca.components_
print(pecog.raw.info['ch_names'][28])
print(comps.shape)
comps=comps.reshape((127,-1,2))
print(np.argmax(comps[:,:,1],axis=1))
#plt.plot(comps[5:,5:,0].T)
plt.plot(comps[:,:,1].T)
plt.ylim(-0.005,0.005)

In [None]:
print(len(pecog.raw.info['chs']))

In [None]:
#print(data_trafo)
plt.plot(good_data_trafo[:,0])
plt.plot(good_data_trafo[:,1])
#plt.xlim(-0.00001,0.00001)
#plt.ylim(-0.00001,0.00001)

#max(data_trafo[:,1])-min(data_trafo[:,1])

In [None]:
##TEST STUFF
###functions for first insight. delete later
start, stop = pecog.raw.time_as_index([0, 51+pecog.wsize])
print(stop-start)
channel=pecog.raw.get_data(picks=[4,5],start=start, stop=stop,reject_by_annotation=None, return_times=False)
#data=np.squeeze(channel[0])

# Get real amplitudes of FFT (only in postive frequencies)
fft_vals = np.absolute(np.fft.rfft(channel))
# Get frequencies for amplitudes in Hz
fft_freq = np.fft.rfftfreq(len(channel[0,:]), 1.0/pecog.sfreq)
#welch method
fr,psd=signal.welch(channel,pecog.sfreq)

#throw away everything above 100Hz for now (and 0)
freqs=fft_freq[fft_freq<101][1:]
vals=fft_vals[:,fft_freq<101][:,1:]


plt.plot(fr,psd[0])
plt.plot(fr,psd[1])

#plt.plot(freqs,vals[1])
#plt.plot(freqs,vals[0])

#plt.plot(fft_freq[1:], fft_vals[1:])
#plt.xlabel('Frequency')
#plt.ylabel('Intensity')
#plt.show()
#plt.plot(freqs,vals)
#plt.xlabel('Frequency')
#plt.ylabel('Intensity')

#for 3D
# fig = pyplot.figure()
# ax = Axes3D(fig)

# ax.scatter(data_trafo[:,0],data_trafo[:,1],data_trafo[:,2],c='b')
# pyplot.show()