# Imports

In [None]:
import torch
import pickle
import imageio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.animation as animation
import numpy as np
import time
import convis
# convis is the module used to produce spikes. It is important to have the latest version (0.6.4) installed with:
# pip install git+https://github.com/jahuth/convis.git


%matplotlib inline
%config InlineBackend.figure_format = 'retina' #Only needed for high resolution displays

In [None]:
# Dark plotting theme: white axes, ticks and text on a near-black axes background.
mpl.rcParams.update({
    'axes.edgecolor': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
    'text.color': 'white',
    'axes.facecolor': '#111111',
    # 0 switches off the "too many open figures" warning.
    'figure.max_open_warning': 0,
})
#mpl.rcParams['lines.linewidth'] = .2

We will group cell types according to the size of their receptive field

In [None]:
# Cell type ids grouped by receptive-field size (small / large / medium).
small_cells = [1, 14, 16, 19, 21, 23, 24, 26]
large_cells = [6, 7, 8, 9, 11, 12, 13, 22, 28, 34]
medium_cells = [
    2, 3, 4, 5, 10, 15, 17, 18, 20, 25, 27,
    29, 30, 31, 32, 33, 35, 36, 37, 38, 39,
]

## Help functions

In [None]:
def transform_data(data, batch_size):
    """Replicate a time series along a batch axis and return it as a float32 tensor.

    Args:
        data: NumPy array indexed by time along its first axis.
        batch_size: Number of copies of every time step along axis 1.

    Returns:
        A ``torch`` float tensor of shape ``(data.shape[0], batch_size, 1)`` in
        which every batch entry of step ``t`` holds ``data[t]``.
    """
    n_steps = data.shape[0]
    replicated = np.zeros((n_steps, batch_size, 1))
    for step, value in enumerate(data):
        # Broadcasting fills the whole (batch_size, 1) slice with this step's value.
        replicated[step] = value
    return torch.from_numpy(replicated).float()

def save(var, name):
    """Pickle ``var`` into the file ``name + ".pkl"``.

    Args:
        var: Any picklable object.
        name: Output path without the ``.pkl`` extension.
    """
    # The context manager closes the file even when pickling raises.
    with open(name + ".pkl", 'wb') as handle:
        pickle.dump(var, handle)

def load(file):
    """Unpickle and return the object stored in ``"./" + file + ".pkl"``.

    Args:
        file: File name without the ``.pkl`` extension, resolved relative to
            the current working directory.

    Returns:
        The unpickled object.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    # The context manager closes the file even when unpickling raises.
    with open("./" + file + ".pkl", 'rb') as handle:
        return pickle.load(handle)

def timeit(fun):
    """Decorator that prints how long each call to ``fun`` takes.

    Args:
        fun: Callable to wrap.

    Returns:
        A wrapper with the same signature, name and docstring as ``fun`` that
        prints the elapsed seconds and returns ``fun``'s result unchanged.
    """
    import functools  # local import keeps this helper self-contained

    @functools.wraps(fun)  # preserve __name__ / __doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the difference can never be negative.
        t1 = time.perf_counter()
        params = fun(*args, **kwargs)
        t2 = time.perf_counter()
        print("Time it took to run the function: {}".format(t2 - t1))
        return params
    return wrapper

def savemodel(model, name, root='./'):
    """Write the model's ``state_dict`` to the path ``root + name``.

    Args:
        model: Object exposing ``state_dict()``.
        name: File name of the checkpoint.
        root: Prefix prepended verbatim to ``name`` (no separator is inserted).
    """
    destination = root + name
    weights = model.state_dict()
    torch.save(weights, destination)

def loadmodel(model, name, root='./', cuda=True, gpu=0):
    """Load a saved ``state_dict`` from ``root + name`` into ``model`` in place.

    Args:
        model: Module whose parameters are overwritten.
        name: Checkpoint file name.
        root: Prefix prepended verbatim to ``name``.
        cuda: If true, map the checkpoint to ``cuda:<gpu>`` and move the model
            there; otherwise map the checkpoint to the CPU.
        gpu: Index of the CUDA device used when ``cuda`` is true.
    """
    # Build the target device locally instead of relying on notebook globals.
    device = torch.device("cuda:{}".format(gpu) if cuda else "cpu")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(root + name, map_location=device)
    model.load_state_dict(state)
    if cuda:
        model.to(device)

# Model
To load the trained PyTorch models you first need to re-create the original model class, because only the weights (the `state_dict`) were saved.

In [None]:
class Lstmcell(nn.Module):
    """Two chained ``nn.LSTMCell`` layers unrolled over a stimulus sequence.

    ``photo`` maps a 1-feature input to ``hidden_size`` features and ``bipol``
    maps those back to a single feature, which is the model output.

    Args:
        device: Device on which the recurrent state tensors are created.
        hidden_size: Hidden size of the ``photo`` cell.
        batch_size: Default batch size used by ``init_weights``.
    """

    def __init__(self, device, hidden_size=51, batch_size=4):
        super().__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.device = device
        # Attribute names are part of the checkpoint's state_dict keys.
        self.photo = nn.LSTMCell(1, self.hidden_size)
        self.bipol = nn.LSTMCell(self.hidden_size, 1)

    def init_weights(self, batch_size=None):
        """Return zeroed ``(h_photo, c_photo, h_bipol, c_bipol)`` states on ``self.device``.

        Despite its name this creates the recurrent states, not the weights.

        Args:
            batch_size: Number of rows of each state; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        h_photo = torch.zeros(batch_size, self.hidden_size).to(self.device)
        c_photo = torch.zeros(batch_size, self.hidden_size).to(self.device)
        h_bipol = torch.zeros(batch_size, 1).to(self.device)
        c_bipol = torch.zeros(batch_size, 1).to(self.device)
        return h_photo, c_photo, h_bipol, c_bipol

    def forward(self, stimulus):
        """Run the two cells step by step over ``stimulus``.

        Args:
            stimulus: Tensor indexed as ``(time, batch, 1)``.

        Returns:
            Tensor with the same size as ``stimulus`` holding the ``bipol``
            hidden state at every time step.
        """
        # Size the states from the stimulus so any batch size works, not only
        # the one given to the constructor.
        h_photo, c_photo, h_bipol, c_bipol = self.init_weights(stimulus.shape[1])
        # Allocated with torch.empty defaults (kept for backward compatibility).
        output = torch.empty(stimulus.size())
        for i in range(stimulus.shape[0]):
            h_photo, c_photo = self.photo(stimulus[i], (h_photo, c_photo))
            h_bipol, c_bipol = self.bipol(h_photo, (h_bipol, c_bipol))
            output[i] = h_bipol
        return output

In [None]:
class spkLstmcell(nn.Module):
    """Two chained ``nn.LSTMCell`` layers unrolled over a stimulus sequence.

    ``photo`` maps a 1-feature input to ``hidden_size`` features and ``bipol``
    maps those back to a single feature, which is the model output.

    Args:
        device: Device on which the recurrent state tensors are created.
        hidden_size: Hidden size of the ``photo`` cell.
        batch_size: Default batch size used by ``init_weights``.
    """

    def __init__(self, device, hidden_size=51, batch_size=4):
        # Zero-argument super() always refers to this class, so the
        # constructor can no longer name the wrong class.
        super().__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.device = device
        # Attribute names are part of the checkpoint's state_dict keys.
        self.photo = nn.LSTMCell(1, self.hidden_size)
        self.bipol = nn.LSTMCell(self.hidden_size, 1)

    def init_weights(self, batch_size=None):
        """Return zeroed ``(h_photo, c_photo, h_bipol, c_bipol)`` states on ``self.device``.

        Despite its name this creates the recurrent states, not the weights.

        Args:
            batch_size: Number of rows of each state; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        h_photo = torch.zeros(batch_size, self.hidden_size).to(self.device)
        c_photo = torch.zeros(batch_size, self.hidden_size).to(self.device)
        h_bipol = torch.zeros(batch_size, 1).to(self.device)
        c_bipol = torch.zeros(batch_size, 1).to(self.device)
        return h_photo, c_photo, h_bipol, c_bipol

    def forward(self, stimulus):
        """Run the two cells step by step over ``stimulus``.

        Args:
            stimulus: Tensor indexed as ``(time, batch, 1)``.

        Returns:
            Tensor with the same size as ``stimulus`` holding the ``bipol``
            hidden state at every time step.
        """
        # Size the states from the stimulus so any batch size works.
        h_photo, c_photo, h_bipol, c_bipol = self.init_weights(stimulus.shape[1])
        output = torch.empty(stimulus.size())
        for i in range(stimulus.shape[0]):
            h_photo, c_photo = self.photo(stimulus[i], (h_photo, c_photo))
            h_bipol, c_bipol = self.bipol(h_photo, (h_bipol, c_bipol))
            output[i] = h_bipol
        return output

Example of loading one type of cell, for CPU uncomment/comment the corresponding lines.

In [None]:
# Pick the GPU when one is available, otherwise the CPU. The model's state
# tensors and the loaded weights must live on the same device, so the `cuda`
# flag passed to loadmodel is derived from `device` instead of being set by hand.
# To force the CPU, replace the next line with: device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = "./models/lstm/"
file = "bipolar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_173563030.pt"
model_type1 = Lstmcell(device, batch_size=8)
loadmodel(model_type1, file, root=root, cuda=(device.type == "cuda"))

### Loading all cell types

In [None]:
files = ["bipolar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_141844336.pt",
"bipolar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_15283127.pt",
"bipolar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_269304424.pt",
"bipolar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_330601724.pt",
"bipolar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_6552616.pt",
"bipolar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_906165115.pt",
"bipolar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_241280767.pt",
"bipolar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_531998175.pt",
"bipolar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_66948984.pt",
"bipolar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_179780013.pt",
"bipolar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_29671914.pt",
"bipolar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_957957141.pt",
"bipolar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1098300983.pt",
"bipolar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_715986001.pt",
"bipolar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_950670156.pt",
"bipolar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_173563030.pt",
"bipolar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_533793902.pt",
"bipolar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_560106992.pt",
"bipolar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1189464144.pt",
"bipolar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1386305591.pt",
"bipolar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1690116631.pt",
"bipolar_type_3_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_6251400.pt",
"bipolar_type_3_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_766344198.pt",
"bipolar_type_4_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1128190960.pt",
"bipolar_type_4_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1303619795.pt",
"bipolar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_15785680.pt",
"bipolar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_162832312.pt",
"bipolar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_275131426.pt",
"bipolar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1166721910.pt",
"bipolar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_245239075.pt",
"bipolar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_75266929.pt",
"bipolar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_126872303.pt",
"bipolar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_3808957.pt",
"bipolar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_549864573.pt",
"bipolar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_159725155.pt",
"bipolar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_907167403.pt",
"bipolar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_998040308.pt",
"bipolar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_133931557.pt",
"bipolar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_74535517.pt",
"bipolar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_783677.pt",
"ganglionar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_237669104.pt",
"ganglionar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_411326535.pt",
"ganglionar_type_10_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_905665101.pt",
"ganglionar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1239763412.pt",
"ganglionar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_563475798.pt",
"ganglionar_type_11_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_82775029.pt",
"ganglionar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_172090477.pt",
"ganglionar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_381057631.pt",
"ganglionar_type_12_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_839874706.pt",
"ganglionar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1241648508.pt",
"ganglionar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1494485751.pt",
"ganglionar_type_13_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_165704750.pt",
"ganglionar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_183174642.pt",
"ganglionar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_491223730.pt",
"ganglionar_type_14_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_94689407.pt",
"ganglionar_type_15_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1177464148.pt",
"ganglionar_type_15_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_138207895.pt",
"ganglionar_type_15_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_249169987.pt",
"ganglionar_type_16_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1122989673.pt",
"ganglionar_type_16_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1612269384.pt",
"ganglionar_type_16_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_584220802.pt",
"ganglionar_type_17_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_160596147.pt",
"ganglionar_type_17_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_705127268.pt",
"ganglionar_type_17_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_989493600.pt",
"ganglionar_type_18_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_321926665.pt",
"ganglionar_type_18_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_6292319.pt",
"ganglionar_type_18_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_78265144.pt",
"ganglionar_type_19_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1175274272.pt",
"ganglionar_type_19_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_156834873.pt",
"ganglionar_type_19_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_329472619.pt",
"ganglionar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_198326626.pt",
"ganglionar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_364471185.pt",
"ganglionar_type_1_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_600241632.pt",
"ganglionar_type_20_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_413860435.pt",
"ganglionar_type_20_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_44144047.pt",
"ganglionar_type_20_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_9231279.pt",
"ganglionar_type_21_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1369609959.pt",
"ganglionar_type_21_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_239711620.pt",
"ganglionar_type_21_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_619669146.pt",
"ganglionar_type_22_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_380063816.pt",
"ganglionar_type_22_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_413063837.pt",
"ganglionar_type_22_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_9730860.pt",
"ganglionar_type_23_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1211777820.pt",
"ganglionar_type_23_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1482063350.pt",
"ganglionar_type_23_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_468143651.pt",
"ganglionar_type_24_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_117458979.pt",
"ganglionar_type_24_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_2085822589.pt",
"ganglionar_type_24_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_28653427.pt",
"ganglionar_type_25_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1089518756.pt",
"ganglionar_type_25_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1599144784.pt",
"ganglionar_type_25_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_164781825.pt",
"ganglionar_type_26_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_131237893.pt",
"ganglionar_type_26_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_172340461.pt",
"ganglionar_type_26_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_350523332.pt",
"ganglionar_type_27_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_479449101.pt",
"ganglionar_type_27_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_55482899.pt",
"ganglionar_type_27_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_84281034.pt",
"ganglionar_type_28_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_264065688.pt",
"ganglionar_type_28_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_463811580.pt",
"ganglionar_type_28_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_566859297.pt",
"ganglionar_type_29_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_152841965.pt",
"ganglionar_type_29_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_530384460.pt",
"ganglionar_type_29_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_87052115.pt",
"ganglionar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_295277855.pt",
"ganglionar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_36236986.pt",
"ganglionar_type_2_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_43875023.pt",
"ganglionar_type_30_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_109044538.pt",
"ganglionar_type_30_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_289226367.pt",
"ganglionar_type_30_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_556781329.pt",
"ganglionar_type_31_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_178961346.pt",
"ganglionar_type_31_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_331931885.pt",
"ganglionar_type_31_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_423348922.pt",
"ganglionar_type_32_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1213524776.pt",
"ganglionar_type_32_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_267392019.pt",
"ganglionar_type_32_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_438881270.pt",
"ganglionar_type_33_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1135858999.pt",
"ganglionar_type_33_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_268015804.pt",
"ganglionar_type_33_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_465011534.pt",
"ganglionar_type_34_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1369503192.pt",
"ganglionar_type_34_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1539395724.pt",
"ganglionar_type_34_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1974221233.pt",
"ganglionar_type_35_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_131716857.pt",
"ganglionar_type_35_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_944573925.pt",
"ganglionar_type_35_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_991169481.pt",
"ganglionar_type_36_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1442639811.pt",
"ganglionar_type_36_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1731006758.pt",
"ganglionar_type_36_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_992378617.pt",
"ganglionar_type_37_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_484059010.pt",
"ganglionar_type_37_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_608691745.pt",
"ganglionar_type_37_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_744266624.pt",
"ganglionar_type_38_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_247298880.pt",
"ganglionar_type_38_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_39931331.pt",
"ganglionar_type_38_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_535897221.pt",
"ganglionar_type_39_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1069691165.pt",
"ganglionar_type_39_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1760278679.pt",
"ganglionar_type_39_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_54671615.pt",
"ganglionar_type_3_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_2062632841.pt",
"ganglionar_type_3_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_532270609.pt",
"ganglionar_type_3_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_797123245.pt",
"ganglionar_type_4_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_100781008.pt",
"ganglionar_type_4_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_244493523.pt",
"ganglionar_type_4_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_402066829.pt",
"ganglionar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1253775427.pt",
"ganglionar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_198058435.pt",
"ganglionar_type_5_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_702858507.pt",
"ganglionar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1571939542.pt",
"ganglionar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1580361806.pt",
"ganglionar_type_6_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_549525671.pt",
"ganglionar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1186972223.pt",
"ganglionar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_1488177546.pt",
"ganglionar_type_7_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_177417729.pt",
"ganglionar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_130857530.pt",
"ganglionar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_222439868.pt",
"ganglionar_type_8_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_565791624.pt",
"ganglionar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_102671336.pt",
"ganglionar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_666754141.pt",
"ganglionar_type_9_net_lstm_hiddensize_51_epochs_30_batches_8_lr_0.001_68539956.pt",
]


In [None]:
models = {}
root = "./models/lstm/"
# Pick the GPU when one is available, otherwise the CPU; the `cuda` flag given
# to loadmodel follows the same choice so model and weights share a device.
# To force the CPU, replace the next line with: device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"
for file in files:
    # Key is the file name up to (not including) "_net", e.g. "bipolar_type_10".
    name = file[:file.find("net") - 1]
    # NOTE: several checkpoints share the same key, so for each cell type the
    # last matching entry of `files` is the one that remains in `models`.
    models[name] = Lstmcell(device, batch_size=8)
    loadmodel(models[name], file, root=root, cuda=use_cuda)

## Example for a user-created stimulus

In [None]:
# Read images 200..381 of the sequence and keep channel index 1 of each one.
frames = np.zeros((182, 288, 384))
for i, n in enumerate(range(200, 382)):
    path = 'dataset/images/mall/EnterExitCrossingPaths1cor0{}.jpg'.format(n)
    frames[i] = imageio.imread(path)[:, :, 1]

Sample image from the series

In [None]:
# Display the first loaded frame with the reversed grey colormap.
first_frame = frames[0]
plt.imshow(first_frame, cmap='Greys_r')

We will crop the images and use only the green channel

In [None]:
# Crop: drop the first 10 rows and the last row, keep columns 10..299.
cropped_frame = frames[0, 10:-1, 10:300]
plt.imshow(cropped_frame, cmap='Greys_r')

And downsample to use fewer neurons. We will use three levels of downsampling to emulate neurons with different size of receptive fields

In [None]:
# Keep every 4th pixel of rows 44..243 and columns 40..251.
downsampled_frame = frames[0, 44:244:4, 40:252:4]
plt.imshow(downsampled_frame, cmap='Greys_r')

In [None]:
# First 120 frames at three spatial strides (1, 2 and 4 pixels).
s_frames = frames[:120, 28:, 40:300]
m_frames = frames[:120, 28:-1:2, 40:300:2]
l_frames = frames[:120, 28:-1:4, 40:300:4]
(s_frames.shape, m_frames.shape, l_frames.shape)

### The following is very slow because we only have one network instance, so we iterate over pixels sequentially. We need to improve this

In [None]:
n = 3
pre = 6
# One response trace of pre + n_frames samples per pixel of the coarse frames.
resp_mov_l = np.zeros((l_frames.shape[0]+pre,l_frames.shape[1],l_frames.shape[2]))
cell_model = models["ganglionar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(l_frames.shape[1]):
        for y in range(l_frames.shape[2]):
            # Prepend `pre` samples of constant value 128 to the pixel's time course.
            stim = np.concatenate((np.ones(pre)*128, l_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp_mov_l[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
            # Remove the per-pixel mean of the trace.
            resp_mov_l[:,x,y] = resp_mov_l[:,x,y] - resp_mov_l[:,x,y].mean()
print('done!')

In [None]:
n = 3
pre = 6
# One response trace of pre + n_frames samples per pixel of the full-resolution frames.
resp_mov = np.zeros((s_frames.shape[0]+pre,s_frames.shape[1],s_frames.shape[2]))
cell_model = models["ganglionar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(s_frames.shape[1]):
        for y in range(s_frames.shape[2]):
            # Prepend `pre` samples of constant value 128 to the pixel's time course.
            stim = np.concatenate((np.ones(pre)*128, s_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
            # Remove the per-pixel mean of the trace.
            resp_mov[:,x,y] = resp_mov[:,x,y] - resp_mov[:,x,y].mean()

    

In [None]:
# convis is already imported in the imports cell at the top of the notebook;
# show which version is installed.
convis.__version__

In [None]:
# Feed resp_mov, scaled by its maximum, to convis' Poisson spiking filter.
spk = convis.filters.spiking.Poisson()
scaled_resp = resp_mov / resp_mov.max()
o = spk.run(scaled_resp)
plt.figure()
o.plot(mode='lines')
spk_out = o.array()

In [None]:
n = 5
pre = 6
# One response trace of pre + n_frames samples per pixel of the 2x-strided frames.
resp_mov_m = np.zeros((m_frames.shape[0]+pre,m_frames.shape[1],m_frames.shape[2]))
cell_model = models["ganglionar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(m_frames.shape[1]):
        for y in range(m_frames.shape[2]):
            # Prepend `pre` samples of constant value 128 to the pixel's time course.
            stim = np.concatenate((np.ones(pre)*128, m_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp_mov_m[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
# Subtract the mean response over the `pre` lead-in samples (all pixels).
resp_mov_m = resp_mov_m - resp_mov_m[:pre].mean()
print('done!')

In [None]:
# Feed resp_mov_m, scaled by its maximum, to convis' Izhikevich spiking filter.
spk = convis.filters.spiking.Izhikevich()
drive_m = resp_mov_m / resp_mov_m.max()
o_m = spk.run(drive_m)
plt.figure()
o_m.plot(mode='lines')
spk_out_m = o_m.array()
spk_out_m.shape

In [None]:
# Same as the previous cell but with the sign of the scaled response flipped.
spk = convis.filters.spiking.Izhikevich()
drive_m_off = -resp_mov_m / resp_mov_m.max()
o_m_off = spk.run(drive_m_off)
plt.figure()
o_m_off.plot(mode='lines')
spk_out_m_off = o_m_off.array()
spk_out_m_off.shape

In [None]:
# Feed resp_mov_l, scaled by its maximum, to convis' Izhikevich spiking filter.
spk = convis.filters.spiking.Izhikevich()
drive_l = resp_mov_l / resp_mov_l.max()
o_l = spk.run(drive_l)
plt.figure()
o_l.plot(mode='lines')

In [None]:
# Materialise the coarse-resolution spike output as an array and show its shape.
spk_out_l = o_l.array()
spk_out_l.shape

In [None]:
# Number of lead-in samples at the start of resp_mov / spk_out; must match the
# `pre` value used when resp_mov was computed.
n_pre = 6
# Colour limits over the stimulus-driven part of the response, computed once.
resp_lo = resp_mov[n_pre:,:,:].min()
resp_hi = resp_mov[n_pre:,:,:].max()
# Iterate over frames (axis 0); axis 2 is the image width.
for i in range(s_frames.shape[0]):
    fig, ax = plt.subplots(1,3)
    fig.set_size_inches(9,3)
    # Left: stimulus frame, middle: model response, right: spikes.
    ax[0].imshow(s_frames[i], cmap='gray')
    ax[0].axis('off')
    ax[1].imshow(resp_mov[i+n_pre], vmin=resp_lo, vmax=resp_hi)
    ax[1].axis('off')
    ax[2].imshow(spk_out[0,0,i+n_pre])
    ax[2].axis('off')
    fig.subplots_adjust(hspace=0, wspace=0)
    fig.savefig('outputs/output_gangliontype3_s_{:03d}.png'.format(i), facecolor='k')

In [None]:
# Number of lead-in samples at the start of resp_mov_l / spk_out_l; must match
# the `pre` value used when resp_mov_l was computed.
n_pre = 6
# Colour limits over the stimulus-driven part of the response, computed once.
resp_lo = resp_mov_l[n_pre:,:,:].min()
resp_hi = resp_mov_l[n_pre:,:,:].max()
for i in range(l_frames.shape[0]):
    fig, ax = plt.subplots(1,3)
    fig.set_size_inches(9,3)
    # Left: stimulus frame, middle: model response, right: spikes.
    ax[0].imshow(l_frames[i], cmap='gray')
    ax[0].axis('off')
    ax[1].imshow(resp_mov_l[i+n_pre,:,:], vmin=resp_lo, vmax=resp_hi)
    ax[1].axis('off')
    # Use the spikes generated from resp_mov_l, matching the other two panels.
    ax[2].imshow(spk_out_l[0,0,i+n_pre])
    ax[2].axis('off')
    fig.subplots_adjust(hspace=0, wspace=0)
    fig.savefig('outputs/output_gangliontype3_l_{:03d}.png'.format(i), facecolor='k')

In [None]:
# Colour limits over the response after the first 6 samples, computed once.
resp_lo = resp_mov_m[6:,:,:].min()
resp_hi = resp_mov_m[6:,:,:].max()
for i in range(m_frames.shape[0]):
    fig, (ax_stim, ax_resp, ax_spk) = plt.subplots(1, 3)
    fig.set_size_inches(9, 3)
    # Left: stimulus frame, middle: model response, right: spikes.
    ax_stim.imshow(m_frames[i], cmap='gray')
    ax_resp.imshow(resp_mov_m[i+6,:,:], vmin=resp_lo, vmax=resp_hi)
    ax_spk.imshow(spk_out_m[0,0,i+6])
    for panel in (ax_stim, ax_resp, ax_spk):
        panel.axis('off')
    fig.subplots_adjust(hspace=0, wspace=0)
    out_path = 'outputs/output_gangliontype5_m_{:03d}.png'.format(i)
    fig.savefig(out_path, facecolor='k')

In [None]:
# Build one image artist per time step of spk_out and save them as a movie.
fig, ax = plt.subplots()
fig.set_size_inches(4, 4)
ims = [[ax.imshow(spk_out[0, 0, i], animated=True)]
       for i in range(spk_out.shape[2])]
# Hide tick marks on both axes.
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
ani = animation.ArtistAnimation(
    fig, ims, interval=50, blit=True, repeat_delay=1000)
ani.save('output_t3nat.mp4')

In [None]:
n = 15
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Size the buffer from the data: every trace has lead_in + n_frames samples.
# A hard-coded length only works for one particular number of frames.
resp15_mov = np.zeros((lead_in + s_frames.shape[0],s_frames.shape[1],s_frames.shape[2]))
cell_model = models["ganglionar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(s_frames.shape[1]):
        for y in range(s_frames.shape[2]):
            stim = np.concatenate((np.ones(lead_in)*128, s_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp15_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
print('done!')



In [None]:
# One image artist per time step of resp15_mov, all sharing global colour limits.
fig = plt.figure()
fig.set_size_inches(4, 4)
resp_lo, resp_hi = resp15_mov.min(), resp15_mov.max()
ims15 = []
for i, resp_frame in enumerate(resp15_mov):
    print('processing frame {}'.format(i))
    im1 = plt.imshow(resp_frame, animated=True, cmap='inferno',
                     vmin=resp_lo, vmax=resp_hi)
    ims15.append([im1])
    print('done!')

In [None]:
# Assemble the collected artists into an animation and write it to disk.
ani15 = animation.ArtistAnimation(
    fig, ims15, interval=50, blit=True, repeat_delay=1000)
ani15.save('output_t15Bxs.mp4')

In [None]:
n = 34
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Size the buffer from the data: every trace has lead_in + n_frames samples.
# A hard-coded length only works for one particular number of frames.
resp34_mov = np.zeros((lead_in + s_frames.shape[0],s_frames.shape[1],s_frames.shape[2]))
cell_model = models["ganglionar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(s_frames.shape[1]):
        for y in range(s_frames.shape[2]):
            stim = np.concatenate((np.ones(lead_in)*128, s_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp34_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
# Report once at the end; printing twice per pixel floods the cell output.
print('done!')




In [None]:
# One image artist per time step of resp34_mov, all sharing global colour limits.
fig = plt.figure()
resp_lo, resp_hi = resp34_mov.min(), resp34_mov.max()
ims34 = []
for i, resp_frame in enumerate(resp34_mov):
    print('processing frame {}'.format(i))
    im1 = plt.imshow(resp_frame, animated=True, cmap='cividis',
                     vmin=resp_lo, vmax=resp_hi)
    ims34.append([im1])
    print('done!')

In [None]:
# Assemble the collected artists into an animation and write it to disk.
ani34 = animation.ArtistAnimation(
    fig, ims34, interval=50, blit=True, repeat_delay=1000)
ani34.save('output_t34B.mp4')

In [None]:
n = 4
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Size the buffer from the data: every trace has lead_in + n_frames samples.
# A hard-coded length only works for one particular number of frames.
resp_b4_mov = np.zeros((lead_in + m_frames.shape[0],m_frames.shape[1],m_frames.shape[2]))
cell_model = models["bipolar_type_{}".format(n)]  # looked up once, not per pixel
# Inference only: no_grad avoids building an autograd graph for every pixel.
with torch.no_grad():
    for x in range(m_frames.shape[1]):
        for y in range(m_frames.shape[2]):
            stim = np.concatenate((np.ones(lead_in)*128, m_frames[:,x,y]))
            # transform_data copies the trace into all 8 batch slots; only slot 0 is read back.
            stimulus_t = transform_data(stim, 8)
            out = cell_model(stimulus_t.to(device))
            resp_b4_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]
print('done!')


In [None]:
# Render resp_b4_mov frame by frame with fixed colour limits [0, 1] and save the movie.
fig = plt.figure()
fig.set_size_inches(4, 4)
ims_b4 = []
for resp_frame in resp_b4_mov:
    im1 = plt.imshow(resp_frame, animated=True, cmap='viridis', vmin=0, vmax=1)
    ims_b4.append([im1])

ani_b4 = animation.ArtistAnimation(
    fig, ims_b4, interval=50, blit=True, repeat_delay=1000)
ani_b4.save('output_b4.mp4')
print('done!')

In [None]:
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Every trace has lead_in + n_frames samples; size the buffer from the data
# instead of hard-coding its length.
n_samples = lead_in + s_frames.shape[0]
for n in small_cells:
    print('processing cell {}'.format(n))
    cell_model = models["ganglionar_type_{}".format(n)]
    resp_mov = np.zeros((n_samples,s_frames.shape[1],s_frames.shape[2]))
    # Inference only: no_grad avoids building an autograd graph for every pixel.
    with torch.no_grad():
        for x in range(s_frames.shape[1]):
            for y in range(s_frames.shape[2]):
                stim = np.concatenate((np.ones(lead_in)*128, s_frames[:,x,y]))
                # Only batch slot 0 of the 8 identical copies is read back.
                stimulus_t = transform_data(stim, 8)
                out = cell_model(stimulus_t.to(device))
                resp_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]

    fig = plt.figure()
    fig.set_size_inches(3,3)
    # Colour limits are loop-invariant, so compute them once per cell type.
    resp_lo, resp_hi = resp_mov.min(), resp_mov.max()
    imsT = []
    for i in range(resp_mov.shape[0]):
        imT = plt.imshow(resp_mov[i], animated=True, cmap = 'inferno', vmin=resp_lo, vmax=resp_hi)
        imsT.append([imT])

    ani_T = animation.ArtistAnimation(fig, imsT, interval=50, blit=True,
                                repeat_delay=1000)
    ani_T.save('videos/output_g{}i.mp4'.format(n))
    # Release the figure before moving on to the next cell type.
    plt.close(fig)
print('done!')

In [None]:
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Every trace has lead_in + n_frames samples; size the buffer from the data
# instead of hard-coding its length.
n_samples = lead_in + m_frames.shape[0]
for n in medium_cells:
    print('processing cell {}'.format(n))
    cell_model = models["ganglionar_type_{}".format(n)]
    resp_mov = np.zeros((n_samples,m_frames.shape[1],m_frames.shape[2]))
    # Inference only: no_grad avoids building an autograd graph for every pixel.
    with torch.no_grad():
        for x in range(m_frames.shape[1]):
            for y in range(m_frames.shape[2]):
                stim = np.concatenate((np.ones(lead_in)*128, m_frames[:,x,y]))
                # Only batch slot 0 of the 8 identical copies is read back.
                stimulus_t = transform_data(stim, 8)
                out = cell_model(stimulus_t.to(device))
                resp_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]

    fig = plt.figure()
    fig.set_size_inches(3,3)
    # Colour limits are loop-invariant, so compute them once per cell type.
    resp_lo, resp_hi = resp_mov.min(), resp_mov.max()
    imsT = []
    for i in range(resp_mov.shape[0]):
        imT = plt.imshow(resp_mov[i], animated=True, cmap = 'inferno', vmin=resp_lo, vmax=resp_hi)
        imsT.append([imT])

    ani_T = animation.ArtistAnimation(fig, imsT, interval=50, blit=True,
                                repeat_delay=1000)
    ani_T.save('videos/output_g{}i.mp4'.format(n))
    # Release the figure before moving on to the next cell type.
    plt.close(fig)
print('done!')

In [None]:
lead_in = 18  # number of constant-128 samples prepended to every pixel trace
# Every trace has lead_in + n_frames samples; size the buffer from the data
# instead of hard-coding its length.
n_samples = lead_in + l_frames.shape[0]
for n in large_cells:
    print('processing cell {}'.format(n))
    cell_model = models["ganglionar_type_{}".format(n)]
    resp_mov = np.zeros((n_samples,l_frames.shape[1],l_frames.shape[2]))
    # Inference only: no_grad avoids building an autograd graph for every pixel.
    with torch.no_grad():
        for x in range(l_frames.shape[1]):
            for y in range(l_frames.shape[2]):
                stim = np.concatenate((np.ones(lead_in)*128, l_frames[:,x,y]))
                # Only batch slot 0 of the 8 identical copies is read back.
                stimulus_t = transform_data(stim, 8)
                out = cell_model(stimulus_t.to(device))
                resp_mov[:,x,y] = out[:,0,:].cpu().detach().numpy()[:,0]

    fig = plt.figure()
    fig.set_size_inches(3,3)
    # Colour limits are loop-invariant, so compute them once per cell type.
    resp_lo, resp_hi = resp_mov.min(), resp_mov.max()
    imsT = []
    for i in range(resp_mov.shape[0]):
        imT = plt.imshow(resp_mov[i], animated=True, cmap = 'inferno', vmin=resp_lo, vmax=resp_hi)
        imsT.append([imT])

    ani_T = animation.ArtistAnimation(fig, imsT, interval=50, blit=True,
                                repeat_delay=1000)
    ani_T.save('videos/output_g{}i.mp4'.format(n))
    # Release the figure before moving on to the next cell type.
    plt.close(fig)
print('done!')