In [None]:
# Packages
using Knet, AutoGrad, LinearAlgebra, Base.Iterators, Statistics, Random

In [None]:
# Constants
# Array type applied to the data (xtype) and to every parameter/state array.
# Alternative left by the author: KnetArray{Float32}
ARRAY = Array{Float64}

# Data and network dimensions
BSIZE = 1          # minibatch size; sumlogdet asserts one label per call
XSIZE = 28 * 28    # input size, passed as the first MLP dimension
YSIZE = 10         # output size, passed as the last MLP dimension
HSIZE = [64]       # hidden layer sizes, splatted between XSIZE and YSIZE

# Regularizer hyperparameters (deliberately not `const`: a later cell reassigns them)
ALPHA = 100.0      # initial diagonal value of each class's B matrix
GAMMA = 0.0001     # default weight γ of the sumlogdet term in the loss
LAMBDA = 0.995     # default λ used by sumlogdet's running updates

In [None]:
# Load minibatched MNIST data:
# Evaluate the "data/mnist.jl" script located via Knet.dir; it is expected to define
# `mnistdata` (NOTE(review): confirm against the installed Knet version).
include(Knet.dir("data", "mnist.jl"))
# Bind the training and test data, converted to ARRAY and minibatched by BSIZE.
# The trailing `;` suppresses the notebook's echo of the result.
dtrn, dtst = mnistdata(; xtype=ARRAY, batchsize=BSIZE);

In [None]:
# Model definition and initialization
# Multi-layer perceptron together with the running state of the sumlogdet regularizer.
# With h = dims[end-1] (last hidden size) and o = dims[end] (output size):
#   W   one weight Param per layer (see initw)
#   b   one bias Param per layer (see initb)
#   μ   h×o array, one column per class, updated by sumlogdet
#   B   h×h×o array, one h×h slice per class, updated by sumlogdet
#   g   one-element array holding the running regularizer value
#   ∇g  h×1 buffer where sumlogdet stores its gradient while training
struct MLP
    W
    b
    μ
    B
    g
    ∇g

    # MLP(dims...; α=ALPHA): `dims` lists the layer sizes from input to output;
    # α is the initial diagonal value of every B slice.
    function MLP(dims...; α=ALPHA)
        ins, outs = dims[1:end-1], dims[2:end]
        h, o = dims[end-1], dims[end]
        return new(
            initw.(ins, outs),
            initb.(outs),
            initμ(h, o),
            initB(h, o; α=α),
            initg(h, o; α=α),
            init∇g(h, o),
        )
    end
end

# Trainable o×i weight matrix for a layer with i inputs and o outputs,
# filled by `xavier` and converted to ARRAY.
function initw(i, o)
    return Param(ARRAY(xavier(o, i)))
end
# Trainable o×1 bias column, initialised to zero.
function initb(o)
    return Param(ARRAY(zeros(o, 1)))
end
# h×o array of zeros: one h-dimensional column per class (not a Param).
function initμ(h, o)
    return ARRAY(zeros(h, o))
end
# h×h×o array whose every slice B[:, :, class] is α times the identity (not a Param).
function initB(h, o; α=ALPHA)
    B = zeros(h, h, o)
    for class in 1:o, d in 1:h
        B[d, d, class] = α
    end
    return ARRAY(B)
end
# One-element array holding -h*o*log(α), i.e. minus the sum of logdet over the
# o slices produced by initB (each slice is α·I of size h, so its logdet is h*log(α)).
function initg(h, o; α=ALPHA)
    return ARRAY([-h * o * log(α)])
end
# h×1 zero buffer that sumlogdet overwrites with its gradient; `o` is accepted
# for signature symmetry with the other initialisers but is not used.
function init∇g(h, o)
    return ARRAY(zeros(h, 1))
end

# Compact display: "MLP" followed by the tuple of layer sizes, taken from the column
# count of the first weight matrix and the length of every bias vector.
function Base.show(io::IO, m::MLP)
    layersizes = (size(m.W[1], 2), length.(m.b)...)
    return print(IOContext(io, :compact => true), "MLP", layersizes)
end

In [None]:
# Predict and loss functions

# Activations of the last hidden layer: pass mat(x) through every layer except the
# final one, applying relu after each affine map. The final layer is left to the caller.
function featurevector(m::MLP, x)
    h = mat(x)
    for (w, b) in zip(m.W[1:end-1], m.b[1:end-1])
        h = relu.(b .+ w * h)
    end
    return h
end

# Predict: apply the final affine layer (no activation) to the feature vector of x.
function (m::MLP)(x)
    feats = featurevector(m, x)
    return m.b[end] .+ m.W[end] * feats
end

# Loss: nll of the predictions plus γ times the sumlogdet regularizer evaluated on the
# last hidden layer's features.
#   γ       weight of the regularizer term (default GAMMA)
#   update  forwarded to sumlogdet; when true the model's μ, B and g state is advanced
function (m::MLP)(x, labels; γ=GAMMA, update=true)
    feats = featurevector(m, x)
    scores = m.b[end] .+ m.W[end] * feats
    # per instance average negative log likelihood loss
    fit = nll(scores, labels)
    penalty = sumlogdet(feats, labels, m; update=update)
    return fit + γ * penalty
end

In [None]:
# Regularization function and its derivative; assume batchsize=1 for now
# sumlogdet(y, labels, m; λ=LAMBDA, update=false)
#
# Evaluate the regularizer for a single sample and return its updated value `g`.
#   y       feature vector of the sample (featurevector output; a column, since the
#           quadratic form below is read out with `[1]`)
#   labels  class labels; exactly one is accepted (asserted below)
#   m       object carrying the state fields μ, B, g and ∇g (an MLP in this file)
#   λ       weighting factor of the running updates, converted to eltype(y)
#   update  when true, write the new g, B slice and μ column back into `m`
# Side effects: while `training()` is true the gradient 2ξz is stored in m.∇g
# (returned later by sumlogdetback); state is only modified when `update` is true.
function sumlogdet(y,labels,m; λ=LAMBDA, update=false)
    @assert length(labels)==1 "Batchsize > 1 not implemented yet."

    λ = convert(eltype(y),λ)
    β = labels[1]   # β(n) class label for the nth sample
    # Range index β:β keeps μ as an h×1 column; for a plain Array both slices are
    # copies, so the write-back at the end does not alias these locals.
    μ = m.μ[:,β:β]  # μ[β(n)](n-1) exponentially weighted mean of class β(n) before the nth sample
    B = m.B[:,:,β]  # B[β(n)](n-1) exponentially weighted inverse covariance matrix of class β(n) before the nth sample
    
    y0 = y - μ      # ybar[L-1](n) the centralized feature vector
    z = B * y0      # unscaled gradient
    # y0' * B * y0 is a 1×1 matrix; `[1]` extracts the scalar quadratic form.
    ξ = 1 / ((1/(1-λ)) + (y0' * B * y0)[1])  # gradient scaling
    # Rank-one correction of B followed by rescaling with 1/λ.
    B2 = (1/λ)*(B - z*z'*ξ)  # updated inverse covariance matrix
    # Running value: replace the old slice's logdet contribution by the new one.
    g = m.g[1] + logdet(B) - logdet(B2)  # updated -sumlogdet(B)

    # NOTE(review): 2ξz looks like the derivative of g with respect to y under the
    # assumption that B is symmetric — confirm against the derivation.
    if training()  # Store gradient if differentiating
        m.∇g .= 2 * ξ * z
    end
    
    if update      # Update state if specified
        m.g[1] = g
        m.B[:,:,β] .= B2
        # The mean is advanced with the raw feature y, not the centred y0.
        m.μ[:,β:β] .= λ * μ + (1-λ) * y
    end

    # g is returned whether or not the state was written back.
    return g
end

# Backward helper for sumlogdet: hand back the gradient buffer stored in `m.∇g`.
# The stored array itself is returned, not a copy.
sumlogdetback(m) = m.∇g

# Register sumlogdet as an AutoGrad primitive whose gradient with respect to its first
# argument `y` is dy * sumlogdetback(model), i.e. dy times the buffer that the forward
# call stored in model.∇g. Only one gradient expression is supplied, so no gradient is
# defined here for `labels` or `model`.
# The gradient expression must use the `model` parameter of this signature: referring
# to `m` would read whatever global `m` exists at backward time, which is the wrong
# buffer for any model not bound to that global.
@primitive sumlogdet(y,labels,model;o...),dy  dy*sumlogdetback(model)

In [None]:
# Run experiments with different hyperparameters (using dtst because it is small)
# Hyperparameters for this run. They are read as keyword defaults at call time, so
# reassigning the globals here takes effect for the model built below.
HSIZE = [64]
GAMMA = 0.0001
LAMBDA = 0.995
ALPHA = 100.0

# Seed the global RNG before building the model (presumably so that the xavier
# weight initialisation is reproducible).
Random.seed!(1)
m = MLP(XSIZE, HSIZE..., YSIZE)

# Train on dtst and then report accuracy on that same dtst, so the figure below is
# measured on the data the model was fitted to.
progress!(adam(m, dtst))
accuracy(m, dtst)