# MNIST with a simple *MLP*

In [1]:
using Knet

In [2]:
include(Knet.dir("data","mnist.jl"))
xtrn, ytrn, xtst, ytst = mnist();
dtrn = minibatch(xtrn, ytrn, 100);
dtst = minibatch(xtst, ytst, 100);

[1m[36mINFO: [39m[22m[36mLoading MNIST...
[39m

In [3]:
features = prod(size(xtrn)[1:2]);
classifiers = 10;
weights = [128, 64];

In [4]:
function predict(w, x)
    for i=1:2:length(w)
        x = w[i]*mat(x) .+ w[i+1]
        if i < length(w)-1
            x = relu.(x)
        end
    end
    return x
end

predict (generic function with 1 method)

In [5]:
loss(w,x,y) = nll(predict(w,x),y)
lossgradient = grad(loss);

In [6]:
function train(w, dtrn; lr=.1, epochs=20)
    for epoch= 1:epochs
        for (x,y) in dtrn
            g = lossgradient(w, x, y)
            update!(w, g; lr=lr)
        end
    end
    return w
end

train (generic function with 1 method)

In [7]:
function make_weights(h...; atype=Array{Float32}, winit=0.1)
    w = Any[]
    x = features
    for y in [h..., classifiers]
        push!(w, convert(atype, winit*randn(y,x)))
        push!(w, convert(atype, zeros(y, 1)))
        x=y
    end
    return w
end

make_weights (generic function with 1 method)

In [13]:
epochs = 10;
learning_rate = 0.1;

In [14]:
@printf("x training size: %s\ny training size: %s\n",size(xtrn), size(xtst))

x training size: (28, 28, 1, 60000)
y training size: (28, 28, 1, 10000)


In [15]:
w = make_weights(weights...);

In [16]:
println((:epoch, 0, :trn, accuracy(w,dtrn,predict), :test, accuracy(w,dtst,predict)))
@time for epoch=1:epochs
   train(w, dtrn; lr=learning_rate, epochs=1)
   println((:epoch, epoch, :trn, accuracy(w,dtrn,predict), :test, accuracy(w,dtst,predict)))
end

(:epoch, 0, :trn, 0.09523333333333334, :test, 0.0941)
(:epoch, 1, :trn, 0.9195166666666666, :test, 0.9218)
(:epoch, 2, :trn, 0.9452, :test, 0.9434)
(:epoch, 3, :trn, 0.9581, :test, 0.9554)
(:epoch, 4, :trn, 0.9651833333333333, :test, 0.9598)
(:epoch, 5, :trn, 0.9699, :test, 0.9644)
(:epoch, 6, :trn, 0.9737, :test, 0.9657)
(:epoch, 7, :trn, 0.9764333333333334, :test, 0.9675)
(:epoch, 8, :trn, 0.9793333333333333, :test, 0.9693)
(:epoch, 9, :trn, 0.9817833333333333, :test, 0.9702)
(:epoch, 10, :trn, 0.98395, :test, 0.9703)
 21.859452 seconds (5.66 M allocations: 10.432 GiB, 3.27% gc time)
