In [None]:
import numpy as np
import torch
import torch.nn as nn
import h5py as h5
import os
import sys
import pickle
sys.path.append("../")
from models import PracticalBNCNN, NormedBNCNN, DalesBNCNN, DalesSSCNN, SSCNN, BNCNN, PracticalBNCNN, DalesHybrid, DalesSkipBNCNN, SkipBNBNCNN
#import metrics
import matplotlib.pyplot as plt
from utils.physiology import Physio
import utils.intracellular as intracellular
import utils.batch_compute as bc
import utils.retinal_phenomena as rp
import utils.stimuli as stimuli
from utils.sta_ascent import STAAscent
from utils.deepretina_loader import loadexpt
import pyret.filtertools as ft
import scipy
import scipy.cluster as cluster
import re
import pickle
from tqdm import tqdm
import gc
import resource
import time
import math

def normalize(x):
    """Standardise `x` to zero mean and roughly unit standard deviation.

    A small epsilon (1e-7) is added to the standard deviation so that a
    constant input maps to zeros instead of triggering a division by zero.
    """
    centred = x - x.mean()
    scale = x.std() + 1e-7
    return centred / scale

def retinal_phenomena_figs(bn_cnn):
    """Run the retinal-phenomena plotting helpers from `utils.retinal_phenomena` on a model.

    A 10x10-inch figure is opened first; whether the `rp` helpers draw into it
    or open their own figures is not visible from this file.
    """
    plt.figure(figsize=(10, 10))
    for plot_fn in (rp.step_response, rp.osr, rp.reversing_grating):
        plot_fn(bn_cnn)
    # NOTE(review): .35 and .05 are presumably the two contrast levels — confirm in rp.
    rp.contrast_adaptation(bn_cnn, .35, .05)
    rp.motion_anticipation(bn_cnn)
    
# Use prepare_stim when the stimulus is something other than plain boxes (e.g. flashes or a moving bar).
def block_mean_columns(stim, factor):
    """Average each row of a 2-D array over consecutive blocks of `factor` columns.

    Equivalent to ``skimage.measure.block_reduce(stim, (1, factor), func=np.mean)``:
    when the width is not a multiple of `factor` the array is zero-padded on the
    right first, so the last block's mean includes those zeros.
    """
    stim = np.asarray(stim)
    n_rows, n_cols = stim.shape
    pad = (-n_cols) % factor
    padded = np.pad(stim, ((0, 0), (0, pad)), mode='constant', constant_values=0)
    return padded.reshape(n_rows, -1, factor).mean(axis=-1)


def prepare_stim(stimuli, stim_type, size=38, bar_block=6, upsample_factor=5):
    """Convert a raw stimulus into a (time, height, width) movie.

    Parameters
    ----------
    stimuli : np.ndarray
        Raw stimulus. For ``'flashes'`` it must hold one value per time step
        (it is reshaped to ``(T, 1, 1)``); for ``'movingbar'`` it must be 2-D ``(T, W)``.
    stim_type : str
        ``'boxes'`` (returned unchanged), ``'flashes'`` or ``'movingbar'``.
    size : int
        Spatial side length of the ``'flashes'`` movie.
    bar_block : int
        Number of adjacent columns averaged together for ``'movingbar'``.
    upsample_factor : int
        Factor passed to ``pyret.stimulustools.upsample`` for ``'movingbar'``.

    Returns
    -------
    np.ndarray or None
        The prepared stimulus (a read-only ``np.broadcast_to`` view for
        ``'flashes'`` and ``'movingbar'``), or None for an unrecognised type.
    """
    if stim_type == 'boxes':
        return stimuli
    if stim_type == 'flashes':
        stim = stimuli.reshape(stimuli.shape[0], 1, 1)
        return np.broadcast_to(stim, (stim.shape[0], size, size))
    if stim_type == 'movingbar':
        # The file only does `import pyret.filtertools as ft`, which does not bind
        # the name `pyret`, so the submodule has to be imported explicitly here.
        import pyret.stimulustools as stimulustools
        stim = block_mean_columns(stimuli, bar_block)
        # [0] selects the upsampled stimulus from the tuple upsample returns.
        stim = stimulustools.upsample(stim.reshape(stim.shape[0], stim.shape[1], 1),
                                      upsample_factor)[0]
        return np.broadcast_to(stim, (stim.shape[0], stim.shape[1], stim.shape[1]))
    return None
    
def index_of(arg, arr):
    """Return the position of the first element of `arr` equal to `arg`, or -1 if none is."""
    matches = (idx for idx in range(len(arr)) if arg == arr[idx])
    return next(matches, -1)

In [None]:
# Use the first GPU when one is available; fall back to CPU so the notebook
# can still be executed on a machine without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Release cached-but-unused GPU memory left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

In [None]:
grand_folder = "bncnn"
exp_folder = "../training_scripts/" + grand_folder
# Only the first os.walk step is needed: the immediate sub-directories of exp_folder.
_, subdirs, _ = next(os.walk(exp_folder))
# Store each model folder relative to ../training_scripts/.
model_folders = [grand_folder + "/" + name for name in subdirs]

In [None]:
# Deterministic order, then list one folder per line.
model_folders = sorted(model_folders)
print(*model_folders, sep="\n")

In [None]:
# Load the checkpoint with the highest consecutive epoch index for the first
# model folder. Files test_epoch_0.pth, test_epoch_1.pth, ... are read in
# order and the scan stops at the first missing file.
folder = model_folders[0]
temp = None
for i in range(250):
    file = "../training_scripts/" + folder + "/test_epoch_{0}.pth".format(i)
    try:
        with open(file, "rb") as fd:
            temp = torch.load(fd)
    except FileNotFoundError:
        # Only a missing file ends the scan; other load errors (e.g. a corrupt
        # checkpoint or a torch.load incompatibility) are surfaced, not hidden.
        break
if temp is None:
    raise FileNotFoundError("No test_epoch_*.pth checkpoints found in ../training_scripts/" + folder)
temp['model']

In [None]:
# Positions of layers inside `model.sequential` that the visualisation loop inspects.
conv_idxs, bn_idxs = [0, 6], [2, 8]   # convolution layers, batch-norm layers
linear_idx = 11                        # readout layer (also the STA-ascent target layer)

# Shapes each batch-norm weight vector is reshaped to for per-channel plotting
# (one entry per element of bn_idxs).
bn_shapes = [
    (8, 36, 36),
    (8, 26, 26),
]
# Shape the readout weights are reshaped to for plotting; its product must equal
# the number of readout weights. Presumably (n_cells, channels, height, width)
# — TODO confirm against the trained model.
linear_shape = (5, 8, 26, 26)

In [None]:
cells = "all"
dataset = '15-10-07'
# Prefer the normalization statistics saved in the checkpoint; when it has no
# usable 'norm_stats' entry, fall back to fixed values.
try:
    norm_stats = [temp['norm_stats']['mean'], temp['norm_stats']['std']]
except (KeyError, TypeError):
    norm_stats = [51.49175, 53.62663279042969]
# NOTE(review): the positional 40, 0 are presumably the stimulus history length
# (matching stim_shape[0] = 40 used later) and an offset — confirm against loadexpt.
test_data = loadexpt(dataset, cells, 'naturalscene', 'test', 40, 0, norm_stats=norm_stats)
test_y = torch.FloatTensor(test_data.y)
test_x = torch.FloatTensor(test_data.X)
# Gradients of the loss w.r.t. the stimulus feed the gradient-clustering plots.
test_x.requires_grad_(True)
# NOTE(review): PoissonNLLLoss defaults to log_input=True, i.e. model outputs are
# treated as log-rates; confirm this matches the models' output nonlinearity.
loss_fxn = torch.nn.PoissonNLLLoss()

## Model Visualizations

In [None]:
batch_size = 500 # For grad clustering
kmeans_k = 8
sta_ascent_obj = STAAscent()
stim_shape = [40, 50, 50]

# Receptive-field sampling settings (loop-invariant, so defined once).
filter_length = 40
conv_layers = ["sequential." + str(x) for x in conv_idxs]
stimulus = test_data.X[:1000]
unit_step = 4


def _load_last_checkpoint(folder, max_epochs=300):
    """Return the checkpoint with the highest consecutive epoch index in `folder`.

    Reads ../training_scripts/<folder>/test_epoch_{i}.pth for i = 0, 1, ... and
    stops at the first missing file. Returns None if epoch 0 does not exist.
    Errors other than a missing file propagate to the caller.
    """
    checkpoint = None
    for epoch in range(max_epochs):
        path = "../training_scripts/" + folder + "/test_epoch_{0}.pth".format(epoch)
        try:
            with open(path, "rb") as fd:
                checkpoint = torch.load(fd)
        except FileNotFoundError:
            break
    return checkpoint


def _plot_slices(stack, title, n_std=1.9, n_slices=8):
    """Show `n_slices` evenly spaced frames (taken along axis 1) of every item of a 4-D array.

    One subplot row per item; each row's colour range is that item's
    mean +/- `n_std` standard deviations.
    """
    fig = plt.figure(figsize=(18, 10), dpi=80, facecolor='w', edgecolor='k')
    fig.suptitle(title, fontsize=16)
    n_items, depth = stack.shape[0], stack.shape[1]
    for row in range(n_items):
        mean, std = np.mean(stack[row]), np.std(stack[row])
        for j in range(n_slices):
            plt.subplot(n_items, n_slices, 1 + row * n_slices + j)
            plt.imshow(stack[row, int(depth * j / n_slices), :, :],
                       vmin=mean - n_std * std, vmax=mean + n_std * std)
    plt.show()


def _plot_rank1(stack, title):
    """Plot the rank-1 spatial (top row) and temporal (bottom row) factors of each item."""
    fig = plt.figure(figsize=(18, 6))
    fig.suptitle(title, fontsize=16)
    n_items = stack.shape[0]
    for k in range(n_items):
        spatial_model, temporal_model = ft.decompose(stack[k])
        lim = np.max(np.abs(spatial_model))
        plt.subplot(2, n_items, k + 1)
        plt.imshow(spatial_model, cmap='seismic', clim=[-lim, lim])
        plt.subplot(2, n_items, k + 1 + n_items)
        plt.plot(temporal_model)
    plt.show()


for folder in model_folders:
    print("Evaluating", folder)
    temp = _load_last_checkpoint(folder)
    if temp is None:
        # Without this check the previous folder's checkpoint would be reused silently.
        print("No checkpoints found for", folder, "- skipping")
        continue
    model = temp['model'].to(DEVICE)
    model.eval()
    # Retinal phenomena (disabled): retinal_phenomena_figs(model)

    # Conv filter visualizations
    for k, idx in enumerate(conv_idxs):
        weights = list(model.sequential[idx].parameters())[0].detach().cpu().numpy()
        _plot_slices(weights, "Conv " + str(k + 1) + " Filters", n_std=4)
        _plot_rank1(weights, "Rank 1 Decomp of each filter")

    # Receptive field samples
    model_response = bc.batch_compute_model_response(stimulus, model, batch_size,
                                                     insp_keys=set(conv_layers))
    for cl in conv_layers:
        layer_resp = model_response[cl]
        rows = range(0, layer_resp.shape[2], unit_step)
        cols = range(0, layer_resp.shape[3], unit_step)
        for ch in range(layer_resp.shape[1]):
            fig = plt.figure(figsize=(18, 18), dpi=80, facecolor='w', edgecolor='k')
            fig.suptitle(cl + " Filter " + str(ch), fontsize=16)
            for r_i, row in enumerate(rows):
                for c_i, col in enumerate(cols):
                    # Index by position in the sampled grid, not by pixel coordinate,
                    # so the subplot number stays within nrows * ncols.
                    plt.subplot(len(rows), len(cols), r_i * len(cols) + c_i + 1)
                    model_cell_response = layer_resp[:, ch, row, col]
                    rc_model, lags_model = ft.revcorr(stimulus[:, 0], model_cell_response,
                                                      nsamples_before=0, nsamples_after=filter_length)
                    spatial_model, temporal_model = ft.decompose(rc_model)
                    lim = np.max(np.abs(spatial_model))
                    plt.imshow(spatial_model.squeeze(), cmap='seismic', clim=[-lim, lim])
            plt.show()

    ## Batchnorm vis
    for k, idx in enumerate(bn_idxs):
        bn_weight = list(model.sequential[idx].named_parameters())[0][1]
        p = bn_weight.detach().cpu().numpy().reshape(bn_shapes[k])
        fig = plt.figure(figsize=(18, 4), dpi=80, facecolor='w', edgecolor='k')
        fig.suptitle("BatchNorm " + str(k + 1) + " Filters", fontsize=16)
        for i in range(p.shape[0]):
            std, mean = np.std(p[i]), np.mean(p[i])
            plt.subplot(1, p.shape[0], 1 + i)
            plt.imshow(p[i], vmin=mean - 1.9 * std, vmax=mean + 1.9 * std)
        plt.show()

    ## Linear Vis
    lin_weight = list(model.sequential[linear_idx].named_parameters())[0][1]
    p = lin_weight.detach().cpu().numpy().reshape(linear_shape)
    fig = plt.figure(figsize=(18, 10), dpi=80, facecolor='w', edgecolor='k')
    fig.suptitle("Linear Filters", fontsize=16)
    std, mean = np.std(p), np.mean(p)
    for j in range(linear_shape[0]):
        for i in range(linear_shape[1]):
            plt.subplot(linear_shape[0], linear_shape[1], j * linear_shape[1] + i + 1)
            plt.imshow(p[j, i], vmin=mean - 1.9 * std, vmax=mean + 1.9 * std)
    plt.show()

    ## Gradient cluster
    # .backward() accumulates into test_x.grad, so reset it for every model.
    test_x.grad = None
    n_batches = math.ceil(test_x.shape[0] / batch_size)
    for start in tqdm(range(0, test_x.shape[0], batch_size)):
        batch_x = test_x[start:start + batch_size].to(DEVICE)
        batch_y = test_y[start:start + batch_size].to(DEVICE)
        loss = loss_fxn(model(batch_x), batch_y)
        loss.backward()
    test_x.grad.div_(n_batches)
    grad = test_x.grad.view(test_x.shape[0], -1).detach().cpu().numpy()
    whitened = cluster.vq.whiten(grad)
    print("Beginning Cluster")
    clust, distort = cluster.vq.kmeans(whitened, kmeans_k)
    # kmeans may return fewer than kmeans_k centroids (empty clusters are dropped),
    # so the plots iterate over the centroids actually returned.
    clust = clust.reshape(tuple([-1] + stim_shape))
    _plot_slices(clust, "Gradient Cluster with " + str(kmeans_k) + " means")
    _plot_rank1(clust, "Rank 1 Decomp of clusters")

    ## Give STA Ascent image
    stas = []
    sta_layer = "sequential." + str(linear_idx)
    n_cells = test_y.shape[-1]
    print("Ascending STAs")
    for gang_cell in range(n_cells):
        sta_img = sta_ascent_obj.sta_ascent(model, sta_layer, units=[gang_cell], n_epochs=10000)
        stas.append(sta_img)
        print(gang_cell, "/", n_cells, end='\r')
    stas = np.asarray(stas).squeeze()
    if stas.ndim == 3:
        # squeeze() also drops the cell axis when there is a single cell; restore it.
        stas = stas[np.newaxis]
    _plot_slices(stas, "STA Ascent Image")
    _plot_rank1(stas, "Rank 1 Decomp of STA Ascent")

    print("\n\n\n\n\n")