In [15]:
#%matplotlib inline
import sys
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy import signal
import warnings
from Util import FeatureUtils as util

In [2]:
"""
Class to generate features. Uses preprocessed data held in memory by FeatDataHolder class.
"""
class Feature_generator:
    """Generate PCA-reduced spectral features from preprocessed recordings.

    The table handed to the constructor is read row by row through
    ``.loc[0]``, ``.loc[1]``, ... (one row per recording day) using the
    columns ``'Day'``, ``'Start'``, ``'End'`` and ``'BinnedData'``.
    Parameters learned while generating training features (standardization
    statistics, artifact limits, PCA axes) are kept on the instance so that
    evaluation/test data can be transformed with the very same parameters.
    """

    def __init__(self, df):
        """Store the (channel-filtered) data table and reset learned state.

        Input: table with one row per day; it is passed through
        ``util.filter_common_channels`` before being stored.
        """
        # Sampling frequency in Hz handed to the Welch estimator (hard-coded).
        self.sfreq = 500
        self.df = util.filter_common_channels(df)
        self.pca = None
        # Standardization parameters: fitted on train data, reused for
        # eval/test data.
        self.std = None
        self.mean = None
        # Artifact-detection parameters: fitted on train data, reused for
        # eval/test data.
        self.std_lim = None
        self.std_med = None
        # Frequency bins returned by util.bin_psd for the first valid window.
        self.fr_bin = None
        # Has to be passed on to the label side. Holds the indices of windows
        # dropped during calculation ('NaNs') and of artifacts found
        # afterwards ('Artifacts').
        self.bad_indices = dict()

    def _calc_features(self, data, time_sta, time_stp, wsize=100, sliding_window=False):
        """Compute binned Welch spectra for consecutive windows of one day.

        Central piece on the feature side. Every window contributes one
        column holding the flattened (row-major) output of ``util.bin_psd``;
        windows whose spectrum contains NaNs are skipped and their window
        index is reported instead.

        Input:
            data: 3-D array sliced as ``data[:, start:stop, :]``; the second
                axis is the one indexed by ``time_sta``/``time_stp``/``wsize``
                (one entry per second according to the original notes;
                presumably channels x seconds x samples - TODO confirm).
            time_sta, time_stp: first and last position on that axis.
            wsize: window length on that axis.
            sliding_window: step between window starts; a falsy value means
                non-overlapping windows (step == wsize).
        Output:
            (mat, bads): ``mat`` is a 2-D array with one column per valid
            window, or None if no valid window was found; ``bads`` is the
            list of window indices (counting valid and invalid windows) whose
            spectrum contained NaNs. Standardization happens later, after
            artifact filtering.
        """
        bads = []
        columns = []
        step = sliding_window if sliding_window else wsize
        time_it = time_sta
        idx = 0
        print('from {} to {}'.format(time_sta,time_stp))
        while True:
            stop = time_it + wsize
            # NOTE(review): this also rejects windows ending on the last two
            # positions of the time axis; kept as in the original.
            if stop >= data.shape[1]-1:
                print('we went here, why? shape is', data.shape[1]-1)
                break
            # Data of ALL channels in the window, flattened to one long
            # signal per first-axis entry.
            curr_data = data[:, time_it:stop, :].reshape(data.shape[0], -1)

            fr, psd = signal.welch(curr_data, self.sfreq, nperseg=250)

            if np.isnan(psd).any():
                # NaNs in the PSD mean something is off: drop the window but
                # remember its index so the labels can be aligned.
                bads.append(idx)
            else:
                fr_bin, psd_bin = util.bin_psd(fr, psd)
                if not columns:
                    self.fr_bin = fr_bin
                columns.append(psd_bin.flatten())

            idx += 1
            time_it += step
            if time_it + wsize >= time_stp+1:
                break

        # Stack once at the end: linear instead of quadratic copying, and a
        # single valid window still yields a proper 2-D (features x 1) matrix.
        mat = np.column_stack(columns) if columns else None
        return mat, bads

    def _calc_features_over_days(self, time_sta, time_stp, wsize=100, sliding_window=False):
        """Collect window features for a time range that may span several days.

        Days are visited in index order. Days lying entirely before
        ``time_sta`` are skipped; for the others ``_calc_features`` is called
        and the resulting columns are concatenated. ``self.bad_indices['NaNs']``
        is reset and filled with the global indices of the dropped windows.

        Input: start and end time, window size and sliding-window step (same
            units as the ``'Start'``/``'End'`` columns).
        Output: 2-D array (features x windows). If the loaded days run out
            before the request is satisfied, whatever has been collected so
            far is returned (None if nothing was collected).
        """
        duration = time_stp - time_sta
        time_passed = 0
        curr_data = None
        # Integer-typed from the start so the indices never decay to floats.
        nan_indices = np.array([], dtype=int)
        self.bad_indices['NaNs'] = nan_indices
        days_used = 0
        idx = 0
        # Stop as soon as not a single additional window would fit.
        while duration > time_passed + wsize:
            print('jo schau')
            try:
                day = self.df['Day'].loc[idx]
            except KeyError:
                print("Not enough data loaded into memory for this request.")
                return curr_data
            print('Day No {}'.format(day))
            day_end = self.df['End'].loc[idx]
            print('time start is {}, and the end of the first day is{}'.format(time_sta, day_end))
            day_length = day_end - self.df['Start'].loc[idx]
            if time_sta >= day_length:
                # The requested start lies after this day: shift the request
                # by the day's length and move on to the NEXT day.
                time_sta -= day_length
                time_stp -= day_length
                print('jo soviel vergangen{}, so sind die nun {},{}'.format(day_length,time_sta,time_stp))
                idx += 1
                continue

            data = self.df['BinnedData'].loc[idx]
            mat, bad = self._calc_features(data, time_sta, time_sta+duration-time_passed, wsize, sliding_window)

            # Window indices of this day are local; offset them by the number
            # of windows (valid and dropped) seen on previous days.
            n_good = 0 if curr_data is None else curr_data.shape[1]
            windows_done = n_good + len(nan_indices)
            nan_indices = np.append(nan_indices, np.asarray(bad, dtype=int) + windows_done)
            self.bad_indices['NaNs'] = nan_indices
            if days_used == 0:
                print('jo hier war ich jetzt drin.')
            else:
                print(self.bad_indices['NaNs'], 'this is how the nan indices look after adding some on the next day')
            if mat is not None:
                curr_data = mat if curr_data is None else np.append(curr_data, mat, axis=1)
            days_used += 1
            idx += 1

            n_good = 0 if curr_data is None else curr_data.shape[1]
            n_windows = n_good + len(nan_indices)
            print('jo das ist die laenge', len(self.bad_indices['NaNs']), 'das die andere', None if curr_data is None else curr_data.shape )
            # How much time is covered by the windows seen so far?
            if n_windows == 0:
                time_passed = 0
            elif sliding_window:
                time_passed = wsize+sliding_window*(n_windows-1)
            else:
                time_passed = wsize*n_windows
            print(time_passed, 'jo stimmt das hier mit der time passed?')
            # Following days start at their beginning, even if the initial
            # starting time wasn't zero.
            time_sta = 0
        return curr_data

    def _setup_PCA(self, curr_data, train, expl_variance):
        """Fit the PCA on training data (if requested) and transform the data.

        Input: data, train bool, amount of variance to be explained (handed
            to ``util.get_no_comps`` to pick the number of components).
        Output: data in PC space.
        Raises: ValueError if ``train`` is False and no PCA has been fitted.
        """
        if not train and self.pca is None:
            raise ValueError('Train set has to be generated first, otherwise no principal axis available for data trafo.')
        if train:
            print('Setting up PCA on current data range...')
            no_comps = util.get_no_comps(curr_data, expl_variance)
            self.pca = PCA(n_components=no_comps)
            self.pca.fit(curr_data)
        return self.pca.transform(curr_data)

    def generate_features(self, start=0, end=None, wsize=100, sliding_window=False, train=True, expl_variance=85):
        """Generate the features for a time range.

        Pipeline: windowed binned spectra -> artifact removal ->
        standardization -> PCA. With ``train=True`` the artifact limits, the
        per-row mean/std and the PCA are fitted and stored; with
        ``train=False`` the stored ones are reused.

        Input: start and end time, window size, sliding-window step, train
            bool, variance to be explained by the PCA. ``end=None`` requests
            everything that is loaded.
        Output: features in PC space, as returned by ``PCA.transform``.
        Raises: ValueError if ``train`` is False before any training run, or
            if no feature window could be computed for the request.
        """
        if not train and self.pca is None:
            # Fail before any work is done (and before helpers receive None
            # parameters) instead of at the final PCA step.
            raise ValueError('Train set has to be generated first, otherwise no principal axis available for data trafo.')
        if end is None:
            # Total length of all loaded days; assumes 'End' - 'Start' is a
            # day's length in the same units as `start`, as the day loop does.
            end = (self.df['End'] - self.df['Start']).sum()
        curr_data = self._calc_features_over_days(start, end, wsize, sliding_window)
        if curr_data is None:
            raise ValueError('No feature window could be computed for the requested range.')
        # From here on, days don't matter anymore: we have one chunk of data.
        if train:
            self.bad_indices['Artifacts'], self.std_lim, self.std_med = util.detect_artifacts(curr_data)
        else:
            self.bad_indices['Artifacts'], _, _ = util.detect_artifacts(curr_data, self.std_lim, self.std_med)
        good_data = util.remove_artifacts(curr_data, self.bad_indices['Artifacts'])
        if train:
            # Train data defines the statistics used for standardization.
            self.std = np.std(good_data, axis=1)
            self.mean = np.mean(good_data, axis=1)
        data_scal = util.standardize(good_data, self.std, self.mean)
        # Transposed so that each row handed to the PCA is one time window.
        princ_components = self._setup_PCA(data_scal.T, train=train, expl_variance=expl_variance)
        return princ_components

    def get_bad_indices(self):
        """Return the bad indices found by filtering.

        Important: the order matters. First filter out the 'NaNs' indices,
        then the 'Artifacts' indices.
        Output: dictionary of bad data points.
        """
        return self.bad_indices
        
        

In [None]:
# pecog=Feature_generator('/data2/users/stepeter/Preprocessing/processed_cb46fd46_4.h5',prefiltered=False,wsize=100)
# mecog = Feature_generator('/nas/ecog_project/derived/processed_ecog/cb46fd46/full_day_ecog/cb46fd46_fullday_4.h5', prefiltered =False, wsize =100)
# for p in range(pecog.pca.n_components):
#     plt.plot(wut[:,p])
# plt.xlabel('Time (in w_size)')
# plt.ylabel('PC Value')
# plt.title('First %d principal components' % pecog.pca.n_components)
# plt.show()

# pecog.vis_raw_data(0,30000,range(20))

# #pecog.vis_raw_data(idx[0]-5,idx[1]+5)
# #pecog.vbis_raw_data(0,idx[1]+400)
# pecog.vis_welch_data(0,30000)



# good_data=pecog.calc_data_mat(idx[1]+100,idx[1]+400)
# pecog.pca.fit(good_data)
# good_data_trafo=pecog.pca.transform(good_data)
# print(good_data_trafo.shape)
# print(good_data.shape)


# print(good_data.shape)
# print(good_data_trafo.shape)

# comps=pecog.pca.components_
# print(pecog.raw.info['ch_names'][28])
# print(comps.shape)
# comps=comps.reshape((127,-1,2))
# print(np.argmax(comps[:,:,1],axis=1))
# #plt.plot(comps[5:,5:,0].T)
# plt.plot(comps[:,:,1].T)
# plt.ylim(-0.005,0.005)

# print(len(pecog.raw.info['chs']))

# #print(data_trafo)
# plt.plot(good_data_trafo[:,0])
# plt.plot(good_data_trafo[:,1])
# #plt.xlim(-0.00001,0.00001)
# #plt.ylim(-0.00001,0.00001)

# #max(data_trafo[:,1])-min(data_trafo[:,1])

In [None]:
##### A WHOLE LOT OF PLOTTING FUNCTIONS


# wut=pecog.generate_features(0,12500,expl_variance=90)

# # print(wut.shape)


# pecog.vis_pc()

# pecog.curr_data.shape


# lel=pecog.curr_data.T
# med=np.median(lel.reshape(-1,8,lel.shape[1]),axis=0)
# men=np.mean(lel.reshape(-1,8,lel.shape[1]),axis=0)
# print(lel.shape)
# for i in range(8):
#     plt.plot(lel[8*8+i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# plt.ylabel('PSD')
# plt.title('Welch Transformation results')
# plt.show()
# for i in range(8):
#     plt.plot(men[i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# #plt.yscale('log')
# plt.ylabel('PSD')
# plt.title('Welch Transformation, Mean over Channels - Standardized ')
# plt.show()

# for i in range(8):
#     plt.plot(med[i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# #plt.yscale('log')
# plt.ylabel('PSD')
# plt.title('Welch Transformation, Median over Channels - Standardized')
# plt.show()



# lel=pecog.temp_mat.T
# med=np.median(lel.reshape(-1,8,lel.shape[1]),axis=0)
# men=np.mean(lel.reshape(-1,8,lel.shape[1]),axis=0)
# print(lel.shape)
# for i in range(8):
#     plt.plot(lel[8*8+i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# plt.ylabel('PSD')
# plt.title('Welch Transformation results')
# plt.show()
# for i in range(8):
#     plt.plot(men[i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# #plt.yscale('log')
# plt.ylabel('PSD')
# plt.title('Welch Transformation, Mean over Channels ')
# plt.show()

# for i in range(8):
#     plt.plot(med[i,:], label='Bin %d' %i)
# plt.legend()
# plt.xlabel('Time Window')
# #plt.yscale('log')
# plt.ylabel('PSD')
# plt.title('Welch Transformation, Median over Channels')
# plt.show()

# f=h5py.File('/data2/users/stepeter/Preprocessing/processed_cb46fd46_4.h5')

# sprr=f['dataset'][()]

# for i in range(8,14):
#     plt.plot(sprr[i,125*100*500:127*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.show()
# for i in range(10,12):
#     plt.plot(pecog.data[i,125*100*500:127*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.show()

# for i in range(8,14):
#     plt.plot(sprr[i,10*100*500:12*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.title('One Patient, Sample Channels')

# for i in range(25,28):
#     plt.plot(pecog.data[i,35*100*500:40*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.show()

# for i in range(25,28):
#     plt.plot(pecog.data[i,38*100*500:40*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.show()

# for i in range(25,28):
#     plt.plot(pecog.data[i,38*100*500:50*100*500])
# plt.xlabel('t')
# plt.ylabel('uV')
# plt.show()

# np.argmax(lel.reshape(-1,8,lel.shape[1])[:,1,34])

# pecog.generate_features()

# ttest=np.abs(pecog.pca.transform(np.eye(pecog.curr_data.shape[1])))

# ttest_sum[:16]

# ttest.shape

# #how much did each bin contribute?
# ttest_sum=ttest.sum(axis=1)
# ttest_sum=ttest[:,0]
# ttest_shaped=ttest_sum.reshape(pecog.data.shape[0],-1)
# cont_bins=ttest_shaped.sum(axis=0)
# cont_elecs=ttest_shaped.sum(axis=1)
# plt.figure(figsize=(7,3))
# plt.bar(np.arange(len(cont_bins)),cont_bins)
# ticks=['[0,1]','(1,2]','(2,4]','(4,8]','(8,16]','(16,32]','(32,64]','(64,150]']
# plt.xticks(np.arange(len(cont_bins)),ticks)
# plt.title('Contributions of Bins to PD 0')
# plt.xlabel('Bins')
# plt.ylabel('Absolute Contribution')
# plt.show()

# plt.figure(figsize=(3,15))
# plt.barh(np.arange(len(cont_elecs)),cont_elecs)
# plt.yticks(np.arange(len(cont_elecs)),list(pecog.chan_labels))
# plt.ylabel('Region')
# plt.xlabel('Absolute Contribution')
# plt.title('Contributions of Chans to PD')

# mni_file=pd.read_excel('/data2/users/stepeter/mni_coords/cb46fd46/cb46fd46_MNI_atlasRegions.xlsx')
# for en,i in enumerate(pecog.chan_labels):
#     print(en)
#     if( sum(mni_file['Electrode'].isin([i]))==0):
#         print(i)
#         print(en)
        

# print(mni_file['Electrode'].loc[7])

# pecog.chan_labels.shape

# pc1=np.sum(np.abs(ttest[:,1].reshape(-1,8)),axis=1)
# pc2=np.sum(np.abs(ttest[:,2].reshape(-1,8)),axis=1)
# pc3=np.sum(np.abs(ttest[:,3].reshape(-1,8)),axis=1)
# pc4=np.sum(np.abs(ttest[:,4].reshape(-1,8)),axis=1)

# mni_coords_fullfile='/data2/users/stepeter/mni_coords/cb46fd46/cb46fd46_MNI_atlasRegions.xlsx'
# plot_ecog_electrodes_mni_from_file_and_labels(mni_coords_fullfile,pecog.chan_labels,num_grid_chans=64, colors=cont_elecs[:-1])

# plot_ecog_electrodes_mni_from_file_and_labels(mni_coords_fullfile,pecog.chan_labels,num_grid_chans=64, colors=pc1[:-1])
# plot_ecog_electrodes_mni_from_file_and_labels(mni_coords_fullfile,pecog.chan_labels,num_grid_chans=64, colors=pc2[:-1])
# plot_ecog_electrodes_mni_from_file_and_labels(mni_coords_fullfile,pecog.chan_labels,num_grid_chans=64, colors=pc3[:-1])
# plot_ecog_electrodes_mni_from_file_and_labels(mni_coords_fullfile,pecog.chan_labels,num_grid_chans=64, colors=pc4[:-1])