In [15]:
import os

# Subject id to analyse; the SUBJECT environment variable overrides the default of 36.
subject = int(os.environ.get('SUBJECT', 36))

In [2]:
import datetime as dt

import joblib
import numpy as np
import scipy
import scipy.fftpack
import scipy.stats

import convvisual.analysis.data_preparation as cdp
import convvisual.analysis.utils as cu
import convvisual.analysis.plot_new as plot
import convvisual.receptive_field.receptive_field as crf


from braindecode.analysis.plot_util import plot_head_signals_tight,plot_head_signals_tight_with_tube
from braindecode.datasets.sensor_positions import tight_cap_positions

import matplotlib.pyplot as plt
from PIL import Image

%load_ext autoreload
%autoreload 2

Using gpu device 0: GeForce GTX 780 (CNMeM is disabled, cuDNN 5005)


In [16]:
# Root folder of the trained models (not referenced again in the cells shown here).
modelpath = '/home/hartmank/braindecode/data/models/'
# Model identifier; its last path component is the subject id.
modelname = 'paper/ours/cnt/deep4/car/{:d}'.format(subject)
# Root folder holding the analysis results and, below 'figures', the generated plots.
savepath = '/home/hartmank/data/convvisual/RF_data_paper/'
# Analysis variant to read; the alternative variant is 'FilterAnalysis_nUnits100'.
folder = 'ClassAnalysis_FilterDiff_nUnits100_nFilters05'

In [17]:
# Input folder: analysis results for this model and analysis variant.
datapath = os.path.join(savepath, modelname, folder)
# Output folder for figures, grouped by subject and analysis variant.
figurepath = os.path.join(os.path.join(savepath, 'figures'), '%d' % subject, folder)

In [18]:
# Metadata stored alongside the analysis results in `datapath`.
misc_data = cdp.load_misc_data(datapath)
sensor_names, sampling_rate, targets = (
    misc_data[key] for key in ('sensor_names', 'sampling_rate', 'targets')
)

# nFeatures is the first dimension of the KS_percentages_* arrays;
# nFilters is not referenced in the cells shown here.
nFilters, nFeatures = 100, 7
# Figure heights handed to plot.KS_bar_plot (fig_h) and plot.grid_plot (fig_h_big).
fig_h, fig_h_big = 10, 50

In [19]:
def save_fig_compress(fname):
    """Save the current matplotlib figure to `fname` and re-save it as a palettised PNG.

    The figure is first written with `plt.savefig`, then reopened with PIL,
    converted to RGB and on to an adaptive-palette ('P' mode) image, and
    written back to the same file name in PNG format.
    """
    plt.savefig(fname)
    rendered = Image.open(fname)
    rgb_image = rendered.convert('RGB')
    palettised = rgb_image.convert('P', palette=Image.ADAPTIVE)
    palettised.save(fname, format='PNG')
    
def join_and_mk_path(oldpath,newpath):
    """Join `newpath` onto `oldpath`, create that directory if it is missing, and return it."""
    joined = os.path.join(oldpath, newpath)
    if os.path.isdir(joined):
        return joined
    os.makedirs(joined)
    return joined

def save_data(data,path,fname):
    """Dump `data` with joblib into the file `fname` inside the directory `path`."""
    target = os.path.join(path, fname)
    joblib.dump(data, target)
    
def save_str(data,path,fname):
    """Write the string form of `data`, followed by a newline, to `fname` inside `path`.

    Parameters:
        data:  any object; `str(data)` is what gets written.
        path:  directory that receives the file (must already exist).
        fname: file name inside `path`; an existing file is overwritten.

    The `with` block closes the file even if the write raises, and the
    explicit `write` replaces the Python-2-only `print >> f` form while
    producing the same file content.
    """
    with open(os.path.join(path,fname),'w') as f:
        f.write(str(data) + '\n')

In [25]:
n1=dt.datetime.now()  # start time; the elapsed seconds are printed at the end of this cell

# Per-feature share of the significant KS scores, averaged over the filters of
# a class. Indexed as [feature, layer, class].
KS_percentages_shape = (nFeatures,len(misc_data['layer_indeces']),len(misc_data['classes']))
KS_percentages_sum = np.zeros(KS_percentages_shape)
KS_percentages_mean = np.zeros(KS_percentages_shape)
KS_percentages_max = np.zeros(KS_percentages_shape)

for i_lay,layer_ind in enumerate(misc_data['layer_indeces']):
    print('Layer%d'%layer_ind)
    tmppath_layer = join_and_mk_path(figurepath,'Layer%d'%layer_ind)

    X_baseline = cdp.load_baseline_data(datapath,layer_ind)
    F_baseline = cdp.load_baseline_feature_data(datapath,layer_ind)
    labels = cdp.load_labels_data(datapath,layer_ind)
    # NOTE(review): this stores the dict inside itself (circular reference) and the
    # 'FFT' key is not read anywhere in this cell -- confirm whether it can be removed.
    labels['FFT'] = labels
    labels_flat = cu.create_flattened_featurearr(labels['labels'])

    feature_entry_flat = cu.create_flattened_featurearr(F_baseline['feature_labels'])
    F_baseline_flat = cu.create_flattened_featurearr(F_baseline['features'],shape=(F_baseline['features'][0].shape[0],-1))

    for i_cl,cl in enumerate(misc_data['classes']):
        print('Class%s'%str(cl))
        tmppath_cl = join_and_mk_path(tmppath_layer,'%s'%str(cl))

        RF_class = crf.load_ClassData(datapath,layer_ind,cl)

        # Per-filter summaries of the significant KS scores, one row per feature type.
        # (KS_counts holds the *sum* of the significant scores, see below.)
        KS_counts = np.zeros((len(F_baseline['feature_names']),len(RF_class.max_filters)))
        KS_means = np.zeros((len(F_baseline['feature_names']),len(RF_class.max_filters)))
        KS_max = np.zeros((len(F_baseline['feature_names']),len(RF_class.max_filters)))

        # Axis 1 of the first feature array is used as the electrode axis and
        # axis 2 as the frequency axis in the reductions below.
        tmp_FFT_shape = F_baseline['features'][0].shape
        FFT_diff = np.zeros((tmp_FFT_shape[2],len(RF_class.max_filters)))
        FFT_KS = np.zeros((tmp_FFT_shape[2],len(RF_class.max_filters)))
        Phase_diff = np.zeros((tmp_FFT_shape[2],len(RF_class.max_filters)))
        Phase_KS = np.zeros((tmp_FFT_shape[2],len(RF_class.max_filters)))
        Electrode_KS = np.zeros((tmp_FFT_shape[1],len(RF_class.max_filters)))
        Electrode_KS_max = np.zeros((tmp_FFT_shape[1],len(RF_class.max_filters)))

        # Tick labels: first half of the FFT bin frequencies, without the first (0 Hz) bin.
        # `//` keeps the slice bound an integer regardless of the division semantics in use.
        frequencies = scipy.fftpack.fftfreq(X_baseline['X_baseline'].shape[2], 1./sampling_rate)
        frequencies = frequencies[:frequencies.shape[0]//2].astype(str)[1:]

        for i_filt,filter_ind in enumerate(RF_class.max_filters):
            print('Filter%d'%filter_ind)
            tmppath_filter = join_and_mk_path(tmppath_cl,'Filter%d'%filter_ind)

            X_RF = cdp.load_RF_data(datapath,layer_ind,cl,filter_ind)
            F_RF = cdp.load_RF_feature_data(datapath,layer_ind,cl,filter_ind)
            KS = cdp.load_KS_score_data(datapath,layer_ind,cl,filter_ind)

            # Class histogram of the inputs behind the maximally activating units.
            max_units = X_RF['max_units_in_filters'][:,0]
            tmp_targets = targets[max_units]
            tmp_target_count = np.bincount(tmp_targets)
            save_str(tmp_target_count.tolist(),tmppath_filter,'targets')

            # Two-sample KS critical value at alpha=0.05:
            #   D_crit = sqrt(-0.5*ln(alpha/2)) * sqrt((n+n2)/(n*n2))
            # The literal -0.5 matters: `-1/2` is integer division under Python 2
            # and evaluates to -1, which inflated the threshold by sqrt(2).
            # It is only used as an argument of the KS bar plot further down.
            n = float(len(X_baseline['X_baseline']))
            n2 = float(len(X_RF['X_RF_cropped']))
            KS_D_critical = np.sqrt(-0.5*np.log(0.05/2))*np.sqrt((n+n2)/(n*n2))

            # Collect the scores and p-values of all feature types into flat arrays.
            all_KS = list()
            all_p = list()
            for feat_name in F_baseline['feature_names']:
                all_KS.extend(KS[feat_name]['KS_kuiper'])
                all_p.extend(KS[feat_name]['p_kuiper'])
            all_KS = np.asarray(all_KS)
            all_p = np.asarray(all_p)

            # Bonferroni correction over every tested feature entry.
            bonferroni_n = len(all_KS)
            all_p *= bonferroni_n

            F_RF_flat = cu.create_flattened_featurearr(F_RF['features'],shape=(F_RF['features'][0].shape[0],-1))

            # Significant entries, ordered by descending score.
            valid_indeces = np.where(all_p<0.05)[0]
            sort_mean_diff = all_KS[valid_indeces].argsort()[::-1]
            sort_mean_diff = valid_indeces[sort_mean_diff]

            plot.print_features(tmppath_filter,all_KS,all_p,labels_flat,sort_mean_diff)

            print('Head')
            if layer_ind != 3:
                tmp_s = cu.percentile_deviation(X_RF['X_RF_cropped'],axis=0)
                plot_head_signals_tight_with_tube(np.median(X_RF['X_RF_cropped'],axis=0),tmp_s, sensor_names=sensor_names,
                    figsize=(25, 25), sensor_map=tight_cap_positions)
                plt.ylim([-25,25])
                save_fig_compress(os.path.join(tmppath_filter,'HeadSignals.png'))
                plt.close()

            print('FeatDist')
            # Distribution comparison for (at most) the 99 highest-scoring significant
            # entries, laid out on a square grid that is large enough to hold them.
            tmppath_dist = join_and_mk_path(tmppath_filter,'Dist_plots')
            tmp_root = int(np.sqrt(len(sort_mean_diff[:99])))+1
            f, ax = plt.subplots(tmp_root,tmp_root,figsize=(8*tmp_root,7*tmp_root))
            if tmp_root>1:
                ax = ax.flatten()
            else:
                ax = [ax]
            for i,idx in enumerate(sort_mean_diff[:99]):
                plt.sca(ax[i])
                plot.plot_dist_comparison(F_RF_flat,F_baseline_flat,labels_flat,idx,title='KS: %f  p:%f'%(all_KS[idx],all_p[idx]))
            save_fig_compress(os.path.join(tmppath_dist,'FeatDist.png'))
            plt.close()

            print('Ch')
            # Channels that carry at least one significant feature entry.
            use_channels = list()
            for entry in feature_entry_flat[valid_indeces]:
                use_channels.append(entry[1][0])
            use_channels = np.unique(use_channels)
            tmppath_chan = join_and_mk_path(tmppath_filter,'Channel_plots')
            tmp_root = int(np.sqrt(len(use_channels)))+1
            f, ax = plt.subplots(tmp_root,tmp_root,figsize=(8*tmp_root,7*tmp_root))
            if tmp_root>1:
                ax = ax.flatten()
            else:
                ax = [ax]
            for i,chan in enumerate(use_channels):
                plt.sca(ax[i])
                plot.plot_channel_avg(X_RF['X_RF_cropped'],chan,title='Channel %s'%sensor_names[chan])
            save_fig_compress(os.path.join(tmppath_chan,'Ch.png'))
            plt.close()

            FFT_base = F_baseline['features'][0]
            FFT_filt = F_RF['features'][0]
            Phase_base = F_baseline['features'][2]
            Phase_filt = F_RF['features'][2]

            # Zero the scores of non-significant entries. This writes into the arrays
            # held by `KS`; `KS` is reloaded for every filter.
            KS_FFT_tmp = KS['FFT']['KS_kuiper']
            FFT_tmp_vali = KS['FFT']['p_kuiper']*bonferroni_n<0.05
            KS_Phase_tmp = KS['Phase']['KS_kuiper']
            Phase_tmp_vali = KS['Phase']['p_kuiper']*bonferroni_n<0.05
            KS_FFT_tmp[np.logical_not(FFT_tmp_vali)] = 0
            KS_Phase_tmp[np.logical_not(Phase_tmp_vali)] = 0

            KS_FFT_tmp = KS_FFT_tmp.reshape(FFT_filt.shape[1:])
            FFT_tmp_vali = FFT_tmp_vali.reshape(FFT_filt.shape[1:])
            KS_Phase_tmp = KS_Phase_tmp.reshape(Phase_filt.shape[1:])
            Phase_tmp_vali = Phase_tmp_vali.reshape(Phase_filt.shape[1:])

            # Per-electrode mean over the significant FFT and Phase entries combined.
            # The zero guard is applied to the combined count: guarding each count
            # separately made the denominator one too large whenever exactly one of
            # the two feature types had no significant entry for an electrode.
            Electrode_vali_sum = FFT_tmp_vali.sum(axis=1)+Phase_tmp_vali.sum(axis=1)
            Electrode_vali_sum[Electrode_vali_sum==0] = 1
            Electrode_KS[:,i_filt] = (KS_FFT_tmp.sum(axis=1)+KS_Phase_tmp.sum(axis=1))/Electrode_vali_sum.astype(float)
            Electrode_KS[np.isnan(Electrode_KS[:,i_filt]),i_filt]=0
            Electrode_KS_max[:,i_filt] = np.maximum(KS_FFT_tmp.max(axis=1),KS_Phase_tmp.max(axis=1))

            # Per-frequency mean over the significant entries (0 where none is significant).
            FFT_tmp_vali_sum = FFT_tmp_vali.sum(axis=0)
            FFT_tmp_vali_sum[FFT_tmp_vali_sum==0] = 1
            Phase_tmp_vali_sum = Phase_tmp_vali.sum(axis=0)
            Phase_tmp_vali_sum[Phase_tmp_vali_sum==0] = 1
            FFT_KS[:,i_filt] = KS_FFT_tmp.sum(axis=0)/FFT_tmp_vali_sum
            FFT_KS[np.isnan(FFT_KS[:,i_filt]),i_filt]=0
            Phase_KS[:,i_filt] = KS_Phase_tmp.sum(axis=0)/Phase_tmp_vali_sum
            Phase_KS[np.isnan(Phase_KS[:,i_filt]),i_filt]=0

            # log2 ratio (filter inputs vs. baseline) of the median amplitude and of the
            # circular phase variance, averaged over axis 0 of the per-entry ratio.
            # The phase ratio is negated.
            FFT_diff[:,i_filt] = np.mean(np.log2(np.divide(np.median(FFT_filt,axis=0),np.median(FFT_base,axis=0))),axis=0)
            Phase_diff[:,i_filt] = -1*np.mean(np.log2(np.divide(scipy.stats.circvar(Phase_filt,axis=0),scipy.stats.circvar(Phase_base,axis=0))),axis=0)

            print('KS')
            # Sum / mean / max of the significant scores per feature type; entries stay
            # at their initial 0 when a feature type has no significant score.
            for i,feat_name in enumerate(F_baseline['feature_names']):
                tmp_KS = KS[feat_name]['KS_kuiper']
                tmp_p = KS[feat_name]['p_kuiper']*bonferroni_n
                tmp_valid = tmp_p<0.05
                if len(tmp_KS[tmp_valid])>0:
                    KS_counts[i,i_filt] = tmp_KS[tmp_valid].sum()
                    KS_means[i,i_filt] = tmp_KS[tmp_valid].mean()
                    KS_max[i,i_filt] = tmp_KS[tmp_valid].max()

            # Normalise each summary by its total over the feature types
            # (the total is replaced by 1 when it is 0 to avoid dividing by zero).
            ctrl_counts = KS_counts[:,i_filt].sum()
            ctrl_means = KS_means[:,i_filt].sum()
            ctrl_max = KS_max[:,i_filt].sum()

            if ctrl_counts==0:
                ctrl_counts = 1
            if ctrl_means==0:
                ctrl_means = 1
            if ctrl_max==0:
                ctrl_max = 1

            KS_percentages_sum[:,i_lay,i_cl] += KS_counts[:,i_filt]/ctrl_counts
            KS_percentages_mean[:,i_lay,i_cl] += KS_means[:,i_filt]/ctrl_means
            KS_percentages_max[:,i_lay,i_cl] += KS_max[:,i_filt]/ctrl_max
            plt.cla()

        # Turn the accumulated shares into averages over the filters of this class.
        KS_percentages_sum[:,i_lay,i_cl] /= len(RF_class.max_filters)
        KS_percentages_mean[:,i_lay,i_cl] /= len(RF_class.max_filters)
        KS_percentages_max[:,i_lay,i_cl] /= len(RF_class.max_filters)

        # NOTE: KS_D_critical is the value computed for the last filter of this class.
        plot.KS_bar_plot(KS_counts,F_baseline['feature_names'],KS_D_critical,fig_h)
        save_fig_compress(os.path.join(tmppath_cl,'FiltKSScores.png'))
        plt.close()
        save_data(KS_counts,tmppath_cl,'KS_counts')

        plot.grid_plot(FFT_diff,sensor_names,frequencies,fig_h_big,
                       (range(FFT_diff.shape[1]),range(FFT_diff.shape[1])),
                       (range(FFT_diff.shape[0]),frequencies),
                       ('Filter','Frequency'),'Mean difference of median amplitude over all electrodes','jet')
        save_fig_compress(os.path.join(tmppath_cl,'FFT_meanDiff.png'))
        plt.close()
        save_data(FFT_diff,tmppath_cl,'FFT_diff')

        plot.grid_plot(FFT_KS,sensor_names,frequencies,fig_h_big,
                       (range(FFT_KS.shape[1]),range(FFT_KS.shape[1])),
                       (range(FFT_KS.shape[0]),frequencies),
                       ('Filter','Frequency'),'Mean KS Score of amplitude distributions over all electrodes','winter')
        save_fig_compress(os.path.join(tmppath_cl,'FFT_meanKS.png'))
        plt.close()
        save_data(FFT_KS,tmppath_cl,'FFT_KS')

        # NOTE(review): the phase plots pass frequencies[1:] as the third argument but the
        # full `frequencies` as tick labels -- confirm against plot.grid_plot which is intended.
        plot.grid_plot(Phase_diff,sensor_names,frequencies[1:],fig_h_big,
                       (range(Phase_diff.shape[1]),range(Phase_diff.shape[1])),
                       (range(Phase_diff.shape[0]),frequencies),
                       ('Filter','Frequency'),'Mean difference in phase variance over all electrodes','jet')
        save_fig_compress(os.path.join(tmppath_cl,'Phase_meanDiff.png'))
        plt.close()
        save_data(Phase_diff,tmppath_cl,'Phase_diff')

        plot.grid_plot(Phase_KS,sensor_names,frequencies[1:],fig_h_big,
                       (range(Phase_KS.shape[1]),range(Phase_KS.shape[1])),
                       (range(Phase_KS.shape[0]),frequencies),
                       ('Filter','Frequency'),'Mean KS Score of phase distribution over all electrodes','winter')
        save_fig_compress(os.path.join(tmppath_cl,'Phase_meanKS.png'))
        plt.close()
        save_data(Phase_KS,tmppath_cl,'Phase_KS')

        plot.grid_plot(Electrode_KS,sensor_names,frequencies,fig_h_big,
                       (range(Electrode_KS.shape[1]),range(Electrode_KS.shape[1])),
                       (range(Electrode_KS.shape[0]),sensor_names[range(Electrode_KS.shape[0])]),
                       ('Filter','Electrode'),'Mean KS Score of Electrodes','winter')
        save_fig_compress(os.path.join(tmppath_cl,'Electrode_meanKS.png'))
        plt.close()
        save_data(Electrode_KS,tmppath_cl,'Electrode_KS')

        plot.grid_plot(Electrode_KS_max,sensor_names,frequencies,fig_h_big,
                       (range(Electrode_KS.shape[1]),range(Electrode_KS.shape[1])),
                       (range(Electrode_KS.shape[0]),sensor_names[range(Electrode_KS.shape[0])]),
                       ('Filter','Electrode'),'Max KS Score of Electrodes','cool')
        save_fig_compress(os.path.join(tmppath_cl,'Electrode_maxKS.png'))
        plt.close()
        save_data(Electrode_KS_max,tmppath_cl,'Electrode_KS_max')
        
        
    
def _plot_KS_evolution(KS_percentages,kind,kind_title):
    """Plot, per class, how the per-feature KS share evolves over the layers, and save the data.

    Parameters:
        KS_percentages: array indexed [feature, layer, class] (one of KS_percentages_sum /
                        _mean / _max).
        kind:           tag used in the file names, e.g. 'sum' gives
                        'KS_Evolution_sum_Class<cl>.png' and 'KS_percentages_sum'.
        kind_title:     tag shown in the figure title, e.g. 'Sum'.

    Relies on the notebook globals `misc_data`, `F_baseline` (as left by the last
    layer iteration), `figurepath`, `save_fig_compress` and `save_data`.
    """
    feature_names = F_baseline['feature_names']
    layer_indeces = misc_data['layer_indeces']
    colors = plt.cm.jet(np.linspace(0, 0.9,len(feature_names)))
    for i,cl in enumerate(misc_data['classes']):
        tmp_percentages = KS_percentages[:,:,i]
        fig = plt.figure()
        for j in range(len(feature_names)):
            plt.plot(tmp_percentages[j].T,color=colors[j],label=feature_names[j])
        plt.legend()
        plt.ylim([0,1])
        # One tick per plotted point, labelled with the layer index. Setting only the
        # labels would attach them to whatever tick positions matplotlib picked.
        plt.xticks(np.arange(len(layer_indeces)),layer_indeces)
        # The title has to be set before saving, otherwise it never reaches the file.
        plt.title('(%s) Percentual Evolution of KS scores Class%s'%(kind_title,str(cl)))
        save_fig_compress(os.path.join(figurepath,'KS_Evolution_%s_Class%s.png'%(kind,str(cl))))
        plt.close(fig)
    save_data(KS_percentages,figurepath,'KS_percentages_%s'%kind)

_plot_KS_evolution(KS_percentages_sum,'sum','Sum')
_plot_KS_evolution(KS_percentages_mean,'mean','Mean')
_plot_KS_evolution(KS_percentages_max,'max','Max')
        
        
# Report how long the cell took (seconds component of the elapsed timedelta).
n2 = dt.datetime.now()
print((n2 - n1).seconds)

Layer16
ClassNone
Filter61
Head
FeatDist
Ch
KS
102


In [8]:
# Debug inspection: shape of KS_Phase_tmp as left in the kernel by the analysis loop
# (this cell only works after that loop has run).
KS_Phase_tmp.shape

(1, 4)

In [22]:
# Debug inspection: the `labels` dict loaded for the last processed layer
# (depends on kernel state left by the analysis loop).
labels

{'feature_names': ('FFT', 'FFTc', 'Phase', 'Phasec', 'Mean', 'Meanc', 'Power'),
 'labels': [['FFT N/A 0.0', 'FFT N/A 25.0', 'FFT N/A 50.0', 'FFT N/A 75.0'],
  ['FFTc N/A 0.0'],
  ['Phase N/A 25.0', 'Phase N/A 50.0', 'Phase N/A 75.0', 'Phase N/A 100.0'],
  ['Phasec N/A 25.0'],
  ['Mean N/A'],
  ['Meanc N/A'],
  ['Power N/A']]}

In [23]:
# Debug inspection: feature label arrays of the last loaded baseline feature data
# (depends on kernel state left by the analysis loop).
F_baseline['feature_labels']

[array([[['FFT', [0, 0]], ['FFT', [0, 1]], ['FFT', [0, 2]], ['FFT', [0, 3]]]], dtype=object),
 array([[['FFTc', [0, 0]]]], dtype=object),
 array([[['Phase', [0, 0]], ['Phase', [0, 1]], ['Phase', [0, 2]],
         ['Phase', [0, 3]]]], dtype=object),
 array([[['Phasec', [0, 0]]]], dtype=object),
 array([['Mean', [0]]], dtype=object),
 array([['Meanc', [0]]], dtype=object),
 array([['Power', [0]]], dtype=object)]