In [None]:
import numpy as np
import matplotlib.pyplot as plt
import keras.backend as K
%matplotlib inline

In [None]:
def deprocess_image(x):
    """Turn a raw float tensor into a displayable ``uint8`` image array.

    The values are centred on 0, divided by their standard deviation (skipped
    when the input is essentially constant), scaled by 0.1, shifted by 0.5,
    clipped to ``[0, 1]`` and finally mapped to ``[0, 255]``.

    All arithmetic is done on a private floating-point copy, so the caller's
    array is never modified and integer-typed inputs are accepted.

    Args:
        x: Array-like image tensor. When the Keras backend reports
            ``'channels_first'`` the array is transposed with axes
            ``(1, 2, 0)``, so it must be 3-D in that case.

    Returns:
        A new ``uint8`` ndarray with values in ``[0, 255]``.
    """
    x = np.asarray(x)
    # Work on a copy: the in-place operators below would otherwise leave the
    # caller's array half-normalised, and they cannot store float results in
    # an integer array at all.
    if np.issubdtype(x.dtype, np.floating):
        x = x.copy()
    else:
        x = x.astype(np.float64)

    # normalize tensor: center on 0., bring std to (approximately) 0.1
    x -= x.mean()
    spread = x.std()
    if spread > 1e-5:
        # the small epsilon keeps the division well-behaved for tiny spreads
        x /= (spread + 1e-5)
    x *= 0.1

    # shift to mid-grey and clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # convert to an RGB-style array with channels last
    x *= 255
    if K.image_data_format() == 'channels_first':
        x = x.transpose((1, 2, 0))
    return np.clip(x, 0, 255).astype('uint8')

In [None]:
# Outputs the std of RF approximations across different random initializations.
# A higher value indicates more "complex" cells (in the sense of simple vs.
# complex cells in V1).
#
# For every layer and every trained model ("trial"), the receptive-field (RF)
# approximations obtained from several random initializations are loaded from
# disk. For each filter, the per-pixel std across initializations is averaged
# into a single number, and those numbers are collected per layer in
# `results_dict`.
import os

brain_layers = 2
retina_width = 1

# Hyper-parameters of the saved models; they are only used here to rebuild
# the file names under which the RF approximations were stored.
noise_start = 0.0
noise_end = 0.0
retina_out_weight_reg = 0.0
retina_out_stride = 1
retina_hidden_channels = 32
task = 'classification'
filter_size = 9
retina_layers = 2
use_b = 1
actreg = 0.0
vvs_width = 32
epochs = 20
reg = 0.0
num_trials = 10
load_dir = os.path.join(os.getcwd(), 'saved_filters')
retina_out_width = 1
vvs_layers = 2

# Number of random initializations stored per model (files RI0 ... RI9).
num_random_inits = 10
layer_names = ['retina_2', 'vvs_1', 'vvs_2']

results_dict = {}
for layer_name in layer_names:
    layer_result = []
    for trial in range(1, 1 + num_trials):
        # The model name does not depend on the random initialization, so it
        # is built once per trial.
        trial_label = 'Trial' + str(trial)
        # NOTE: '_vvs_layers' is not followed by an underscore, unlike the
        # other fields; it is reproduced as-is because it has to match the
        # names of the files already on disk.
        model_name = (
            'cifar10_type_' + trial_label
            + '_noise_start_' + str(noise_start)
            + '_noise_end_' + str(noise_end)
            + '_reg_' + str(reg)
            + '_retina_reg_' + str(retina_out_weight_reg)
            + '_retina_hidden_channels_' + str(retina_hidden_channels)
            + '_SS_' + str(retina_out_stride)
            + '_task_' + task
            + '_filter_size_' + str(filter_size)
            + '_retina_layers_' + str(retina_layers)
            + '_vvs_layers' + str(vvs_layers)
            + '_bias_' + str(use_b)
            + '_actreg_' + str(actreg)
            + '_retina_out_channels_' + str(retina_out_width)
            + '_vvs_width_' + str(vvs_width)
            + '_epochs_' + str(epochs)
        )
        model_name = 'SAVED' + '_' + model_name

        # filter index -> list of flattened RF images, one per random init
        # in which the filter produced a non-flat RF.
        trial_result = {}
        for random_init in range(num_random_inits):
            filename = 'RI' + str(random_init) + '_' + model_name + '_' + str(layer_name) + '.npy'
            file_path = os.path.join(load_dir, filename)
            RFs = np.load(file_path)
            for filt in range(RFs.shape[0]):
                # Skip (near-)constant RFs: they carry no structure and would
                # only contribute a flat grey image.
                if np.max(RFs[filt]) - np.min(RFs[filt]) > 1e-4:
                    # Kept in a local instead of being written back into RFs,
                    # which would cast the image to the loaded array's dtype
                    # and fail if deprocess_image changed the axis order.
                    rf_image = deprocess_image(RFs[filt]) / 255.0
                    trial_result.setdefault(filt, []).append(rf_image.flatten())

        # Iterate over the filters that were actually collected rather than
        # over the shape of whichever file happened to be loaded last.
        for filt in sorted(trial_result):
            trial_result_filt = np.array(trial_result[filt])
            # std across random inits per pixel, then averaged over pixels
            trial_result_filt_std = np.mean(np.std(trial_result_filt, axis=0))
            layer_result.append(trial_result_filt_std)
    results_dict[layer_name] = np.array(layer_result)
        
# Report, per layer, the mean of the collected std values together with the
# half-width 1.96 * std / sqrt(n) of the interval around that mean.
for layer in ('retina_2', 'vvs_1', 'vvs_2'):
    layer_result = results_dict[layer].flatten()
    layer_mean = np.mean(layer_result)
    standard_error = np.std(layer_result) / np.sqrt(len(layer_result))
    print(layer, layer_mean, 'plus or minus', 1.96 * standard_error)
