In [None]:
using Knet, Plots, Images, ImageMagick, Statistics
# Array type used for the model parameters (see pa) and for the minibatch inputs (xtype):
# KnetArray{Float32} when gpu() >= 0, otherwise a plain Array{Float32}.
# NOTE(review): presumably gpu() returns a negative value when no GPU is active — confirm.
Atype = gpu() >= 0 ? KnetArray{Float32} : Array{Float32}
#ENV["COLUMNS"]=80             # column width for array printing
#Plots.plotlyjs()              # for interactive plots
#Plots.scalefontsizes(1.5)     # for presentations

## Define linear model

In [None]:
# A linear model has two components: w=weightMatrix, b=biasVector.
# The field types are struct parameters instead of untyped (`Any`) slots, so the
# concrete array types are known to the compiler and the methods below specialize on
# them. `Linear(w, b)` still constructs a model exactly as before.

# 1. type declaration
struct Linear{W,B} <: Model
    w::W    # weight matrix, multiplied with mat(x) in the predict function
    b::B    # bias vector, broadcast-added to the product w*mat(x)
end

# 2. parameter iterator: calling the model with no arguments returns (w, b)
(f::Linear)() = (f.w, f.b)

# 3. predict function: scores for the input x, computed as w*mat(x) .+ b
(f::Linear)(x) = (f.w * mat(x) .+ f.b)

# 4. loss function: nll of the predictions f(x) against the answers y
(f::Linear)(x, y) = nll(f(x), y)

In [None]:
# Convert x to the configured array type (Atype) and wrap it in a Param.
function pa(x)
    return Param(Atype(x))
end

# 6. constructor: build a model with `i` inputs and `o` outputs. The weights are an
# o-by-i matrix of normal random numbers scaled by `std`; the biases start at zero.
function Linear(i::Int, o::Int; std=0.01)
    weights = pa(std * randn(o, i))
    bias = pa(zeros(o))
    return Linear(weights, bias)
end

f = Linear(784, 10)

## Accuracy and zero-one loss

In [None]:
# Load data
include(Knet.dir("data", "mnist.jl"))
xtrn, ytrn, xtst, ytst = mnist()

# Split both sets into minibatches of 100 whose inputs are stored as Atype.
# dtrn and dtst each behave like [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches.
dtrn, dtst = map(((xtrn, ytrn), (xtst, ytst))) do (inputs, labels)
    minibatch(inputs, labels, 100; xtype=Atype)
end

In [None]:
x,y = first(dtst)
summary.((x,y))       # Take a look at the first minibatch: summaries of its inputs x and answers y

In [None]:
ypred = f(x)          # predictions are given as a 10xN score matrix

In [None]:
y'                    # correct answers are given as an array of integers

In [None]:
accuracy(ypred,y)     # accuracy gives the fraction (not percentage) of correct answers on this minibatch

In [None]:
accuracy(f,dtst)      # or on the whole dataset

In [None]:
zeroone(f,dtst)       # zeroone loss (error) defined as 1 - accuracy

## Softmax loss function

In [None]:
# Calculate softmax (cross entropy, negative log likelihood) loss of a model for one minibatch (x,y).
# The two-argument call is the model's loss method, defined earlier as nll(f(x),y).
f(x,y)

In [None]:
# Manual loss calculation: recompute the value of f(x,y) step by step.
using SparseArrays
ypred=f(x)                      # score matrix for the minibatch
yp1 = exp.(ypred)               # unnormalized probabilities
yp2 = yp1 ./ sum(yp1,dims=1)    # normalize every column so it sums to 1
yp3 = -log.(yp2)                # negative log probabilities
# One-hot mask of the correct answers. The dimensions are passed explicitly: without
# them sparse() infers the row count from maximum(y), so a minibatch that happens not
# to contain the highest label would produce a mask with too few rows. The number of
# columns follows length(y) instead of a hard-coded batch size.
yc1 = Array(sparse(convert(Vector{Int},y), 1:length(y), 1f0, size(yp3,1), length(y)))
sum(Array(yp3).*yc1) / length(y)   # average negative log probability of the correct answers

In [None]:
# 5. optional dataset loss function: the mean of the minibatch losses f(x,y) over d
function (f::Linear)(d::Data)
    batchlosses = (f(xb, yb) for (xb, yb) in d)
    return mean(batchlosses)
end

f(dtst)             # per-instance average softmax loss for the whole test set

## Calculating the gradient using Knet

In [None]:
Knet.seed!(9)               # seed the random number generator before re-creating the model
f = Linear(784,10,std=0.1)  # use a larger std to get a larger gradient for this example

In [None]:
f(x,y)                      # loss of the new model on the minibatch (x,y)

In [None]:
J = differentiate(f,x,y)    # J is queried below with value(J) and gradient(J, parameter)

In [None]:
value(J)                    # presumably the same loss as f(x,y) above — compare the two outputs

In [None]:
∇w = gradient(J,f.w)        # gradient with respect to the weight matrix

In [None]:
∇b = gradient(J,f.b) # gradients have the same size and shape as the corresponding parameters

## Checking the gradient using numerical approximation

In [None]:
@show ∇b;
# Meaning of gradient:
# If I move the last entry of f.b by epsilon, loss will go up by 0.345075 epsilon!

In [None]:
@show f.b;

In [None]:
f(x,y)          # loss before the perturbation, for comparison with the value below

In [None]:
f.b[10] = 0.1   # to numerically check the gradient let's move the last entry by +0.1.
@show f.b;

In [None]:
f(x,y)
# We see that the loss moves by about +0.03 (0.345075 * 0.1) as expected.
# You should check all/most entries in your gradients this way to make sure they are correct.

In [None]:
f.b[10] = 0     # reset the perturbed entry to 0, the value the biases were initialized with

## Checking the gradient using manual implementation

In [None]:
# Hand-written gradient of the softmax loss for the linear model w*mat(x) .+ b.
# Returns the pair (dJdw, dJdb) so it can be compared with the gradients from Knet.
function softgrad_manual(w,b,x,y)
    xmat = mat(x)
    scores = w * xmat .+ b
    shifted = scores .- maximum(scores, dims=1)   # for numerical stability
    unnorm = exp.(shifted)
    probs = unnorm ./ sum(unnorm, dims=1)         # softmax over each column
    # One-hot matrix of the correct answers, converted to the same type as probs.
    ninst = length(y)
    answers = convert(Vector{Int}, y)
    onehot = oftype(probs, sparse(answers, 1:ninst, 1, size(probs, 1), ninst))
    # Gradient with respect to the scores, averaged over the columns of xmat.
    dJdy = (probs - onehot) / size(xmat, 2)
    return dJdy * xmat', vec(sum(dJdy, dims=2))
end;

In [None]:
∇w2,∇b2 = softgrad_manual(f.w,f.b,x,y)   # manual gradients for the same model and minibatch

In [None]:
∇w2 ≈ ∇w    # true when the manual weight gradient approximately equals gradient(J,f.w)

In [None]:
∇b2 ≈ ∇b    # true when the manual bias gradient approximately equals gradient(J,f.b)

## Training (SGD) loop

In [None]:
# Train model f with SGD on the minibatches in `data` and return `record`.
#
# Keyword arguments:
#   epochs   - number of passes over `data`
#   lr       - learning rate forwarded to update!
#   record   - vector that receives the snapshots (a fresh one is created per call)
#   testdata - dataset used for the test loss/error entries; defaults to the global dtst
#
# A snapshot is taken before every epoch and once more after the last one, so `record`
# grows by 5*(epochs+1) entries. Each snapshot appends, in this order:
#   a copy of the parameters f(), the loss on `data`, the loss on `testdata`,
#   the zero-one error on `data`, the zero-one error on `testdata`.
# The training metrics are computed on the `data` argument itself rather than on the
# global dtrn, so the function reports on whatever dataset it was asked to train with.
function train!(f, data; epochs=100, lr=0.1, record=[], testdata=dtst)
    # deepcopy is required here: the parameters are updated in place during training,
    # so storing f() directly would make every snapshot alias the final weights.
    function snapshot!()
        return push!(record, deepcopy(f()), f(data), f(testdata),
                     zeroone(f, data), zeroone(f, testdata))
    end
    for epoch in 1:epochs
        snapshot!()
        for (x, y) in data
            J = differentiate(f, x, y)
            update!(f, J; lr=lr)
        end
    end
    return snapshot!()
end;

## Training the linear model and underfitting

In [None]:
# Reuse the cached record from lin.jld2 when it exists; otherwise train a fresh linear
# model on dtrn, reshape the flat record to 5 rows (one column per snapshot) and cache it.
# Rows: 1 = parameters, 2 = train loss, 3 = test loss, 4 = train error, 5 = test error.
if isfile("lin.jld2")
    r = Knet.load("lin.jld2", "record")
else
    r = []
    @time train!(Linear(784, 10), dtrn, record=r)
    r = reshape(r, 5, :)
    Knet.save("lin.jld2", "record", r)
end
(minimum(r[3, :]), minimum(r[5, :]))  # best test loss and test error: 0.2667, 0.0744

In [None]:
# Loss curves: rows 2 and 3 of r hold the training and test loss recorded at each snapshot.
plot([r[2,:], r[3,:]],ylim=(.0,.4),labels=[:trnloss :tstloss],xlabel="Epochs",ylabel="Loss") 
# Demonstrates underfitting: training loss not close to 0
# Also slight overfitting: test loss higher than train

In [None]:
# Error curves: rows 4 and 5 of r hold the training and test zero-one error.
plot([r[4,:], r[5,:]],ylim=(.0,.12),labels=[:trnerr :tsterr],xlabel="Epochs",ylabel="Error")  
# this is the error plot, we get to about 7.5% test error, i.e. 92.5% accuracy

## Visualizing the learned weights

In [None]:
# Animate the learned weights: step through 10 snapshot indices spaced logarithmically
# between 1 and size(r,2), and show the 10 weight images of each snapshot side by side.
for logstep in range(0, stop=log10(size(r,2)), length=10) #logspace(0,2,20)
    i = floor(Int, 10 ^ logstep)
    wsnap = r[1,i][1]                         # first recorded parameter: the weight matrix
    wstack = reshape(Array(value(wsnap))', (28,28,1,10))
    wshown = clamp.(wstack .+ 0.5, 0, 1)      # shift by 0.5 and clip to [0,1] for display
    IJulia.clear_output(true)
    display(hcat([mnistview(wshown,k) for k in 1:10]...))
    display("Epoch $i")
    sleep(1) # (0.96^i)
end