In [None]:
import os

import numpy as np
from numpy.random import RandomState, SeedSequence, MT19937
from scipy.stats import expon

import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

In [None]:
def plot_voltage_traces(mem, spk=None, dim=(2,4), spike_height=5, w=2, h=1.5):
    """Plot traces for a grid of evenly spaced samples along the first axis of `mem`.

    Args:
        mem: tensor whose first axis indexes samples; `mem[k]` is handed directly
            to `Axes.plot` (e.g. the `[batch, time, unit]` output of `run_nsnn`).
        spk: optional tensor broadcastable against `mem`; wherever it is positive
            the trace is overwritten with `spike_height` so spikes appear as bars.
        dim: (rows, cols) of the subplot grid.
        spike_height: value drawn at spike positions.
        w, h: width and height (inches) of a single panel.

    Returns:
        The matplotlib Figure.
    """
    rows, cols = dim
    # squeeze=False keeps `ax` 2-D even for a single row/column, so indexing is uniform.
    fig, ax = plt.subplots(rows, cols, figsize=(cols*w, rows*h),
                           sharex=True, sharey=True, squeeze=False)
    # Detach before editing so marking spikes does not add in-place ops to the autograd graph;
    # `1.0 *` makes a float copy so the caller's tensor is never modified.
    dat = 1.0 * mem.detach()
    if spk is not None:
        dat[spk > 0.0] = spike_height
    dat = dat.cpu().numpy()
    # One sample index per panel, evenly spread over the first axis.
    idx = np.linspace(0, len(mem), rows*cols, endpoint=False, dtype=int)
    for k, panel in enumerate(ax.flat):  # row-major, i.e. k = i*cols + j
        panel.plot(dat[idx[k]], lw=1)
        panel.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6, .6, .6])
    sns.despine(fig=fig)
    return fig

In [None]:
def run_nsnn(inputs):
    """Simulate a layer of non-spiking leaky-integrator neurons.

    Reads the notebook-level globals `w`, `alpha` and `beta`.

    Args:
        inputs: spike raster indexed `[batch, input_unit, time]`; its second
            axis is contracted with the first axis of `w`.

    Returns:
        Membrane potentials indexed `[batch, time, output_unit]`. Entry `t`
        is the potential *before* the update performed at step `t`.
    """
    # (batch, in, time) x (in, out) -> (batch, time, out)
    I_inp = torch.einsum('abc,bd->acd', inputs, w)
    # Size the state from the actual input instead of the global batch_size /
    # n_outputs / n_steps, so any batch size or simulation length works.
    n_batch, n_time, n_out = I_inp.shape
    I_syn_curr = torch.zeros((n_batch, n_out), device=I_inp.device, dtype=I_inp.dtype)
    Vm_curr = torch.zeros_like(I_syn_curr)
    Vm = []
    for t in range(n_time):
        Vm.append(Vm_curr)
        I_syn_next = alpha * I_syn_curr + I_inp[:, t, :]
        # The membrane integrates the *current* synaptic state (one-step delay).
        Vm_next = beta * Vm_curr + I_syn_curr
        I_syn_curr, Vm_curr = I_syn_next, Vm_next
    return torch.stack(Vm, dim=1)

In [None]:
def run_snn(inputs):
    """Simulate a layer of leaky integrate-and-fire neurons with reset by subtraction.

    Reads the notebook-level globals `w`, `alpha` and `beta`.

    Args:
        inputs: spike raster indexed `[batch, input_unit, time]`; its second
            axis is contracted with the first axis of `w`.

    Returns:
        (Vm, spikes), both indexed `[batch, time, output_unit]`. `Vm[:, t]` is
        the potential before the update at step `t`; `spikes[:, t]` is 1 where
        that potential is above the threshold of 1, else 0.
    """
    def heaviside(x, thresh=0.):
        # 1 where x > thresh, else 0. Built from sign() so the result stays in the
        # autograd graph with an exactly zero gradient; a u / sqrt(u**2) form would
        # be NaN at x == thresh and leak round-off gradients near the threshold.
        return torch.sign(x - thresh).clamp(min=0.)

    # (batch, in, time) x (in, out) -> (batch, time, out)
    I_inp = torch.einsum('abc,bd->acd', inputs, w)
    # Size the state from the actual input rather than global constants.
    n_batch, n_time, n_out = I_inp.shape
    I_syn_curr = torch.zeros((n_batch, n_out), device=I_inp.device, dtype=I_inp.dtype)
    Vm_curr = torch.zeros_like(I_syn_curr)
    Vm = []
    spikes = []
    for t in range(n_time):
        out = heaviside(Vm_curr, 1.)
        I_syn_next = alpha * I_syn_curr + I_inp[:, t, :]
        # Reset by subtraction: each spike removes the threshold value (1) from Vm.
        Vm_next = beta * Vm_curr + I_syn_curr - out
        Vm.append(Vm_curr)
        spikes.append(out)
        I_syn_curr, Vm_curr = I_syn_next, Vm_next
    return torch.stack(Vm, dim=1), torch.stack(spikes, dim=1)

In [None]:
# neuron parameters
tau_mem    = 10e-3  # membrane time constant [s]
tau_syn    = 5e-3   # synaptic time constant [s]
# network parameters
input_rate = 10     # [Hz]
n_inputs   = 512
n_outputs  = 2
# simulation parameters
tend       = 0.2    # simulated duration [s]
dt         = 1e-3   # time step [s]
# round() instead of truncation: tend / dt can land just below an integer
# (e.g. 0.29 / 0.01 == 28.999...), which int() alone would turn into one step too few.
n_steps    = int(round(tend / dt))
# batch size
batch_size = 256

In [None]:
dtype = torch.float
device = torch.device("cpu")
# Uncomment the line below to run on GPU
# device = torch.device("cuda:0")
# Seed torch's global RNG (used for the input spike trains and the weight init)
# so that Restart & Run All reproduces the same inputs and weights.
SEED = 0
torch.manual_seed(SEED);

In [None]:
scale = 10  # the slow population fires `scale` times less often than the fast one
half_inputs = n_inputs // 2


def poisson_spike_times(rate, shape, t_max):
    """Draw spike times of independent homogeneous Poisson processes.

    Returns a tensor of shape `(*shape, k)` of cumulative exponential
    inter-spike intervals (mean `1/rate`). Intervals are drawn until every
    train has passed `t_max`; drawing a fixed ceil(t_max * rate) intervals
    would cap each train at its *expected* spike count. Times beyond `t_max`
    are left for the caller to discard.
    """
    n_draw = max(1, int(np.ceil(t_max * rate)))
    isi = torch.empty((*shape, n_draw), dtype=dtype).exponential_(rate)
    times = torch.cumsum(isi, dim=-1)
    while (times[..., -1] < t_max).any():
        isi = torch.empty((*shape, n_draw), dtype=dtype).exponential_(rate)
        times = torch.cat([times, times[..., -1:] + torch.cumsum(isi, dim=-1)], dim=-1)
    return times


spike_times_fast = poisson_spike_times(input_rate, (batch_size, half_inputs), tend)
spike_times_slow = poisson_spike_times(input_rate / scale, (batch_size, half_inputs), tend)

In [None]:
def _spike_raster(spike_times):
    """Bin spike times [s] of shape (batch, units, k) into a 0/1 raster (batch, units, n_steps).

    Spikes falling at or after the end of the simulation are dropped.
    """
    step_idx = torch.floor(spike_times / dt).long()
    b, u, s = (step_idx < n_steps).nonzero(as_tuple=True)
    raster = torch.zeros((*spike_times.shape[:2], n_steps), dtype=dtype)
    raster[b, u, step_idx[b, u, s]] = 1
    return raster


raster_fast = _spike_raster(spike_times_fast)
raster_slow = _spike_raster(spike_times_slow)
half_batches = batch_size // 2
# First half of the batch: fast units first, slow units second; second half: swapped.
inputs = torch.zeros((batch_size, n_inputs, n_steps), dtype=dtype)
inputs[:half_batches, :half_inputs] = raster_fast[:half_batches]
inputs[:half_batches, half_inputs:2*half_inputs] = raster_slow[:half_batches]
inputs[half_batches:, :half_inputs] = raster_slow[half_batches:]
inputs[half_batches:, half_inputs:2*half_inputs] = raster_fast[half_batches:]
# Move once to the simulation device so the einsum with `w` works on GPU too.
inputs = inputs.to(device)
print(f'Total number of input spikes: {inputs.sum():.0f}.')

In [None]:
rows, cols = 2, 3
# One evenly spaced batch sample per panel.
idx = np.linspace(0, batch_size, rows*cols, endpoint=False, dtype=int)
fig, ax = plt.subplots(rows, cols, figsize=(cols*2, rows*1.5), sharex=True, sharey=True)
for panel, sample in zip(ax.flat, idx):
    panel.imshow(inputs[sample].cpu(), cmap=plt.cm.gray_r, aspect='auto')
sns.despine()
for panel in ax[-1, :]:
    panel.set_xlabel('Time (ms)')
for panel in ax[:, 0]:
    panel.set_ylabel('Unit')
fig.tight_layout()

In [None]:
# Labels: the first half of the batch is class 1, the rest class 0.
truth = torch.zeros(batch_size, dtype=torch.long)
truth[:half_batches] = 1
classes = truth.unique()
# The one-hot width is set by the number of readout neurons rather than by the
# classes present in this batch, so `y` always matches the (batch, n_outputs)
# network output; it is moved to `device` so the loss works on GPU too.
y = nn.functional.one_hot(truth, num_classes=n_outputs).float().to(device)

In [None]:
# Per-step decay factors of the synaptic current and membrane potential.
alpha = float(np.exp(-dt / tau_syn))
beta = float(np.exp(-dt / tau_mem))
weight_scale = 7 * (1 - beta)
μ = 5e-3
w = torch.empty((n_inputs, n_outputs), device=device, dtype=dtype, requires_grad=True)
# Each output draws mean +μ from one half of the input units and -μ from the other,
# with opposite signs for the two outputs. The draw order below is significant for RNG reproducibility.
# NOTE(review): output 0 is excited by the first half of the units, which fire fast for the
# samples labelled 1 (truth[:half_batches] = 1) — confirm this anti-alignment is intended.
init_std = weight_scale / np.sqrt(n_inputs)
for unit_slice, out_col, init_mean in (
    (slice(None, half_inputs), 0, μ),
    (slice(half_inputs, None), 0, -μ),
    (slice(None, half_inputs), 1, -μ),
    (slice(half_inputs, None), 1, μ),
):
    torch.nn.init.normal_(w[unit_slice, out_col], mean=init_mean, std=init_std)

In [None]:
# Visualisation only: no_grad avoids building an autograd graph over every time step.
with torch.no_grad():
    Vm = run_nsnn(inputs)
fig = plot_voltage_traces(Vm)
fig.tight_layout()

In [None]:
# Visualisation only: no_grad avoids building an autograd graph over every time step.
with torch.no_grad():
    Vm, spikes = run_snn(inputs)
fig = plot_voltage_traces(Vm, spikes)
fig.tight_layout()

In [None]:
Vm = run_nsnn(inputs)
# Time-averaged membrane potentials serve as logits. CrossEntropyLoss applies
# log-softmax itself, so it must receive raw logits, not softmax outputs.
logits = Vm.mean(dim=1)
y_hat = logits.softmax(dim=1)  # class probabilities, for inspection only
loss_fun = nn.CrossEntropyLoss()
loss = loss_fun(logits, y)
print(f'Loss: {loss:g}.')
# Drop any gradient left in w by an earlier backward() so w.grad reflects this loss only.
w.grad = None
loss.backward()

In [None]:
Vm, spikes = run_snn(inputs)
# Spike counts serve as logits. CrossEntropyLoss applies log-softmax itself,
# so it must receive raw logits, not softmax outputs.
logits = spikes.sum(dim=1)
y_hat = logits.softmax(dim=1)  # class probabilities, for inspection only
loss_fun = nn.CrossEntropyLoss()
loss = loss_fun(logits, y)
print(f'Loss: {loss:g}.')
# Drop any gradient left in w by an earlier backward() so w.grad reflects this loss only.
w.grad = None
loss.backward()