In [85]:
# Packages
using Knet, AutoGrad, LinearAlgebra, Base.Iterators, Statistics

In [133]:
# Constants
ARRAY=Array{Float64} # KnetArray{Float32}
BSIZE=1
XSIZE=28*28
YSIZE=10
HSIZE=64
ALPHA=100.0
GAMMA=0.0001
LAMBDA=0.999

0.999

In [3]:
# Load MNIST data:
include(Knet.dir("data","mnist.jl"))
dtrn, dtst = mnistdata(xtype=ARRAY, batchsize=BSIZE);

┌ Info: Loading MNIST...
└ @ Main /kuacc/users/dyuret/.julia/dev/Knet/data/mnist.jl:33


In [134]:
# Model definition and initialization
struct TwoLayerMLP; w1; b1; w2; b2; μ; B; g; ∇g; end

TwoLayerMLP(i,h,o;α=ALPHA)=TwoLayerMLP(
    initw(i,h),initb(h),
    initw(h,o),initb(o),
    initμ(h,o),
    initB(h,o,α=α),
    initg(h,o,α=α), init∇g(h,o)    
)

initw(i,o)=Param(ARRAY(xavier(o,i)))
initb(o)=Param(ARRAY(zeros(o,1)))
initμ(h,o)=ARRAY(zeros(h,o))
init∇g(h,o)=ARRAY(zeros(h,1))
initg(h,o;α=ALPHA)=[-h*o*log(α)]

function initB(h,o;α=ALPHA)
    B = zeros(h,h,o)
    for i in 1:o, j in 1:h
        B[j,j,i] = α
    end
    return ARRAY(B)
end

initB (generic function with 1 method)

In [5]:
# Predict and loss functions
function (m::TwoLayerMLP)(x) # predict
    m.b2 .+ m.w2 * relu.(m.b1 .+ m.w1 * mat(x))
end

function (m::TwoLayerMLP)(x,labels; γ=GAMMA) # loss
    y1 = relu.(m.b1 .+ m.w1 * mat(x))
    y2 = m.b2 .+ m.w2 * y1
    J = nll(y2,labels)  # per instance average negative log likelihood loss
    g = sumlogdet(y1,labels,m,update=true)
    return J + γ * g
end

In [125]:
# Regularization function and its derivative; assume batchsize=1 for now
function sumlogdet(y,labels,m; λ=LAMBDA, update=false)
#    global B,B2,β,μ,y0,z,ξ,g # DBG
    # TODO: handle batchsize > 1
    @assert length(labels)==1 "Batchsize > 1 not implemented yet."
    
    β = labels[1]   # β(n) class label for the nth sample
    μ = m.μ[:,β:β]  # μ[β(n)](n-1) exponentially weighted mean of class β(n) before the nth sample
    B = m.B[:,:,β]  # B[β(n)](n-1) exponentially weighted inverse covariance matrix of class β(n) before the nth sample
    
    y0 = y - μ      # ybar[L-1](n) the centralized feature vector
    z = B * y0      # unscaled gradient
    ξ = 1 / ((1/(1-λ)) + (y0' * B * y0)[1])  # gradient scaling
    B2 = (1/λ)*(B - z*z'*ξ)  # updated inverse covariance matrix
    g = m.g[1] + logdet(B) - logdet(B2)  # updated -sumlogdet(B)

    if training()  # Store gradient if differentiating
        m.∇g .= 2 * ξ * z
    end
    
    if update      # Update state if specified
        m.g[1] = g
        m.B[:,:,β] .= B2
        m.μ[:,β:β] .= λ * μ + (1-λ) * y
    end

    return g
end

function sumlogdetback(m)
    m.∇g
end

@primitive sumlogdet(y1,y,m;o...)  sumlogdetback(m)

In [60]:
# Check gradients
using AutoGrad: @gcheck
m = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
x,y = first(dtrn)
@gcheck m(x,y)

true

In [61]:
# Check gradients
y1 = Param(relu.(m.b1 .+ m.w1 * mat(x)))
@gcheck sumlogdet(y1,y,m)

true

In [63]:
# Train model m1 without regularization
m1 = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
pred1(x)=m1.b2 .+ m1.w2 * relu.(m1.b1 .+ m1.w1 * mat(x))
loss1(x,y)=nll(pred1(x), y)
trn100,tst100= mnistdata(xtype=ARRAY, batchsize=100)
progress!(adam(loss1,repeat(trn100,100)))
accuracy(pred1,trn100),accuracy(pred1,tst100)

1.44e-05  100.00%┣███████████████████████████████████┫ 60000/60000 [01:52/01:52, 533.53i/s]


In [126]:
# Compute mean and covariance without regularization
(xtrn,ytrn,xtst,ytst) = mnist()
x60k = ARRAY(reshape(xtrn,28*28,:))
h60k = relu.(m1.b1 .+ m1.w1 * x60k)
h = Any[ h60k[:,ytrn.==i] for i in 1:10 ]
μ = Any[ mean(h[i],dims=2) for i in 1:10 ]
C = Any[ (h0=h[i] .- μ[i]; h0 * h0' / size(h0,2)) for i in 1:10 ]
B = Any[ inv(C[i]) for i in 1:10 ]
@show norm.(μ)
@show logdet.(C)
@show logdet.(B);

norm.(μ) = [19.9217, 27.9571, 30.2238, 26.5543, 26.1994, 26.1632, 26.3726, 25.2343, 25.3175, 34.8589]
logdet.(C) = [-57.7199, 73.4593, 61.9336, 28.543, 58.8477, 25.8958, 33.7952, 55.8698, 32.4212, 31.5816]
logdet.(B) = [57.7199, -73.4593, -61.9336, -28.543, -58.8477, -25.8958, -33.7952, -55.8698, -32.4212, -31.5816]


In [135]:
# Do we converge to the right inv cov with our updates?
m1.μ .= initμ(HSIZE,YSIZE)
m1.B .= initB(HSIZE,YSIZE)
# just compute loss and update μ,B, do not change weights
progress!(m1(x,y) for (x,y) in take(dtrn,10000))

-1.55e+00  0.01%┣                                         ┫ 1/10000 [00:00/08:09, 20.44i/s]-1.47e+00  3.28%┣█▏                                    ┫ 328/10000 [00:01/00:32, 312.19i/s]-1.42e+00  6.84%┣██▌                                   ┫ 684/10000 [00:02/00:30, 333.51i/s]-1.39e+00  10.46%┣███▊                                ┫ 1046/10000 [00:03/00:29, 342.81i/s]-1.36e+00  14.17%┣█████                               ┫ 1417/10000 [00:04/00:29, 349.65i/s]-1.35e+00  17.85%┣██████▍                             ┫ 1785/10000 [00:05/00:28, 353.18i/s]-1.34e+00  21.52%┣███████▋                            ┫ 2152/10000 [00:06/00:28, 355.34i/s]-1.33e+00  25.29%┣█████████                           ┫ 2529/10000 [00:07/00:28, 358.36i/s]-1.32e+00  29.00%┣██████████▍                         ┫ 2900/10000 [00:08/00:28, 359.85i/s]-1.31e+00  32.69%┣███████████▊                        ┫ 3269/10000 [00:09/00:28, 360.85i/s]-1.30e+00  36.28%┣█████████████                       ┫ 3628/10000 [00:10/00:28

In [136]:
# Compare μ
for i in 1:10
    println((real=norm(μ[i]),pred=norm(m1.μ[:,i]),diff=norm(μ[i]-m1.μ[:,i])))
end

(real = 19.921662550809856, pred = 13.828536406381236, diff = 6.1283650774635925)
(real = 27.957112791647585, pred = 17.54194278401869, diff = 10.47183825571481)
(real = 30.223848099072185, pred = 19.857264495487815, diff = 10.40106992209713)
(real = 26.554283326683308, pred = 16.822313098196727, diff = 9.754920931114366)
(real = 26.199362447537467, pred = 14.695964602910797, diff = 11.582652584149859)
(real = 26.163237587245103, pred = 16.80837921156584, diff = 9.391836152417179)
(real = 26.37258752026844, pred = 17.431162006128584, diff = 8.966467061709775)
(real = 25.234255806222098, pred = 15.299735244052941, diff = 9.955916646167811)
(real = 25.317502304304885, pred = 15.97361696714776, diff = 9.363439789986582)
(real = 34.858885282642596, pred = 22.27957592138737, diff = 12.597850480976998)


In [137]:
# Compare B
for i in 1:10
    println((real=norm(B[i]),pred=norm(m1.B[:,:,i]),diff=norm(B[i]-m1.B[:,:,i])))
end

(real = 127.58358854028396, pred = 172.15031629690796, diff = 66.30301061572978)
(real = 7.078215930093306, pred = 12.859943950804903, diff = 6.328364212340608)
(real = 8.284792935381866, pred = 14.003642747773327, diff = 6.232317499641436)
(real = 64.91594464678964, pred = 107.57085562264054, diff = 46.079288440718734)
(real = 8.704410068496124, pred = 17.736126463600105, diff = 9.519552282459005)
(real = 20.93898661955777, pred = 38.85031054020234, diff = 19.270871029529083)
(real = 19.817326109905256, pred = 31.73697979969098, diff = 12.960315278752782)
(real = 8.76908949311271, pred = 15.740594044692417, diff = 7.559772619780569)
(real = 13.117170583006624, pred = 23.84868435150763, diff = 11.521579131182111)
(real = 24.252019372716877, pred = 51.597795434523505, diff = 29.897922747894942)


In [138]:
# Compare logdet(B)
for i in 1:10
    println((real=logdet(B[i]),pred=logdet(m1.B[:,:,i])))
end

(real = 57.71992034738368, pred = 80.72109893430219)
(real = -73.45927850158313, pred = -41.757905414556944)
(real = -61.93363597624913, pred = -34.09119813286764)
(real = -28.542967439463844, pred = 2.0328763533353187)
(real = -58.8476822640747, pred = -19.421990843425476)
(real = -25.895756131576373, pred = 5.675230355276158)
(real = -33.795214178715895, pred = -7.603199417642858)
(real = -55.869844755435146, pred = -24.121762831176305)
(real = -32.421167117789686, pred = -2.2190724590249538)
(real = -31.58162687406806, pred = -2.1537291232839673)


# JUNK below this line

In [78]:
a,b = mnistdata(xtype=ARRAY,batchsize=60000)
length(a)

1

In [43]:
det(B),det(z*z'*ξ)

(3.86915974358613e305, 0.0)

In [53]:
[logdet(mlp.B[:,:,i]) for i in 1:size(mlp.B,3)]

10-element Array{Float64,1}:
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894
 -147.36544595161894

In [47]:
logdet(B),logdet(B2)

(704.8692515439695, 711.1449552901472)

In [58]:
ENV["COLUMNS"]=90
mlp.μ

64×10 Array{Float64,2}:
 0.00496261   0.0459837    0.0645267    …  0.0486146    0.000528384  0.0470241  
 0.000158066  0.0615158    0.00209705      0.00182906   0.000214186  0.00959594 
 0.0          0.0          0.0             0.0          0.0          0.0        
 0.000113193  0.00301375   2.24206e-5      6.63593e-6   2.65391e-5   0.00838022 
 0.0          0.0          0.0             0.0          0.0          2.38151e-8 
 9.23947e-5   0.0171907    0.000137434  …  0.0033648    0.00053961   0.0133446  
 8.78306e-6   0.0140433    0.0106992       0.00625449   0.00307099   0.0602241  
 0.000100281  9.24048e-5   4.49226e-5      0.000150243  0.000770539  0.00351912 
 0.000707857  0.00310398   0.0136447       0.001852     0.00436781   0.00851628 
 8.74099e-5   2.04788e-5   4.0206e-5       1.39807e-5   0.00024543   0.0569637  
 0.0868852    0.000464658  2.16367e-5   …  0.00260179   4.693e-5     3.36568e-6 
 0.290715     0.00169766   3.07977e-5      0.00347629   5.90167e-5   1.5863e-5  
 0.0

In [139]:
GAMMA=0.0
mlp = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
progress!(mlp.g[1] for x in adam(mlp, take(dtrn,5000)))

-2.95e+03  0.02%┣                                          ┫ 1/5000 [00:00/02:11, 38.19i/s]-2.93e+03  4.86%┣█▉                                     ┫ 243/5000 [00:01/00:21, 236.12i/s]-2.90e+03  10.00%┣███▊                                  ┫ 500/5000 [00:02/00:20, 246.35i/s]-2.86e+03  15.14%┣█████▊                                ┫ 757/5000 [00:03/00:20, 249.63i/s]-2.83e+03  20.32%┣███████▌                             ┫ 1016/5000 [00:04/00:20, 251.82i/s]-2.80e+03  25.28%┣█████████▎                           ┫ 1264/5000 [00:05/00:20, 251.04i/s]-2.78e+03  30.44%┣███████████▎                         ┫ 1522/5000 [00:06/00:20, 252.06i/s]-2.76e+03  35.60%┣█████████████▏                       ┫ 1780/5000 [00:07/00:20, 252.83i/s]-2.75e+03  40.88%┣███████████████▏                     ┫ 2044/5000 [00:08/00:20, 254.06i/s]-2.74e+03  46.04%┣█████████████████                    ┫ 2302/5000 [00:09/00:20, 254.45i/s]-2.73e+03  51.20%┣██████████████████▉                  ┫ 2560/5000 [00:10/00:20

In [25]:
accuracy(mlp,dtrn),accuracy(mlp,dtst)

(0.68285, 0.6743)

In [None]:
mlp = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
(x,y) = first(dtrn)
mlp(x,y)

In [None]:
# numerical gradient check:
(x,y) = first(dtrn)
m = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
@show y1 = relu.(m.b1 .+ m.w1 * mat(x))
p1 = Param(y1)
@show J = @diff sumlogdet(p1,y,m)
@show grad(J,p1)
ϵ = 1e-4
for i in 1:length(y1)
    y1i = y1[i]
    y1[i] = y1i + ϵ
    f1 = sumlogdet(y1,y,m)
    y1[i] = y1i - ϵ
    f2 = sumlogdet(y1,y,m)
    println((i,((f1-f2)/2ϵ)))
    y1[i] = y1i
end

In [None]:
@show rand(5)

In [None]:
using AutoGrad: @gcheck
@gcheck sumlogdet(Param(y1),y,m)

In [None]:
mlp = TwoLayerMLP(XSIZE,HSIZE,YSIZE)
progress!(adam(mlp, repeat(dtrn,10)))
accuracy(mlp,dtrn),accuracy(mlp,dtst)