In [None]:
import os

import numpy as np
from numpy.random import RandomState, SeedSequence, MT19937
from scipy.stats import expon, norm
from sklearn.preprocessing import OneHotEncoder

import tensorflow as tf

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Width and height (inches) of a single subplot panel; grid figures multiply
# these by their number of columns / rows.
panel_w = 2
panel_h = 1.5

In [None]:
def plot_voltage_traces(Vm, spikes=None, dim=(4,4), spike_height=5):
    """Plot membrane-potential traces for a grid of evenly spaced trials.

    Parameters
    ----------
    Vm : array-like
        Membrane potentials indexed first by trial; ``Vm[k]`` is passed
        directly to ``Axes.plot`` (for a (trial, time, unit) array this draws
        one line per unit).
    spikes : array-like or None, optional
        Array with the same shape as ``Vm``. Wherever it is positive the trace
        is overwritten with ``spike_height`` so spikes appear as tall vertical
        deflections. ``Vm`` itself is never modified.
    dim : (int, int), optional
        (rows, columns) of the subplot grid. Single-row or single-column grids
        are supported.
    spike_height : float, optional
        Value drawn at spike positions.

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding the grid of traces.
    """
    rows, cols = dim
    # squeeze=False keeps `ax` 2-D even for 1xN / Nx1 grids.
    fig, ax = plt.subplots(rows, cols, figsize=(cols*panel_w, rows*panel_h),
                           sharex=True, sharey=True, squeeze=False)
    if spikes is not None:
        data = 1.0 * np.asarray(Vm)  # new array, so the caller's Vm is untouched
        data[np.asarray(spikes) > 0.0] = spike_height
    else:
        data = Vm
    # Evenly spaced trial indices, one per panel, in row-major panel order.
    idx = np.linspace(0, len(Vm), rows*cols, endpoint=False, dtype=int)
    for k, a in enumerate(ax.flat):
        a.plot(data[idx[k]], lw=1)
        a.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])
    # despine acts on every axes of the figure, so one call is enough.
    sns.despine(fig=fig)
    return fig

In [None]:
def plot_spikes(spikes, dim=(4,4)):
    """Show spike rasters for a grid of evenly spaced trials.

    Parameters
    ----------
    spikes : numpy.ndarray
        Array indexed first by trial; each ``spikes[k]`` is drawn with
        ``imshow`` as a 2-D image whose rows are units and whose columns are
        time bins (matching the axis labels set below).
    dim : (int, int), optional
        (rows, columns) of the subplot grid. Single-row or single-column grids
        are supported.

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding the grid of rasters.
    """
    rows, cols = dim
    # Evenly spaced trial indices, one per panel, in row-major panel order.
    idx = np.linspace(0, spikes.shape[0], rows*cols, endpoint=False, dtype=int)
    # squeeze=False keeps `ax` 2-D even for 1xN / Nx1 grids, so the row/column
    # slices used for the axis labels below always work.
    fig, ax = plt.subplots(rows, cols, figsize=(cols*panel_w, rows*panel_h),
                           sharex=True, sharey=True, squeeze=False)
    for k, a in enumerate(ax.flat):
        a.imshow(spikes[idx[k]], cmap=plt.cm.gray_r, aspect='auto')
    # despine acts on every axes of the figure, so one call is enough.
    sns.despine(fig=fig)
    # NOTE(review): the x axis is the time-bin index; the "ms" label assumes a 1 ms step.
    for a in ax[-1, :]:
        a.set_xlabel('Time (ms)')
    for a in ax[:, 0]:
        a.set_ylabel('Unit')
    return fig

In [None]:
def run_leaky_nn(inputs, weights, dtype=tf.float32):
    """Simulate a layer of non-spiking leaky integrators driven by input spikes.

    Discrete-time dynamics, starting from zero current and potential:
        I_syn[t+1] = alpha * I_syn[t] + (inputs @ weights)[t]
        Vm[t+1]    = beta  * Vm[t]    + I_syn[t]
    The decay factors ``alpha`` and ``beta`` are read from module-level
    globals, which must be defined before this function is called.

    Args:
        inputs: rank-3 tensor (batch, n_inputs, n_steps).
        weights: (n_inputs, n_outputs) tensor or variable.
        dtype: dtype of the zero-initialised state tensors.

    Returns:
        Tensor (batch, n_steps, n_outputs) of membrane potentials; entry t is
        the state *before* the t-th update, so ``Vm[:, 0]`` is all zeros.
    """
    batch_size, _, n_steps = inputs.shape
    n_outputs = weights.shape[1]
    # Weighted input current for every time step: (batch, n_steps, n_outputs).
    I_inp = tf.einsum('abc,bd->acd', inputs, weights)
    I_syn = tf.zeros((batch_size, n_outputs), dtype=dtype)
    Vm_curr = tf.zeros((batch_size, n_outputs), dtype=dtype)
    Vm = []
    for t in range(n_steps):
        Vm.append(Vm_curr)
        # Simultaneous update: Vm uses the *old* synaptic current, so input
        # reaches the membrane with a one-step delay.
        I_syn, Vm_curr = alpha * I_syn + I_inp[:, t, :], beta * Vm_curr + I_syn
    return tf.stack(Vm, axis=1)

In [None]:
def run_spiking_nn(inputs, weights, thresh=1., dtype=tf.float32, surrogate_scale=100.):
    """Simulate a layer of current-based LIF neurons with subtractive reset.

    Discrete-time dynamics, starting from zero current and potential:
        S[t]       = H(Vm[t] - thresh)
        I_syn[t+1] = alpha * I_syn[t] + (inputs @ weights)[t]
        Vm[t+1]    = beta  * Vm[t]    + I_syn[t] - S[t]
    ``alpha`` and ``beta`` are read from module-level globals.

    The step function H is exact in the forward pass. In the backward pass it
    uses the derivative of the fast sigmoid x / (1 + surrogate_scale*|x|), so
    losses built from the spike trains produce non-zero gradients. Gradients
    are not propagated through the reset term.

    Args:
        inputs: rank-3 tensor (batch, n_inputs, n_steps).
        weights: (n_inputs, n_outputs) tensor or variable. This parameter was
            previously misspelled ``weigths`` and silently ignored in favour of
            a global ``weights``.
        thresh: firing threshold.
        dtype: dtype of the zero-initialised state tensors.
        surrogate_scale: steepness of the surrogate gradient.

    Returns:
        (Vm, spikes), both tensors (batch, n_steps, n_outputs); entry t is the
        state before the t-th update.
    """
    batch_size, _, n_steps = inputs.shape
    n_outputs = weights.shape[1]

    @tf.custom_gradient
    def spike_fn(x):
        # Forward: exact Heaviside step (1 where x > 0, else 0; never NaN).
        out = tf.cast(x > 0.0, x.dtype)

        def grad(dy):
            # Backward: derivative of the fast sigmoid x / (1 + scale*|x|).
            return dy / (surrogate_scale * tf.abs(x) + 1.0) ** 2

        return out, grad

    # Weighted input current for every time step: (batch, n_steps, n_outputs).
    I_inp = tf.einsum('abc,bd->acd', inputs, weights)
    I_syn = tf.zeros((batch_size, n_outputs), dtype=dtype)
    Vm_curr = tf.zeros((batch_size, n_outputs), dtype=dtype)
    Vm, spikes = [], []
    for t in range(n_steps):
        spk = spike_fn(Vm_curr - thresh)
        # Only the emitted spikes carry the surrogate gradient, not the reset.
        reset = tf.stop_gradient(spk)
        Vm.append(Vm_curr)
        spikes.append(spk)
        # Simultaneous update: Vm uses the *old* synaptic current.
        I_syn, Vm_curr = alpha * I_syn + I_inp[:, t, :], beta * Vm_curr + I_syn - reset
    return tf.stack(Vm, axis=1), tf.stack(spikes, axis=1)

In [None]:
# One seeded generator feeds every random draw below (input spike times and
# initial weights), so a fresh "Run All" is reproducible.
seed = 100
rs = RandomState(MT19937(SeedSequence(seed)))

In [None]:
# Output-layer model: True -> spiking LIF neurons (run_spiking_nn),
# False -> non-spiking leaky integrators (run_leaky_nn).
spiking: bool = False

In [None]:
# neuron parameters
tau_mem    = 10e-3
tau_syn    = 5e-3
# network parameters
input_rate = 10 # [Hz]
n_inputs   = 50
n_outputs  = 2
# simulation parameters
tend       = 0.2
dt         = 1e-3
n_steps    = int(tend / dt)
# batch size
batch_size = 32

In [None]:
def poisson_spike_times(rate, duration, size, random_state):
    """Draw spike times of independent homogeneous Poisson processes.

    Spike times are cumulative sums of exponential inter-spike intervals with
    mean 1/rate. More intervals are drawn until every train's last spike time
    is >= ``duration``, so no spike inside [0, duration) is lost. Times past
    ``duration`` are kept and must be discarded by the caller.

    Args:
        rate: firing rate (must be > 0), in the inverse unit of ``duration``.
        duration: length of the window to cover.
        size: leading shape of the output (one train per entry).
        random_state: numpy RandomState used for all draws.

    Returns:
        Array of shape ``size + (K,)`` with increasing spike times along the last axis.
    """
    size = tuple(size)
    n_isi = int(np.ceil(rate * duration)) + 1

    def draw_times():
        isi = expon.rvs(scale=1/rate, size=size + (n_isi,), random_state=random_state)
        return np.cumsum(isi, axis=-1)

    times = draw_times()
    while times.size and times[..., -1].min() < duration:
        # Continue every train from its last spike time.
        times = np.concatenate([times, times[..., -1:] + draw_times()], axis=-1)
    return times


scale = 10  # ratio between the fast and the slow input rate
half_inputs = n_inputs // 2
spike_times_fast = poisson_spike_times(input_rate, tend, (batch_size, half_inputs), rs)
spike_times_slow = poisson_spike_times(input_rate / scale, tend, (batch_size, half_inputs), rs)

In [None]:
# Binary input raster (batch, n_inputs, n_steps). In the first half of the
# batch, units [0, half_inputs) fire at the fast rate and the next half_inputs
# units at the slow rate; the second half of the batch swaps the two roles.
inputs = np.zeros((batch_size, n_inputs, n_steps), dtype=np.float32)
half_batches = batch_size // 2
for trial in range(batch_size):
    fast_first = trial < half_batches
    for unit in range(half_inputs):
        bins_fast = np.floor(spike_times_fast[trial, unit, :] / dt).astype(int)
        bins_fast = bins_fast[bins_fast < n_steps]
        bins_slow = np.floor(spike_times_slow[trial, unit, :] / dt).astype(int)
        bins_slow = bins_slow[bins_slow < n_steps]
        lower_bins, upper_bins = (bins_fast, bins_slow) if fast_first else (bins_slow, bins_fast)
        inputs[trial, unit, lower_bins] = 1
        inputs[trial, unit + half_inputs, upper_bins] = 1
print(f'Total number of input spikes: {inputs.sum():.0f}.')

In [None]:
# Rasters of a 4x4 sample of trials (rows: input units, columns: time bins).
fig = plot_spikes(inputs, dim=(4, 4))
fig.tight_layout()

In [None]:
# Class labels: trials in the first half of the batch get label 1, the rest 0;
# the encoder turns them into one-hot int8 targets.
truth = np.zeros((batch_size, 1))
truth[:half_batches] = 1
enc = OneHotEncoder()
y = enc.fit_transform(truth).toarray().astype(np.int8)

In [None]:
# Per-step decay factors of the synaptic current (alpha) and membrane potential (beta).
alpha   = float(np.exp(-dt/tau_syn))
beta    = float(np.exp(-dt/tau_mem))
weight_scale = 7 * (1 - beta)
μ = 0
# Gaussian initial weights, one column per output unit: rows [0, half_inputs)
# are drawn with mean +μ and the remaining rows with mean -μ. Draw sizes come
# from the slices themselves, so odd n_inputs and any n_outputs work. Columns
# are filled in order, top half before bottom half, consuming `rs` in that order.
weight_std = weight_scale / np.sqrt(n_inputs)
weights = np.zeros((n_inputs, n_outputs))
input_halves = ((slice(None, half_inputs), μ), (slice(half_inputs, None), -μ))
for out_col in range(n_outputs):
    for in_rows, loc in input_halves:
        weights[in_rows, out_col] = norm.rvs(loc=loc, scale=weight_std,
                                             size=weights[in_rows, out_col].size,
                                             random_state=rs)

In [None]:
dtype = tf.float32
# Inputs and targets are constants; the weights are the only trainable tensor.
inputs_tf, y_tf = (tf.constant(arr, dtype=dtype) for arr in (inputs, y))
weights_tf = tf.Variable(weights, dtype=dtype)

In [None]:
# Simulate with the initial weights and plot a sample of output membrane
# traces; in spiking mode, spikes are drawn as tall deflections.
if not spiking:
    Vm = run_leaky_nn(inputs_tf, weights_tf, dtype=dtype)
    fig = plot_voltage_traces(Vm.numpy())
else:
    Vm, spikes = run_spiking_nn(inputs_tf, weights_tf, dtype=dtype)
    fig = plot_voltage_traces(Vm.numpy(), spikes.numpy())
fig.tight_layout()

In [None]:
if spiking:
    # spikes is (trial, time, unit); plot_spikes wants (trial, unit, time).
    output = np.transpose(spikes.numpy(), (0, 2, 1))
    fig = plot_spikes(output)
    fig.tight_layout()

In [None]:
# Single forward/backward pass with the initial weights: reports the starting
# loss and computes its gradient with respect to weights_tf.
loss_fun = tf.nn.softmax_cross_entropy_with_logits
with tf.GradientTape() as tape:
    if not spiking:
        Vm = run_leaky_nn(inputs_tf, weights_tf, dtype=dtype)
        # Logits: time-averaged membrane potential of each output unit.
        y_hat = tf.math.reduce_mean(Vm, axis=1)
    else:
        Vm, spikes = run_spiking_nn(inputs_tf, weights_tf, dtype=dtype)
        # Logits: spike count of each output unit.
        y_hat = tf.math.reduce_sum(spikes, axis=1)
    per_trial_loss = loss_fun(labels=y_tf, logits=y_hat)
    loss = tf.math.reduce_sum(per_trial_loss)
grad = tape.gradient(loss, [weights_tf])
print(f'Loss: {loss:g}.')

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-3, beta_1=0.9, beta_2=0.999)
loss_fun = tf.nn.softmax_cross_entropy_with_logits
n_epochs = 500
loss_hist = []
for epoch in tqdm(range(n_epochs)):
    with tf.GradientTape() as tape:
        # Forward pass: logits are time-averaged Vm (leaky) or spike counts (spiking).
        if not spiking:
            Vm = run_leaky_nn(inputs_tf, weights_tf, dtype=dtype)
            y_hat = tf.math.reduce_mean(Vm, axis=1)
        else:
            _, spikes = run_spiking_nn(inputs_tf, weights_tf, dtype=dtype)
            y_hat = tf.math.reduce_sum(spikes, axis=1)
        loss = tf.math.reduce_sum(loss_fun(labels=y_tf, logits=y_hat))
    # Backward pass, then one Adam step on the weights.
    grad = tape.gradient(loss, [weights_tf])
    optimizer.apply_gradients(zip(grad, [weights_tf]))
    loss_hist.append(loss.numpy())

In [None]:
# Training curve: summed cross-entropy per epoch.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(loss_hist, color='k')
ax.set(xlabel='Epoch', ylabel='Loss')
ax.grid(which='major', axis='y', lw=0.5, ls=':', color=[.6, .6, .6])
sns.despine()
fig.tight_layout()

In [None]:
# Re-simulate with the trained weights and plot the same sample of traces.
if not spiking:
    Vm = run_leaky_nn(inputs_tf, weights_tf, dtype=dtype)
    fig = plot_voltage_traces(Vm.numpy())
else:
    Vm, spikes = run_spiking_nn(inputs_tf, weights_tf, dtype=dtype)
    fig = plot_voltage_traces(Vm.numpy(), spikes.numpy())
fig.tight_layout()

In [None]:
if spiking:
    # spikes is (trial, time, unit); plot_spikes wants (trial, unit, time).
    output = np.transpose(spikes.numpy(), (0, 2, 1))
    fig = plot_spikes(output)
    fig.tight_layout()