In [None]:
using Printf: @printf
using Random
using Tullio
using Statistics
using LinearAlgebra
using Distributions: Exponential
using Flux
using Plots
using ProgressMeter

In [None]:
# Logistic function 1 / (1 + exp(-α (x - thresh))): centred on `thresh`, slope set by `α`.
# NOTE(review): `using Flux` also exports a `sigmoid`; this definition shadows it in this
# notebook — confirm that is intended.
sigmoid(x::Number; thresh::Number=0., α::Number=1.) = 1. / (1 + exp(-α*(x-thresh)));
# Step function: 0 below `thresh`, 1 above it, and 0.5 exactly at `thresh`.
# `sign` replaces the former (x-thresh)/sqrt((x-thresh)^2), which is 0/0 = NaN at the
# threshold and Inf/Inf = NaN for infinite inputs.
heaviside(x::Number; thresh::Number=0.) = 0.5 * (1 + sign(x - thresh));

In [None]:
"""
    simulate_SNN(τm, τs, Δt, spikes_in, w)

Simulate a layer of leaky integrators (no spiking/reset) driven by input spikes and
return the full membrane-potential trace.

`spikes_in` is indexed `(batch, input, time step)` and `w` is `(input, output)`; the
result `Vm` is `(batch, output, time step)`. With `α = exp(-Δt/τs)` and
`β = exp(-Δt/τm)`, both state arrays start at zero and for `t = 1:n_steps-1`

    I_syn[t+1] = α * I_syn[t] + spikes_in[t] * w
    Vm[t+1]    = β * Vm[t]    + I_syn[t]

so `Vm[:, :, 1]` is all zeros and synaptic current reaches the membrane one step after
it was driven. A `DimensionMismatch` is thrown if `size(spikes_in, 2) != size(w, 1)`.
"""
function simulate_SNN(τm::Number, τs::Number, Δt::Number, spikes_in::Array, w::Matrix)
    α = exp(-Δt / τs)
    β = exp(-Δt / τm)
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # Sized from the arguments (previously the notebook-global `batch_size` was used,
    # which broke any call with a different batch size).
    I_syn = zeros(n_batches, n_outputs, n_steps)
    Vm    = zeros(n_batches, n_outputs, n_steps)
    @views for t in 1:n_steps-1
        # input current of step t: (batch × input) * (input × output)
        I_syn[:, :, t+1] .= α .* I_syn[:, :, t] .+ spikes_in[:, :, t] * w
        Vm[:, :, t+1]    .= β .* Vm[:, :, t] .+ I_syn[:, :, t]
    end
    return Vm
end;

In [None]:
"""
    SNN_sim(tau_mem, tau_syn, dt, spikes_in, w)

Run leaky-integrator dynamics (no spiking/reset) and return the membrane potential
averaged over all `n_steps` time steps, indexed `(batch, output)`.

`spikes_in` is `(batch, input, time step)` and `w` is `(input, output)`. With
`alpha = exp(-dt/tau_syn)` and `beta = exp(-dt/tau_mem)`, starting from zero state:

    I_syn[t+1] = alpha * I_syn[t] + I_inp[t]
    Vm[t+1]    = beta  * Vm[t]    + I_syn[t]

and the result is `(Vm[1] + … + Vm[n_steps]) / n_steps` (with `n_steps == 0` this
is `NaN`). State arrays are rebound rather than written into, which keeps the
function usable inside `Flux.withgradient` (Zygote does not support array mutation).
"""
function SNN_sim(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix)
    alpha = exp(-dt / tau_syn)
    beta  = exp(-dt / tau_mem)
    n_batches, _, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # Sized from the arguments (previously the notebook-global `batch_size` was used).
    I_syn_curr = zeros(n_batches, n_outputs)
    Vm_curr    = zeros(n_batches, n_outputs)
    Vm_acc     = zeros(n_batches, n_outputs)
    # Input current for all steps at once: I_inp[a,c,d] = Σ_b spikes_in[a,b,d] * w[b,c].
    # Computed up front so the time loop only slices this (batch × output × step) tensor.
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1:n_steps
        Vm_acc = Vm_acc + Vm_curr          # add Vm at step t before advancing
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next    = beta * Vm_curr + I_syn_curr
        I_syn_curr = I_syn_next
        Vm_curr    = Vm_next
    end
    return Vm_acc ./ n_steps
end;

In [None]:
# Plot traces for the first prod(layout) batch entries of `Vm`, one subplot per entry and
# one line per column of the (transposed) slice — presumably one per output neuron when
# `Vm` is (batch, neuron, time) as returned by simulate_SNN. If `spikes` (same indexing
# as `Vm`) is given, the samples where it equals 1 are drawn at `spike_height`.
# `w` and `h` set the size of each subplot; `Vm` itself is never modified.
function plot_traces(Vm; spikes=nothing, layout::Tuple=(3,5), spike_height::Number=5,
        w::Number=150, h::Number=100)
    n_panels = prod(layout)
    traces = 1 .* view(Vm, 1:n_panels, :, :)   # broadcasting allocates a fresh array
    if spikes !== nothing
        spike_mask = view(spikes, 1:n_panels, :, :) .== 1
        traces[spike_mask] .= spike_height
    end
    panels = map(k -> plot(permutedims(traces[k, :, :]), axis=false), 1:n_panels)
    figure_size = (w * layout[2], h * layout[1])
    return plot!(panels..., layout=layout, lw=1.5, axis=nothing, legend=nothing,
                 size=figure_size)
end;

In [None]:
# neuron parameters (time constants in seconds, matching `input_rate` in Hz)
tau_mem    = 10e-3
tau_syn    = 5e-3
# network parameters
input_rate = 10 # [Hz]
n_inputs   = 100
n_outputs  = 2
# simulation parameters
tend       = 0.2   # [s] simulated duration
dt         = 1e-3  # [s] time step
# round(): a float quotient can land a hair off an integer (e.g. 0.3 / 1e-3), and
# Int(...) would then throw an InexactError.
n_steps    = round(Int, tend / dt)
# batch size
batch_size = 32;

In [None]:
# Seeded generator: random draws that pass `rng` explicitly are repeatable across runs.
rng = Random.MersenneTwister(100);

In [None]:
# Poisson spike times as cumulative sums of exponential inter-spike intervals (ISIs).
# `spike_times_fast` has rate `input_rate`; `spike_times_slow` has rate
# `input_rate / scale` (Exponential(θ) has mean θ). Draws use the seeded `rng` so the
# dataset is reproducible (they previously used the unseeded global RNG).
scale = 10
half_inputs = Int(n_inputs / 2)
# Number of ISIs per train for rate `r`: the expected count tend*r plus a wide margin.
# The cut at `tend` (applied when binning) must decide how many spikes fall in the
# window. Drawing only ceil(tend*r) ISIs would cap every train at its expected count.
n_isi(r) = ceil(Int, tend * r + 8 * sqrt(tend * r) + 8)
ISI = rand(rng, Exponential(1/input_rate),
           (batch_size, half_inputs, n_isi(input_rate)));
spike_times_fast = cumsum(ISI, dims=3);
ISI = rand(rng, Exponential(scale/input_rate),
           (batch_size, half_inputs, n_isi(input_rate / scale)));
spike_times_slow = cumsum(ISI, dims=3);

In [None]:
# Bin spike times into the 0/1 input raster `inputs` (batch × input × time step).
# Batch entries 1:half_batches get the fast trains on inputs 1:half_inputs and the slow
# trains on the other half; the remaining entries get the swapped assignment.
inputs = zeros(batch_size, n_inputs, n_steps)
half_batches = Int(batch_size / 2)
for i in 1:batch_size, j in 1:half_inputs
    # time-step index of every spike, keeping only those inside the simulated window
    fast_steps = filter(<=(n_steps), ceil.(Int, spike_times_fast[i, j, :] ./ dt))
    slow_steps = filter(<=(n_steps), ceil.(Int, spike_times_slow[i, j, :] ./ dt))
    low, high = i <= half_batches ? (fast_steps, slow_steps) : (slow_steps, fast_steps)
    inputs[i, j, low] .= 1
    inputs[i, j + half_inputs, high] .= 1
end
@printf("Total number of input spikes: %d.", sum(inputs))

In [None]:
# Input rasters of rows*cols batch entries spread evenly over the batch
# (values inverted with 1 .- inputs before plotting in grayscale).
rows, cols = 4, 3;
idx = round.(Int, range(1, batch_size, length=rows*cols));
p = map(i -> heatmap(1 .- inputs[i, :, :], color=:grays, title="Batch #"*string(i)), idx);
# y-label only on the left column, x-label only on the bottom row
foreach(i -> plot!(p[i], ylabel="Input #"), 1:cols:rows*cols)
foreach(i -> plot!(p[i], xlabel="Time"), (rows-1)*cols+1:rows*cols)
plot(p..., layout=(rows, cols), colorbar=nothing, size=(250*cols, 200*rows))

In [None]:
# Targets: batch entries 1:half_batches (fast trains on the first half of the inputs)
# are class 2, the rest class 1. `unique` keeps first-appearance order, so
# classes == [2, 1] and column 1 of `y` (batch × class one-hot) is class 2.
truth = [b <= half_batches ? 2 : 1 for b in 1:batch_size]
classes = unique(truth)
y = adjoint(Flux.onehotbatch(truth, classes));

In [None]:
# Initial weights (input × output): zero-mean Gaussian with standard deviation
# 7 * (1 - exp(-dt/tau_mem)) / sqrt(n_inputs), drawn from the seeded `rng`.
weight_scale = 7 * (1 - exp(-dt / tau_mem));
w = randn(rng, Float64, (n_inputs, n_outputs)) * (weight_scale / sqrt(n_inputs));

In [None]:
# Loss of the untrained network, using the peak membrane potential of each output
# neuron over time as the logit. Rows of `y_hat` and `y` are batch entries and columns
# are classes, so the softmax must run along dims=2 (Flux's default dims=1 would
# normalise across the batch instead).
# NOTE(review): training optimises the time-averaged potential (SNN_sim), so this
# peak-based loss is not directly comparable to the training curve.
Vm = simulate_SNN(tau_mem, tau_syn, dt, inputs, w)
y_hat = maximum(Vm, dims=3)[:, :, 1]
loss = Flux.logitcrossentropy(y_hat, y; dims=2)
@printf("Loss: %g.", loss)

In [None]:
# Membrane traces of the first 15 batch entries (default 3×5 layout) before training.
Vm |> plot_traces

In [None]:
# Full-batch training of `w` with Adam (learning rate 2e-3, decay rates (0.9, 0.999))
# for 1000 epochs. The logit is the time-averaged membrane potential from SNN_sim.
# Rows of `y_hat`/`y` are batch entries and columns are classes, so the cross-entropy
# softmax runs along dims=2 (Flux's default dims=1 would normalise across the batch).
pars = Flux.Params([w])
optimizer = Adam(2e-3, (0.9, 0.999))
loss = Float64[]   # training loss per epoch
@showprogress for epoch in 1:1_000
    l, grads = Flux.withgradient(pars) do
        y_hat = SNN_sim(tau_mem, tau_syn, dt, inputs, w)
        Flux.logitcrossentropy(y_hat, y; dims=2)
    end
    Flux.Optimise.update!(optimizer, pars, grads)
    push!(loss, l)
end

In [None]:
# Training-loss curve, one point per epoch.
plot(loss; label=false, color=:black, lw=2, xlabel="Epoch", ylabel="Loss", size=(400, 250))

In [None]:
# Loss of the trained network, using the peak membrane potential of each output neuron
# over time as the logit. Rows of `y_hat` and `y` are batch entries and columns are
# classes, so the softmax must run along dims=2 (Flux's default dims=1 would normalise
# across the batch instead).
# NOTE(review): training optimised the time-averaged potential (SNN_sim), so this
# peak-based loss is not directly comparable to the training curve.
Vm = simulate_SNN(tau_mem, tau_syn, dt, inputs, w)
y_hat = maximum(Vm, dims=3)[:, :, 1]
loss = Flux.logitcrossentropy(y_hat, y; dims=2)
@printf("Loss: %g.", loss)

In [None]:
# Membrane traces of the first 15 batch entries (default 3×5 layout) after training.
Vm |> plot_traces