In [None]:
import numpy as np
import torch
import torch.nn as nn
import h5py as h5
import os
import sys
import pickle
from torchdeepretina.models import *
#import metrics
import matplotlib.pyplot as plt
from   torchdeepretina.physiology import Physio
import torchdeepretina.intracellular as intracellular
import torchdeepretina.batch_compute as bc
import torchdeepretina.retinal_phenomena as rp
import torchdeepretina.stimuli as stimuli
from   torchdeepretina.sta_ascent import STAAscent
from   torchdeepretina.deepretina_loader import loadexpt
import pyret.filtertools as ft
import scipy
import scipy.cluster as cluster
import re
import pickle
from tqdm import tqdm
import gc
import resource
import time
import math

def normalize(x):
    """Standardise ``x`` to zero mean and unit standard deviation.

    A small epsilon (1e-7) is added to the denominator so that a constant
    input does not cause a division by zero.
    """
    centered = x - x.mean()
    return centered / (x.std() + 1e-7)

def retinal_phenomena_figs(bn_cnn):
    """Run the battery of retinal-phenomena probes from ``rp`` on ``bn_cnn``.

    A 10x10 inch matplotlib figure is opened first; each probe is then called
    in order with the model (contrast adaptation additionally receives the
    two contrast levels .35 and .05). Nothing is returned.
    """
    plt.figure(figsize=(10, 10))
    probes = [
        (rp.step_response, ()),
        (rp.osr, ()),
        (rp.reversing_grating, ()),
        (rp.contrast_adaptation, (.35, .05)),
        (rp.motion_anticipation, ()),
    ]
    for probe, extra_args in probes:
        probe(bn_cnn, *extra_args)
    
# Preprocess a raw stimulus according to its type (needed when using stimuli other than "boxes").
def prepare_stim(stim, stim_type):
    if stim_type == 'boxes':
        return 2*stim - 1
    elif stim_type == 'flashes':
        stim = stim.reshape(stim.shape[0], 1, 1)
        return np.broadcast_to(stim, (stim.shape[0], 38, 38))
    elif stim_type == 'movingbar':
        stim = block_reduce(stim, (1,6), func=np.mean)
        stim = pyret.stimulustools.upsample(stim.reshape(stim.shape[0], stim.shape[1], 1), 5)[0]
        return np.broadcast_to(stim, (stim.shape[0], stim.shape[1], stim.shape[1]))
    elif stim_type == 'lines':
        stim_averaged = np.apply_along_axis(lambda m: np.convolve(m, 0.5*np.ones((2,)), mode='same'), 
                                            axis=1, arr=stim)
        stim = stim_averaged[:,::2]
        # now stack stimulus to convert 1d to 2d spatial stimulus
        return stim.reshape(-1,1,stim.shape[-1]).repeat(stim.shape[-1], axis=1)
    else:
        print("Invalid stim type")
        assert False
    
def index_of(arg, arr):
    """Return the index of the first element of ``arr`` equal to ``arg``.

    Elements are compared as ``arg == arr[i]``. Returns -1 when there is no
    match, including when ``arr`` is empty.
    """
    return next((idx for idx in range(len(arr)) if arg == arr[idx]), -1)

def load_model(folder, pth):
    hyps=get_hyps(folder)
    try:
        hyps['model_type'] = hyps['model_type'].split(".")[-1].split("\'")[0].strip()
        hyps['model_type'] = globals()[hyps['model_type']]
        bn_cnn = hyps['model_type'](**temp['model_hyps'])
    except Exception as e:
        model_hyps = {"n_units":5,"noise":float(hyps['noise']),"bias":bool(hyps['bias'])}
        if "chans" in hyps:
            model_hyps['chans'] = [int(x) for x in 
                                   hyps['chans'].replace("[", "").replace("]", "").strip().split(",")]
        if "adapt_gauss" in hyps:
            model_hyps['adapt_gauss'] = hyps['adapt_gauss']
        fn_args = set(hyps['model_type'].__init__.__code__.co_varnames)
        for k in model_hyps.keys():
            if k not in fn_args:
                del model_hyps[k]
        bn_cnn = hyps['model_type'](**model_hyps)
    return bn_cnn

def get_hyps(folder, root="../training_scripts/"):
    """Parse ``<root>/<folder>/hyperparams.txt`` into a dict of raw strings.

    Each line of the form ``key: value`` becomes ``hyps[key] = value``. Values
    are not type-converted and keep any whitespace after the colon (the line
    itself is stripped). Lines containing "(" or ")" are skipped, as are lines
    without a colon.

    Args:
        folder: experiment sub-folder under ``root``.
        root: directory holding the experiment folders; defaults to the path
            this notebook has always used.

    Returns:
        dict mapping hyperparameter names to their string values.
    """
    hyps = dict()
    with open(os.path.join(root, folder, "hyperparams.txt")) as f:
        for line in f:
            if "(" in line or ")" in line:
                continue
            # Split on the first colon only, so values that themselves contain
            # ':' (paths, timestamps) are kept whole instead of truncated.
            splt = line.strip().split(":", 1)
            if len(splt) > 1:
                hyps[splt[0]] = splt[1]
    return hyps

In [None]:
# Use the first GPU when available; fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing on first .to(DEVICE).
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Load the interneuron recordings.
# num_pots stores, per file, the first dimension of boxes/detrended_membrane_potential
#   (presumably the number of cells in that file)
# stims[k] / mem_pots[k] hold, per file, the prepared stimulus and the membrane potentials for stimulus k
# psst, you can find the raw "data" folder in /home/grantsrb on the deepretina server
root_path = "~/interneuron_data/"
files = ['bipolars_late_2012.h5', 'bipolars_early_2012.h5', 'amacrines_early_2012.h5',
         'amacrines_late_2012.h5', 'horizontals_early_2012.h5', 'horizontals_late_2012.h5']
files = [os.path.expanduser(root_path + name) for name in files]
# The id is the leading token of the file *name* (e.g. "bipolars"). Split the
# basename, not the full path, or the id would be a directory prefix. Raw
# string: '\.' is an invalid escape in a normal string literal.
file_ids = [re.split(r'_|\.', os.path.basename(path))[0] for path in files]
filter_length = 40
window_size = 2
num_pots = []
stims = dict()
mem_pots = dict()
keys_to_use = {"boxes"}
for fi in files:
    with h5.File(fi, 'r') as f:
        for k in f.keys():
            if k not in keys_to_use:
                continue
            stims.setdefault(k, [])
            mem_pots.setdefault(k, [])
            try:
                stim_k = prepare_stim(np.asarray(f[k+'/stimuli']), k)
                # Drop the first filter_length samples of every trace.
                pots_k = np.asarray(f[k]['detrended_membrane_potential'])[:, filter_length:]
            except Exception as e:
                # Best effort: report and keep loading the remaining files.
                print("stim error at", k, e)
                continue
            # Append both only after both succeeded so stims[k][i] and
            # mem_pots[k][i] always come from the same file.
            stims[k].append(stim_k)
            mem_pots[k].append(pots_k)
        num = np.array(f['boxes/detrended_membrane_potential'].shape[0])
        num_pots.append(num)

In [None]:
grand_folder = "absbnbncnn"
exp_folder = "../training_scripts/"+grand_folder
# Each immediate sub-directory of exp_folder holds one trained model; record
# them as "<grand_folder>/<sub-directory>".
subdirs = next(os.walk(exp_folder))[1]
model_folders = [grand_folder + "/" + sub for sub in subdirs]

In [None]:
# Deterministic (alphabetical) evaluation order; list the folders one per line.
model_folders = sorted(model_folders)
print(*model_folders, sep="\n")

In [None]:
# Load the first model folder's epoch-0 checkpoint to inspect the architecture.
# A failed load raises here instead of being swallowed and resurfacing later as
# a confusing NameError on `temp`.
file = "../training_scripts/"+model_folders[0]+"/test_epoch_0.pth"
with open(file, "rb") as fd:
    temp = torch.load(fd)

# Some checkpoints store the full model object under 'model'; otherwise rebuild
# the architecture from the folder's hyperparameters.
if 'model' in temp:
    model = temp['model']
else:
    model = load_model(model_folders[0], temp)
print(model)

In [None]:
# Indices into model.sequential and reshape targets used by the evaluation loop.
# NOTE(review): these are tied to one specific architecture; confirm they match
# the model printed above before running the loop.
conv_idxs = [0, 6]                        # conv layers probed for receptive fields; only the first has its weights plotted
bn_idxs = [2, 8]                          # batch-norm layers whose first parameter is visualised
sta_idx = 6
linear_idx = 11                           # linear layer whose weights are visualised
linear_shape = (5, 24, 26, 26)            # linear weight reshaped to this; presumably (units, chans, h, w)
bn_shapes = [(8, 36, 36), (24, 26, 26)]   # per-bn_idxs reshape target; presumably (chans, h, w)

In [None]:
cells = "all"
dataset = '15-10-07'
# Prefer the normalisation statistics stored with the checkpoint. Fall back to
# hard-coded values only when the checkpoint does not carry them; other errors
# (including KeyboardInterrupt) are no longer swallowed by a bare except.
try:
    norm_stats = [temp['norm_stats']['mean'], temp['norm_stats']['std']]
except (KeyError, TypeError):
    norm_stats = [51.49175, 53.62663279042969]
test_data = loadexpt(dataset,cells,'naturalscene','test',40,0, norm_stats=norm_stats)
test_y = torch.FloatTensor(test_data.y)
test_x = torch.FloatTensor(test_data.X)
# Track gradients w.r.t. the stimulus for the gradient-clustering analysis.
test_x.requires_grad = True
# NOTE(review): PoissonNLLLoss defaults to log_input=True (expects log-rates);
# confirm this matches what the model outputs.
loss_fxn = torch.nn.PoissonNLLLoss()

## Model Visualizations

In [None]:
# Settings for the visualisation loop below.
batch_size = 500    # batch size for model-response computation and gradient clustering
kmeans_k = 8        # number of k-means clusters in the gradient analysis
sta_ascent_obj = STAAscent()
stim_shape = [40, 50, 50]   # per-sample shape the flattened cluster centres are reshaped to

In [None]:
for folder in model_folders:
    print("Evaluating", folder)

    # Only the newest loadable checkpoint is used, so search downward from the
    # highest epoch instead of deserialising all 300 candidate files. `data` is
    # reset per folder so a folder without checkpoints cannot silently reuse
    # the previous folder's weights.
    data = None
    for epoch in reversed(range(300)):
        file = "../training_scripts/"+folder+"/test_epoch_{0}.pth".format(epoch)
        if not os.path.exists(file):
            continue
        try:
            with open(file, "rb") as fd:
                data = torch.load(fd)
            break
        except Exception as e:
            # Unreadable checkpoint: report it and try the next-older epoch.
            print("Could not load", file, "-", e)
    if data is None:
        print("No loadable checkpoint found in", folder, "- skipping")
        continue
    temp = data
    # Some checkpoints store the full model object; otherwise rebuild the
    # architecture from the folder's hyperparameters.
    if 'model' in data:
        model = data['model']
    else:
        model = load_model(folder, data)
    model.load_state_dict(data['model_state_dict'])
    model = model.to(DEVICE)
    model.eval()
    # Retinal Phenomena (disabled; uncomment to also run the rp probes)
    # retinal_phenomena_figs(model)

    # Conv Filter visualizations (only the first conv layer's weights)
    for k,idx in enumerate(conv_idxs[:1]):
        module = model.sequential[idx]
        ps = list(module.parameters())
        p = ps[0].cpu().detach().numpy()
        fig=plt.figure(figsize=(18, 10), dpi= 80, facecolor='w', edgecolor='k')
        fig.suptitle("Conv " + str(k+1) + " Filters", fontsize=16)
        n_slices = 8
        for i in range(p.shape[0]):
            # Per-filter colour range: mean +/- 4 std.
            std = np.std(p[i])
            mean = np.mean(p[i])
            vmin = mean - 4*std
            vmax = mean + 4*std
            # n_slices evenly spaced slices along axis 1 of each filter.
            for j in range(n_slices):
                plt.subplot(p.shape[0],n_slices,1 + i*n_slices+j)
                plt.imshow(p[i,int(p.shape[1]*j/n_slices),:,:], vmin=vmin, vmax=vmax)
        plt.show()

        # Rank decomp
        fig = plt.figure(figsize=(18,6))
        fig.suptitle("Rank 1 Decomp of each filter", fontsize=16)
        for i in range(p.shape[0]):
            spatial_model, temporal_model = ft.decompose(p[i])
            plt.subplot(2,p.shape[0], i+1)
            plt.imshow(spatial_model, cmap = 'seismic', clim=[-np.max(abs(spatial_model)),
                                                                   np.max(abs(spatial_model))])
            plt.subplot(2,p.shape[0], i+1 + p.shape[0])
            plt.plot(temporal_model)
        plt.show()

    # Receptive field samples
    filter_length = 40
    conv_layers = ["sequential." + str(x) for x in conv_idxs]
    stimulus_num = 2
    stimulus = stims['boxes'][stimulus_num]
    padded_stim = intracellular.pad_to_edge(scipy.stats.zscore(stimulus))
    unit_step = 8
    model_response = bc.batch_compute_model_response(stimuli.concat(padded_stim), model, batch_size,
                                                            insp_keys=set(conv_layers))
    # Plot the receptive fields for model cells layer 1
    cl = conv_layers[0]
    for c in range(model_response[cl].shape[1]):
        figidx = 0
        fig=plt.figure(figsize=(18, 18), dpi= 80, facecolor='w', edgecolor='k')
        fig.suptitle(cl + " Filter " + str(c), fontsize=16)
        for row in range(0,model_response[cl].shape[2],unit_step):
            for col in range(0,model_response[cl].shape[3],unit_step):
                figidx += 1
                try:
                    plt.subplot(int(np.ceil(model_response[cl].shape[2]/unit_step)),
                                int(np.ceil(model_response[cl].shape[3]/unit_step)),
                                figidx)
                    model_cell_response = model_response[cl][:, c, row, col]
                    rc_model, lags_model = ft.revcorr(scipy.stats.zscore(stimulus[filter_length:]), model_cell_response,
                                                      nsamples_before=0, nsamples_after=filter_length)

                    spatial_model, temporal_model = ft.decompose(rc_model)
                    img = plt.imshow(spatial_model.squeeze(), cmap = 'seismic', clim=[-np.max(abs(spatial_model)),
                                                                       np.max(abs(spatial_model))])
                except Exception as e:
                    print(e)
                    print("model_response[cl].shape", model_response[cl].shape)

        plt.show()

    # Plot the receptive fields for model cells layer 2
    cl = conv_layers[1]
    s = intracellular.pad_to_edge(scipy.stats.zscore(stimulus[filter_length:]))
    for c in range(model_response[cl].shape[1]):
        for row in range(0,model_response[cl].shape[2],unit_step):
            figidx = 0
            n_cols = int(np.ceil(model_response[cl].shape[3]/unit_step))
            fig=plt.figure(figsize=(18, 7), dpi= 80, facecolor='w', edgecolor='k')
            fig.suptitle(cl + " Filter " + str(c), fontsize=16)
            for col in range(0,model_response[cl].shape[3],unit_step):
                figidx += 1
                try:
                    # Top row: reverse-correlation RF; bottom row: STA ascent.
                    plt.subplot(2, n_cols, figidx)
                    model_cell_response = model_response[cl][:, c, row, col]

                    rc_model, lags_model = ft.revcorr(s, model_cell_response,
                                                      nsamples_before=0, nsamples_after=filter_length)
                    spatial_model, temporal_model = ft.decompose(rc_model)
                    img = plt.imshow(spatial_model.squeeze(), cmap = 'seismic', clim=[-np.max(abs(spatial_model)),
                                                                       np.max(abs(spatial_model))])
                    plt.subplot(2, n_cols, figidx+n_cols)
                    sta_img = sta_ascent_obj.sta_ascent(model, cl, units=[(c,row,col)],
                                                        n_epochs=5000, constraint=.3)
                    spatial_model, temporal_model = ft.decompose(sta_img.squeeze())
                    img = plt.imshow(spatial_model.squeeze(), cmap="seismic", clim=[-np.max(abs(spatial_model)),
                                                                       np.max(abs(spatial_model))])
                except Exception as e:
                    print(e)
                    print("model_response[cl].shape", model_response[cl].shape)

            plt.show()

    ## Batchnorm vis
    for k, idx in enumerate(bn_idxs):
        module = model.sequential[idx]
        ps = list(module.named_parameters())
        p = ps[0][1].cpu().detach().numpy().reshape(bn_shapes[k])
        fig=plt.figure(figsize=(18, 4), dpi= 80, facecolor='w', edgecolor='k')
        fig.suptitle("BatchNorm " + str(k+1) + " Filters", fontsize=16)
        for i in range(p.shape[0]):
            std = np.std(p[i])
            mean = np.mean(p[i])
            vmin = mean - 1.9*std
            vmax = mean + 1.9*std
            plt.subplot(1,p.shape[0], 1+i)
            plt.imshow(p[i], vmin=vmin, vmax=vmax)
        plt.show()

    ## Linear Vis
    module = model.sequential[linear_idx]
    ps = list(module.named_parameters())
    p = ps[0][1].cpu().detach().numpy().reshape(linear_shape)
    fig=plt.figure(figsize=(18, 10), dpi= 80, facecolor='w', edgecolor='k')
    fig.suptitle("Linear Filters", fontsize=16)
    # One shared colour range (mean +/- 1.9 std) across all linear weights.
    std = np.std(p)
    mean = np.mean(p)
    vmin = mean - 1.9*std
    vmax = mean + 1.9*std
    for j in range(linear_shape[0]):
        for i in range(linear_shape[1]):
            plt.subplot(linear_shape[0],linear_shape[1], j*linear_shape[1]+i+1)
            plt.imshow(p[j,i], vmin=vmin, vmax=vmax)
    plt.show()

    ## Gradient cluster
    # test_x is shared by every folder in this loop, and .grad accumulates
    # across backward() calls. Clear it (and the model's grads) first so this
    # model's clusters are not contaminated by earlier models' gradients.
    test_x.grad = None
    model.zero_grad()
    n_batches = 0
    for i in tqdm(range(0,test_x.shape[0], batch_size)):
        batch_x = test_x[i:i+batch_size].to(DEVICE)
        batch_y = test_y[i:i+batch_size].to(DEVICE)
        preds = model(batch_x)
        loss = loss_fxn(preds, batch_y)
        loss.backward()
        n_batches += 1
    # Average over the batches actually processed. Floor division by
    # batch_size was zero (ZeroDivisionError) when test_x had fewer rows than
    # batch_size. The overall scale is removed by whiten() anyway.
    test_x.grad.data /= n_batches
    grad = test_x.grad.data
    grad = grad.view(grad.shape[0], -1).detach().cpu().numpy()
    whitened = cluster.vq.whiten(grad)
    print("Beginning Cluster")
    clust,distort = cluster.vq.kmeans(whitened, kmeans_k)
    clust = clust.reshape(tuple([-1] + stim_shape))
    fig=plt.figure(figsize=(18, 10), dpi= 80, facecolor='w', edgecolor='k')
    fig.suptitle("Gradient Cluster with " + str(kmeans_k)+" means", fontsize=16)
    # kmeans may return fewer than kmeans_k centroids; iterate over what exists.
    for i in range(clust.shape[0]):
        std = np.std(clust[i])
        mean = np.mean(clust[i])
        vmin = mean - 1.9*std
        vmax = mean + 1.9*std
        n_slices = 8
        for j in range(n_slices):
            plt.subplot(kmeans_k,n_slices,1 + i*n_slices+j)
            plt.imshow(clust[i,int(stim_shape[0]*j/n_slices),:,:], vmin=vmin, vmax=vmax)
    plt.show()
    ## Rank decomp
    fig = plt.figure(figsize=(18,6))
    fig.suptitle("Rank 1 Decomp of clusters", fontsize=16)
    for i in range(clust.shape[0]):
        spatial_model, temporal_model = ft.decompose(clust[i])
        plt.subplot(2,clust.shape[0], i+1)
        plt.imshow(spatial_model, cmap = 'seismic', clim=[-np.max(abs(spatial_model)),
                                                               np.max(abs(spatial_model))])
        plt.subplot(2,clust.shape[0], i+1 + clust.shape[0])
        plt.plot(temporal_model)
    plt.show()

    print("\n\n\n\n\n")