In [None]:
using Printf: @printf
using Random
using LinearAlgebra
using ChainRulesCore
using Zygote: pullback, gradient
using Flux
using ProgressMeter
using Plots

In [None]:
"""
    spike_fun(x)

Heaviside step used as the neuron's spiking nonlinearity: return `1.0` when
`x > 0` and `0.0` otherwise (including `x == 0` and `NaN`). The result is always
a `Float64` literal, regardless of the type of `x`.
"""
spike_fun(x) = ifelse(x > 0.0, 1.0, 0.0)
function ChainRulesCore.rrule(::typeof(spike_fun), x)
    # The exact derivative of a step is zero almost everywhere, which would stop
    # all learning. Backpropagation instead uses the surrogate derivative
    # 1 / (1 + 100|x|)^2, which equals 1 at the threshold x = 0 and decays
    # quickly on either side of it.
    function spike_fun_pullback(Δ)
        denom = (1.0 + 100.0 * abs(x))^2
        return (NoTangent(), Δ / denom)
    end
    return spike_fun(x), spike_fun_pullback
end

In [None]:
# Evaluate a toy function containing the spike nonlinearity, together with its
# derivative, on a fine grid around the threshold and plot both.
N = 101
x = range(-0.2, 0.2, N)

# Choose which Zygote/Flux entry point differentiates `fun`.
# `use_withgradient` takes precedence, then `use_gradient`; if both are false an
# explicit `pullback` is used. All three paths compute the same `y` and `dy`.
use_withgradient = true
use_gradient = true

# Linear term plus the spiking step: `y` shows the jump at 0, while `dy` shows
# the constant slope 2 plus the surrogate bump from the custom `rrule`.
fun(x) = 2*x + spike_fun(x)

if use_withgradient
    # Value and gradient from a single call.
    y = zeros(size(x))
    dy = zeros(size(x))
    for i in eachindex(x)
        y[i], g = Flux.withgradient(fun, x[i])
        dy[i] = g[1]
    end
elseif use_gradient
    # Value by broadcasting; gradient computed separately per point.
    y = fun.(x)
    dy = zeros(size(x))
    for i in eachindex(x)
        dy[i] = gradient(fun, x[i])[1]
    end
else
    # Explicit pullback seeded with 1.0. This must differentiate `fun` (not
    # `spike_fun`) so that this branch plots the same curves as the others.
    y = zeros(size(x))
    dy = zeros(size(x))
    for i in eachindex(x)
        y[i], back = pullback(fun, x[i])
        dy[i] = back(1.0)[1]
    end
end

plot(x, y, color=:black, lw=2, label=nothing, size=(400,200))
plot!(x, dy, color=:magenta, lw=2, label=nothing)

In [None]:
# Reproducible toy problem: N random Gaussian weights, a random binary input
# pattern (an entry is 1 when its uniform draw exceeds 0.5) and a target of 1.
rng = MersenneTwister(1983)
N = 10
w = randn(rng, Float64, N)
x = Int.(rand(rng, Float64, N) .> 0.5)
y = 1;

In [None]:
# Show the starting weights on one line, each formatted with `%g` and followed
# by a space.
@printf("Initial weights:\n")
foreach(ww -> @printf("%g ", ww), w)
@printf("\n")

In [None]:
"""
    loss_fun(w, x, y)

Squared error `(y - ŷ)^2` between the target `y` and the neuron output
`ŷ = loss_fun(w, x)`.
"""
loss_fun(w, x, y) = (y - loss_fun(w, x))^2

"""
    loss_fun(w, x)

Output of the single spiking neuron: `spike_fun` applied to the weighted input
`dot(w, x)`. Despite its name, this method returns the prediction, not a loss.
"""
loss_fun(w, x) = spike_fun(dot(w, x))

In [None]:
# Train `w` with plain gradient descent on the squared error `loss_fun(w, x, y)`.
# The step nonlinearity is differentiated through its surrogate rrule, so `w`
# still receives a nonzero gradient even though the forward loss is piecewise
# constant.
η = 10                      # learning rate
n_epochs = 1_000
optimizer = Flux.Descent(η)
pars = Flux.Params([w])     # implicit-parameter API: gradients are taken w.r.t. `w`
loss = Float64[]
sizehint!(loss, n_epochs)
@showprogress for i in 1:n_epochs
    l, grad = Flux.withgradient(pars) do
        loss_fun(w, x, y)
    end
    Flux.Optimise.update!(optimizer, pars, grad)   # mutates `w` in place
    push!(loss, l)
end

In [None]:
# Show the trained weights on one line, each formatted with `%g` and followed
# by a space.
@printf("Final weights:\n")
foreach(ww -> @printf("%g ", ww), w)
@printf("\n")

In [None]:
# Training curve: the loss value recorded at each epoch of the training loop.
plot(loss;
     lw=3, color=:black, label=nothing, size=(400,250),
     xlabel="Epoch", ylabel="Loss")