# MNIST Neural Network

In [1]:
using DelimitedFiles
using StatsBase
using Distributions
using LinearAlgebra

# read MNIST data
# Each CSV is parsed as a comma-delimited matrix of Int, one row per line.
# The paths are relative to the current working directory.
# Rows of trainx/testx are observations (GD() and accuracy() index them by row);
# trainy/testy hold the matching labels, which classify()/digit2vector() treat as
# digits 0-9.
const testx = readdlm("testx.csv", ',', Int, '\n')
const testy = readdlm("testy.csv", ',', Int, '\n')
const trainx = readdlm("trainx.csv", ',', Int, '\n')
const trainy = readdlm("trainy.csv", ',', Int, '\n')

# Network shape. sizes[1] must equal the number of columns of trainx/testx, and
# sizes[end] = 10 matches the 10-element target vectors built by digit2vector().
# L must equal length(sizes).
const L = 3                 # number of layers including input and output
const sizes = [784, 30, 10] # number of neurons in each layer

3-element Vector{Int64}:
 784
  30
  10

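The next cell defines the building blocks of the network: the sigmoid activation and its derivative, one-hot encoding of digit labels, the forward pass, classification of a single input, and backpropagation.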

In [2]:
# the activation function used by every layer: σ(x) = 1 / (1 + exp(-x)), applied
# elementwise. Accepts a scalar (returns a scalar) or an array (returns an array of
# the same shape). For very negative x, exp(-x) overflows to Inf and the result is
# 0.0 (not NaN).
sigmoid(x) = @. 1 / (1 + exp(-x))

# derivative of the sigmoid, σ'(x) = σ(x) * (1 - σ(x)), applied elementwise.
# σ(x) is evaluated once and reused rather than computed twice per element.
function sigmoidPrime(x)
    s = sigmoid(x)
    return @. s * (1 - s)
end

# HELPER: one-hot encode a digit d ∈ 0:9 as a 10-element Int vector with a single 1
# at index d+1 (Julia arrays are 1-based), e.g. 6 is converted to
# [0,0,0,0,0,0,1,0,0,0].
# Throws an ArgumentError for digits outside 0:9.
function digit2vector(d::Integer)
    0 <= d <= 9 || throw(ArgumentError("digit must be in 0:9, got $d"))
    v = zeros(Int, 10)
    v[d + 1] = 1
    return v
end

# feedforward: propagate a single input through the network.
# inputs:
#    -W: vector of weight matrices; W[l] maps layer l to layer l+1
#    -b: vector of bias vectors; b[l] is added when computing layer l+1
#    -x: input of a single training example (a vector of length 784)
# returns (a, z), each a vector holding one Float64 vector per layer:
#    a[l] = activations of layer l, z[l] = weighted inputs of layer l.
#    a[1] and z[1] both hold x (converted to Float64 if needed). z[1] is never
#    used; it is there only so that the indices match the layer numbers.
# The number of layers is derived from W (length(W) + 1) rather than the global L,
# so the function agrees with whatever weights it is given.
function feedforward(W, b, x)
    nlayers = length(W) + 1
    z = Vector{Vector{Float64}}(undef, nlayers)
    a = Vector{Vector{Float64}}(undef, nlayers)
    z[1] = x
    a[1] = x
    for l in 2:nlayers
        z[l] = W[l-1] * a[l-1] + b[l-1]
        a[l] = sigmoid(z[l])
    end
    return a, z
end

# given an input vector x, return the predicted digit: the position of the most
# activated output neuron, shifted from Julia's 1-based index to a 0-based digit.
# Ties go to the first maximum (argmax semantics, same as findmax).
function classify(W, b, x)
    a, _ = feedforward(W, b, x)
    return argmax(last(a)) - 1
end

# HELPER for backprop(): compute the error δ of each layer for one training example
# (equations BP1 and BP2). The output-layer error is -(y - a) ⊙ σ'(z), i.e. the
# gradient of the quadratic cost ½‖y - a‖².
#   W: the weights of the network.
#   a: the activations returned by feedforward().
#   z: the weighted inputs returned by feedforward().
#   y: the correct digit.
# returns δ, a vector of three error vectors with lengths sizes[1], sizes[2], sizes[3].
# δ[1] is always zeros: it is a placeholder so that δ[l] lines up with layer l.
function compute_error(W, a, z, y)
    onehot = digit2vector(y)

    # output layer L (BP1)
    δoutput = -(onehot .- a[3]) .* sigmoidPrime(z[3])

    # earlier layers L-1, ..., 2 (BP2); with three layers this is only layer 2
    δhidden = W[2]' * δoutput .* sigmoidPrime(z[2])

    return [zeros(sizes[1]), δhidden, δoutput]
end

# helper function for backprop(). given the errors δ and the activations a for a
# single training example, return the gradient components (∇W, ∇b) using the
# equations BP3 and BP4:
#    ∇W[l] = δ[l+1] * a[l]'   (BP4, an outer product)
#    ∇b[l] = δ[l+1]           (BP3)
# ∇W[l] and ∇b[l] are the gradients of the weights and biases that feed layer l+1.
# The number of weight layers is length(a) - 1, so any depth is supported.
# Results are built directly. Preallocating zero matrices here would waste a
# sizes[2]×sizes[1] allocation per example, because they are overwritten at once.
# Entries are converted to Float64 arrays; ∇b[l] shares storage with δ[l+1]
# when δ[l+1] is already a Vector{Float64}.
function compute_gradients(δ, a)
    nsteps = length(a) - 1
    ∇W = Matrix{Float64}[δ[l+1] * a[l]' for l in 1:nsteps]   # BP4
    ∇b = Vector{Float64}[δ[l+1] for l in 1:nsteps]           # BP3
    return ∇W, ∇b
end

# backpropagation for a single training example (x, y): run the forward pass,
# push the output error back through the layers, and return the resulting
# gradients (∇W, ∇b) of the weights and biases.
function backprop(W, b, x, y)
    activations, weighted = feedforward(W, b, x)
    errors = compute_error(W, activations, weighted, y)
    return compute_gradients(errors, activations)
end

backprop (generic function with 1 method)

In [3]:
# One step of mini-batch gradient descent with an L2 penalty on the weights.
#   W     = weights in the network (one matrix per layer transition)
#   b     = biases in the network (one vector per non-input layer)
#   batch = indices of the observations in the batch, i.e. rows of the global
#           trainx; the labels are read from the global trainy
#   α     = step size
#   λ     = regularization parameter (added to the weight gradient only, not b)
# returns (W, b, ∇W, ∇b): the updated weights and biases, plus the averaged,
# regularized gradients that were used for the step.
function GD(W, b, batch; α=0.01, λ=0.01)
    m = length(batch)

    # running sums of the per-example gradients over the batch
    # (ΔW and Δb in the notes and in Ng's article)
    ΔW = [zeros(sizes[2], sizes[1]), zeros(sizes[3], sizes[2])]
    Δb = [zeros(sizes[2]), zeros(sizes[3])]

    for idx in batch
        gW, gb = backprop(W, b, trainx[idx, :], trainy[idx])
        ΔW = ΔW + gW
        Δb = Δb + gb
    end

    # average gradient over the batch; only the weights get the λ penalty term
    ∇W = (1 / m) * ΔW .+ λ * W
    ∇b = (1 / m) * Δb

    return W .- α * ∇W, b .- α * ∇b, ∇W, ∇b
end

# classification accuracy on the test set (the globals testx / testy): the
# fraction of test rows whose predicted digit equals the stored label.
function accuracy(W, b)
    ntest = length(testy)
    predictions = [classify(W, b, testx[row, :]) for row in 1:ntest]
    hits = sum(testy .== predictions)
    return hits / ntest
end

# train the neural network using mini-batch gradient descent.
# this is a driver function that repeatedly calls GD().
# N = number of observations in the training data (rows 1:N of trainx are used)
# m = batch size (must be positive; the last batch of an epoch is smaller when
#     m does not divide N)
# epochs = number of complete passes through rows 1:N
# α = learning rate / step size
# λ = regularization parameter
# Prints the test-set accuracy after every epoch.
# returns (W, b, ∇W, ∇b): the trained weights and biases and the gradients of the
# last GD step (all zeros if no step was taken).
function BGD(N, m, epochs; α=0.01, λ=0.01)
    m >= 1 || throw(ArgumentError("batch size m must be positive, got $m"))

    # random initialization of the weights and biases from a standard normal
    d = Normal(0, 1)
    W = [ rand(d, sizes[2], sizes[1]),  # layer 1 to 2
          rand(d, sizes[3], sizes[2]) ] # layer 2 to 3
    b = [ rand(d, sizes[2]),   # layer 2
          rand(d, sizes[3]) ]  # layer 3
    ∇W = [ zeros(sizes[2], sizes[1]),  # layer 1 to 2
           zeros(sizes[3], sizes[2]) ] # layer 2 to 3
    ∇b = [ zeros(sizes[2]),   # layer 2
           zeros(sizes[3]) ]  # layer 3
    for j in 1:epochs
        # Shuffle 1:N once per epoch and cut it into consecutive batches of size m,
        # so every row is visited exactly once per epoch in O(N) time. Drawing each
        # batch from the shrinking remainder would fail once fewer than m rows are
        # left, and the repeated setdiff would cost O(N) per batch.
        order = sample(1:N, N, replace=false)
        for batch in Iterators.partition(order, m)
            (W, b, ∇W, ∇b) = GD(W, b, batch; α=α, λ=λ)
        end
        println("epoch ", j, ", accuracy = ", accuracy(W,b))
    end
    return W, b, ∇W, ∇b
end


BGD (generic function with 1 method)

In [None]:
# tuning parameters for the training run, followed by the run itself
N = length(trainy)  # number of training observations
m = 20              # batch size
epochs = 50         # number of complete passes through the training data
α = 0.01            # learning rate / step size
λ = 0.01            # regularization parameter
W, b, ∇W, ∇b = BGD(N, m, epochs; α, λ)

epoch 1, accuracy = 0.2051
epoch 2, accuracy = 0.2815
epoch 3, accuracy = 0.3488
epoch 4, accuracy = 0.4218
epoch 5, accuracy = 0.4912
epoch 6, accuracy = 0.546
epoch 7, accuracy = 0.6325
epoch 8, accuracy = 0.7428
epoch 9, accuracy = 0.8069
epoch 10, accuracy = 0.848
epoch 11, accuracy = 0.8632
epoch 12, accuracy = 0.8776
epoch 13, accuracy = 0.888
epoch 14, accuracy = 0.8922
epoch 15, accuracy = 0.898
epoch 16, accuracy = 0.8929
epoch 17, accuracy = 0.8922
epoch 18, accuracy = 0.8952
epoch 19, accuracy = 0.8986
epoch 20, accuracy = 0.8952
epoch 21, accuracy = 0.8988
epoch 22, accuracy = 0.8924
epoch 23, accuracy = 0.9003
epoch 24, accuracy = 0.9017
epoch 25, accuracy = 0.9042
epoch 26, accuracy = 0.9014
epoch 27, accuracy = 0.8991
epoch 28, accuracy = 0.9076
epoch 29, accuracy = 0.9012
