In [None]:
using Printf: @printf
using Random
using LinearAlgebra
using ChainRulesCore
using Zygote: pullback, gradient
using Flux
using ProgressMeter
using Plots

In [None]:
"""
    spike_fun(x)

Step non-linearity: return `1.0` where `x` is strictly positive and `0.0`
otherwise (so `spike_fun(0)` is `0.0`). The result is always `Float64`,
whatever the numeric type of the input. Arrays are processed element-wise.
"""
function spike_fun(x::Number)
    return ifelse(x > 0.0, 1.0, 0.0)
end

# Array method: applies the scalar method to every element, so a whole-array
# call such as `spike_fun(A)` goes through a single `spike_fun` method.
function spike_fun(x::AbstractArray{<:Number})
    return broadcast(spike_fun, x)
end
# Reverse-mode rule for `spike_fun`.
#
# The forward pass returns the exact step output. The backward pass replaces the
# derivative of the step (zero everywhere except at the jump) with the smooth
# surrogate `1 / (1 + 100 * |x|)^2`, so that a useful gradient can flow through
# the non-linearity.
#
# Arguments:
#   config - rule configuration; only used for dispatch (reverse-mode AD systems).
#   x      - scalar or array input of `spike_fun`.
#
# Returns the primal output together with the pullback. The pullback maps the
# output cotangent to `(NoTangent(), x̄)`: no tangent for the function itself and
# the surrogate-scaled cotangent for `x`.
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:HasReverseMode},
        ::typeof(spike_fun), x::Union{Number,AbstractArray{<:Number}})
    retval = spike_fun(x)
    function pullback_spike_fun(ȳ)
        # AD back-ends are allowed to hand over a lazy (thunked) cotangent;
        # `unthunk` is a no-op for ordinary numbers and arrays.
        Δ = ChainRulesCore.unthunk(ȳ)
        # Fully dotted so the whole expression fuses into a single broadcast.
        return NoTangent(), Δ ./ (1.0 .+ 100.0 .* abs.(x)) .^ 2
    end
    return retval, pullback_spike_fun
end

In [None]:
# Sanity check of the surrogate gradient: evaluate a test function built on
# `spike_fun` over a grid around zero and plot its value (black) together with
# the derivative reported by automatic differentiation (magenta).
N = 101
x = range(-0.2, 0.2, N)

# Choose the AD entry point to exercise. `use_withgradient` takes precedence
# when both flags are set; with both unset the raw `pullback` API is used.
use_withgradient = true
use_gradient = true

fun(x) = 2*x + spike_fun(x)

# Every branch fills the same two buffers: function values and derivatives.
y = zeros(size(x))
dy = zeros(size(x))

if use_withgradient
    for i in eachindex(x)
        # `withgradient` yields the value and the tuple of gradients in one call.
        y[i], grads = Flux.withgradient(fun, x[i])
        dy[i] = grads[1]
    end
elseif use_gradient
    y .= fun.(x)
    for i in eachindex(x)
        dy[i] = gradient(fun, x[i])[1]
    end
else
    for i in eachindex(x)
        # Differentiate `fun` here as well, so that all three branches produce
        # the same curves and only the AD entry point differs.
        y[i], back = pullback(fun, x[i])
        # Seeding the pullback with 1 gives the derivative of the scalar output.
        dy[i] = back(1.0)[1]
    end
end

plot(x, y, color=:black, lw=2, label=nothing, size=(400,200))
plot!(x, dy, color=:magenta, lw=2, label=nothing)

In [None]:
# Toy classification problem: random binary inputs, random binary labels and a
# randomly initialised weight matrix. A seeded RNG keeps the cell reproducible.
rng = MersenneTwister(1983)

n_inputs = 10
n_outputs = 2
batch_size = 5
# One column of weights per output unit.
w = randn(rng, Float64, (n_inputs, n_outputs))
# Inputs: one column per sample; an entry is 1 where a uniform draw exceeds 0.5.
x = zeros(Int, (n_inputs, batch_size))
x[rand(rng, Float64, size(x)) .> 0.5] .= 1

# Labels are drawn the same way, so each one is either 0 or 1.
truth = zeros(Int, batch_size)
truth[rand(rng, Float64, size(truth)) .> 0.5] .= 1
# Enumerate the label set explicitly instead of reading it off the sample:
# `unique(truth)` would order the one-hot rows by first appearance and would
# produce a single row whenever every label happens to be identical, which no
# longer matches the `n_outputs` rows of the network output.
classes = collect(0:(n_outputs - 1))
y = Flux.onehotbatch(truth, classes)

In [None]:
# Network output for the whole batch with the current weights: `w'*x` has one
# row per output unit and one column per sample; `spike_fun` thresholds it.
spike_fun(w'*x)

In [None]:
# Dump the weight matrix before training: one text line per output unit,
# listing the weights of all its inputs with three decimals.
@printf("Initial weights:\n\n")
for unit in 1:n_outputs
    foreach(input -> @printf("%6.3f ", w[input, unit]), 1:n_inputs)
    @printf("\n")
end
@printf("\n")

In [None]:
# Train `w` with Adam for 1000 full-batch steps, recording the loss of each step.
# η is the learning rate and β the pair of momentum decay rates handed to Adam.
η, β = 1, (0.9, 0.8)
optimizer = Flux.Adam(η, β)
# Implicit-parameter style: `w` is the only trainable array.
pars = Flux.Params([w])
# Concretely typed history (instead of an untyped `[]`), so the vector does not
# end up as `Vector{Any}`.
loss = Float64[]
@showprogress for epoch in 1 : 1_000
    value, grads = Flux.withgradient(pars) do
        # Forward pass: thresholded linear read-out, scored against the one-hot
        # targets. Gradients pass through `spike_fun` via its custom rule.
        y_hat = spike_fun(w'*x)
        Flux.logitcrossentropy(y_hat, y)
    end
    Flux.Optimise.update!(optimizer, pars, grads)
    push!(loss, value)
end

In [None]:
# Dump the weight matrix again in the same layout as the initial dump:
# one text line per output unit, three decimals per weight.
@printf("Final weights:\n\n")
for unit in 1:n_outputs
    for weight in (w[input, unit] for input in 1:n_inputs)
        @printf("%6.3f ", weight)
    end
    @printf("\n")
end
@printf("\n")

In [None]:
# Learning curve: one recorded loss value per training iteration.
plot(loss, lw=3, color=:black, label=nothing, size=(400,250), xlabel="Epoch", ylabel="Loss")

In [None]:
# One-hot encoded targets (one column per sample), shown for comparison with
# the network output.
y

In [None]:
# Network output for the batch with the current weights (the trained ones when
# the cells are run in order); compare column by column with the targets `y`.
spike_fun(w'*x)