# Code for calculating and plotting linearity and linear separability


### Note: We recently changed the layer convention so that (Image, bipolar, ganglion, VVS 1-4) correspond to layers (0, 2, 3, 4-7). I believe this code has been updated to match the new convention, but I have not yet been able to test it.


In [None]:
# These are all the modules we'll be using later. 
import time
import os
import sys
#from __future__ import print_function
import tensorflow as tf
import numpy as np
import scipy.linalg
import scipy.stats
import scipy as sp
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle

import sklearn
import sklearn.linear_model

sys.stdout.flush()

# Global matplotlib styling shared by every figure in this notebook.
# 'sans-serif' replaces the original 'normal', which is a font *weight*, not a
# font family: matplotlib could not find a family with that name, warned, and
# fell back to its default font.
font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 18}

plt.rc('font', **font)
plt.rcParams['lines.markersize'] = 8

# Larger labels/ticks and a wide default figure.
params = {'legend.fontsize': 'x-large',
          'figure.figsize': (15, 5),
          'axes.labelsize': 'x-large',
          'axes.titlesize': 'x-large',
          'xtick.labelsize': 'x-large',
          'ytick.labelsize': 'x-large'}
plt.rcParams.update(params)




#Helper method to ensure that a directory exists
def ensure_dir(file_path):
    """Create the parent directory of ``file_path`` if it is missing.

    Args:
        file_path: Path of a file that is about to be written. Only its
            directory component is created; the file itself is not touched.

    A bare file name (no directory component) is a no-op, because the file
    would land in the current working directory, which already exists.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)

#Logging helper: callers pass their verbosity flag instead of wrapping every print in an if
def verbose_print(verbose, str_to_print):
    """Print ``str_to_print`` and flush stdout, but only when ``verbose`` is truthy."""
    if not verbose:
        return
    print(str_to_print)
    sys.stdout.flush()



# Subroutines for running and computing activations
------------

#### We use a scheme where we pickle activations once and then load them for analysis

In [None]:
# Locations used by the rest of the notebook: where saved models are read from,
# where activation pickles are stored, and the user's Desktop folder.
model_input_path = 'saved_models/'
pickle_dir = '/hdd/ActivationPickleHoard/'
desktop_path = os.path.join(os.path.expanduser('~'), 'Desktop/')

print('Pickle Dir is ' + pickle_dir)


def big_pickle_dump(data, file_path, should_replace = True):
    """Pickle ``data`` to ``file_path``.

    Args:
        data: Any picklable object.
        file_path: Destination file; it is opened in binary write mode, so any
            existing content is overwritten.
        should_replace: When true, an existing file at ``file_path`` is removed
            before the new one is written.

    The data is written with a single ``pickle.dump`` call; no chunked writing
    is performed. The file is closed deterministically once the dump finishes
    (the previous version never closed the handle).
    """
    if should_replace and os.path.exists(file_path):
        os.remove(file_path)

    with open(file_path, 'wb') as f_out:
        pickle.dump(data, f_out)
    
def big_pickle_load(file_path):
    """Load and return the object pickled at ``file_path``.

    ``encoding='latin1'`` is passed to ``pickle.load``; this is the setting
    needed to read byte strings from pickles written by Python 2.
    The file is closed as soon as loading finishes (the previous version left
    the handle open).

    NOTE: unpickling executes arbitrary code, so only load files you created.
    """
    with open(file_path, 'rb') as f_in:
        return pickle.load(f_in, encoding='latin1')



def ensure_batch_files_exist(model_name, layer, path = pickle_dir, batch_size = 200, full_size = 10000, verbose = True):
    """Print a warning for every expected per-batch activation pickle that is missing.

    One file is expected per batch start in ``range(0, full_size, batch_size)``.
    Nothing is created or returned, and ``verbose`` is not consulted.
    """
    for batch_start in range(0, full_size, batch_size):
        expected_path = batch_pickle_file_name(model_name, layer, batch_start, path)
        if os.path.exists(expected_path):
            continue
        print('Warning, ' + expected_path + ' doesnt exist!')





def load_layer_activations_from_batch_files(model_name, layer, path = pickle_dir, batch_size = 200, full_size = 10000, verbose = True):
    """Load one layer's activations from its per-batch pickles and join them.

    Args:
        model_name: Model identifier used in the batch file names.
        layer: Layer index used in the batch file names.
        path: Directory prefix of the batch files.
        batch_size: Step between consecutive batch start indices.
        full_size: Total number of samples; batch starts are
            ``range(0, full_size, batch_size)``.
        verbose: Print each file name and the shapes encountered.

    Returns:
        A single array with all batches joined along the first axis, i.e.
        ``(n_batches * batch_size, ...)`` when every batch is full.

    The batches are concatenated rather than stacked-and-reshaped, so a short
    final batch (``full_size`` not a multiple of ``batch_size``) is handled.
    """
    batch_activations = []
    for batch_start in range(0, full_size, batch_size):
        batch_pickle_path = batch_pickle_file_name(model_name, layer, batch_start, path)
        if verbose:
            print('Loading ' + batch_pickle_path)
        cur_layer_batch_activations = big_pickle_load(batch_pickle_path)
        if verbose:
            print('Appending Shape of ' + str(np.shape(cur_layer_batch_activations)))
        batch_activations.append(cur_layer_batch_activations)

    if not batch_activations:
        # No batches requested (full_size <= 0): return an empty array, as before.
        return np.array([])

    joined_activations = np.concatenate(batch_activations, axis=0)
    if verbose:
        print('Concatenated ' + str(len(batch_activations)) + ' batches into shape ' + str(np.shape(joined_activations)))
    return joined_activations



def load_layer_activations(model_name, layer, path = pickle_dir):
    """Load the single assembled ("FullActs") activation pickle for one model layer."""
    return big_pickle_load(layer_pickle_file_name(model_name, layer, path))


def batch_pickle_file_name(model_name, layer, batch, path):
    """Build the pickle path for one (model, layer, batch) combination.

    ``layer`` is zero-padded to 2 characters and ``batch`` to 5; ``path`` is
    used as a plain string prefix (no separator is inserted).
    """
    layer_tag = "_layer" + "{:02}".format(layer)
    batch_tag = '_batch' + "{:05}".format(batch)
    return "".join([path, "BatchActs_", model_name, layer_tag, batch_tag, '.p'])


def layer_pickle_file_name(model_name, layer,  path):
    """Build the pickle path for the assembled activations of one (model, layer).

    ``layer`` is zero-padded to 2 characters; ``path`` is used as a plain
    string prefix (no separator is inserted).
    """
    layer_tag = "_layer" + "{:02}".format(layer)
    return "".join([path, "FullActs_", model_name, layer_tag, '.p'])



def output_activations_to_pickle(model, input_data, layers_to_keep, model_name, output_dir = pickle_dir, verbose = True, batch_size = 200, full_size = 10000, make_batches = False, compile_batches = False, delete_batches_when_done = False, conditional_make_batches = True):
    """Pickle the activations of selected layers of ``model``.

    Two independent phases, each enabled by its own flag:

    * ``make_batches``: run the model on ``input_data`` one batch at a time and
      write one pickle per (model, layer, batch).
    * ``compile_batches``: read those per-batch pickles back and write one
      assembled pickle per (model, layer).

    Args:
        model: Keras model; only needed when ``make_batches`` is true.
        input_data: Model input, sliced along its first axis per batch.
        layers_to_keep: Indices into ``model.layers`` whose outputs are saved.
        model_name: Identifier used in all pickle file names.
        output_dir: Directory prefix for all pickle files.
        verbose: Print progress details.
        batch_size: Number of samples per batch.
        full_size: Total number of samples; batch starts are
            ``range(0, full_size, batch_size)``.
        make_batches: Enable the per-batch dump phase.
        compile_batches: Enable the assembly phase.
        delete_batches_when_done: Remove a layer's batch files after its
            assembled pickle has been written.
        conditional_make_batches: When true, an already existing batch file is
            reported and removed before it is rewritten.
    """
    batch_start_list = list(range(0, full_size, batch_size))

    if(make_batches):
        # The Keras backend is only needed here and is not imported at module
        # level, so import it where it is used.
        from keras import backend as K

        inp = model.input # input placeholder
        out = [model.layers[layer_nb].output for layer_nb in layers_to_keep] # layer outputs
        func = K.function([inp, K.learning_phase()], out) # function relating input to outputs

        #We make a bunch of individual pickle files. Every (model, layer, batch) combination will have its own pickle file.
        for batch_start in batch_start_list:
            if(verbose):
                print('Doing batch', batch_start)
            # The trailing 0 is the value fed for K.learning_phase().
            output = func([input_data[batch_start:(batch_start+batch_size)], 0])
            for (i_layer, layer) in enumerate(layers_to_keep):
                batch_pickle_path = batch_pickle_file_name(model_name, layer, batch_start, output_dir)
                if(os.path.exists(batch_pickle_path) and conditional_make_batches):
                    print('Path Already exists, deleting before making again')
                    os.remove(batch_pickle_path)
                if(verbose):
                    print('Dumping to ' + batch_pickle_path)
                big_pickle_dump(output[i_layer], batch_pickle_path)

    if(compile_batches):
        #Assemble the per-batch pickles into one pickle per (model, layer).
        for layer in layers_to_keep:
            print('Assembling Layer', layer)
            batch_paths = [batch_pickle_file_name(model_name, layer, batch_start, output_dir)
                           for batch_start in batch_start_list]

            cur_layer_activations_list = []
            for batch_pickle_path in batch_paths:
                cur_layer_batch_activations = big_pickle_load(batch_pickle_path)
                print('Appending Shape of ' + str(np.shape(cur_layer_batch_activations)))
                cur_layer_activations_list.append(cur_layer_batch_activations)

            # Join along the sample axis; unlike stack-and-reshape this also
            # works when the final batch is shorter than batch_size.
            if cur_layer_activations_list:
                full_layer_activations = np.concatenate(cur_layer_activations_list, axis=0)
            else:
                full_layer_activations = np.array([])

            layer_pickle_path = layer_pickle_file_name(model_name, layer, output_dir)
            if(verbose):
                print('Dumping to ' + layer_pickle_path)
                print('Full layer activations shape ' + str(np.shape(full_layer_activations)))
            big_pickle_dump(full_layer_activations, layer_pickle_path)

            # Only remove the batch files once the assembled pickle is safely
            # on disk, so a failure part-way through cannot lose data.
            if(delete_batches_when_done):
                for batch_pickle_path in batch_paths:
                    os.remove(batch_pickle_path)


 

# Subroutines for calculating linear fits and separability

In [None]:
   
def isconv_channels_pixels(input_array):
    """Return ``(is_conv, n_channels, n_pixels)`` for an activation array.

    A 4-dimensional input is treated as convolutional: the channel count is
    its last axis and the pixel count is the product of axes 1 and 2.
    Anything else reports axis 1 as the channel count and a single pixel.
    """
    dims = np.shape(input_array)
    if len(dims) != 4:
        return (False, dims[1], 1)
    return (True, dims[3], dims[1] * dims[2])


        
def possibly_ravel(data_to_fit):
    """Flatten a single-target array to 1-D; return anything else untouched.

    Companion to ``possibly_multitask_lasso_cv``: a 1-D array, or a 2-D array
    with exactly one column, is raveled so it can be fitted as a single
    target. Multi-column data is returned as the same object.
    """
    n_dims = np.ndim(data_to_fit)
    is_single_column = n_dims == 2 and np.shape(data_to_fit)[1] == 1
    if n_dims == 1 or is_single_column:
        return data_to_fit.ravel()
    return data_to_fit

    
    
def possibly_multitask_lasso_cv(data_to_fit):
    """Return a fresh cross-validated lasso estimator suited to the target's shape.

    MultiTaskLassoCV is slower than LassoCV when there is only one target, so
    a 1-D target (or a 2-D target with a single column) gets a plain LassoCV;
    every other shape gets a MultiTaskLassoCV.
    """
    n_dims = np.ndim(data_to_fit)
    single_target = n_dims == 1 or (n_dims == 2 and np.shape(data_to_fit)[1] == 1)
    if single_target:
        return sklearn.linear_model.LassoCV()
    return sklearn.linear_model.MultiTaskLassoCV()

        
def one_way_multi_channel_ridge_var_exp(act_1, act_2, pixel_loc = 15, verbose = True, frac_for_training = .8):
    """Fit a cross-validated ridge regression predicting ``act_2`` from ``act_1``.

    The first ``frac_for_training`` fraction of the samples (first axis) is
    used for fitting and the remainder for evaluation.

    Args:
        act_1: Predictor activations; everything after the first axis is
            flattened into one feature vector per sample.
        act_2: Target activations. With more than 2 dimensions only the single
            location ``[:, pixel_loc, pixel_loc, :]`` is predicted; otherwise
            the two middle columns are predicted.
        pixel_loc: Index used on both axes 1 and 2 of ``act_2``
            (presumably the centre of the feature map -- confirm for maps
            that are not about 32 wide).
        verbose: Print progress messages.
        frac_for_training: Fraction of samples used for fitting.

    Returns:
        dict with keys ``'frac_var_exp'`` (estimator score on the held-out
        samples), ``'estimator'`` (the fitted RidgeCV), ``'Y_real_vs_fit'``
        (held-out targets, held-out predictions), ``'Notes'`` (how the target
        was cropped) and ``'chan_wise_scores'`` (per-target-column Pearson
        correlation between held-out targets and predictions).
    """
    verbose_print(verbose, 'Reshaping')

    n_data = np.shape(act_2)[0]
    # np.int was removed from NumPy (1.24); the builtin int truncates identically.
    n_training = int(n_data * frac_for_training)

    #Slices X(Act1) into training and test
    reshaped_act1_x = act_1.reshape((np.shape(act_1)[0], np.prod(np.shape(act_1)[1:])))
    act1_x_train = reshaped_act1_x[:n_training, :]
    act1_x_test = reshaped_act1_x[n_training:, :]

    #Slices Y(Act2) into training and test
    verbose_print(verbose, 'Original shapes are '+ str((np.shape(act_1), np.shape(act_2)) ) )
    if(np.ndim(act_2) > 2):
        cropped_act2 = act_2[:, pixel_loc, pixel_loc, :]
        notes_to_return = 'Took center pixel'
    else:
        pix_val_to_use = int(np.shape(act_2)[1]/2)
        cropped_act2 = act_2[:, [pix_val_to_use, pix_val_to_use+1]]
        notes_to_return = 'Warning, trying to predict a nonconvolutional layer, just took the medium two entries to predict'
        verbose_print(verbose, notes_to_return)
    reshaped_act2_y = cropped_act2.reshape((np.shape(cropped_act2)[0], np.prod(np.shape(cropped_act2)[1:])))
    act2_y_train = reshaped_act2_y[:n_training, :]
    act2_y_test = reshaped_act2_y[n_training:, :]

    est_12 = sklearn.linear_model.RidgeCV()
    verbose_print(verbose, 'Training Est12 ' + str(type(est_12)) +  ' with shapes of ' + str(np.shape(reshaped_act1_x)) + ', ' + str(np.shape(reshaped_act2_y)))
    est_12.fit(act1_x_train,  act2_y_train)
    est_12_frac_var_exp = est_12.score(act1_x_test,  act2_y_test)
    predict_act2_y_test = est_12.predict(act1_x_test)
    verbose_print(verbose, 'Finished!')

    return_dict =  {'frac_var_exp': est_12_frac_var_exp, 'estimator': est_12, 'Y_real_vs_fit': (act2_y_test, predict_act2_y_test)}
    return_dict['Notes'] = notes_to_return

    # Correlation between real and predicted held-out values, one per target column.
    return_dict['chan_wise_scores'] = [
        np.corrcoef(act2_y_test[:, chan], predict_act2_y_test[:, chan])[1,0]
        for chan in range(np.shape(act2_y_test)[1])
    ]

    return return_dict
    
    
def multi_channel_lasso_var_exp(act1, act2, pixel_loc = 15, verbose = True):
    """Fit cross-validated lasso regressions between two layers, in both directions.

    Each direction predicts the other layer's values at the single location
    ``[:, pixel_loc, pixel_loc, :]`` from the fully flattened predictor layer.
    Both estimators are scored on the same data they were fitted on.

    Args:
        act1, act2: Activation arrays indexed as ``[:, i, j, :]``.
        pixel_loc: Index used on both axes 1 and 2 to select the target location.
        verbose: Print progress messages.

    Returns:
        Tuple ``(explained_var_1, total_var_1, explained_var_2, total_var_2,
        est_21, est_12)`` where ``total_var_k`` is the summed per-column
        variance of layer k's target, ``explained_var_k`` is that total times
        the corresponding estimator score, ``est_21`` predicts layer 1 from
        layer 2 and ``est_12`` predicts layer 2 from layer 1.
    """
    verbose_print(verbose, 'Reshaping')
    # The body previously referenced act_1/act_2, which are not this function's
    # parameters (NameError unless globals of that name happened to exist).
    reshaped_act1_x = act1.reshape((np.shape(act1)[0], np.prod(np.shape(act1)[1:])))
    reshaped_act2_x = act2.reshape((np.shape(act2)[0], np.prod(np.shape(act2)[1:])))

    cropped_act1 = act1[:, pixel_loc, pixel_loc, :]
    cropped_act2 = act2[:, pixel_loc, pixel_loc, :]

    reshaped_act1_y = cropped_act1.reshape((np.shape(cropped_act1)[0], np.prod(np.shape(cropped_act1)[1:])))
    act1_y_var = np.sum(np.var(reshaped_act1_y, axis = 0))
    verbose_print(verbose, 'Var1 is ' + str(act1_y_var))
    # Single-column targets are flattened so the faster LassoCV can be used.
    reshaped_act1_y = possibly_ravel(reshaped_act1_y)

    reshaped_act2_y = cropped_act2.reshape((np.shape(cropped_act2)[0], np.prod(np.shape(cropped_act2)[1:])))
    act2_y_var = np.sum(np.var(reshaped_act2_y, axis = 0))
    reshaped_act2_y = possibly_ravel(reshaped_act2_y)
    verbose_print(verbose, 'Var2 is ' + str(act2_y_var))

    verbose_print(verbose, 'Training Est12 ' + ' with shapes of ' + str(np.shape(reshaped_act1_x)) + ', ' + str(np.shape(reshaped_act2_y)))

    # Layer 1 -> layer 2
    est_12 = possibly_multitask_lasso_cv(reshaped_act2_y)
    verbose_print(verbose, 'Training Est12 ' + str(type(est_12)) +  ' with shapes of ' + str(np.shape(reshaped_act1_x)) + ', ' + str(np.shape(reshaped_act2_y)))
    est_12.fit(reshaped_act1_x, reshaped_act2_y)
    est_12_frac_var_exp = est_12.score(reshaped_act1_x, reshaped_act2_y)

    # Layer 2 -> layer 1
    est_21 = possibly_multitask_lasso_cv(reshaped_act1_y)
    verbose_print(verbose, 'Training Est21' + str(type(est_21)) + ' with shapes of ' + str(np.shape(reshaped_act2_x)) + ', ' + str(np.shape(reshaped_act1_y)))
    verbose_print(verbose, 'Est21 has type ' + str(est_21))
    est_21.fit(reshaped_act2_x, reshaped_act1_y)
    est_21_frac_var_exp = est_21.score(reshaped_act2_x, reshaped_act1_y)

    verbose_print(verbose, 'Finished!')

    return (est_21_frac_var_exp * act1_y_var, act1_y_var, est_12_frac_var_exp * act2_y_var, act2_y_var, est_21, est_12)

    
def lasso_var_exp(act1, act2, ind_to_use):
    """Fit LassoCV in both directions for one column and report explained variance.

    Column ``ind_to_use`` of ``act2`` is predicted from all of ``act1`` and
    vice versa. Each estimator is scored on the data it was fitted on.

    Returns:
        ``(explained_var_1, total_var_1, explained_var_2, total_var_2)``, where
        the ``_1`` values refer to the column of ``act1`` (predicted from
        ``act2``) and the ``_2`` values to the column of ``act2``.
    """
    def _score_and_variance(predictors, target):
        # One direction: fit, score on the same data, and measure target variance.
        estimator = sklearn.linear_model.LassoCV()
        estimator.fit(predictors, target)
        return estimator.score(predictors, target), np.var(target)

    frac_12, var_12 = _score_and_variance(act1, act2[:, ind_to_use])
    frac_21, var_21 = _score_and_variance(act2, act1[:, ind_to_use])

    return (frac_21 * var_21, var_21, frac_12 * var_12, var_12)

# Subroutines for filling and querying datastructures for calculated quantities


#### The basic pattern is to fill a dictionary with values that depend on the model, channel, layer, and possibly the pair of classes to be linearly separated. Afterwards, the data can be sliced in any way by querying the dictionary.

In [None]:


def model_name_string(bottleneck_width, brain_layers, instance, noise_start = 0.0, noise_end = 0.0,
                      retina_out_weight_reg = 0.0, retina_out_stride = 1, retina_hidden_channels = 32,
                      task = 'classification', filter_size = 9, retina_layers = 2, use_b = 1, actreg = 0.0,
                      vvs_width = 32, epochs = 20, reg = 0.0):
    """Build the saved-model name that encodes a model's hyperparameters.

    ``instance`` fills the ``type`` field, ``brain_layers`` the ``vvs_layers``
    field and ``bottleneck_width`` the ``retina_out_channels`` field. The
    result is prefixed with ``saved_models/SAVED_``.
    """
    # (label, value) pairs in the exact order they appear in the name.
    labelled_values = (
        ('cifar10_type_', str(instance)),
        ('_noise_start_', str(noise_start)),
        ('_noise_end_', str(noise_end)),
        ('_reg_', str(reg)),
        ('_retina_reg_', str(retina_out_weight_reg)),
        ('_retina_hidden_channels_', str(retina_hidden_channels)),
        ('_SS_', str(retina_out_stride)),
        ('_task_', task),
        ('_filter_size_', str(filter_size)),
        ('_retina_layers_', str(retina_layers)),
        ('_vvs_layers', str(brain_layers)),
        ('_bias_', str(use_b)),
        ('_actreg_', str(actreg)),
        ('_retina_out_channels_', str(bottleneck_width)),
        ('_vvs_width_', str(vvs_width)),
        ('_epochs_', str(epochs)),
    )

    pieces = ['saved_models/SAVED', '_']
    for label, value in labelled_values:
        pieces.append(label)
        pieces.append(value)
    return ''.join(pieces)


# Print one example model name so the naming scheme is visible in the cell output.
print(model_name_string(brain_layers=2, instance=4, bottleneck_width=4))

def create_model_name_list(bottleneck_width_list, brain_layers_list, instance_list):
    """List the model names for every (instance, brain_layers, bottleneck_width) combination.

    Instances vary slowest and bottleneck widths fastest; all other
    hyperparameters take the defaults of ``model_name_string``.
    """
    return [
        model_name_string(bottleneck_width = width, brain_layers = n_brain_layers, instance = inst)
        for inst in instance_list
        for n_brain_layers in brain_layers_list
        for width in bottleneck_width_list
    ]


def query_mean_sep_score(lin_sep_dict, model_name, n_classes = 10, layer = 3):
    """Average the stored separability score over all unordered class pairs.

    Looks up ``lin_sep_dict[(model_name, layer, class1, class2)]`` for every
    pair with ``class1 < class2 < n_classes`` and returns the mean.
    """
    pair_scores = [
        lin_sep_dict[(model_name, layer, first, second)]
        for first in range(n_classes)
        for second in range(first + 1, n_classes)
    ]
    return np.mean(pair_scores)




def query_lin_sep_retina_plot_data(lin_sep_dict, bottleneck_width_list,brain_layer_list = range(5),  instance_list = range(1, 11, 1), layer = 4):
    """Collect mean/stderr of linear separability versus number of brain layers.

    Args:
        lin_sep_dict: Dictionary of pairwise separability scores keyed by
            ``(model_name, layer, class1, class2)``.
        bottleneck_width_list: One plotted curve per bottleneck width.
        brain_layer_list: One plotted point per number of brain layers.
        instance_list: Model instances averaged over for each point.
        layer: Layer index whose scores are queried (previously hard-coded
            to 4 -- TODO confirm this matches the current layer convention).

    Returns:
        ``(all_plot_mean_stderr, fig_legend_list)``: a list with one
        ``(means, stderrs)`` pair per bottleneck width, and the global
        ``fig_legend_list``.

    The mean pairwise score is doubled before averaging over instances. The
    standard error is ``std / sqrt(n - 1)``, which is not finite when only a
    single instance is given.

    NOTE(review): legends are appended to the module-level ``fig_legend_list``,
    which therefore must exist and keeps growing across calls; this side
    effect is kept because the plotting code reads that global.
    """
    all_plot_mean_stderr = []

    for bottleneckwidth in bottleneck_width_list:
        cur_plot_mean_list = []
        cur_plot_stderr_list = []

        for brain_layers in brain_layer_list:
            cur_val_list = []
            for cur_instance in instance_list:
                gen_model_name = model_name_string(bottleneck_width=bottleneckwidth, brain_layers=brain_layers, instance=cur_instance)
                # Use the dictionary that was passed in; the global
                # my_lin_sep_dict was previously read here by mistake.
                mean_score = query_mean_sep_score(lin_sep_dict, gen_model_name, layer=layer, n_classes=10)
                cur_val_list.append(2 * mean_score)
            cur_plot_mean_list.append(np.mean(cur_val_list))
            cur_plot_stderr_list.append(np.std(cur_val_list)/np.sqrt(len(cur_val_list)-1))

        all_plot_mean_stderr.append((cur_plot_mean_list, cur_plot_stderr_list))
        fig_legend_list.append(str(bottleneckwidth) + ' channels')

    return (all_plot_mean_stderr, fig_legend_list)

def query_lin_sep_vs_layer_plot_data(lin_sep_dict, bottleneck_width_list, layer_list = (0, 4, 6, 8, 10, 12),
                                     instance_list = range(1, 9, 1), brain_layers = 4):
    """Collect mean/stderr of linear separability versus layer index.

    Args:
        lin_sep_dict: Dictionary of pairwise separability scores keyed by
            ``(model_name, layer, class1, class2)``.
        bottleneck_width_list: One plotted curve per bottleneck width.
        layer_list: One plotted point per layer index (the default values may
            predate the current layer convention -- TODO confirm).
        instance_list: Model instances averaged over for each point.
        brain_layers: Number of brain layers of the queried models.

    Returns:
        ``(all_plot_mean_stderr, all_plot_legends)``: one ``(means, stderrs)``
        pair and one legend string per bottleneck width.

    The mean pairwise score is doubled before averaging over instances. The
    standard error is ``std / sqrt(n - 1)``, which is not finite when only a
    single instance is given.

    NOTE(review): each legend is also appended to the module-level
    ``fig_legend_list``, which must exist; this side effect is kept because
    the plotting code reads that global.
    """
    all_plot_mean_stderr = []
    all_plot_legends = []

    for bottleneckwidth in bottleneck_width_list:
        cur_plot_mean_list = []
        cur_plot_stderr_list = []

        for layer in layer_list:
            cur_val_list = []
            for cur_instance in instance_list:
                gen_model_name = model_name_string(bottleneck_width=bottleneckwidth, brain_layers=brain_layers, instance=cur_instance)
                # Use the dictionary that was passed in; the global
                # my_lin_sep_dict was previously read here by mistake.
                mean_score = query_mean_sep_score(lin_sep_dict, gen_model_name, layer=layer, n_classes=10)
                cur_val_list.append(2 * mean_score)
            cur_plot_mean_list.append(np.mean(cur_val_list))
            cur_plot_stderr_list.append(np.std(cur_val_list)/np.sqrt(len(cur_val_list)-1))

        all_plot_mean_stderr.append((cur_plot_mean_list, cur_plot_stderr_list))
        legend = str(bottleneckwidth) + ' channels'
        all_plot_legends.append(legend)
        fig_legend_list.append(legend)

    return (all_plot_mean_stderr, all_plot_legends)


def query_linearity_retina_plot_data(linearity_dict, bottleneck_width_list,brain_layer_list = range(5), instance_list = range(1, 9, 1), layer = 3):
    """Collect mean/stderr of the linear-fit score versus number of brain layers.

    Args:
        linearity_dict: Dictionary keyed by ``(model_name, layer)`` whose
            values are dicts containing a ``'frac_var_exp'`` entry.
        bottleneck_width_list: One plotted curve per bottleneck width.
        brain_layer_list: One plotted point per number of brain layers.
        instance_list: Model instances averaged over for each point.
        layer: Layer index whose fit is queried.

    Returns:
        ``(all_plot_mean_stderr, all_plot_legends)``: one ``(means, stderrs)``
        pair and one legend string per bottleneck width. (The legend list used
        to be returned empty; it is now filled, matching
        ``query_lin_sep_vs_layer_plot_data``.)

    The standard error is ``std / sqrt(n - 1)``, which is not finite when only
    a single instance is given. The per-width means and standard errors are
    printed as they are computed.

    NOTE(review): each legend is also appended to the module-level
    ``fig_legend_list``, which must exist; this side effect is kept because
    the plotting code reads that global.
    """
    all_plot_mean_stderr = []
    all_plot_legends = []

    for bottleneckwidth in bottleneck_width_list:
        cur_plot_mean_list = []
        cur_plot_stderr_list = []

        for brain_layers in brain_layer_list:
            cur_val_list = []
            for cur_instance in instance_list:
                gen_model_name = model_name_string(bottleneck_width=bottleneckwidth, brain_layers=brain_layers, instance=cur_instance)
                cur_return_dict = linearity_dict[(gen_model_name, layer)]
                cur_val_list.append(cur_return_dict['frac_var_exp'])
            cur_plot_mean_list.append(np.mean(cur_val_list))
            cur_plot_stderr_list.append(np.std(cur_val_list)/np.sqrt(len(cur_val_list)-1))

        all_plot_mean_stderr.append((cur_plot_mean_list, cur_plot_stderr_list))

        print('PLOT MEAN LIST:')
        print(cur_plot_mean_list)

        print('PLOT STDERR LIST:')
        print(cur_plot_stderr_list)

        legend = str(bottleneckwidth) + ' channels'
        all_plot_legends.append(legend)
        fig_legend_list.append(legend)

    return (all_plot_mean_stderr, all_plot_legends)

        
def query_chanwise_sep_score(lin_sep_dict, model_name, n_classes = 10, layer = 3, n_channels = 1):
    """Return, for each channel, the mean separability score over all class pairs.

    Looks up ``lin_sep_dict[(model_name, layer, class1, class2, chan)]`` for
    every pair with ``class1 < class2 < n_classes`` and every channel below
    ``n_channels``; the result is a list with one mean per channel.
    """
    class_pairs = [
        (first, second)
        for first in range(n_classes)
        for second in range(first + 1, n_classes)
    ]
    return [
        np.mean([lin_sep_dict[(model_name, layer, first, second, chan)]
                 for (first, second) in class_pairs])
        for chan in range(n_channels)
    ]


# Calculate Linear Separability and linearity of channels
--------------------------------------
#### This fills up several large dictionary structures, which will be queried to make plots

In [None]:
## Measure linear separability with a linear SVM
# For each layer, activations are loaded and classified pairwise by class.

# sklearn.svm is imported explicitly: `import sklearn` does not guarantee that
# the svm submodule is loaded, so sklearn.svm.LinearSVC could otherwise fail.
import sklearn.svm
from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_test2 = y_test.copy()


model_SVM = sklearn.svm.LinearSVC(penalty='l2', loss='squared_hinge', dual=True, tol=0.0001, C=1.0, 
                      multi_class='ovr', fit_intercept=True, intercept_scaling=1, class_weight=None, 
                      verbose=0, random_state=None, max_iter=1000)
# Toggle: set to 1 to start from empty result dictionaries, or leave at 0 to
# load the previously saved ones and add more values to them.
if(0):
    my_lin_fit_dict = dict()
    my_im_fit_dict = dict()
    my_lin_sep_dict = dict()
    my_chanwise_lin_sep_dict = dict()
else:
    my_lin_sep_dict = big_pickle_load('AnalysisOutputs/LinSeparationDict.p')
    my_lin_fit_dict = big_pickle_load('AnalysisOutputs/LinFitDict.p')
    my_chanwise_lin_sep_dict = big_pickle_load('AnalysisOutputs/ChanwiseLinSeparationDict.p')
    my_im_fit_dict = big_pickle_load('AnalysisOutputs/ImFitdict.p')
    
    

    
    
    
    
# Which models and layers to analyse in this run.
bottleneck_width_list = [1, 32]  # full sweep: [1, 2, 4, 8, 16, 32]
brain_layers_list = [4]  # full sweep: [0, 1, 2, 3, 4]
instance_list = list(range(1, 11))  # Instances MUST start from 1

# We linearly separate all possible pairs of classes.
# Setting this below 10 makes things run faster.
n_classes_to_do = 10

# Layer 0 is the image, layer 2 is the first layer of Retina-Net, layer 3 is the
# bottleneck, and layers 4, 5, 6, 7 are the activations in VVS-Net.
layers_to_do = [0, 2, 4, 5, 6, 7]

all_model_names = create_model_name_list(bottleneck_width_list, brain_layers_list, instance_list)
print('N models is ' + str(len(all_model_names)))
#all_model_names = all_model_names[0:]; 
#print('N Clipped models is ' + str(len(all_model_names)))



#all_model_names = all_model_names[43:]
# Main analysis loop. For every model: load its pickled activations, fit ridge
# regressions between the image and each layer (both directions), measure the
# pairwise linear class separability of each layer with the SVM, and then save
# all result dictionaries.
for (i_model_name, model_name) in enumerate(all_model_names):
    print('\n\nDoing model '   +  str(i_model_name) + ' /'  + str(len(all_model_names)) )
    print('Name:' + model_name)
    print(time.asctime())
    
    # NOTE(review): model_nlayers is loaded but only referenced by the
    # commented-out cur_layers_to_keep alternative below.
    model_nlayers_filename = pickle_dir +  'NLayersOf_' + model_name + '.p';
    model_nlayers = big_pickle_load(model_nlayers_filename);
    # Layer 0 activations serve as the "image" side of the linear fits below.
    pixels_act = load_layer_activations_from_batch_files(model_name, layer=0, verbose=False)
    

    
    cur_layers_to_keep = layers_to_do
#    cur_layers_to_keep = [0, 3] + list(range(4, model_nlayers, 2));    
    for cur_layer in cur_layers_to_keep:
        print('Calculating Properties of Layer', cur_layer)
        # load activation
        
        act = load_layer_activations_from_batch_files(model_name, cur_layer, verbose=False)
        print('Loaded activations of shape ' + str(np.shape(act)))
        
        
        # The if(1)/if(0) switches below enable or disable individual analyses.
        if(1):
            #Calculates linear fits between image and activations
            print('Calculating linear fit from image to layer')
            cur_return_dict = one_way_multi_channel_ridge_var_exp(act_1 = pixels_act, act_2 = act, verbose=True);
            my_lin_fit_dict[(model_name, cur_layer)] = cur_return_dict
        if(1):
            print('Calculating linear fit from layer to image')            
            cur_return_dict = one_way_multi_channel_ridge_var_exp(act_1 = act, act_2 = pixels_act, verbose=True);
            my_im_fit_dict[(model_name, cur_layer)] = cur_return_dict
            
        if(0): 
            # Currently disabled: linear class separability channel by channel.
            # Indexes act with four axes, channel last.
            #for each pair of class, measure linear separability on testing set
            i = 0
            for chan in range(np.shape(act)[3]):
                print('Doing Channel-Wise Separability for Channel ' + str(chan))
                
                cur_chan_acts = act[:, :, :, chan]
                cur_chan_acts = cur_chan_acts.reshape(cur_chan_acts.shape[0],cur_chan_acts.shape[1]*cur_chan_acts.shape[2])

                for class1 in range(n_classes_to_do):
                    for class2 in range(class1+1,n_classes_to_do):
                        # Boolean mask of the samples belonging to either class.
                        keep = np.squeeze((y_test2==class1) + (y_test2==class2))
                        actk = cur_chan_acts[keep]
                        yk = y_test[keep]
                        # Binary target: True for class1, False for class2.
                        yk = yk==class1
                        # Fit on the first 1000 kept samples, evaluate on the rest.
                        model_SVM.fit(actk[:1000], yk[:1000])
                        pred = np.squeeze(model_SVM.predict(actk[1000:]))
                        real = np.squeeze(yk[1000:])
                        # Stored score is the fraction correct minus .5.
                        score = np.mean(pred==real)- .5
                        my_chanwise_lin_sep_dict[(model_name, cur_layer, class1, class2, chan)] = score;


                        i = i+1

        if(1): #Calculates Linear Class Separability as a whole
            # Flatten everything but the sample axis; note this rebinds act for
            # the rest of this iteration. Assumes a 4-axis array when ndim > 2.
            if (len(np.shape(act))>2):
                act = act.reshape(act.shape[0],act.shape[1]*act.shape[2]*act.shape[3])
            
            #for each pair of class, measure linear separability on testing set
            i = 0
            print('Calculating All-Channel Linear Separability')
            for class1 in range(n_classes_to_do):
                for class2 in range(class1+1,n_classes_to_do):
                    print('All-Channel Linear Separability, Comparing classes ' +str((class1, class2)))

                    # Boolean mask of the samples belonging to either class.
                    keep = np.squeeze((y_test2==class1) + (y_test2==class2))
                    actk = act[keep]
                    yk = y_test[keep]
                    # Binary target: True for class1, False for class2.
                    yk = yk==class1
                    # Fit on the first 1000 kept samples, evaluate on the rest.
                    model_SVM.fit(actk[:1000], yk[:1000])
                    pred = np.squeeze(model_SVM.predict(actk[1000:]))
                    real = np.squeeze(yk[1000:])
                    # Stored score is the fraction correct minus .5.
                    score = np.mean(pred==real)- .5
                    my_lin_sep_dict[(model_name, cur_layer, class1, class2)] = score;
                    
                    
                    i = i+1
    #We save the dictionaries once we're done with each model.
    #If there were too many models, this would be inefficient, but in practice this doesn't take too much time
    ensure_dir('AnalysisOutputs/ImFitdict.p')
    big_pickle_dump(my_im_fit_dict, 'AnalysisOutputs/ImFitdict.p')
    big_pickle_dump(my_chanwise_lin_sep_dict, 'AnalysisOutputs/ChanwiseLinSeparationDict.p') 
    big_pickle_dump(my_lin_sep_dict, 'AnalysisOutputs/LinSeparationDict.p')
    big_pickle_dump(my_lin_fit_dict, 'AnalysisOutputs/LinFitDict.p') 
    print('Finished with model ' + model_name)
print('Finished making dictionaries!')

# Plots variance explained of linear fits or linear separability for various architectures


### Organized as a bunch of nearly identical sections which plot different quantities


##

In [None]:
# Reload the saved result dictionaries so this cell can run without recomputing them.
my_lin_sep_dict = big_pickle_load('AnalysisOutputs/LinSeparationDict.p')
my_lin_fit_dict = big_pickle_load('AnalysisOutputs/LinFitDict.p')
my_im_fit_dict = big_pickle_load('AnalysisOutputs/ImFitdict.p')
new_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf'];

all_lines  =[]
brain_layer_list = list(range(5));
# fig_legend_list is appended to by the query_* functions as a side effect and
# is what plt.legend reads below.
fig_legend_list = [];
# (color, linestyle) per plotted curve, indexed by curve position.
plot_line_format_list = [];
if(0):
    plot_line_format_list.append(('k', '-'))
    plot_line_format_list.append(([.65, .65, .65], '-'))
else:
    
    bottleneck_width_list  =[1, 2, 4, 8, 16, 32]
    plot_line_format_list.append(('k', '-'))
    
    for color in new_colors:
        plot_line_format_list.append((color, '-'))
    
    plot_line_format_list.append(('r', ':'))
    plot_line_format_list.append(('g', ':'))
    plot_line_format_list.append(('b', ':'))
    plot_line_format_list.append(('purple', ':'))
    plot_line_format_list.append(([.65, .65, .65], '-'))
    
    
plotted_model_names = []
# Horizontal padding around the first/last x tick.
# NOTE: the first pair of values is immediately overridden by the second.
x_pad_r = .045
x_pad_l = .06

x_pad_r = .1
x_pad_l = .1
# Figure: linearity of the layer -> image fit versus number of brain layers,
# one curve per bottleneck width.
if(1):
    bottleneck_width_list =  [1, 2, 4]
               
    # NOTE(review): instance_list is set here but not passed to the query
    # below, which therefore uses its own default (range(1, 9, 1)) -- confirm.
    instance_list = list(range(1, 11, 1))

    (mean_stderr_list, legend_list) = query_linearity_retina_plot_data(my_im_fit_dict, bottleneck_width_list)
    plt.clf();
    all_lines = []

    for (i_bottleneck_width, (cur_means, cur_stderrs) ) in enumerate (mean_stderr_list):
        cur_plot_line_format = plot_line_format_list[i_bottleneck_width];
        # Error bars show two standard errors.
        cur_line  = plt.errorbar(brain_layer_list, cur_means, yerr = 2 *  np.array(cur_stderrs), c = cur_plot_line_format[0], linestyle = cur_plot_line_format[1], marker = 'o')
        print('At ' + str(i_bottleneck_width))
        all_lines.append(cur_line)
        

    plt.xlabel('N Brain Layers')
    plt.ylabel('Linearity (Retina ->Image)')

    # The legend text comes from the global fig_legend_list, not from legend_list.
    plt.legend(tuple(all_lines), tuple(fig_legend_list))   
    # NOTE: the second ylim call overrides the first.
    plt.ylim([.4, 1])
    plt.ylim([.0, 1])
    
    plt.xlim([-x_pad_l + np.min(brain_layer_list), np.max(brain_layer_list)+x_pad_r])
    plt.xticks(brain_layer_list)
    fig_out_path = 'Outputs/ImageLinearityVsNLayers.pdf'    
    ensure_dir(fig_out_path)
    plt.savefig(fig_out_path)
    plt.show()
    
if(1): ## Plot the linear separability of the retinal layer as a function of the number of layers and bottleneck width
    #This is figure 3D
    bottleneck_width_list = [1, 4]
    instance_list = [1, 2]
    brain_layers_list = [0, 1, 2, 3, 4]  # not read in this section; the x axis uses brain_layer_list

    (mean_stderr_list, legend_list) = query_lin_sep_retina_plot_data(my_lin_sep_dict, bottleneck_width_list, instance_list=instance_list)
    plt.clf()
    all_lines = []

    # One error-bar curve per entry of mean_stderr_list; bars are 2x the returned stderr values.
    for (i_bottleneck_width, (cur_means, cur_stderrs)) in enumerate(mean_stderr_list):
        cur_plot_line_format = plot_line_format_list[i_bottleneck_width]
        cur_line = plt.errorbar(brain_layer_list, cur_means, yerr = 2 * np.array(cur_stderrs), c = cur_plot_line_format[0], linestyle = cur_plot_line_format[1], marker = 'o')
        print('At Bottleneck Width' + str(i_bottleneck_width))
        all_lines.append(cur_line)

    # Build the labels in the same order as all_lines: the curves first, then
    # the reference line. Passing only ['Raw Pixels'] attached that label to
    # the first curve and left the reference line unlabeled.
    # NOTE(review): assumes legend_list holds one label per curve, in curve order -- confirm.
    legend_labels = list(legend_list)

    #Hardcoded in the linear separability of the image itself
    cur_line, = plt.plot([0, 4], 1 * np.array([0.2738, 0.2738]), c = [.1, .1, .1], linestyle = '--')
    all_lines.append(cur_line)
    legend_labels.append('Raw Pixels')

    plt.legend(tuple(all_lines), tuple(legend_labels))
    plt.ylim([.0, 1])

    plt.xlim([-x_pad_l + np.min(brain_layer_list), np.max(brain_layer_list)+x_pad_r])
    plt.xticks(brain_layer_list)
    fig_out_path = 'AnalysisOutputs/RetinaLinSepVsNLayers.pdf'
    ensure_dir(fig_out_path)
    plt.savefig(fig_out_path)
    plt.show()

    

if(1): #Figure 3A: linearity of the bottleneck layer vs. number of layers, one curve per bottleneck width
    bottleneck_width_list = [1, 4]
    instance_list = [1, 2]  #Instances MUST start from 1
    brain_layers_list = [0, 1, 2, 3, 4]

    (mean_stderr_list, legend_list) = query_linearity_retina_plot_data(
        my_lin_fit_dict, bottleneck_width_list, instance_list=instance_list)

    plt.clf()
    all_lines = []
    for (i_bottleneck_width, (cur_means, cur_stderrs)) in enumerate(mean_stderr_list):
        cur_plot_line_format = plot_line_format_list[i_bottleneck_width]
        (line_color, line_style) = (cur_plot_line_format[0], cur_plot_line_format[1])
        # Error bars span twice the returned stderr values.
        error_bar_heights = 2 * np.array(cur_stderrs)
        cur_line = plt.errorbar(brain_layer_list, cur_means, yerr = error_bar_heights,
                                c = line_color, linestyle = line_style, marker = 'o')
        print('At ' + str(i_bottleneck_width))
        all_lines.append(cur_line)

    # No axis labels, legend or y-limits are set for this figure.
    left_edge = -x_pad_l + np.min(brain_layer_list)
    right_edge = np.max(brain_layer_list) + x_pad_r
    plt.xlim([left_edge, right_edge])
    plt.xticks(brain_layer_list)

    fig_out_path = 'AnalysisOutputs/RetinaLinearityVsNLayers.pdf'
    ensure_dir(fig_out_path)
    plt.savefig(fig_out_path)
    plt.show()
    
if(1): #Plots the linear separability at each layer
    print('Doing linear separability for each layer!')

    #I currently only have these for the extreme bottlenecks
    bottleneck_width_list = [1, 32]
    # This section replaces the shared format list: black for the first curve, grey for the second.
    plot_line_format_list = [('k', '-'), ([.5, .5, .5], '-')]
    instance_list = list(range(1, 9, 1))

    lin_sep_layer_list = [0, 2, 3, 4, 5, 6, 7]
    (mean_stderr_list, legend_list) = query_lin_sep_vs_layer_plot_data(
        my_lin_sep_dict, bottleneck_width_list,
        layer_list=lin_sep_layer_list, instance_list=instance_list)
    plt.clf()
    all_lines = []

    # Layer indices are halved to give the x positions and tick locations.
    layer_x_positions = np.array(lin_sep_layer_list)/2

    for (i_bottleneck_width, (cur_means, cur_stderrs)) in enumerate(mean_stderr_list):
        cur_plot_line_format = plot_line_format_list[i_bottleneck_width]
        cur_line = plt.errorbar(layer_x_positions, cur_means, yerr = 2 * np.array(cur_stderrs),
                                c = cur_plot_line_format[0], linestyle = cur_plot_line_format[1], marker = 'o')
        print('At ' + str(i_bottleneck_width))
        all_lines.append(cur_line)

    plt.xlim([-x_pad_l + np.min(layer_x_positions), np.max(layer_x_positions)+x_pad_r])
    plt.xticks(layer_x_positions)
    fig_out_path = 'AnalysisOutputs/LinSepVsLayerDepth.pdf'
    ensure_dir(fig_out_path)
    plt.savefig(fig_out_path)
    plt.show()
print('Finished Making All Plots!')

# Plot linear separability vs linearity for individual channels

In [None]:
if False:
    # Start from empty dictionaries instead of the pickled results; flip the condition to enable.
    my_lin_fit_dict = {}
    my_lin_sep_dict = {}
    my_chanwise_lin_sep_dict = {}
else:
    my_lin_sep_dict = big_pickle_load('AnalysisOutputs/LinSeparationDict.p')
    my_lin_fit_dict = big_pickle_load('AnalysisOutputs/LinFitDict.p')
    my_chanwise_lin_sep_dict = big_pickle_load('AnalysisOutputs/ChanwiseLinSeparationDict.p')

    # Architectures to scatter: a single bottleneck width, every brain depth.
    bottleneck_width_list = [4]
    brain_layers_list = [0, 1, 2, 3, 4]
    n_classes_to_do = 10

# Instances MUST start from 1. This assignment applies to both branches above.
instance_list = range(1, 11, 1)

# Marker shapes indexed by brain depth (shapes up to pentagon).
filled_markers = ('o', '|', '^', 's', 'p')

# Colours indexed by model instance in the per-architecture figures.
my_other_colors = ('k', 'r', 'g', 'b', 'c', 'm', 'y', '#000066', [.3, .3, .3], [.75, .75, .75])

# Colours indexed by brain depth in the pooled figures.
my_colors = ('#000066', '#3333cc', '#cc33ff', '#ff3399', '#ff9966', '#FFD700', '#99ff66')
#my_colors = ('#000066', #3333cc, #cc33ff, #ff3399, #ff9966, #ff9966, #ffff99)
def _show_channel_scatter(x_vals, y_vals, colors, markers, n_points, marker_size, title):
    """Draw one linearity-vs-separability scatter in a new figure, show it, then clear it.

    Points with indices n_points-1 down to 0 are plotted one at a time (in that
    order) so that each can carry its own colour and marker.
    """
    plt.figure(figsize=(20,10))
    for i in range(n_points - 1, -1, -1):
        plt.plot(x_vals[i], y_vals[i], color = colors[i], marker = markers[i], linestyle = '', alpha = .5, markersize = marker_size)
    plt.ylabel('Separability')
    plt.xlabel('Linearity')
    plt.title(title)
    plt.show()
    plt.clf()


if(1):
    plt.clf()

    cur_layer = 3
    for (ibneckwidth, bneckwidth) in enumerate(bottleneck_width_list):

        # "mega" lists pool every channel seen so far for this bottleneck width.
        mega_x_list, mega_y_list = [], []
        mega_color_list, mega_marker_list, mega_fmt_list = [], [], []

        plt.clf()
        plt.figure(figsize=(20,10))

        my_chanwise_linsep_list = []
        my_chanwise_lin_fit_list = []
        for (iblayers, blayers) in enumerate(brain_layers_list):

            all_cur_model_names = create_model_name_list(bottleneck_width_list = [bneckwidth], brain_layers_list = [blayers], instance_list = instance_list)

            # "cur_arch" lists hold only the channels of the current (width, depth) architecture.
            cur_arch_x_list, cur_arch_y_list = [], []
            cur_arch_color_list, cur_arch_marker_list, cur_arch_fmt_list = [], [], []

            for (imname, model_name) in enumerate(all_cur_model_names):
                print('IModelName is '+ str(imname))

                # y values: per-channel separability scores; x values: per-channel linear-fit scores.
                cur_lin_sep_list = query_chanwise_sep_score(my_chanwise_lin_sep_dict, model_name = model_name, n_classes=n_classes_to_do, layer=cur_layer, n_channels= bneckwidth)
                my_chanwise_linsep_list.extend(cur_lin_sep_list)

                cur_lin_fit_element = my_lin_fit_dict[(model_name, cur_layer)]
                cur_lin_fit_list = cur_lin_fit_element['chan_wise_scores']
                my_chanwise_lin_fit_list.extend(cur_lin_fit_list)

                cur_n_points = len(cur_lin_fit_list)
                print('Lin fit, lin sep have ' + str((len(cur_lin_fit_list), len(cur_lin_sep_list))))

                # Pooled figure: colour encodes brain depth, marker is always a circle.
                cur_color = my_colors[iblayers]
                cur_marker = 'o'
                mega_x_list.extend(cur_lin_fit_list)
                mega_y_list.extend(cur_lin_sep_list)
                mega_color_list.extend([cur_color] * cur_n_points)
                mega_marker_list.extend([cur_marker] * cur_n_points)
                mega_fmt_list.extend([cur_color + cur_marker] * cur_n_points)

                # Per-architecture figure: colour encodes the model instance.
                cur_arch_y_list.extend(cur_lin_sep_list)
                cur_arch_x_list.extend(cur_lin_fit_list)
                cur_arch_color_list.extend([my_other_colors[imname]] * len(cur_lin_sep_list))

            # Figure 1 of 2 for this depth: only the current architecture's channels.
            _show_channel_scatter(
                cur_arch_x_list, cur_arch_y_list, cur_arch_color_list,
                ['o'] * len(cur_arch_y_list), len(cur_arch_y_list), 40,
                'Instances of Bneck, Depth ' + str((bneckwidth, blayers)))

            cur_color = my_colors[ibneckwidth]
            cur_marker = filled_markers[iblayers]
            print((cur_color, cur_marker))
            cur_fmt = cur_color+cur_marker
            cur_n = len(my_chanwise_lin_fit_list)
            print(cur_n)

            # Figure 2 of 2 for this depth: everything pooled so far for this bottleneck width.
            _show_channel_scatter(
                mega_x_list, mega_y_list, mega_color_list,
                mega_marker_list, len(mega_x_list), 15,
                'Bneck ' + str(bneckwidth))

    if(0):
        plt.ylabel('Separability')
        plt.xlabel('Linearity')
        plt.show()
        plt.clf()
        print('Here!')
    

#plt.show()
if False:
    # Disabled variant: one scatter call per model at a fixed channel count of 32,
    # followed by a dump of the accumulated score lists.
    for model_name in all_model_names:
        cur_lin_sep_list = query_chanwise_sep_score(
            my_chanwise_lin_sep_dict, model_name = model_name,
            n_classes=n_classes_to_do, layer=cur_layer, n_channels= 32)
        my_chanwise_linsep_list.extend(cur_lin_sep_list)

        cur_lin_fit_element = my_lin_fit_dict[(model_name, cur_layer)]
        cur_lin_fit_list = cur_lin_fit_element['chan_wise_scores']
        my_chanwise_lin_fit_list.extend(cur_lin_fit_list)

        plt.scatter(cur_lin_fit_list, cur_lin_sep_list)

    for score_list in (my_chanwise_linsep_list, my_chanwise_lin_fit_list):
        print(np.shape(score_list))

    for score_list in (my_chanwise_linsep_list, my_chanwise_lin_fit_list):
        print(score_list)

    plt.show()


# Pickle All models

### Here we pickle all model activations to load later. While this must be run first, we only need to do this once, so it's at the bottom 

In [None]:

import keras
from keras.datasets import cifar10

import keras.backend as K
K.set_image_data_format('channels_last')
from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.optimizers import RMSprop, SGD, Adam
import keras
import tensorflow as tf

import math
import os
import sys  



# Switches selecting which stages of this cell run.
should_load_models = False               # load each saved model (needed by the two pickling stages)
should_pickle_activations = False        # write activation batches for each loaded model
should_pickle_n_layers = False           # write the layer count of each loaded model
should_try_to_load_all_pickles = True    # reload every pickled batch as a check


##############  IMPORT DATASET  ####################
# CIFAR-10 with axis 3 averaged away (kept as a length-1 axis).
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = np.mean(x_train, axis=3, keepdims=True)
x_test = np.mean(x_test, axis=3, keepdims=True)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

num_classes = 10

# Keep copies of the integer labels before they are one-hot encoded.
y_train2, y_test2 = y_train.copy(), y_test.copy()

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Cast to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
keras.__version__


# Every entry of the model directory is treated as a saved model (first 300 at most).
import os
untrimmed_model_names = os.listdir(model_input_path)
all_model_names = untrimmed_model_names[:300]

from keras.models import load_model

# Number of samples per model and the batch size used for the pickled activations.
full_size_to_use = 10000
batch_size_to_use = 200

if should_load_models:
    n_models_total = len(all_model_names)
    for (i_model, cur_model_name) in enumerate(all_model_names):
        # Reset the Keras backend session before loading the next model.
        keras.backend.clear_session()

        print(cur_model_name + '\n')
        print('Loading model ' + str(i_model) + '/' + str(n_models_total) + '...')
        cur_model = load_model(model_input_path + '/' + cur_model_name)

        # Layers 0, 2 and 3, then every layer from index 4 onwards.
        cur_layers_to_keep = [0, 2, 3] + list(range(4, len(cur_model.layers)))
        print('Finished Loading model')

        if should_pickle_activations:
            output_activations_to_pickle(
                cur_model, x_test, cur_layers_to_keep, cur_model_name,
                verbose = False,
                batch_size = batch_size_to_use,
                full_size = full_size_to_use,
                make_batches = True,
                compile_batches = False,
                delete_batches_when_done = False)
            print('Finished pickling acts of size' + str(full_size_to_use))

        if should_pickle_n_layers:
            # Store the layer count so later stages can read it back from this file.
            cur_model_n_layers = len(cur_model.layers)
            cur_path_name = pickle_dir + '/' + 'NLayersOf_' + cur_model_name + '.p'
            print('Dumped n_layers ' + str(cur_model_n_layers))
            print('\n\n' + cur_path_name + '\n\n')
            big_pickle_dump(cur_model_n_layers, cur_path_name)

    print('Finished Making Batches!')
        
if(should_try_to_load_all_pickles):
    # Check pass: for every model, make sure the batch files of each kept
    # layer exist, load them back, and print the loaded shape.
    print('At last bit of conditional, just loading and viewing')
    print(time.asctime())
    for (i_model, cur_model_name) in enumerate(all_model_names):
        keras.backend.clear_session()

        print(cur_model_name + '\n')
        if(0):
            #This is what we do if we don't have a pickled layer count: load the model itself
            print('Loading model ' + str(i_model) + '/' + str(len(all_model_names)) + ' ...')
            cur_model = load_model(model_input_path + '/' + cur_model_name)
            cur_layers_to_keep = [0, 2, 3] + list(range(4, len(cur_model.layers)))
        else:
            # Read the layer count from its 'NLayersOf_<model>.p' pickle instead
            # of loading the model.
            cur_path_name = pickle_dir + '/' + 'NLayersOf_' + cur_model_name + '.p'
            with open(cur_path_name, 'rb') as n_layers_file:
                cur_model_n_layers = pickle.load(n_layers_file)

            # No model is loaded in this branch, so the layer list must come
            # from the pickled count; referring to cur_model here would raise
            # NameError (or silently reuse a model left over from earlier).
            cur_layers_to_keep = [0, 2, 3] + list(range(4, cur_model_n_layers))

            print('Model ' + str(i_model) + '/' + str(len(all_model_names)) + 'Has ' + str(cur_model_n_layers) + ' Layers')
            print(time.asctime())

        for cur_layer in cur_layers_to_keep:
            print('Loading Layer ' + str(cur_layer))
            ensure_batch_files_exist(cur_model_name, cur_layer, pickle_dir, batch_size=batch_size_to_use, full_size=full_size_to_use)
            cur_layer_acts = load_layer_activations_from_batch_files(cur_model_name, cur_layer, pickle_dir, batch_size=batch_size_to_use, full_size=full_size_to_use, verbose= False)
            print('Loaded Shape is ' + str(np.shape(cur_layer_acts)))


print('\n\n*************Done with everything in this cell!*********')



# Here we scatter bottleneck activations predicted from a linear fit of the image against actual activations (Figure 3B)

In [None]:
all_model_names = create_model_name_list(bottleneck_width_list=[1], instance_list=[1], brain_layers_list=[0, 4])

for model_name in all_model_names:
    layer_num = 3
    cur_return_dict = my_lin_fit_dict[(model_name, layer_num)]

    plt.clf()
    plt.figure(figsize=(8,3))

    # Fixed axis bounds (x uses both, y uses only the upper one).
    max_of_each = 1
    min_of_each = -.5

    (y_real, y_fit) = cur_return_dict['Y_real_vs_fit']
    scatter_scaling = 1.5

    # With more than one dimension, each column along axis 1 is scattered
    # separately; otherwise the arrays are scattered as they are.
    if np.ndim(y_real) > 1:
        real_fit_pairs = ((y_real[:, chan], y_fit[:, chan]) for chan in range(np.shape(y_real)[1]))
    else:
        real_fit_pairs = iter([(y_real, y_fit)])

    for (cur_y_real, cur_y_fit) in real_fit_pairs:
        # Both axes are divided by the maximum of the real values, then scaled.
        plt.scatter(scatter_scaling * cur_y_fit/np.max(cur_y_real),
                    scatter_scaling * cur_y_real/np.max(cur_y_real),
                    c = 'k', marker='.', alpha=.03, s = 60)

    plt.xlim([min_of_each, max_of_each])
    plt.ylim([-.1, max_of_each])
    plt.xticks([-.5, 0, .5, 1])
    plt.yticks([0, .5, 1])

    print(model_name)
    output_path = 'PlotOutputs/' + model_name + 'Layer' + str(layer_num) + '.png'
    ensure_dir(output_path)
    plt.savefig(output_path)
    plt.show()



## Visualize bottleneck activations


#### When the bottleneck is a single channel, it's nice to visualize it as an image and see how it varies with network architecture. 

In [None]:
# Settings for the bottleneck visualisation below.
should_do_full_matrix = False
model_pair_list = []
crop_amount = 2800  #can be 2400 when I don't store pandas objects

# Activation batches are read with these sizes.
batch_size_to_use = 200
full_size_to_use = 10000

# Number of example images written per model.
ims_to_make = 30

my_varexp_matr_dict = {}
mega_im_bottleneck_dict = {}

# Ratios appended once per processed model (plain and Fourier versions).
var_exp_1_list, var_exp_2_list = [], []
fft_var_exp_1_list, fft_var_exp_2_list = [], []

# NOTE(review): positional arguments; judging by the keyword calls elsewhere these
# are presumably (bottleneck widths, brain layers, instances) -- confirm against the helper.
all_model_names = create_model_name_list([1], range(5), [1])

for (imodel_name, model_name) in enumerate(all_model_names):
    print(model_name)

    # Only models whose name contains 'bnC_1_' are visualised.
    if 'bnC_1_' not in model_name:
        continue

    # Layer 0 and layer 3 activations of this model.
    image_act = load_layer_activations_from_batch_files(model_name, 0, pickle_dir, batch_size=batch_size_to_use, full_size=full_size_to_use, verbose= False)
    bottleneck_act = load_layer_activations_from_batch_files(model_name, 3, pickle_dir, batch_size=batch_size_to_use, full_size=full_size_to_use, verbose= False)

    (fft_act1_exp, fft_act1_var, fft_act2_exp, fft_act2_var) = dft_ccas.fourier_var_exp(image_act, bottleneck_act)

    # Flatten each sample to 1024 values and pass the transposed matrices to the CCA helper.
    flat_shape = (full_size_to_use, 1024)
    reshaped_im_act = image_act.reshape(flat_shape)
    reshaped_bn_act = bottleneck_act.reshape(flat_shape)
    (act1_exp, act1_var, act2_exp, act2_var) = cca_core.frac_variance_explained(reshaped_im_act.T, reshaped_bn_act.T)

    var_exp_1_list.append(act1_exp/act1_var)
    fft_var_exp_1_list.append(fft_act1_exp/fft_act1_var)
    var_exp_2_list.append(act2_exp/act2_var)
    fft_var_exp_2_list.append(fft_act2_exp/fft_act2_var)

    for i_image in range(ims_to_make):
        # Reshape to 32x32 and divide each tile by its own maximum.
        cur_image = image_act[i_image, :].reshape((32,32))
        cur_act = bottleneck_act[i_image, :].reshape((32,32))
        cur_image = cur_image/np.max(cur_image)
        cur_act = cur_act/np.max(cur_act)

        # Save input and bottleneck side by side.
        cat_im = np.concatenate([cur_image, cur_act], axis = 1)
        cat_im_path = 'BottleneckVisualizations/' + model_name + '/Im' + str(i_image) + '.png'
        ensure_dir(cat_im_path)
        plt.imsave(cat_im_path, cat_im, cmap = 'gray')

        # Grid slot 0 holds the input image; slot 5-imodel_name holds this model's bottleneck.
        mega_im_bottleneck_dict[(i_image, 0)] = cur_image
        mega_im_bottleneck_dict[(i_image, 5-imodel_name)] = cur_act
            

# Assemble the overview image: for each example, stack its six tiles
# (keys (i_image, 0) .. (i_image, 5) of mega_im_bottleneck_dict) along axis 0,
# display that strip, then join all strips along axis 1.
vert_cat_im_list = []
for i_image in range(ims_to_make):
    cat_im_list = [mega_im_bottleneck_dict[(i_image, i_grid)] for i_grid in range(6)]
    cat_im = np.concatenate(cat_im_list, axis = 0)
    plt.imshow(cat_im)
    plt.show()
    vert_cat_im_list.append(cat_im)
vert_cat_im = np.concatenate(vert_cat_im_list, axis = 1)
vert_cat_im_path = 'BottleneckVisualizations/' + 'VertCat.png'
# Create the output directory first, as every other save in this file does,
# so this save does not depend on an earlier step having made it.
ensure_dir(vert_cat_im_path)
plt.imsave(vert_cat_im_path, vert_cat_im, cmap = 'gray')



print('Finished!')