# Backpropagation learning for multilayer perceptron
# with minibatches

In [None]:
using MNIST, Plots
# Select the GR backend for Plots, with a 600x600 figure and no legend.
# Change the size tuple if the figure does not fit your screen.
gr(size = (600, 600), legend = :none)

In [None]:
# Activation function. It is applied to whole layers at once (one column per
# example), so the elementwise intent is spelled out with broadcasting rather
# than relying on `tanh` accepting an array: implicit vectorisation is
# deprecated, and `tanh` of a matrix is the matrix function in later releases.
f(x) = tanh.(x)

# Derivative of f written in terms of the output y = f(x): 1 - tanh(x)^2 = 1 - y^2.
# The integer literal leaves the element type of y unchanged (Float32 in,
# Float32 out), independent of scalar/array promotion rules.
df(y) = 1 .- y.*y

#f(x) = 1 ./ (1 .+ exp.(-x))       # logistic function
#df(y) = y.*(1 .- y)

n0 = 784       # widths of layers: n0 is the input, n3 the output
n1 = 200
n2 = 100
n3 = 10

eta = 0.01       # learning rate parameter
epsinit = 0.01   # magnitude of initial conditions for synaptic weights

# Random Float32 array of the given dimensions: standard normal draws scaled by
# epsinit. The scale is converted to Float32 before multiplying so the result
# is Float32 whatever the scalar/array promotion rules are; the BLAS weight
# updates in the training loop need every operand to share one element type.
initparam(dims...) = Float32(epsinit) * randn(Float32, dims...)

# Weight matrices, sized (units in layer) x (units in previous layer).
W1 = initparam(n1, n0)
W2 = initparam(n2, n1)
W3 = initparam(n3, n2)

# Biases, stored as single-column matrices.
b1 = initparam(n1, 1)
b2 = initparam(n2, 1)
b3 = initparam(n3, 1)

tmax = 600000       # maximum number of minibatch updates
tshow = 1000         # how often to pause (# of minibatches) for visualization

# Learning curves: one entry per minibatch for the training errors, and one
# entry per visualization pause for the validation error.
errsq = zeros(tmax)
errcl = zeros(tmax);
errclvalidate = zeros(div(tmax,tshow))

# Load the training set, store the images as Float32 and divide the values by 255.
train, trainlabels = traindata()
train = convert(Matrix{Float32}, train)
train /= 255.0
# Labels are later used as 1-based row indices into the target matrix, so the
# label 0 is moved to 10 before the labels are converted to integers.
trainlabels[trainlabels .== 0] = 10
trainlabels = convert(Vector{Int64}, trainlabels)

# Split off a validation set: shuffle the example indices once, keep the last
# mvalidate of them for validation and the first mtrain for training.
mtotal = size(train, 2)          # total number of examples in original training set
mvalidate = 10000                # desired size of validation set
mtrain = mtotal - mvalidate      # remaining examples will be new training set
srand(2017)                      # fixed seed so that the validation set is reproducible
ind = randperm(mtotal)
heldout = ind[end-mvalidate+1:end]
kept = ind[1:mtrain]
validate = train[:, heldout]
validatelabels = trainlabels[heldout]
train = train[:, kept]
trainlabels = trainlabels[kept]

batchsize = 32     # minibatch size

In [None]:
# Train the network by minibatch gradient descent with backpropagation.
# Each step draws a random minibatch, runs it forward through the three layers,
# backpropagates the output error and updates the weights and biases.
# Every tshow steps the classification error on the validation set is computed
# and the learning curves and activation histograms are redrawn.

# Step applied to gradients summed over a minibatch. It is converted to Float32
# once so that weight and bias updates use the same step and the biases keep
# the element type of the weights.
stepsize = Float32(eta/batchsize)

# Fraction of columns of `out` whose largest entry is not in the row given by `labels`.
classerr(out,labels) = mean(float(map(indmax,[out[:,i] for i=1:size(out,2)]) .!= labels))

for t = 1:tmax     # time in number of minibatches
    # Random minibatch, sampled uniformly with replacement. Drawing integers
    # directly cannot produce index 0, unlike rounding up a float from [0,1).
    batchindices = rand(1:mtrain,batchsize)
    x0 = train[:,batchindices]
    batchlabels = trainlabels[batchindices]
    # Targets: +1 in the row of the correct class, -1 everywhere else.
    y = -ones(Float32,n3,batchsize)              # appropriate for tanh
#    y = zeros(Float32,n3,batchsize)               # appropriate for logistic function
    for (k,label) in enumerate(batchlabels)
        y[label,k] = 1
    end

    # forward pass
    x1 = f(W1*x0 .+ b1)
    x2 = f(W2*x1 .+ b2)
    x3 = f(W3*x2 .+ b3)
    # error computation (squared error and classification error on this minibatch)
    errsq[t] = sum((y-x3).^2)/batchsize
    errcl[t] = classerr(x3,batchlabels)
    delta3 = (y-x3).*df(x3)
    # backward pass; uses the weights from before this step's update
    delta2 = (W3'*delta3).*df(x2)
    delta1 = (W2'*delta2).*df(x1)

#    W3 += stepsize*delta3*x2'
#    W2 += stepsize*delta2*x1'
#    W1 += stepsize*delta1*x0'
    # following is equivalent to but faster than the above weight updates
    # (in-place W = stepsize*delta*x' + W)
    BLAS.gemm!('N','T',stepsize,delta3,x2,1.0f0,W3)
    BLAS.gemm!('N','T',stepsize,delta2,x1,1.0f0,W2)
    BLAS.gemm!('N','T',stepsize,delta1,x0,1.0f0,W1)

    # bias gradients are the deltas summed over the examples of the minibatch
    b3 += stepsize*sum(delta3,2)
    b2 += stepsize*sum(delta2,2)
    b1 += stepsize*sum(delta1,2)

    if rem(t,tshow) == 0    # visualization every tshow steps
        nshown = div(t,tshow)          # number of visualization pauses so far
        x0 = validate                  # compute error on validation set
        x1 = f(W1*x0 .+ b1)
        x2 = f(W2*x1 .+ b2)
        x3 = f(W3*x2 .+ b3)
        errclvalidate[nshown] = classerr(x3,validatelabels)
        # average the per-minibatch training errors over windows of tshow steps
        avgerrsq = mean(reshape(errsq[1:t],tshow,nshown),1)'
        avgerrcl = mean(reshape(errcl[1:t],tshow,nshown),1)'
        IJulia.clear_output(true)
        # Panels: squared error (log scale), classification error on training
        # and validation data (log scale, then linear scale up to 0.1), and
        # histograms of the layer activations on the validation set.
        plot(
            plot(avgerrsq,
                ylabel = "sq err",
                ylim = (0.001,4),
                yscale = :log10
                ),
            plot([avgerrcl, errclvalidate[1:nshown]],
                ylabel = "class err",
                ylim = (0.001,1),
                yscale = :log10,
                title = @sprintf("t=%d",t),
                xlabel = @sprintf("x%d minibatches",tshow)
                ),
            plot([avgerrcl, errclvalidate[1:nshown]],
                ylabel = "class err",
                ylim = (0,0.1),
                title = @sprintf("t=%d",t),
                xlabel = @sprintf("x%d minibatches",tshow)
                ),
            histogram(x1[:], xlabel = "x1"),
            histogram(x2[:], xlabel = "x2"),
            histogram(x3[:], xlabel = "x3")
            ) |> display
        sleep(0.01)
    end
end