In [117]:
# Packages
using Knet, AutoGrad, LinearAlgebra, Base.Iterators, Statistics, Random

In [118]:
# Constants
ENV["COLUMNS"] = 64
ARRAY=Array{Float64} # KnetArray{Float32}
UPDATE=true # keep this true (false only useful for checking gradients)
BSIZE=1     # keep batchsize=1 until larger ones supported
XSIZE=28*28
YSIZE=10
HSIZE=[64]
ALPHA=100.0
GAMMA1=0.0001
GAMMA2=0.01
LAMBDA=0.995
ETA=0.1
MU0=0.0001

0.0001

In [119]:
# Load minibatched MNIST data:
include(Knet.dir("data","mnist.jl"))
dtrn, dtst = mnistdata(xtype=ARRAY, batchsize=BSIZE)
xtrn, ytrn, xtst, ytst = mnist()
xtrn = ARRAY(reshape(xtrn,(XSIZE,:)))
xtst = ARRAY(reshape(xtst,(XSIZE,:)));

In [120]:
# Model definition and initialization
struct MLP; W; b; μ; B; g1; ∇g1; g2; ∇g2;
    function MLP(dims...;α=ALPHA)
        h,o = dims[end-1:end]
        W = initw.(dims[1:end-1],dims[2:end])
        b = initb.(dims[2:end])
        μ = initμ(h,o)
        B = initB(h,o;α=α)
        g1 = initg1(B)
        ∇g1 = init∇g1(h)
        g2 = initg2(μ)
        ∇g2 = init∇g2(h)
        new(W, b, μ, B, g1, ∇g1, g2, ∇g2)
    end
end

initw(i,o)=Param(ARRAY(xavier(o,i)))
initb(o)=Param(ARRAY(zeros(o,1)))
initμ(h,o)=ARRAY(MU0*randn(h,o))
initB(h,o;α=ALPHA)=(B = zeros(h,h,o); for i in 1:o, j in 1:h; B[j,j,i] = α; end; ARRAY(B))
initg1(B)=[ -sum(logdet.(B[:,:,i] for i in 1:size(B,3))) ]
init∇g1(h)=ARRAY(zeros(h,1))
initg2(μ)=((d,n)=(0,size(μ,2));for i=1:n-1,j=i+1:n;d-=log(norm(μ[:,i]-μ[:,j])^2);end;[d])
init∇g2(h)=ARRAY(zeros(h,1))

Base.show(io::IO, m::MLP)=print(IOContext(io,:compact=>true), "MLP", (size(m.W[1],2),length.(m.b)...))

In [121]:
# Featurevec, predict and loss functions
function featurevector(m::MLP,x)
    L,y = length(m.W),mat(x)
    for l in 1:L-1
        y = relu.(m.b[l] .+ m.W[l] * y)
    end
    return y
end

function (m::MLP)(x) # predict
    m.b[end] .+ m.W[end] * featurevector(m,x)
end

function (m::MLP)(x,labels;γ1=GAMMA1,γ2=GAMMA2) # loss
    @assert length(labels)==1 "Batchsize > 1 not implemented yet."
    yfeat = featurevector(m,x)
    ypred = m.b[end] .+ m.W[end] * yfeat
    J = nll(ypred,labels)
    g1 = sumlogdet(yfeat,labels,m)
    g2 = meandist(yfeat,labels,m)
    return J + γ1 * g1 + γ2 * g2
end

In [122]:
# computes and returns g1 = ∑ logdet(Ci) = -Σ logdet(Bi)
# computes m.∇g1 if training()
# updates m.g1 and m.B if update=TRUE
function sumlogdet(y,labels,m; λ=LAMBDA, η=ETA, update=UPDATE)
    β = labels[1]   # β(n) class label for the nth sample
    μ = m.μ[:,β:β]  # μ[β(n)](n-1) exponentially weighted mean of class β(n) before the nth sample
    B = m.B[:,:,β]  # B[β(n)](n-1) exponentially weighted inverse covariance matrix of class β(n) before the nth sample
    
    y0 = y - μ      # ybar[L-1](n) the centralized feature vector
    z = B * y0      # unscaled gradient
    κ = (1-λ)*λ
    ξ = 1 / ((1/(1-λ)) + (y0' * B * y0)[1])  # gradient scaling
    A = (1/λ)*(B - z*z'*ξ)  
    B2 = A-(1-λ)*η*A*A/(1+(1-λ)*η*tr(A))  # updated inverse covariance matrix  
    g1 = m.g1[1] + logdet(B) - logdet(B2) # updated -sumlogdet(B)

    if training()  # Store gradient if differentiating
        m.∇g1 .= 2 * κ * B2 * y0
    end
    
    if update      # Update state if specified
        m.g1[1] = g1
        m.B[:,:,β] .= B2
    end

    return g1
end

@primitive sumlogdet(y,l,m;o...),dy  dy*m.∇g1

In [123]:
# computes and returns g2 = -Σ log |μi-μj|^2
# computes m.∇g2 if training()
# updates m.g2 and m.μ if update=TRUE
function meandist(y,labels,m; λ=LAMBDA, update=UPDATE)
    M = size(m.μ,2) # number of classes
    β = labels[1]   # β(n) class label for the nth sample
    μ1 = m.μ[:,β:β] # μ[β(n)](n-1) exponentially weighted mean of class β(n) before the nth sample
    μ2 = λ * μ1 + (1-λ) * y   # updated mean
    g2 = 0
    if training(); m.∇g2 .= 0; end
    for k=1:M
        if (k!=β)
            olddist = norm(m.μ[:,k:k]-μ1)^2
            newdist = norm(m.μ[:,k:k]-μ2)^2
            g2 = g2 + log(olddist) - log(newdist)
            if training()
                m.∇g2 .+= (2 * (1-λ) / newdist) * (m.μ[:,k:k]-μ2)
            end
        end
    end    
    if update
        m.g2[1] = g2
        m.μ[:,β:β] .= μ2
    end
    return g2
end

@primitive meandist(y,l,m;o...),dy  dy*m.∇g2

In [140]:
# Experiment 1: check model functions
UPDATE=false
(x,labels) = first(dtrn)
m = MLP(XSIZE,HSIZE...,YSIZE)
@show x |> summary
@show labels
@show (y = featurevector(m,x)) |> summary
@show (scores = m(x)) |> summary
@show J=nll(scores,labels)
@show g1=sumlogdet(y,labels,m)
@show g2=meandist(y,labels,m)
@show J + GAMMA1 * g1 + GAMMA2 * g2
@show m(x,labels)
UPDATE=true;

x |> summary = "28×28×1×1 Array{Float64,4}"
labels = UInt8[0x05]
(y = featurevector(m, x)) |> summary = "64×1 Array{Float64,2}"
(scores = m(x)) |> summary = "10×1 Array{Float64,2}"
J = nll(scores, labels) = 2.4778561096154585
g1 = sumlogdet(y, labels, m) = -2945.9880482939625
g2 = meandist(y, labels, m) = -35.75083284918922
J + GAMMA1 * g1 + GAMMA2 * g2 = -30.557107658243087
m(x, labels) = -30.557107658243087


In [129]:
# Experiment 2: check gradients
using AutoGrad: @gcheck, gcheck
(x,labels) = first(dtrn)
m = MLP(XSIZE,HSIZE...,YSIZE)
y = featurevector(m,x)
py = Param(y)
UPDATE=false
@show @gcheck sumlogdet(py,labels,m)
@show @gcheck meandist(py,labels,m)
@show @gcheck nll(m(x),labels)
@show @gcheck m(x,labels)
UPDATE=true;

#= In[129]:8 =# @gcheck(sumlogdet(py, labels, m)) = true
#= In[129]:9 =# @gcheck(meandist(py, labels, m)) = true
#= In[129]:10 =# @gcheck(nll(m(x), labels)) = true
#= In[129]:11 =# @gcheck(m(x, labels)) = true


In [138]:
# Experiment 3: train one epoch with regularization
Random.seed!(1)
m = MLP(XSIZE,HSIZE...,YSIZE)
GAMMA1,GAMMA2=0.01,0.1
progress!(adam(m,dtst))
(acc=accuracy(m,dtst),nll=nll(m(xtst),ytst),g1=initg1(m.B)[1],g2=initg2(m.μ)[1]) 

-5.96e+00  100.00%┣████████┫ 10000/10000 [00:28/00:28, 362.85i/s]


(acc = 0.919, nll = 0.27138646039786907, g1 = -595.6611794864707, g2 = -235.20288970750443)

In [131]:
# Experiment 4: train one epoch without regularization
Random.seed!(1)
m = MLP(XSIZE,HSIZE...,YSIZE)
GAMMA1,GAMMA2 = 0,0
progress!(adam(m,dtst))
(acc=accuracy(m,dtst),nll=nll(m(xtst),ytst),g1=initg1(m.B)[1],g2=initg2(m.μ)[1]) 

2.56e-05  100.00%┣█████████┫ 10000/10000 [00:29/00:29, 347.50i/s]


(acc = 0.9198, nll = 0.2600133985319897, g1 = -647.7278184729371, g2 = -232.03334179923695)