In [None]:
import numpy as np
import torch
import torch.nn as nn
import h5py as h5
import os
import sys
import pickle
sys.path.append("../")
from models import PracticalBNCNN, NormedBNCNN, DalesBNCNN, DalesSSCNN, SSCNN, BNCNN, PracticalBNCNN, DalesHybrid, DalesSkipBNCNN, SkipBNBNCNN
#import metrics
import matplotlib.pyplot as plt
from utils.deepretina_loader import loadexpt
from utils.physiology import Physio
import utils.intracellular as intracellular
import utils.batch_compute as bc
import utils.retinal_phenomena as rp
import utils.stimuli as stimuli
import pyret.filtertools as ft
import scipy
import re
import pickle
from tqdm import tqdm
import gc
import resource
import time
import math

def normalize(x):
    """Z-score ``x``: subtract its mean, then divide by its standard deviation.

    A small epsilon (1e-7) is added to the standard deviation so a constant
    input does not cause a division by zero. ``x`` only needs ``.mean()`` and
    ``.std()`` methods (e.g. a numpy array).
    """
    centered = x - x.mean()
    scale = x.std() + 1e-7
    return centered / scale

def retinal_phenomena_figs(bn_cnn):
    """Run the probes from ``utils.retinal_phenomena`` (``rp``) on ``bn_cnn``.

    The probes are called in this order: ``step_response``, ``osr``,
    ``reversing_grating``, ``contrast_adaptation`` (with the extra arguments
    .35 and .05) and ``motion_anticipation``. Whatever figures or output they
    produce come from those helpers; this function returns None.
    """
    probes = (
        (rp.step_response, ()),
        (rp.osr, ()),
        (rp.reversing_grating, ()),
        (rp.contrast_adaptation, (.35, .05)),
        (rp.motion_anticipation, ()),
    )
    for probe, extra_args in probes:
        probe(bn_cnn, *extra_args)
    
#If you want to use stimulus that isnt just boxes
def prepare_stim(stim, stim_type):
    if stim_type == 'boxes':
        return 2*stim - 1
    elif stim_type == 'flashes':
        stim = stim.reshape(stim.shape[0], 1, 1)
        return np.broadcast_to(stim, (stim.shape[0], 38, 38))
    elif stim_type == 'movingbar':
        stim = block_reduce(stim, (1,6), func=np.mean)
        stim = pyret.stimulustools.upsample(stim.reshape(stim.shape[0], stim.shape[1], 1), 5)[0]
        return np.broadcast_to(stim, (stim.shape[0], stim.shape[1], stim.shape[1]))
    elif stim_type == 'lines':
        stim_averaged = np.apply_along_axis(lambda m: np.convolve(m, 0.5*np.ones((2,)), mode='same'), 
                                            axis=1, arr=stim)
        stim = stim_averaged[:,::2]
        # now stack stimulus to convert 1d to 2d spatial stimulus
        return stim.reshape(-1,1,stim.shape[-1]).repeat(stim.shape[-1], axis=1)
    else:
        print("Invalid stim type")
        assert False
    
def index_of(arg, arr):
    """Return the position of the first element of ``arr`` equal to ``arg``, or -1 if there is none."""
    return next((pos for pos, item in enumerate(arr) if arg == item), -1)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so the
# notebook does not crash on CUDA-less machines.
# NOTE(review): checkpoints saved from a GPU may still need `map_location` to load on CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()  # no-op when CUDA was never initialised

In [None]:
# Load the intracellular interneuron recordings.
# - stims[k][i]    : prepared stimulus of type k from files[i]
# - mem_pots[k][i] : detrended membrane potentials of type k from files[i], with the first
#                    `filter_length` samples dropped
# - num_pots[i]    : number of recordings (rows of 'boxes/detrended_membrane_potential') in files[i]
# psst, you can find the "data" folder in /home/grantsrb on deepretina server
# psssst, note the additional ../ added to each path in files

files = ['../data/bipolars_late_2012.h5', '../data/bipolars_early_2012.h5', '../data/amacrines_early_2012.h5', '../data/amacrines_late_2012.h5', '../data/horizontals_early_2012.h5', '../data/horizontals_late_2012.h5']
files = ["../" + name for name in files]
# Short id per file, e.g. 'bipolars'. Split the base name only: the full path starts with
# '.', so splitting it made every id ''.
file_ids = [re.split(r'[_.]', os.path.basename(path))[0] for path in files]
filter_length = 40
window_size = 2
num_pots = []
stims = dict()
mem_pots = dict()
keys_to_use = {"boxes"}
for fi in files:
    with h5.File(fi, 'r') as f:
        for k in f.keys():
            if k not in keys_to_use:
                continue
            stims.setdefault(k, [])
            mem_pots.setdefault(k, [])
            try:
                stim = prepare_stim(np.asarray(f[k+'/stimuli']), k)
                pot = np.asarray(f[k]['detrended_membrane_potential'])[:, filter_length:]
            except Exception as err:
                print("stim error at", k, err)
                continue
            # Append both only after both loaded, so stims[k][n] and mem_pots[k][n] stay paired.
            # NOTE(review): a skipped file still shifts these lists relative to `files`,
            # and later cells index them by file position.
            stims[k].append(stim)
            mem_pots[k].append(pot)
        num_pots.append(int(f['boxes/detrended_membrane_potential'].shape[0]))

In [None]:
grand_folder = "bncnnMultiData"
exp_folder = "../training_scripts/"+grand_folder
# os.walk yields nothing for a missing directory, which would surface as a bare
# StopIteration. Fail with an explicit message instead.
walk_root = next(os.walk(exp_folder), None)
if walk_root is None:
    raise FileNotFoundError("Experiment folder not found: " + exp_folder)
# Immediate sub-directories of exp_folder, as paths relative to ../training_scripts/.
model_folders = [grand_folder + "/" + name for name in walk_root[1]]

In [None]:
def _model_sort_key(path):
    """Sort key: the token after 'dataset', then everything after 'stim_type'."""
    dataset_name = path.split("dataset")[-1].split("_")[0]
    stim_suffix = path.split("stim_type")[-1]
    return dataset_name, stim_suffix

model_folders = sorted(model_folders, key=_model_sort_key)
print("\n".join(model_folders))

In [None]:
# Peek at one saved checkpoint to see the model it holds (and its layer names).
# Errors are deliberately not swallowed: a silent `except: pass` here left `temp`
# undefined (NameError below) or, worse, displayed a stale `temp` from an earlier run.
file = "../training_scripts/"+model_folders[0]+"/test_epoch_0.pth"
with open(file, "rb") as fd:
    temp = torch.load(fd)

temp['model']

In [None]:
# Model sub-modules whose activations are recorded (insp_keys of bc.batch_compute_model_response)
# and searched for the unit best correlated with each interneuron recording.
conv_layers = ["sequential.{}".format(idx) for idx in (2, 8)]

In [None]:
gc.collect()
max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print("Memory Used: {:.2f} mb".format(max_mem_used / 1024))

## Evaluate the performance of every trained model

In [None]:
# Bookkeeping for the model sweep below.
n_epochs = 250  # upper bound on checkpoints (test_epoch_<i>.pth) read per model folder

# Best-so-far trackers as (folder, score): losses start high, accuracies/correlations low.
best_folder_by_loss, best_loss = "", 100
best_folder_by_val_loss, best_val_loss = "", 100
best_folder_by_val_acc, best_val_acc = "", -100
best_folder_by_test_acc, best_test_acc = "", -100
best_folder_by_intr_cor, best_intr_cor = "", -1

# Text results are appended (mode 'a'), so output from earlier runs is kept.
results_file_name = "{}_analysis_results.txt".format(grand_folder)
results_file = open(results_file_name, 'a')
batch_compute_size = 500  # batch-size argument passed to bc.batch_compute_model_response

model_stats, intraneuron_cors, test_accs = dict(), dict(), dict()
print("Using layers:", " and ".join(conv_layers))
# For every trained model: reload its last readable checkpoint, report train/val/test
# metrics, plot example responses and, for models above a quality threshold, correlate
# its hidden units with the intracellular interneuron recordings.
for folder in model_folders:
    stats = dict()
    starttime = time.time()
    losses = []
    val_losses = []
    val_accs = []
    # Read checkpoints epoch by epoch until one is missing or unreadable. `temp` is reset so
    # a folder without checkpoints cannot silently reuse the previous folder's model.
    temp = None
    for i in range(n_epochs):
        file = "../training_scripts/"+folder+"/test_epoch_{0}.pth".format(i)
        try:
            with open(file, "rb") as fd:
                checkpoint = torch.load(fd)
            epoch_metrics = (checkpoint['loss'], checkpoint['val_loss'], checkpoint['val_acc'])
        except FileNotFoundError:
            break
        except Exception as err:  # e.g. a partially written or malformed checkpoint
            print("Could not read", file, "-", err)
            break
        temp = checkpoint
        losses.append(epoch_metrics[0])
        val_losses.append(epoch_metrics[1])
        val_accs.append(epoch_metrics[2])
    print("Folder:", folder)
    results_file.write(folder + "\n")
    if temp is None:
        print("No readable checkpoints, continuing...\n\n\n\n")
        results_file.write("No readable checkpoints, continuing...\n\n\n\n")
        continue
    bn_cnn = temp['model']
    bn_cnn = bn_cnn.to(DEVICE)
    bn_cnn.eval()
    for label, key, value in (("Train Loss:", 'TrainLoss', losses[-1]),
                              ("Val Loss:", 'ValLoss', val_losses[-1]),
                              ("Val Acc:", 'ValAcc', val_accs[-1])):
        print(label, value)
        stats[key] = value
        results_file.write(label + str(value) + "\n")
    if(math.isnan(losses[-1]) or math.isnan(val_losses[-1]) or math.isnan(val_accs[-1])):
        print("NaN results, continuing...\n\n\n\n")
        results_file.write("NaN results, continuing...\n\n\n\n")
        continue

    # Best-so-far trackers on final-epoch metrics (previously never updated, so the
    # summary printed empty folder names).
    if losses[-1] < best_loss:
        best_loss, best_folder_by_loss = losses[-1], folder
    if val_losses[-1] < best_val_loss:
        best_val_loss, best_folder_by_val_loss = val_losses[-1], folder
    if val_accs[-1] > best_val_acc:
        best_val_acc, best_folder_by_val_acc = val_accs[-1], folder

    cells = "all"
    dataset = folder.split("dataset")[-1].split("_")[0]
    stats['dataset'] = dataset
    stim_type = folder.split("stim_type")[-1].split("_")[0]
    stats['stim_type'] = stim_type
    # Some checkpoints store the input normalization under 'data_norm_stats' instead.
    try:
        norm_stats = [temp['norm_stats']['mean'], temp['norm_stats']['std']]
    except (KeyError, TypeError):
        norm_stats = [temp['data_norm_stats']['mean'], temp['data_norm_stats']['std']]
    test_data = loadexpt(dataset,cells,stim_type,'test',40,0, norm_stats=norm_stats)

    model_response = bc.batch_compute_model_response(test_data.X, bn_cnn, batch_compute_size, insp_keys=set(conv_layers))
    # Mean Pearson r over every recorded ganglion cell (column of test_data.y). The previous
    # `range(len(cells))` with cells == "all" only evaluated the first 3 cells.
    n_test_cells = test_data.y.shape[1]
    avg_test_acc = np.mean([scipy.stats.pearsonr(model_response['output'][:, c], test_data.y[:, c])[0]
                            for c in range(n_test_cells)])
    test_accs[folder] = avg_test_acc
    stats['TestAcc'] = avg_test_acc
    if math.isnan(avg_test_acc):
        print("NaN results, continuing...\n\n\n\n")
        results_file.write("NaN results, continuing...\n\n\n\n")
        continue
    if avg_test_acc > best_test_acc:
        best_test_acc, best_folder_by_test_acc = avg_test_acc, folder
    print("\nFinal Test Acc:", avg_test_acc)
    results_file.write("Final Test Acc:"+ str(avg_test_acc) + "\n")
    with open("../training_scripts/"+folder+"/hyperparams.txt", 'a') as f:
        f.write("\nTest Ganglion Cell Correlation: " + str(avg_test_acc))

    # Plot loss curve and response snippet
    plt.plot(normalize(model_response['output'][:400, 0]))
    plt.plot(normalize(test_data.y[:400,0]), alpha=.7)
    plt.legend(["model", "data"])
    plt.title("Firing Rate")
    plt.show()
    plt.plot(losses)
    plt.plot(val_losses)
    plt.legend(["TrainLoss", "ValLoss"])
    plt.title("Loss Curves")
    plt.show()
    retinal_phenomena_figs(bn_cnn)
    plt.show()

    if avg_test_acc < .5 or losses[-1] > 1:
        print("Skipping further analysis due to poor results...\n\n\n\n")
        results_file.write("Skipping further analysis due to poor results...\n\n\n\n")
        continue
    print("Calculating model responses...\n")
    # y_true[n] is the interneuron type of the n-th recording across all files
    # (0 for bipolar, 1 for amacrine, 2 for horizontal), in the same order as mem_pots.
    y_true = []
    filter_length = 40
    model_responses = dict()
    cell_type_keys = ['bipolar', 'amacrine', 'horizontal']
    for i in tqdm(range(len(files))):
        file_name = files[i]
        for type_label, type_key in enumerate(cell_type_keys):
            if type_key in file_name:
                # int(): num_pots entries may be 0-d numpy arrays, and `list * ndarray`
                # would broadcast instead of repeating the list.
                y_true.extend([type_label] * int(num_pots[i]))
                break
        for k in stims.keys():
            stim = stims[k][i]
            padded_stim = intracellular.pad_to_edge(scipy.stats.zscore(stim))
            response = bc.batch_compute_model_response(stimuli.concat(padded_stim),
                                                       bn_cnn, batch_compute_size,
                                                       insp_keys=set(conv_layers))
            # Reshape potentially flat layers. NOTE(review): hard-coded for 8-channel layers
            # with 36x36 or 26x26 maps; confirm against the model architecture.
            for cl in conv_layers:
                if len(response[cl].shape) <= 2:
                    try:
                        response[cl] = response[cl].reshape((-1,8,36,36))
                    except (ValueError, RuntimeError):
                        response[cl] = response[cl].reshape((-1,8,26,26))
            model_responses.setdefault(k, []).append(response)

    # For each interneuron recording, intracellular.classify returns the most correlated
    # model unit as (layer, channel, (row, col), cor_coef). y_pred records the index of
    # that unit's layer within conv_layers as a baseline "classification".
    # See intracellular.py for more info. This takes a really long time to run.
    print("Calculating intercellular correlations...\n")
    all_cell_info = dict()
    y_pred = dict()
    for i in tqdm(range(len(files))):
        for k in stims.keys():
            model_response = model_responses[k][i]
            stim = stims[k][i]
            for potential in mem_pots[k][i]:
                cell_info = intracellular.classify(potential, model_response, stim.shape[0],
                                                   layer_keys=conv_layers)
                all_cell_info.setdefault(k, []).append(cell_info)
                y_pred.setdefault(k, []).append(index_of(cell_info[0], conv_layers))

    stats['all_cell_info'] = all_cell_info
    intraneuron_cors[folder] = all_cell_info

    # Best-unit correlation (last element of each classify result) of every recording,
    # flattened over stimulus types and computed once for all summary statistics.
    intr_cors = np.asarray([info[-1] for k in all_cell_info for info in all_cell_info[k]])
    avg_intr_cor = np.mean(intr_cors)
    print("Mean intracellular:", avg_intr_cor)
    stats['IntrCor'] = avg_intr_cor
    results_file.write("Mean intracellular:" + str(avg_intr_cor) + "\n")
    std = np.std(intr_cors)
    print("Std intracellular:", std)
    results_file.write("Std intracellular:" + str(std) + "\n")
    m = np.min(intr_cors)
    print("Min intracellular:", m)
    results_file.write("Min intracellular:" + str(m) + "\n")
    m = np.max(intr_cors)
    print("Max intracellular:", m)
    results_file.write("Max intracellular:" + str(m) + "\n")

    if avg_intr_cor > best_intr_cor:
        best_intr_cor = avg_intr_cor
        best_folder_by_intr_cor = folder

    # Example correlation map for the last recording of the last file. A separate name is
    # used so this model's `stim_type` is not overwritten.
    example_stim = 'boxes'
    model_response = model_responses[example_stim][-1]
    potential = mem_pots[example_stim][-1][-1]
    layer, channel, (row, col), r = all_cell_info[example_stim][-1]
    print("Layer", layer, "correlation map")
    plt.imshow(intracellular.correlation_map(potential, model_response[layer][:, channel]))
    plt.show()

    # Tally, per interneuron type, which layer held the maximally correlated unit.
    layer_dict = {}
    for n, type_label in enumerate(y_true):
        counts = layer_dict.setdefault(type_label, [0] * len(conv_layers))
        for k in y_pred.keys():
            counts[y_pred[k][n]] += 1

    width = 0.5
    ind = np.arange(0,len(conv_layers))
    for type_label, counts in layer_dict.items():
        plt.bar(ind, counts, width)
        plt.xticks(ind,conv_layers)
        # type_label is the cell-type index itself (the loop position was used before).
        print(cell_type_keys[type_label])
        plt.title("Layer of unit with max correlation")
        plt.show()

    # Spatial receptive field (reverse correlation) of one model unit per layer:
    # channel 1 at position (15, 15), driven by the stimulus of file index 3.
    stimulus_num = 3
    for type_key in stims.keys():
        if type_key == "flashes":
            continue
        stimulus = stims[type_key][stimulus_num]
        zscored_stimulus = scipy.stats.zscore(stimulus)[filter_length:]
        for i,cl in enumerate(conv_layers):
            model_cell_response = model_responses[type_key][stimulus_num][cl][:, 1, 15, 15]
            print("Receptive field of", type_key,"model cell in Layer", i)
            rc_model, lags_model = ft.revcorr(zscored_stimulus, model_cell_response,
                                              nsamples_before=0, nsamples_after=filter_length)
            spatial_model, temporal_model = ft.decompose(rc_model)
            vmax = np.max(abs(spatial_model))
            plt.imshow(spatial_model, cmap = 'seismic', clim=[-vmax, vmax])
            plt.show()

    gc.collect()
    max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in kilobytes on Linux but in bytes on macOS.
    rss_units_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024
    print("Memory Used: {:.2f} mb".format(max_mem_used / rss_units_per_mb))
    print("Completed in", time.time()-starttime, "seconds")
    print("\n\n\n\n")
    results_file.write("\n\n\n\n")
    model_stats[folder] = stats
                    
# Summary of the best folder under each criterion, then close the results log.
summary_rows = (
    ("Best by validation loss:", best_folder_by_val_loss),
    ("Best by training loss:", best_folder_by_loss),
    ("Best by val accuracy:", best_folder_by_val_acc),
    ("Best by test accuracy:", best_folder_by_test_acc),
    ("Best by intracellular correlation:", best_folder_by_intr_cor),
)
for row_label, row_folder in summary_rows:
    print(row_label, row_folder)
results_file.close()


In [None]:
def get_best_models(model_stats, metric):
    """Pick the highest-``metric`` folder for every (dataset, stim_type) pair.

    Args:
        model_stats: mapping folder -> stats dict with at least 'dataset',
            'stim_type' and ``metric`` keys.
        metric: stats key to maximise (e.g. "TestAcc").

    Returns:
        Nested dict ``{dataset: {stim_type: folder}}``. On ties the folder seen
        first in ``model_stats`` is kept.
    """
    best_models = dict()
    for folder, stats in model_stats.items():
        by_stim = best_models.setdefault(stats['dataset'], dict())
        incumbent = by_stim.setdefault(stats['stim_type'], folder)
        if model_stats[incumbent][metric] < stats[metric]:
            by_stim[stats['stim_type']] = folder
    return best_models

## Performance of the best model per dataset and stimulus type (selected by test accuracy)

In [None]:
# {dataset: {stim_type: folder}} of the analysed models with the highest test accuracy.
best_folders = get_best_models(metric="TestAcc", model_stats=model_stats)

In [None]:
# Flatten {dataset: {stim_type: folder}} into a single list of the selected folders.
folders = [folder for by_stim in best_folders.values() for folder in by_stim.values()]

In [None]:
# Display the selected folders (one per dataset / stimulus-type pair).
folders

In [None]:
# Mean final test correlation over the selected best models.
selected_test_accs = [model_stats[folder]['TestAcc'] for folder in folders]
print("Avg Test Acc:", np.mean(selected_test_accs))

In [None]:
# Mean test accuracy of each dataset's best natural-scene model. Datasets with no analysed
# natural-scene model (e.g. every such model was skipped) are left out instead of raising KeyError.
keys = [key for key in best_folders if 'naturalscene' in best_folders[key]]
print("Nat Scenes:", np.mean([model_stats[best_folders[key]['naturalscene']]['TestAcc'] for key in keys]))

In [None]:
# Mean test accuracy of each dataset's best white-noise model. Datasets with no analysed
# white-noise model (e.g. every such model was skipped) are left out instead of raising KeyError.
keys = [key for key in best_folders if 'whitenoise' in best_folders[key]]
print("Whitenoise:", np.mean([model_stats[best_folders[key]['whitenoise']]['TestAcc'] for key in keys]))

In [None]:
# Mean best-unit correlation per interneuron type, pooled over the best model of every
# dataset / stimulus-type pair. y_true (built in the analysis loop) gives the type index
# of each recording, in the same order as the entries of all_cell_info.
avg_cors = {k: [] for k in cell_type_keys}  # cell_type_keys are defined above as ['bipolar', 'amacrine', 'horizontal']
for dataset, by_stim in best_folders.items():
    for stim_type, best_folder in by_stim.items():
        all_cell_info = model_stats[best_folder]['all_cell_info']
        for intr_stim_type, infos in all_cell_info.items():  # each stim type for the intraneuron recordings
            for j, info in enumerate(infos):
                avg_cors[cell_type_keys[y_true[j]]].append(info[-1])
for k, cors in avg_cors.items():
    print("Mean", k, ":", np.mean(cors))
grand_avg = [x for cors in avg_cors.values() for x in cors]
print("Mean across all:", np.mean(grand_avg))