In [None]:
import os

import numpy as np
from numpy.random import RandomState, SeedSequence, MT19937
from scipy.stats import expon

import torch
import torch.nn as nn

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Size of one subplot panel (matplotlib figsize units, i.e. inches); the plotting
# helpers multiply these by the grid shape to get the full figure size.
panel_w = 2
panel_h = 1.5

In [None]:
def plot_voltage_traces(mem, spk=None, dim=(4,4), spike_height=5):
    """Plot membrane-potential traces for a grid of evenly spaced batch samples.

    Parameters
    ----------
    mem : torch.Tensor
        Membrane potentials; ``mem[k]`` is plotted in one panel. In this notebook
        it is the ``(batch, n_steps, n_outputs)`` tensor from ``run_spiking_nn``,
        so each panel shows one line per output unit.
    spk : torch.Tensor, optional
        Spikes with the same shape as ``mem``. Wherever ``spk > 0`` the trace is
        drawn at ``spike_height`` so spikes appear as tall peaks.
    dim : (int, int)
        Grid shape ``(rows, cols)``; ``rows*cols`` samples are shown.
    spike_height : float
        Value drawn at spike times.

    Returns
    -------
    matplotlib.figure.Figure
    """
    rows, cols = dim
    # squeeze=False keeps `ax` 2-D, so 1xN / Nx1 / 1x1 grids work too.
    fig, ax = plt.subplots(rows, cols, figsize=(cols*panel_w, rows*panel_h),
                           sharex=True, sharey=True, squeeze=False)
    # Detach first so the caller's tensor and its autograd graph are never touched.
    dat = mem.detach()
    if spk is not None:
        dat = 1.0 * dat  # fresh (floating) copy; the overwrite below must not alias `mem`
        dat[spk.detach() > 0.0] = spike_height
    dat = dat.cpu().numpy()
    idx = np.linspace(0, len(dat), rows*cols, endpoint=False, dtype=int)
    # ax.flat walks the grid row by row, matching sample index k = i*cols + j.
    for k, a in enumerate(ax.flat):
        a.plot(dat[idx[k]], lw=1)
        a.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])
    sns.despine(fig=fig)
    return fig

In [None]:
def plot_spikes(spikes, dim=(4,4)):
    """Show spike rasters for a grid of evenly spaced batch samples.

    Parameters
    ----------
    spikes : torch.Tensor or array-like
        Indexed as ``spikes[sample]`` -> 2-D raster with units along rows and
        time steps along columns, e.g. ``(batch, units, n_steps)``. Torch tensors
        may live on any device and may require grad; they are copied to host.
    dim : (int, int)
        Grid shape ``(rows, cols)``; ``rows*cols`` samples are shown.

    Returns
    -------
    matplotlib.figure.Figure
    """
    rows, cols = dim
    if torch.is_tensor(spikes):
        # imshow needs host memory without autograd history (e.g. CUDA outputs).
        spikes = spikes.detach().cpu().numpy()
    idx = np.linspace(0, spikes.shape[0], rows*cols, endpoint=False, dtype=int)
    # squeeze=False keeps `ax` 2-D, so the ax[-1, :] / ax[:, 0] labelling below
    # also works for single-row or single-column grids.
    fig, ax = plt.subplots(rows, cols, figsize=(cols*panel_w, rows*panel_h),
                           sharex=True, sharey=True, squeeze=False)
    for k, a in enumerate(ax.flat):
        a.imshow(spikes[idx[k]], cmap=plt.cm.gray_r, aspect='auto')
    sns.despine(fig=fig)
    # The x-axis counts time steps; the "ms" label assumes dt = 1 ms as in the
    # notebook's parameter cell.
    for a in ax[-1, :]:
        a.set_xlabel('Time (ms)')
    for a in ax[:, 0]:
        a.set_ylabel('Unit')
    return fig

In [None]:
def run_leaky_nn(inputs):
    """Simulate one layer of non-spiking leaky integrators driven by `inputs`.

    Uses the notebook-level ``weights`` (n_inputs, n_outputs) and the per-step
    decay factors ``alpha`` (synaptic current) and ``beta`` (membrane).

    Parameters
    ----------
    inputs : torch.Tensor
        Input raster of shape ``(batch, n_inputs, n_steps)``. It is moved to the
        device/dtype of ``weights``, so a CPU raster works with GPU weights.

    Returns
    -------
    torch.Tensor
        Membrane potentials of shape ``(batch, n_steps, n_outputs)``; entry ``t``
        is the potential at the start of step ``t`` (entry 0 is all zeros).
    """
    inputs = inputs.to(device=weights.device, dtype=weights.dtype)
    # Shapes come from the arguments, so sub-batches or other durations work.
    batch, _, steps = inputs.shape
    # (batch, n_inputs, n_steps) x (n_inputs, n_outputs) -> (batch, n_steps, n_outputs)
    I_inp = torch.einsum('abc,bd->acd', inputs, weights)
    I_syn_curr = torch.zeros((batch, weights.shape[1]), device=weights.device, dtype=weights.dtype)
    Vm_curr = torch.zeros_like(I_syn_curr)
    Vm = []
    for t in range(steps):
        Vm.append(Vm_curr)
        # Both updates read the current state: the membrane integrates the
        # synaptic current of the previous step (a one-step delay).
        I_syn_next = alpha * I_syn_curr + I_inp[:, t, :]
        Vm_next = beta * Vm_curr + I_syn_curr
        I_syn_curr, Vm_curr = I_syn_next, Vm_next
    return torch.stack(Vm, dim=1)

In [None]:
def run_spiking_nn(inputs):
    """Simulate one layer of leaky integrate-and-fire neurons driven by `inputs`.

    Uses the notebook-level ``weights`` (n_inputs, n_outputs) and the per-step
    decay factors ``alpha`` (synaptic current) and ``beta`` (membrane). A neuron
    spikes when its membrane potential is strictly above 1, and 1 is then
    subtracted from its potential (reset by subtraction).

    Parameters
    ----------
    inputs : torch.Tensor
        Input raster of shape ``(batch, n_inputs, n_steps)``. It is moved to the
        device/dtype of ``weights``, so a CPU raster works with GPU weights.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``Vm`` and ``spikes``, both of shape ``(batch, n_steps, n_outputs)``.
        ``Vm[:, t]`` is the potential at the start of step ``t`` (before reset),
        and ``spikes[:, t]`` is the 0/1 spike it triggered.
    """
    def heaviside(x, thresh=0.):
        # Step function 1[x > thresh]. torch.sign keeps the result in the autograd
        # graph with an exactly zero gradient, so loss.backward() still runs. The
        # form d/sqrt(d**2) evaluates 0/0 = NaN whenever x == thresh exactly.
        return torch.sign(x - thresh).clamp(min=0.0)

    inputs = inputs.to(device=weights.device, dtype=weights.dtype)
    batch, _, steps = inputs.shape
    # (batch, n_inputs, n_steps) x (n_inputs, n_outputs) -> (batch, n_steps, n_outputs)
    I_inp = torch.einsum('abc,bd->acd', inputs, weights)
    I_syn_curr = torch.zeros((batch, weights.shape[1]), device=weights.device, dtype=weights.dtype)
    Vm_curr = torch.zeros_like(I_syn_curr)
    Vm = []
    spikes = []
    for t in range(steps):
        out = heaviside(Vm_curr, 1.)
        reset = out  # neurons that spiked lose 1 (the threshold value)
        I_syn_next = alpha * I_syn_curr + I_inp[:, t, :]
        Vm_next = beta * Vm_curr + I_syn_curr - reset
        Vm.append(Vm_curr)
        spikes.append(out)
        I_syn_curr, Vm_curr = I_syn_next, Vm_next
    return torch.stack(Vm, dim=1), torch.stack(spikes, dim=1)

In [None]:
# neuron parameters
tau_mem    = 10e-3 # [s] membrane time constant
tau_syn    = 5e-3  # [s] synaptic time constant
# network parameters
input_rate = 10 # [Hz]
n_inputs   = 512
n_outputs  = 2
# simulation parameters
tend       = 0.2   # [s] trial duration
dt         = 1e-3  # [s] time step
# round instead of truncating: tend/dt can land just below an integer in floating point
n_steps    = int(round(tend / dt))
# batch size
batch_size = 256
# seed torch's RNG so the input spike trains and initial weights are reproducible
seed       = 42
torch.manual_seed(seed);

In [None]:
# Floating-point type used for all tensors the network creates
# (torch.float32 is the same object as torch.float).
dtype = torch.float32
# Compute device: CPU by default; set device = torch.device("cuda:0") to use a GPU.
device = torch.device('cpu')

In [None]:
scale = 10  # slow inputs fire `scale` times less often than fast ones
half_inputs = n_inputs // 2

def _poisson_spike_times(rate):
    """Sample spike times [s] of homogeneous Poisson trains firing at `rate` [Hz].

    Returns a (batch_size, half_inputs, n_max) tensor of cumulative exponential
    inter-spike intervals. n_max is chosen far above the expected count
    rate*tend, so trains are not cut short before tend. Drawing only
    ceil(rate*tend) intervals would drop every spike past that count (for
    10 Hz over 0.2 s about 1 in 3 trains would be truncated). Times beyond tend
    are discarded later, when the raster is built.
    """
    expected = rate * tend
    n_max = int(np.ceil(expected + 10 * np.sqrt(expected))) + 10
    isi = torch.zeros((batch_size, half_inputs, n_max), dtype=dtype)
    isi.exponential_(rate)  # Exp(rate): mean interval 1/rate seconds
    return torch.cumsum(isi, dim=-1)

spike_times_fast = _poisson_spike_times(input_rate)
spike_times_slow = _poisson_spike_times(input_rate / scale)

In [None]:
half_batches = batch_size // 2

def _spike_raster(spike_times):
    """Bin spike times [s] of shape (batch, units, k) into a 0/1 raster (batch, units, n_steps).

    Spikes falling at or after step n_steps are dropped; several spikes landing
    in the same bin collapse to a single 1.
    """
    bins = torch.floor(spike_times / dt).long()
    batch_idx, unit_idx, spike_idx = (bins < n_steps).nonzero(as_tuple=True)
    raster = torch.zeros((*spike_times.shape[:2], n_steps), dtype=dtype)
    raster[batch_idx, unit_idx, bins[batch_idx, unit_idx, spike_idx]] = 1
    return raster

raster_fast = _spike_raster(spike_times_fast)
raster_slow = _spike_raster(spike_times_slow)
# First half of the batch: fast trains on input units [0, half_inputs) and slow
# trains on [half_inputs, 2*half_inputs); the second half of the batch is swapped.
# Any leftover unit when n_inputs is odd stays silent.
first, second = slice(0, half_inputs), slice(half_inputs, 2 * half_inputs)
inputs = torch.zeros((batch_size, n_inputs, n_steps), dtype=dtype)
inputs[:half_batches, first] = raster_fast[:half_batches]
inputs[:half_batches, second] = raster_slow[:half_batches]
inputs[half_batches:, first] = raster_slow[half_batches:]
inputs[half_batches:, second] = raster_fast[half_batches:]
print(f'Total number of input spikes: {inputs.sum():.0f}.')

In [None]:
# Input spike rasters (unit x time) for 16 evenly spaced samples of the batch.
fig = plot_spikes(inputs, dim=(4, 4))
fig.tight_layout()

In [None]:
# Class labels: 1 for the first half of the batch, 0 for the second half.
truth = (torch.arange(batch_size) < half_batches).long()
classes = torch.unique(truth)
# One-hot float targets, one column per class.
y = nn.functional.one_hot(truth, num_classes=len(classes)).to(torch.float32)

In [None]:
# Per-step decay factors of the synaptic current (alpha) and membrane potential (beta).
alpha = float(np.exp(-dt / tau_syn))
beta = float(np.exp(-dt / tau_mem))
weight_scale = 7 * (1 - beta)
μ = 0
weights = torch.empty((n_inputs, n_outputs), device=device, dtype=dtype, requires_grad=True)
# Output 0 draws with mean +μ from the first input half and -μ from the second;
# output 1 is the reverse (with μ = 0 all blocks share one distribution).
# Draw order (column 0 before column 1, first half before second) fixes which
# random numbers each block receives.
first_half, second_half = slice(None, half_inputs), slice(half_inputs, None)
for rows, col, mean in ((first_half, 0, μ), (second_half, 0, -μ),
                        (first_half, 1, -μ), (second_half, 1, μ)):
    torch.nn.init.normal_(weights[rows, col], mean=mean, std=weight_scale / np.sqrt(n_inputs))

In [None]:
# Forward pass for inspection only: no_grad avoids building an autograd graph
# that the global Vm/spikes would otherwise keep alive.
with torch.no_grad():
    Vm, spikes = run_spiking_nn(inputs)
fig = plot_voltage_traces(Vm, spikes)
fig.tight_layout()

In [None]:
# Output spike rasters: swap to (batch, unit, time) as plot_spikes expects.
output = spikes.detach().transpose(1, 2)
fig = plot_spikes(output, dim=(4, 4))
fig.tight_layout()

In [None]:
optimizer = torch.optim.Adam([weights], lr=2e-3, betas=(0.9,0.999))
# CrossEntropyLoss applies log-softmax internally, so it must receive raw class
# scores (logits); y holds one-hot class probabilities.
loss_fun = nn.CrossEntropyLoss()
loss_hist = []
n_epochs = 500
spiking = True  # False: train through the non-spiking leaky integrators instead
for _ in tqdm(range(n_epochs)):
    if spiking:
        _, spikes = run_spiking_nn(inputs)
        output = spikes.sum(axis=1)  # spike count per output unit = class score
    else:
        Vm = run_leaky_nn(inputs)
        output = Vm.mean(axis=1)  # time-averaged potential = class score
    # compute the loss on the raw scores; a softmax here as well would squash
    # them twice and shrink the gradients
    loss = loss_fun(output, y.to(output.device))
    # update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # store loss value
    loss_hist.append(loss.item())

In [None]:
# Training curve: cross-entropy loss per epoch.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(loss_hist, color='k')
ax.set(xlabel='Epoch', ylabel='Loss')
ax.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6, .6, .6])
sns.despine(ax=ax)
fig.tight_layout()

In [None]:
# Forward pass with the trained weights, for inspection only: no_grad avoids
# building an autograd graph that the global Vm/spikes would otherwise keep alive.
with torch.no_grad():
    Vm, spikes = run_spiking_nn(inputs)
fig = plot_voltage_traces(Vm, spikes)
fig.tight_layout()

In [None]:
# Output spike rasters after training, rearranged to (batch, unit, time).
output = spikes.detach().transpose(1, 2)
fig = plot_spikes(output, dim=(4, 4))
fig.tight_layout()