In [None]:
using Printf: @printf
using Random
using Plots
using Einsum
using Statistics
using Flux

In [None]:
"""
    sigmoid(x; thresh=0., α=1.)

Logistic function of `x`, centred at `thresh`, with slope factor `α`.
Returns 0.5 at `x == thresh`.
"""
function sigmoid(x::Number; thresh::Number=0., α::Number=1.)
    z = α * (x - thresh)
    return 1.0 / (1 + exp(-z))
end;
"""
    heaviside(x; thresh=0.)

Step function of `x - thresh`: 0 below the threshold, 1 above it and 0.5 exactly
at it. The previous formula `(x-thresh)/sqrt((x-thresh)^2)` evaluated 0/0 = NaN
at the threshold; `sign` gives the same values everywhere else and is defined there.
"""
heaviside(x::Number; thresh::Number=0.) = 0.5 * (1 + sign(x - thresh));

In [None]:
"""
    plot_traces(Vm; spikes=nothing, layout=(3,5), spike_height=5, w=150, h=100)

Draw one subplot per batch sample for the first `prod(layout)` samples of `Vm`.
Each subplot has one line per neuron: `Vm[k, :, :]` is transposed so that time
runs along the x axis. Wherever `spikes` equals 1, the plotted value is replaced
by `spike_height`, which makes spikes visible on top of the traces. The overall
figure is `w * layout[2]` wide and `h * layout[1]` high.
"""
function plot_traces(Vm; spikes=nothing, layout::Tuple=(3,5), spike_height::Number=5,
        w::Number=150, h::Number=100)
    n_panels = prod(layout)
    # Slicing copies, so Vm itself is never modified; the `* 1` also turns a Bool
    # input into Int so that spike_height can be written into it.
    traces = Vm[1:n_panels, :, :] .* 1
    if spikes !== nothing
        spike_mask = spikes[1:n_panels, :, :] .== 1
        traces[spike_mask] .= spike_height
    end
    panels = map(k -> plot(traces[k, :, :]', axis=false), 1:n_panels)
    fig_size = (w * layout[2], h * layout[1])
    return plot!(panels...; layout=layout, lw=1.5, axis=nothing, legend=nothing,
                 size=fig_size)
end;

In [None]:
"""
    simulate_SNN(τm, τs, Δt, spikes_in, w_ih, w_ho; θ=1.)

Simulate a two-layer network in discrete time steps of `Δt`.

The hidden layer integrates synaptic current into a membrane potential and emits
a spike (via `heaviside`) when the potential crosses `θ`. Each spike lowers the
next membrane potential by 1. The readout layer integrates the hidden spikes the
same way but never spikes.

# Arguments
- `τm`, `τs`: membrane and synaptic time constants, in the same unit as `Δt`.
- `Δt`: simulation time step.
- `spikes_in`: input spikes of shape `(n_batches, n_inputs, n_steps)`.
- `w_ih`: input→hidden weights, shape `(n_inputs, n_hidden)`.
- `w_ho`: hidden→readout weights, shape `(n_hidden, n_outputs)`.
- `θ`: firing threshold of the hidden neurons.

# Returns
`(Vm_o, Vm_h, spikes_h, I_syn_h, I_syn_o)`. Each array has shape
`(n_batches, n_units, n_steps)`, where `n_units` is `n_hidden` or `n_outputs`.
"""
function simulate_SNN(τm::Number, τs::Number, Δt::Number, spikes_in::Array,
        w_ih::Matrix, w_ho::Matrix; θ::Number=1.)
    α = exp(-Δt/τs)   # per-step decay factor of the synaptic currents
    β = exp(-Δt/τm)   # per-step decay factor of the membrane potentials
    # Take the batch size from the input itself, not from a global variable, so
    # the function works for any batch.
    n_batches, _, n_steps = size(spikes_in)
    n_hidden, n_outputs = size(w_ho)

    # hidden layer
    @einsum I_ih[a,c,d] := spikes_in[a,b,d] * w_ih[b,c]
    I_syn_h  = zeros(n_batches, n_hidden, n_steps)
    Vm_h     = zeros(size(I_syn_h))
    spikes_h = zeros(size(I_syn_h))
    for t in 1 : n_steps-1
        spikes_h[:, :, t]  = heaviside.(Vm_h[:, :, t], thresh=θ)
        I_syn_h[:, :, t+1] = α * I_syn_h[:, :, t] + I_ih[:, :, t]
        # a spike at step t pulls the potential at t+1 down by 1
        Vm_h[:, :, t+1]    = β * Vm_h[:, :, t] + I_syn_h[:, :, t] - spikes_h[:, :, t]
    end
    # The loop above stops one step early, so also record threshold crossings of
    # the final membrane potential; otherwise spikes_h is always 0 at n_steps.
    if n_steps >= 1
        spikes_h[:, :, n_steps] = heaviside.(Vm_h[:, :, n_steps], thresh=θ)
    end

    # readout layer
    @einsum I_ho[a,c,d] := spikes_h[a,b,d] * w_ho[b,c]
    I_syn_o  = zeros(n_batches, n_outputs, n_steps)
    Vm_o     = zeros(size(I_syn_o))
    for t in 1 : n_steps-1
        I_syn_o[:, :, t+1] = α * I_syn_o[:, :, t] + I_ho[:, :, t]
        Vm_o[:, :, t+1]    = β * Vm_o[:, :, t] + I_syn_o[:, :, t]
    end

    return Vm_o, Vm_h, spikes_h, I_syn_h, I_syn_o
end;

In [None]:
# neuron parameters: membrane and synaptic time constants [s]
tau_mem, tau_syn = 10e-3, 5e-3
# network parameters
input_rate = 5                              # [Hz] firing rate of each input channel
n_inputs, n_hidden, n_outputs = 50, 4, 2    # layer sizes
# simulation parameters
dt      = 1e-3   # [s] time step
n_steps = 100    # number of time steps
# batch size
batch_size = 64;

In [None]:
# Fixed-seed generator, so random draws that use it are reproducible.
rng = Random.MersenneTwister(1983);

In [None]:
prob = input_rate * dt   # probability that an input channel spikes in one step
# One independent Bernoulli(prob) draw per (sample, channel, step): true → 1.0, false → 0.0
inputs = Float64.(rand(rng, Float64, (batch_size, n_inputs, n_steps)) .< prob);
@printf "Total number of input spikes: %d." sum(inputs)

In [None]:
# Random class label in 1:n_outputs for each sample, drawn from the seeded rng so
# the labels are reproducible like the inputs and weights.
truth = rand(rng, 1:n_outputs, batch_size)
# Fix the class order explicitly. `unique(truth)` would order the classes by first
# appearance, so column k of y would not reliably mean class k. It would also drop
# classes absent from the batch, so y could end up narrower than the readout.
classes = 1:n_outputs
# onehotbatch gives (n_classes, batch); transpose to (batch, n_classes) to match y_hat
y = Flux.onehotbatch(truth, classes)';

In [None]:
# Overall weight scale, expressed relative to the per-step membrane leak 1 - exp(-dt/tau_mem).
weight_scale = 7 * (1 - exp(-dt/tau_mem));
# Gaussian weights with std weight_scale / sqrt(fan_in) for each layer. The fan-in of
# the readout layer is n_hidden, not n_inputs.
w_ih = weight_scale / sqrt(n_inputs) * randn(rng, Float64, (n_inputs, n_hidden));
w_ho = weight_scale / sqrt(n_hidden) * randn(rng, Float64, (n_hidden, n_outputs));

In [None]:
# Forward pass of the untrained network on the random input spikes.
Vm_o, Vm_h, spikes_h, I_syn_h, I_syn_o =
    simulate_SNN(tau_mem, tau_syn, dt, inputs, w_ih, w_ho; θ=1.0);

In [None]:
# Hidden-layer membrane traces, with spike times drawn at height 5.
plot_traces(Vm_h; spikes=spikes_h, spike_height=5)

In [None]:
# Readout membrane traces; the readout layer never spikes, so no spike overlay.
plot_traces(Vm_o; spikes=nothing)

In [None]:
# Prediction: peak readout potential over time, shape (batch_size, n_outputs).
y_hat = dropdims(maximum(Vm_o; dims=3); dims=3)
# y_hat and y are (batch, class), so the softmax must run over the class axis
# (dims=2). The default dims=1 would normalise across the batch instead.
loss = Flux.logitcrossentropy(y_hat, y; dims=2)
@printf("Loss: %g.", loss)

In [None]:
# Per-step decay factors, matching simulate_SNN: alpha decays the synaptic currents
# (tau_syn) and beta decays the membrane potentials (tau_mem).
alpha     = exp(-dt / tau_syn)
beta      = exp(-dt / tau_mem)

"""
    snn_peak_readout(inputs, w_ih, w_ho, α, β; θ=1.)

Run the same dynamics as `simulate_SNN`, but without mutating any array so that
Zygote can differentiate through it. Indexed assignment and `@einsum`'s internal
writes are both unsupported by Zygote.

Only the current state is kept. The function returns the maximum over time of the
readout membrane potential, with shape `(n_batches, n_outputs)`.
`inputs` has shape `(n_batches, n_inputs, n_steps)`.
"""
function snn_peak_readout(inputs::AbstractArray, w_ih::AbstractMatrix,
        w_ho::AbstractMatrix, α::Number, β::Number; θ::Number=1.)
    n_batches, _, n_steps = size(inputs)
    I_h = zeros(n_batches, size(w_ih, 2))
    V_h = zeros(size(I_h))
    I_o = zeros(n_batches, size(w_ho, 2))
    V_o = zeros(size(I_o))
    peak = V_o   # the readout starts at 0 and that initial value counts toward the max
    for t in 1:n_steps-1
        s_h = heaviside.(V_h; thresh=θ)
        # Each right-hand side uses the step-t state: tuple assignment evaluates
        # both sides before rebinding.
        I_h, V_h = α .* I_h .+ inputs[:, :, t] * w_ih, β .* V_h .+ I_h .- s_h
        I_o, V_o = α .* I_o .+ s_h * w_ho, β .* V_o .+ I_o
        peak = max.(peak, V_o)
    end
    return peak
end

pars = Flux.Params([w_ih, w_ho])

# NOTE(review): heaviside has zero derivative away from the threshold, so no
# gradient reaches w_ih through the hidden spikes. A surrogate gradient would be
# needed to train that layer.
loss, grads = Flux.withgradient(pars) do
    y_hat = snn_peak_readout(inputs, w_ih, w_ho, alpha, beta; θ=1.)
    # (batch, class) layout: the softmax runs over classes, i.e. dims=2
    Flux.logitcrossentropy(y_hat, y; dims=2)
end