In [44]:
# KnetArray{Float64, 3}(undef, (2,3,4));
# a = [KnetArray{Float64,3}(undef,(10, 1, 1)),
#     KnetArray{Float64,3}(undef,(1, 784, 1)),
#     KnetArray{Float64,3}(undef,(1, 1, 100))
#     ];

In [2]:
using Knet, LinearAlgebra, AutoGrad
using AutoGrad: full

┌ Info: Precompiling Knet [1902f260-5fb4-5aff-8c31-6271790ab950]
└ @ Base loading.jl:1273


In [4]:
mutable struct SM3
    lr#::AbstractFloat
    eps#::AbstractFloat
    gclip#::AbstractFloat
    dims
#     momentum::AbstractArray{AbstractFloat}
    accumulators#::AbstractArray{AbstractFloat}
end

In [12]:
SM3(; lr=0.001, eps=1e-30, gclip=0.0) = SM3(lr, eps, gclip, nothing, nothing)
sm3(f,d; lr=0.001, eps=1e-30, gclip=0.0,o...) = Knet.minimize(f,d,SM3(lr, eps, gclip, nothing, nothing))
sm3!(x...;o...) = for y in sm3(x...;o...); end
Knet.clone(a::SM3) = SM3(a.lr, a.eps, a.gclip, nothing, nothing)

In [104]:
for T in (Array{Float32},Array{Float64},KnetArray{Float32},KnetArray{Float64}); @eval begin
    function Knet.update!(w::$T, g, p::SM3)
        Knet.gclip!(g, p.gclip)
        g = full(g)
        if p.accumulators==nothing; 
            p.dims=size(w);
            p.accumulators=[KnetArray(zeros(Float32, _shape_for_broadcasting(p.dims, i))) for i in 1:length(p.dims)];
        end
        accumulator = _compute_past_accumulator(p.accumulators, p.dims)
        accumulator .+= g.*g
        #TODO: Add momentum tensor for scaled gradient
        axpy!(-p.lr, g./(sqrt.(accumulator .+ p.eps)), w)
        #TODO: Add accumulator updates
    end
end;end

In [105]:
function _shape_for_broadcasting(dims, desired)
    rank = length(dims)
    return tuple([i==desired ? dims[i] : 1 for i in 1:rank]...)
end

_shape_for_broadcasting (generic function with 1 method)

In [106]:
function _compute_past_accumulator(accumulators, dims)
    rank = length(dims)
    accumulators_for_broadcasting = [
        reshape(accumulators[i], _shape_for_broadcasting(dims, i))
        for i in 1:rank]
    
    result = accumulators_for_broadcasting[1]
#     return result
    # Check if min is doing for number of elmns.
    for i in 1:rank
        result = min.(result, accumulators_for_broadcasting[i])
    end
    return result
    
end

_compute_past_accumulator (generic function with 1 method)

In [107]:
# a = [KnetArray{Float64,3}(undef,(10, 1, 1)),
#     KnetArray{Float64,3}(undef,(1, 784, 1)),
#     KnetArray{Float64,3}(undef,(1, 1, 100))
#     ];
# b = [10, 784, 100];
# _compute_past_accumulator(a, b)

In [111]:

# Define convolutional layer:
struct Conv; w; b; f; end
(c::Conv)(x) = c.f.(pool(conv4(c.w, x) .+ c.b))
Conv(w1,w2,cx,cy,f=relu) = Conv(param(w1,w2,cx,cy), param0(1,1,cy,1), f)

# Define dense layer:
struct Dense; w; b; f; end
(d::Dense)(x) = d.f.(d.w * mat(x) .+ d.b)
Dense(i::Int,o::Int,f=relu) = Dense(param(o,i), param0(o), f)

# Define a chain of layers and a loss function:
struct Chain; layers; end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
(c::Chain)(x,y) = nll(c(x),y)

# Load MNIST data:
include(Knet.dir("data","mnist.jl"))
dtrn, dtst = mnistdata()

# Define, train and test LeNet (about 30 secs on a gpu to reach 99% accuracy)
LeNet = Chain((Conv(5,5,1,20), Conv(5,5,20,50), Dense(800,500), Dense(500,10,identity)))
sm3!(LeNet, repeat(dtrn,10))
accuracy(LeNet, dtst)

0.9272