In [25]:
using Pkg

In [1]:
using Knet
using LinearAlgebra
using AutoGrad: full

In [11]:
"""
    SM3

Optimizer state for the SM3 update rule implemented by `Knet.update!` in this notebook.

Hyperparameters (set at construction):
- `lr`: step size; each update performs `w -= lr * g / sqrt(acc + eps)`.
- `eps`: small constant added under the square root to avoid division by zero.
- `gclip`: gradient-clipping threshold forwarded to `Knet.gclip!`.
- `momentum`: stored, but not used by the active `update!` code (momentum is unfinished).

State (`nothing` until the first `update!` call):
- `velocity`: reserved for momentum; never assigned by the active code.
- `dims`: `size(w)` of the parameter this optimizer instance is attached to.
- `accumulators`: one array per axis of `w`; entry `i` has length `dims[i]` along axis `i`
  and 1 along every other axis.
"""
mutable struct SM3
    lr::AbstractFloat
    eps::AbstractFloat
    gclip::AbstractFloat
    momentum::AbstractFloat
    velocity
    dims
    accumulators
end

# Define `Knet.update!` for SM3 on dense CPU (`Array`) and GPU (`KnetArray`) parameters of
# either float precision; `@eval` stamps out one method per concrete array type.
for T in (Array{Float32}, Array{Float64}, KnetArray{Float32}, KnetArray{Float64})
    @eval begin
        # One SM3 step on parameter `w` with gradient `g`, mutating `w` and the state in `p`:
        #  1. clip `g` to `p.gclip` (Knet.gclip!) and densify it (`full`, for sparse grads);
        #  2. on the first call, allocate one zero accumulator per axis of `w`; entry `i`
        #     has length `size(w, i)` along axis `i` and 1 elsewhere. `similar(w, ...)`
        #     keeps them on the same device and element type as `w` (a hard-coded Float32
        #     `KnetArray` broke CPU parameters and silently downcast Float64 ones);
        #  3. rebuild the full-size accumulator (element-wise min of the per-axis ones),
        #     add g.^2 and step `w -= lr * g / sqrt(acc + eps)`;
        #  4. fold that accumulator back into the per-axis accumulators (max reduction).
        # Momentum (`p.momentum`, `p.velocity`) is not applied yet. Returns `w`.
        function Knet.update!(w::$T, g, p::SM3)
            Knet.gclip!(g, p.gclip)
            g = full(g)
            if p.accumulators === nothing
                p.dims = size(w)
                p.accumulators = [fill!(similar(w, _shape_for_broadcasting(p.dims, i)), 0)
                                  for i in 1:length(p.dims)]
            end
            accumulator = _compute_past_accumulator(p.accumulators, p.dims)
            accumulator .+= g .* g
            # Match w's element type so Float32 parameters are not promoted to Float64.
            eps_w = convert(eltype(w), p.eps)
            lr_w = convert(eltype(w), p.lr)
            axpy!(-lr_w, g ./ sqrt.(accumulator .+ eps_w), w)
            p.accumulators = _accumulator_updater(p.accumulators, p.dims, accumulator)
            return w
        end
    end
end

# Keyword constructor: all optimizer state (velocity, dims, accumulators) starts as
# `nothing` and is allocated on the first `update!`.
function SM3(; lr=0.001, eps=1e-30, gclip=0.0, momentum=0.9)
    return SM3(lr, eps, gclip, momentum, nothing, nothing, nothing)
end

# Train `f` over data `d` with a fresh SM3 optimizer; returns the iterator produced by
# `Knet.minimize`. Extra keywords `o...` are forwarded to `Knet.minimize`.
function sm3(f, d; lr=0.001, eps=1e-30, gclip=0.0, momentum=0.9, o...)
    optimizer = SM3(; lr=lr, eps=eps, gclip=gclip, momentum=momentum)
    return Knet.minimize(f, d, optimizer; o...)
end

# Eager variant of `sm3`: consumes the whole iterator so training actually runs.
function sm3!(x...; o...)
    for _ in sm3(x...; o...)
    end
end

# Copy the hyperparameters of `a` into a new optimizer with empty state.
Knet.clone(a::SM3) = SM3(; lr=a.lr, eps=a.eps, gclip=a.gclip, momentum=a.momentum)

In [3]:
"""
    _shape_for_broadcasting(dims, desired)

Return a tuple with one entry per element of `dims`. Position `desired` holds
`dims[desired]` and every other position holds `1`; for example `((3, 4, 5), 2)`
gives `(1, 4, 1)`.
"""
_shape_for_broadcasting(dims, desired) =
    ntuple(axis -> axis == desired ? dims[axis] : 1, length(dims))

"""
    _compute_past_accumulator(accumulators, dims)

Rebuild the full-size SM3 accumulator from the per-axis accumulators.

`accumulators[i]` is reshaped to `dims[i]` along axis `i` and 1 elsewhere. The per-axis
arrays are then combined with a broadcast element-wise `min`, which yields an array of
size `dims`.

The result is always a freshly allocated array, so callers may update it in place
(`.+=`) without touching the stored accumulators. `reshape` shares memory, which matters
for rank-1 parameters. Rank 0 (`dims == ()`) is not supported and throws `BoundsError`.
"""
function _compute_past_accumulator(accumulators, dims)
    rank = length(dims)
    shaped = [reshape(accumulators[i], _shape_for_broadcasting(dims, i)) for i in 1:rank]
    # `copy` so a rank-1 result does not alias `accumulators[1]`.
    result = copy(shaped[1])
    for i in 2:rank
        result = min.(result, shaped[i])
    end
    return result
end

"""
    _accumulator_updater(accumulators, dims, update_tensor)

Fold `update_tensor` (an array of size `dims`) into the per-axis SM3 accumulators.

For each axis `i`, `update_tensor` is max-reduced over every other axis, keeping singleton
dimensions so the shapes line up. `accumulators[i]` is then replaced by the element-wise
maximum of its old value and that reduction.

The `accumulators` vector is mutated in place and also returned.
"""
function _accumulator_updater(accumulators, dims, update_tensor)
    rank = length(dims)
    for i in 1:rank
        other_dims = Tuple(j for j in 1:rank if j != i)
        # Rank 1 has no other axes: skip the empty-`dims` reduction and use the tensor.
        slice_max = isempty(other_dims) ? update_tensor :
                    maximum(update_tensor; dims=other_dims)
        accumulators[i] = max.(accumulators[i], slice_max)
    end
    return accumulators
end

_accumulator_updater (generic function with 1 method)

In [23]:
using Knet

# Convolutional layer: convolve with filters `w`, add bias `b`, pool, then apply the
# activation `f` element-wise.
struct Conv
    w  # filters, created as param(w1, w2, cx, cy)
    b  # bias, created as param0(1, 1, cy, 1)
    f  # element-wise activation (relu by default)
end

function (layer::Conv)(x)
    pooled = pool(conv4(layer.w, x) .+ layer.b)
    return layer.f.(pooled)
end

Conv(w1, w2, cx, cy, f=relu) = Conv(param(w1, w2, cx, cy), param0(1, 1, cy, 1), f)

# Fully connected layer: f.(w * mat(x) .+ b), where Knet's `mat` reshapes the input to a
# matrix before the product.
struct Dense
    w  # weights, created as param(o, i)
    b  # bias, created as param0(o)
    f  # element-wise activation (relu by default)
end

function (layer::Dense)(x)
    z = layer.w * mat(x)
    return layer.f.(z .+ layer.b)
end

Dense(i::Int, o::Int, f=relu) = Dense(param(o, i), param0(o), f)

# Sequential model: the one-argument call feeds `x` through `layers` in order. The
# two-argument call returns the `nll` loss of the model output against labels `y`.
struct Chain
    layers
end

(model::Chain)(x) = foldl((h, layer) -> layer(h), model.layers; init=x)

(model::Chain)(x, y) = nll(model(x), y)

# Load MNIST with the loader script shipped in Knet's `data` directory; `mnistdata`
# yields the training and test minibatch data (batch size 100).
include(Knet.dir("data","mnist.jl"))
dtrn, dtst = mnistdata(batchsize=100)

# LeNet: two Conv layers (5x5 filters, 1->20 and 20->50 channels, relu + pooling),
# then Dense 800->500 (relu) and Dense 500->10 with identity output (scores fed to `nll`).
# NOTE(review): 800 presumably equals 4*4*50 for 28x28 inputs after two conv5 + 2x2-pool
# stages; update it if the input size or pooling window changes.
# Train with SM3 (default hyperparameters) for 10 passes over the training set, then
# report test accuracy as the cell's displayed value.
# NOTE(review): the recorded run of this cell printed 0.953 test accuracy with SM3
# defaults; an earlier claim of ~99% in ~30 s on a GPU was not reproduced here.
LeNet = Chain((Conv(5,5,1,20), Conv(5,5,20,50), Dense(800,500), Dense(500,10,identity)))
sm3!(LeNet, repeat(dtrn,10))
accuracy(LeNet, dtst)

0.953