In [None]:
using Printf: @printf
using LinearAlgebra
using Random
using Statistics
using Distributions: Exponential
using Tullio
using ChainRulesCore
using Flux
using Plots
using ProgressMeter

# Default width and height of a single subplot panel. The plotting helpers in this
# file use them as defaults and multiply them by the layout's columns/rows to
# obtain the overall figure `size`.
panel_width, panel_height = 150, 100;

In [None]:
# Model switch: `true` selects the spiking network functions (`*_spiking_NN`),
# `false` selects the non-spiking leaky ones (`*_leaky_NN`).
spiking = true;

In [None]:
# Step nonlinearity used to emit spikes: 1.0 for strictly positive input and
# 0.0 otherwise (an input of exactly zero does not spike). Always a Float64.
spike_fun(x::Number) = ifelse(x > 0.0, 1.0, 0.0)
# Array method: apply the scalar step to every element.
spike_fun(x::AbstractArray{<:Number}) = broadcast(spike_fun, x)
# Custom reverse-mode rule for `spike_fun`. The forward value is the exact step
# function; the pullback replaces its derivative with the smooth surrogate
# 1 / (1 + 100|x|)^2, evaluated element-wise at the forward input `x`.
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:HasReverseMode},
        ::typeof(spike_fun), x::Union{Number,AbstractArray{<:Number}})
    function spike_fun_pullback(dy)
        # No tangent for the function itself; surrogate-scaled cotangent for `x`.
        dx = dy ./ (1.0 .+ 100.0 * abs.(x)).^2
        return NoTangent(), dx
    end
    return spike_fun(x), spike_fun_pullback
end

In [None]:
"""
    simulate_leaky_NN(τm, τs, Δt, spikes_in, w)

Simulate a layer of non-spiking leaky integrators and return the complete
membrane-potential trace.

# Arguments
- `τm`: membrane time constant (same units as `Δt`).
- `τs`: synaptic time constant (same units as `Δt`).
- `Δt`: integration time step.
- `spikes_in`: 3-d input array indexed as `[batch, input, step]`.
- `w`: weight matrix indexed as `[input, output]`.

# Returns
`Vm`, a `batches × outputs × steps` array of `Float64`. The first time slice is
all zeros (initial condition).
"""
function simulate_leaky_NN(τm::Number, τs::Number, Δt::Number, spikes_in::Array, w::Matrix)
    α = exp(-Δt/τs)   # per-step decay factor of the synaptic current
    β = exp(-Δt/τm)   # per-step decay factor of the membrane potential
    # The batch size is read from the input itself (not from a global variable), so
    # the function works for any input and does not depend on notebook state.
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # Weighted input current: sum over the input index `b`.
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    I_syn = zeros(n_batches, n_outputs, n_steps)
    Vm    = zeros(n_batches, n_outputs, n_steps)
    for t in 1 : n_steps-1
        I_syn[:, :, t+1] = α * I_syn[:, :, t] + I_inp[:, :, t]
        Vm[:, :, t+1]    = β * Vm[:, :, t] + I_syn[:, :, t]
    end
    return Vm
end;

In [None]:
"""
    simulate_spiking_NN(τm, τs, Δt, spikes_in, w, θ=1.)

Simulate a layer of leaky integrate-and-fire neurons and return the complete
membrane-potential and output-spike traces.

# Arguments
- `τm`: membrane time constant (same units as `Δt`).
- `τs`: synaptic time constant (same units as `Δt`).
- `Δt`: integration time step.
- `spikes_in`: 3-d input array indexed as `[batch, input, step]`.
- `w`: weight matrix indexed as `[input, output]`.
- `θ`: firing threshold; a neuron spikes when its potential is strictly above `θ`.

# Returns
A tuple `(Vm, spikes)` of `batches × outputs × steps` `Float64` arrays. A threshold
crossing detected at step `t` is recorded in `spikes` at step `t+1`, and 1 is
subtracted from the potential at that same step.
"""
function simulate_spiking_NN(τm::Number, τs::Number, Δt::Number, spikes_in::Array, w::Matrix, θ::Number=1.)
    α = exp(-Δt/τs)   # per-step decay factor of the synaptic current
    β = exp(-Δt/τm)   # per-step decay factor of the membrane potential
    # The batch size is read from the input itself (not from a global variable), so
    # the function works for any input and does not depend on notebook state.
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # Weighted input current: sum over the input index `b`.
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    I_syn  = zeros(n_batches, n_outputs, n_steps)
    Vm     = zeros(n_batches, n_outputs, n_steps)
    spikes = zeros(n_batches, n_outputs, n_steps)
    for t in 1 : n_steps-1
        # 1.0 where the potential at step t is above threshold, 0.0 elsewhere.
        reset = spike_fun(Vm[:, :, t] .- θ)
        I_syn[:, :, t+1]  = α * I_syn[:, :, t] + I_inp[:, :, t]
        Vm[:, :, t+1]     = β * Vm[:, :, t] + I_syn[:, :, t] - reset
        spikes[:, :, t+1] = reset
    end
    return Vm, spikes
end;

In [None]:
"""
    run_leaky_NN(tau_mem, tau_syn, dt, spikes_in, w)

Run the non-spiking leaky-integrator layer and return the membrane potential
averaged over time.

Unlike `simulate_leaky_NN`, no array is written in place: every step builds new
state arrays (presumably so the function can be differentiated by reverse-mode
AD — confirm against the training code).

# Arguments
- `tau_mem`: membrane time constant (same units as `dt`).
- `tau_syn`: synaptic time constant (same units as `dt`).
- `dt`: integration time step.
- `spikes_in`: 3-d input array indexed as `[batch, input, step]`.
- `w`: weight matrix indexed as `[input, output]`.

# Returns
A `batches × outputs` matrix holding the sum of the membrane potential at the
start of each of the `n_steps` steps, divided by `n_steps`.
"""
function run_leaky_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix)
    alpha = exp(-dt / tau_syn)   # per-step decay factor of the synaptic current
    beta  = exp(-dt / tau_mem)   # per-step decay factor of the membrane potential
    # The batch size is read from the input itself (not from a global variable), so
    # the function works for any input and does not depend on notebook state.
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    I_syn_curr = zeros(n_batches, n_outputs)
    Vm_curr    = zeros(n_batches, n_outputs)
    Vm_acc     = zeros(n_batches, n_outputs)
    # Weighted input current: sum over the input index `b`.
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        # Accumulate the potential *before* advancing the state.
        Vm_acc = Vm_acc + Vm_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    return Vm_acc ./ n_steps
end;

In [None]:
"""
    run_spiking_NN(tau_mem, tau_syn, dt, spikes_in, w, θ=1.)

Run the leaky integrate-and-fire layer and return the number of output spikes
per batch element and output neuron.

Unlike `simulate_spiking_NN`, no array is written in place: every step builds new
state arrays (presumably so the function can be differentiated by reverse-mode
AD — confirm against the training code).

# Arguments
- `tau_mem`: membrane time constant (same units as `dt`).
- `tau_syn`: synaptic time constant (same units as `dt`).
- `dt`: integration time step.
- `spikes_in`: 3-d input array indexed as `[batch, input, step]`.
- `w`: weight matrix indexed as `[input, output]`.
- `θ`: firing threshold; a neuron spikes when its potential is strictly above `θ`.

# Returns
A `batches × outputs` matrix of spike counts accumulated over all `n_steps` steps.
"""
function run_spiking_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix, θ::Number=1.)
    alpha = exp(-dt / tau_syn)   # per-step decay factor of the synaptic current
    beta  = exp(-dt / tau_mem)   # per-step decay factor of the membrane potential
    # The batch size is read from the input itself (not from a global variable), so
    # the function works for any input and does not depend on notebook state.
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    I_syn_curr = zeros(n_batches, n_outputs)
    Vm_curr    = zeros(n_batches, n_outputs)
    spikes_out = zeros(n_batches, n_outputs)
    # Weighted input current: sum over the input index `b`.
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        # 1.0 where the current potential is above threshold, 0.0 elsewhere.
        reset = spike_fun(Vm_curr .- θ)
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr - reset
        spikes_out += reset
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    return spikes_out
end;

In [None]:
"""
    plot_voltage_traces(Vm; spikes=nothing, layout=(3,4), spike_height=5,
                        w=panel_width, h=panel_height)

Plot membrane-potential traces for a selection of batch elements, one panel per
element, arranged in a `layout = (rows, cols)` grid.

`rows * cols` batch indices are picked evenly spaced between 1 and `size(Vm, 1)`.
Each panel shows the traces of all outputs of that batch element over the third
dimension of `Vm`. When `spikes` is given, every sample where `spikes == 1` is
drawn at `spike_height` instead of its membrane potential.

# Keywords
- `spikes`: optional array with the same indexing as `Vm`, or `nothing`.
- `layout`: `(rows, cols)` of the panel grid.
- `spike_height`: value plotted at spike positions.
- `w`, `h`: width and height of a single panel; the figure size is
  `(w * cols, h * rows)`.

# Returns
The combined plot object.
"""
function plot_voltage_traces(Vm; spikes=nothing, layout::Tuple=(3,4), spike_height::Number=5,
        w::Number=panel_width, h::Number=panel_height)
    rows, cols = layout
    n = rows * cols
    # `range(a, b, 1)` throws when `a != b`, so a single-panel layout is handled
    # explicitly by showing the first batch element.
    idx = n == 1 ? [1] : floor.(Int, range(1, size(Vm, 1), n))
    data = 1 .* Vm[idx, :, :]   # fresh array: `Vm` itself is never modified
    if !isnothing(spikes)
        data[spikes[idx, :, :] .== 1] .= spike_height
    end
    # Transpose so that rows are time steps and each column is one output trace.
    p = [plot(data[k, :, :]') for k in 1:n]
    return plot!(p..., layout=layout, lw=1.5, axis=nothing, legend=nothing, size=(w*cols, h*rows))
end;

In [None]:
"""
    plot_spikes(spikes, layout=(3,4), w=panel_width, h=panel_height)

Show spike rasters for a selection of batch elements as grayscale heatmaps, one
panel per element, arranged in a `layout = (rows, cols)` grid.

`rows * cols` batch indices are picked evenly spaced between 1 and
`size(spikes, 1)`. Each panel displays `1 .- spikes[i, :, :]` with the `:grays`
color scheme.

# Arguments
- `spikes`: 3-d array whose first dimension is the batch.
- `layout`: `(rows, cols)` of the panel grid.
- `w`, `h`: width and height of a single panel; the figure size is
  `(w * cols, h * rows)`.

# Returns
The combined plot object.
"""
function plot_spikes(spikes, layout::Tuple=(3,4),
        w::Number=panel_width, h::Number=panel_height)
    rows, cols = layout
    n = rows * cols
    # `range(a, b, 1)` throws when `a != b`, so a single-panel layout is handled
    # explicitly by showing the first batch element.
    idx = n == 1 ? [1] : floor.(Int, range(1, size(spikes, 1), n))
    p = [heatmap(1 .- spikes[i, :, :], color=:grays) for i in idx]
    return plot(p..., layout=(rows, cols), axis=nothing, colorbar=nothing, size=(w*cols, h*rows))
end;

In [None]:
# neuron parameters (times presumably in seconds, consistent with `input_rate` in Hz)
tau_mem    = 10e-3   # membrane time constant
tau_syn    = 5e-3    # synaptic time constant
# network parameters
input_rate = 10 # [Hz]
n_inputs   = 100
n_outputs  = 2
# simulation parameters
tend       = 0.2     # total simulated duration
dt         = 1e-3    # integration time step
# `tend / dt` is a Float64 and, for some values, lands just off an integer, in which
# case `Int(...)` throws an InexactError. Rounding to the nearest integer is robust.
n_steps    = round(Int, tend / dt)
# batch size
batch_size = 128;

In [None]:
# for reproducibility: a Mersenne Twister generator with a fixed seed, passed
# explicitly to the random draws that should be repeatable across runs
rng = MersenneTwister(100);

In [None]:
# Spike times are built by drawing exponentially distributed inter-spike intervals
# (ISI) and accumulating them along the third dimension. "Fast" trains use a mean
# ISI of 1/input_rate; "slow" trains use a mean ISI that is `scale` times longer.
scale = 10
half_inputs = Int(n_inputs / 2)
# Sample with the seeded `rng` rather than the global RNG, so that the input spike
# trains (and not only the initial weights) are reproducible from run to run.
ISI = rand(rng, Exponential(1/input_rate), (batch_size, half_inputs, ceil(Int, tend * input_rate)));
spike_times_fast = cumsum(ISI, dims=3);
ISI = rand(rng, Exponential(scale/input_rate), (batch_size, half_inputs, ceil(Int, tend * input_rate / scale)));
spike_times_slow = cumsum(ISI, dims=3);

In [None]:
# Binary input raster indexed as [batch, input, step]: 1 marks a spike.
inputs = zeros(batch_size, n_inputs, n_steps)
half_batches = Int(batch_size / 2)
for i in 1:batch_size, j in 1:half_inputs
    # Convert spike times to 1-based step indices and drop those past the end.
    bins_fast = filter(<=(n_steps), ceil.(Int, spike_times_fast[i, j, :] / dt))
    bins_slow = filter(<=(n_steps), ceil.(Int, spike_times_slow[i, j, :] / dt))
    # First half of the batch: fast trains on the lower inputs, slow trains on
    # the upper inputs. Second half of the batch: the assignment is swapped.
    ch_fast, ch_slow = i <= half_batches ? (j, j + half_inputs) : (j + half_inputs, j)
    inputs[i, ch_fast, bins_fast] .= 1
    inputs[i, ch_slow, bins_slow] .= 1
end
@printf("Total number of input spikes: %d.", sum(inputs))

In [None]:
# Show the input spike rasters for a selection of batch elements.
plot_spikes(inputs)

In [None]:
# Class labels: the first `half_batches` batch elements are class 2, the rest class 1.
truth = [i <= half_batches ? 2 : 1 for i in 1:batch_size]
# Distinct labels in order of first appearance.
classes = unique(truth)
# One-hot targets, transposed so that rows are batch elements and columns follow
# the order of `classes`.
y = adjoint(Flux.onehotbatch(truth, classes));

In [None]:
# Initial weights: normally distributed (drawn from the seeded `rng`), with a
# standard deviation of weight_scale / sqrt(n_inputs).
weight_scale = 7 * (1 - exp(-dt/tau_mem));
weights = (weight_scale / sqrt(n_inputs)) * randn(rng, Float64, n_inputs, n_outputs);

In [None]:
# Forward pass with the current weights. The per-class score `y_hat` is the output
# spike count (spiking model) or the time-averaged membrane potential (leaky model).
if spiking
    Vm,spikes = simulate_spiking_NN(tau_mem, tau_syn, dt, inputs, weights)
    y_hat = sum(spikes, dims=3)[:,:,1]
else
    Vm = simulate_leaky_NN(tau_mem, tau_syn, dt, inputs, weights)
    y_hat = mean(Vm, dims=3)[:,:,1]
    spikes = nothing
end
# `y_hat` and `y` are laid out as batch × class, whereas Flux's loss functions
# default to `dims=1` (classes along the first dimension). Without `dims=2` the
# softmax would be taken across batch elements instead of across classes.
loss = Flux.logitcrossentropy(y_hat, y; dims=2)
@printf("Loss: %g.", loss)

In [None]:
# Membrane-potential traces from the forward pass above; when `spikes` is not
# `nothing`, spike positions are drawn at the spike height.
plot_voltage_traces(Vm, spikes=spikes)

In [None]:
# Output spike rasters exist only for the spiking model.
if spiking
    plot_spikes(spikes)
end

In [None]:
# Choose the forward pass that matches the selected model.
fun = spiking ? run_spiking_NN : run_leaky_NN
# Only the weight matrix is trained.
pars = Flux.Params([weights])
optimizer = Adam(2e-3, (0.9, 0.999))
# Typed container for the per-epoch loss values (an untyped `[]` is a Vector{Any}).
loss = Float64[]
@showprogress for epoch in 1:1_000
    l,grads = Flux.withgradient(pars) do
        y_hat = fun(tau_mem, tau_syn, dt, inputs, weights)
        # `y_hat` and `y` are batch × class; Flux's losses default to `dims=1`,
        # which would apply the softmax across batch elements instead of classes.
        Flux.logitcrossentropy(y_hat, y; dims=2)
    end
    Flux.Optimise.update!(optimizer, pars, grads)
    push!(loss, l)
end

In [None]:
# Training curve: the loss value recorded at each epoch of the optimization loop.
plot(loss, lw=2, color=:black, xlabel="Epoch", ylabel="Loss", label=false, size=(400,250))

In [None]:
# Forward pass with the current weights. The per-class score `y_hat` is the output
# spike count (spiking model) or the time-averaged membrane potential (leaky model).
if spiking
    Vm,spikes = simulate_spiking_NN(tau_mem, tau_syn, dt, inputs, weights)
    y_hat = sum(spikes, dims=3)[:,:,1]
else
    Vm = simulate_leaky_NN(tau_mem, tau_syn, dt, inputs, weights)
    y_hat = mean(Vm, dims=3)[:,:,1]
    spikes = nothing
end
# `y_hat` and `y` are laid out as batch × class, whereas Flux's loss functions
# default to `dims=1` (classes along the first dimension). Without `dims=2` the
# softmax would be taken across batch elements instead of across classes.
loss = Flux.logitcrossentropy(y_hat, y; dims=2)
@printf("Loss: %g.", loss)

In [None]:
# Membrane-potential traces from the forward pass above; when `spikes` is not
# `nothing`, spike positions are drawn at the spike height.
plot_voltage_traces(Vm, spikes=spikes)

In [None]:
# Output spike rasters exist only for the spiking model.
if spiking
    plot_spikes(spikes)
end