Goal: compare established perceptron learning rules with a novel algorithm that aims to associate low generalization error with large fluctuations. The model and approach loosely follow Goldt, S., & Seifert, U. (2017). Thermodynamic efficiency of learning a rule in neural networks. New Journal of Physics, 19(11), 113001. https://doi.org/10.1088/1367-2630/aa89ff

In [1]:
using DifferentialEquations
using Plots
plotly()

Plots.PlotlyBackend()

In [2]:
"""
    convert32(x)

Convert `x` to `Float32`.
"""
convert32(x) = convert(Float32, x)

"""
    force(dweights, weights, (k, learning_rate, learning_signal), t)

Write the deterministic part of the weight dynamics into `dweights` and return it.

The value stored is `-k * weights + learning_rate * learning_signal(weights)`.

# Arguments
- `dweights`: output array, overwritten in place.
- `weights`: current weight vector.
- `(k, learning_rate, learning_signal)`: parameter tuple; `learning_signal` is a
  callable that maps the weights to an array of the same shape.
- `t`: current time (unused).
"""
function force(dweights, weights, (k, learning_rate, learning_signal), t)
    # Evaluate the signal once, outside the broadcast, so that `@.` does not turn
    # the call into an element-wise `learning_signal.(weights)`.
    signal = learning_signal(weights)
    # Single fused in-place loop: no temporary arrays are allocated per step.
    @. dweights = -k * weights + learning_rate * signal
    return dweights
end


"""
    noise(dweights, weights, (k, learning_rate, learning_signal), t)

Fill `dweights` with the constant noise amplitude `2` and return it.

The weights, the parameter tuple and the time are accepted only so that the
signature matches `force`; none of them is read.

NOTE(review): the amplitude is hard-coded to 2 rather than sqrt(2) -- confirm
that this is the intended noise strength for the model being reproduced.
"""
function noise(dweights, weights, (k, learning_rate, learning_signal), t)
    return fill!(dweights, 2f0)
end


"""
    hebbian_batch_signal(weights, teacher)

Return the element-wise sign of `teacher` scaled by `1 / sqrt(n)`, where `n` is
the length of `weights`. Only the length of `weights` is used.
"""
function hebbian_batch_signal(weights, teacher)
    scale = 1f0 / sqrt(convert32(length(weights)))
    return sign.(teacher) * scale
end


"""
    learning_signal(weights, teacher, learning_rule, batch_size)

Draw a fresh batch of random examples and return the batch-averaged update
direction for `weights`.

Each example is labelled with the sign of its product with `teacher`. The
student activations are the products with `weights`, scaled by `1 / sqrt(n)`
with `n = length(weights)`. `learning_rule(activations, labels)` supplies a
per-example factor (or a single scalar); the result is the sum over the batch
of `example * label * factor / batch_size`.
"""
function learning_signal(weights, teacher, learning_rule, batch_size)
    dimension = length(weights)
    inputs = random_examples(batch_size, dimension)
    targets = sign.(inputs * teacher)
    scaled_activations = (inputs * weights) * (1f0 / sqrt(convert32(dimension)))
    # Per-example factor chosen by the rule, already divided by the batch size.
    per_example = learning_rule(scaled_activations, targets) * (1f0 / batch_size)
    return inputs' * (targets .* per_example)
end


"""
    hebbian_rule(activations, labels)

Return the constant factor `1f0`: every example contributes equally, whatever
its activation or label.
"""
hebbian_rule(activations, labels) = 1f0


"""
    perceptron_rule(activations, labels)

Return, element-wise, whether `labels * activations` is strictly negative, i.e.
whether the activation's sign disagrees with the label.
"""
function perceptron_rule(activations, labels)
    margins = labels .* activations
    return margins .< 0f0
end


"""
    adatron_rule(activations, labels)

Return `abs(activation)` for examples flagged by `perceptron_rule` and zero for
all others.
"""
function adatron_rule(activations, labels)
    mistakes = perceptron_rule(activations, labels)
    return abs.(activations) .* mistakes
end


"""
    my_rule(activations, labels, exponent)

Return a per-example factor that is `1` for examples flagged by
`perceptron_rule` and `abs(activation)^exponent` for all other examples.
"""
function my_rule(activations, labels, exponent)
    mistakes = perceptron_rule(activations, labels)
    magnitude_term = abs.(activations) .^ exponent
    # `mistakes` contributes the 1s; the magnitude term is masked to the rest.
    return mistakes .+ (magnitude_term .* (.!mistakes))
end


"""
    random_examples(batch_size, n)

Return a `batch_size`-by-`n` `Float32` matrix whose entries are drawn at random
from `-1` and `+1`.
"""
random_examples(batch_size, n) = rand((-1f0, 1f0), batch_size, n)


random_examples (generic function with 1 method)

In [3]:
# π as a Float32, so expressions that use it stay in single precision.
const pi32 = convert32(π)


"""
    run(teacher, initial_weights, k, learning_rate, learning_rule, tspan, dt, batch_size)

Integrate the stochastic weight dynamics and return the solver's solution.

An `SDEProblem` is built from `force` (drift) and `noise` (diffusion), starting
at `initial_weights` over `tspan`. Its parameter tuple is
`(k, learning_rate, signal)`, where `signal(weights)` evaluates `learning_signal`
for the given `teacher`, `learning_rule` and `batch_size`. The problem is solved
with `EM()` using the fixed step `dt` (adaptive stepping disabled).
"""
function run(
        teacher, initial_weights, k, learning_rate, learning_rule, tspan, dt, batch_size
    )
    # Close over everything except the weights so `force` can call it with one argument.
    signal(weights) = learning_signal(weights, teacher, learning_rule, batch_size)
    parameters = (k, learning_rate, signal)
    problem = SDEProblem(force, noise, initial_weights, tspan, parameters)
    return solve(problem, EM(); dt=dt, adaptive=false)
end


"""
    results(sol, skip)

Subsample a solution and return `(times, states)`.

`times` holds every `skip`-th entry of `sol.t`, starting with the first.
`states` is a matrix whose columns are the corresponding entries of `sol.u`.

# Arguments
- `sol`: any object with fields `t` (vector of times) and `u` (vector of state
  vectors).
- `skip`: stride between retained entries.
"""
function results(sol, skip)
    times = sol.t[1:skip:end]
    # `reduce(hcat, ...)` concatenates the columns without splatting a
    # collection of unknown length into a varargs call.
    states = reduce(hcat, sol.u[1:skip:end])
    return (times, states)
end


"""
    generalization_error(weights, teacher)

Return `acos(cos_angle) / pi` for every column of `weights`, where `cos_angle`
is the cosine of the angle between that column and `teacher`.

`weights` holds one weight vector per column; the result has one entry per
column. An entry is 0 for a column parallel to `teacher`, 0.5 for an orthogonal
column and 1 for an antiparallel one.

The cosine is clamped to `[-1, 1]` before `acos` is applied: with finite
precision the normalised overlap of (nearly) aligned vectors can round to just
outside that interval, which would otherwise raise a `DomainError`.
"""
function generalization_error(weights, teacher)
    overlaps = weights' * teacher
    weight_sqnorms = sum(weights .* weights; dims=1)[1, :]
    teacher_sqnorm = teacher' * teacher
    cosines = overlaps ./ sqrt.(teacher_sqnorm * weight_sqnorms)
    # Guard against rounding that pushes the cosine slightly past +/-1.
    cosines = clamp.(cosines, -1f0, 1f0)
    # Convert pi to single precision here so the block needs no global constant.
    return (1f0 / Float32(π)) * acos.(cosines)
end

generalization_error (generic function with 1 method)

In [4]:
# Length of the teacher and weight vectors.
const n = 10000

# Coefficient of the `-k * weights` term in `force`.
const k = 1f0
# Prefactor of the learning signal in `force`; scales with sqrt(n).
const learning_rate = 3f0 * sqrt(convert32(n))

# Integration window: from 0 to 12 / k.
const tspan = (0f0, 12f0 / k)
# Fixed solver step handed to `run`.
const dt = 1f-3

# Number of random examples drawn for each learning-signal evaluation.
const batch_size = 30

# Teacher vector and starting weights. The weights are standard-normal draws
# scaled by 1 / sqrt(k). The teacher is drawn first, so swapping these two
# lines changes which random numbers each of them receives.
const teacher = randn(Float32, n)
const initial_weights = randn(Float32, n) * (1f0 / sqrt(k))

# Learning rules under comparison: the three established rules followed by
# `my_rule` with exponents 0.25, 0.5, 0.75 and 1. Each entry is callable as
# `rule(activations, labels)`.
const learning_rules = [
    hebbian_rule,
    perceptron_rule,
    adatron_rule,
    map(
        exponent -> ((activations, labels) -> my_rule(activations, labels, exponent)),
        [1f0 / 4f0, 2f0 / 4f0, 3f0 / 4f0, 1f0],
    )...,
]

# Plot labels, in the same order as `learning_rules`.
const learning_rule_names = [
    "hebbian",
    "perceptron",
    "adatron",
    "α = 0.25",
    "α = 0.50",
    "α = 0.75",
    "α = 1.00",
]

7-element Array{String,1}:
 "hebbian"   
 "perceptron"
 "adatron"   
 "α = 0.25"  
 "α = 0.50"  
 "α = 0.75"  
 "α = 1.00"  

In [5]:
# Simulate every learning rule in turn (same teacher, initial weights and
# settings) and keep every 100th stored solution point as a
# `(times, weight matrix)` pair.
learning_rule_results = map(learning_rules) do rule
    solution = run(
        teacher, initial_weights, k, learning_rate, rule, tspan, dt, batch_size
    )
    results(solution, 100)
end

7-element Array{Tuple{Array{Float32,1},Array{Float32,2}},1}:
 ([0.0, 0.1, 0.2, 0.3, 0.399998, 0.499997, 0.599996, 0.699995, 0.799993, 0.899992  …  11.0008, 11.1009, 11.2009, 11.3009, 11.401, 11.501, 11.6011, 11.7011, 11.8011, 11.9012], [0.377467 0.277766 … -2.00195 -1.76512; -2.19013 -3.70453 … -6.60533 -6.52913; … ; 0.817706 0.0536202 … -0.443698 -2.41512; -1.72295 -0.127687 … 1.31788 0.482723])                 
 ([0.0, 0.1, 0.2, 0.3, 0.399998, 0.499997, 0.599996, 0.699995, 0.799993, 0.899992  …  11.0008, 11.1009, 11.2009, 11.3009, 11.401, 11.501, 11.6011, 11.7011, 11.8011, 11.9012], [0.377467 -1.06181 … -1.36179 -1.61509; -2.19013 -1.9739 … -0.730052 -0.082684; … ; 0.817706 0.936066 … -2.19775 -3.50032; -1.72295 -1.59391 … -0.528703 0.0625194])                
 ([0.0, 0.1, 0.2, 0.3, 0.399998, 0.499997, 0.599996, 0.699995, 0.799993, 0.899992  …  11.0008, 11.1009, 11.2009, 11.3009, 11.401, 11.501, 11.6011, 11.7011, 11.8011, 11.9012], [0.377467 0.294753 … -1.12796 -0.837542; -2.19013 -2

In [6]:
# Generalization error over time for every learning rule. The first three
# (established) rules are drawn dashed, the `my_rule` variants solid.
plot(
    first(learning_rule_results)[1],
    generalization_error(first(learning_rule_results)[2], teacher);
    line=(3, :dash),
    label=first(learning_rule_names),
)
for (index, (times, weights)) in enumerate(learning_rule_results)
    # The first curve was already drawn by the `plot` call that creates the figure.
    index == 1 && continue
    plot!(
        times,
        generalization_error(weights, teacher);
        line=(index <= 3 ? (3, :dash) : (3,)),
        label=learning_rule_names[index],
    )
end
xaxis!("time")
yaxis!("generalization error")