### A convolutional network roughly patterned after LeNet-5
Unlike LeNet-5, it uses max pooling instead of average pooling.

The `FastConv` package provides a faster alternative to the built-in `conv` and `conv2` functions. Install it with `Pkg.clone("https://github.com/aamini/FastConv.jl.git")` before running the code below.

In [None]:
using FastConv

In [None]:
function FastConv.fastconv(A, B, shape::AbstractString)
    # Extend FastConv.fastconv with a MATLAB-style third argument selecting the output:
    #   "full"  -- the whole convolution, identical to fastconv(A, B)
    #   "valid" -- only the entries where the smaller operand lies entirely inside the
    #              larger one: along each dimension the result has
    #              abs(size(A,d) - size(B,d)) + 1 entries. Unlike MATLAB's
    #              conv2(..., 'valid'), this is symmetric in A and B, which the weight
    #              updates rely on (e.g. a 14x14 image "valid"-convolved with 10x10 deltas).
    # Any other value throws an ArgumentError.
    # NOTE(review): this adds a method to another package's function for a Base type
    # (type piracy); a locally owned wrapper would be cleaner, but callers use this name.
    if shape == "full"
        return FastConv.fastconv(A, B)
    elseif shape == "valid"
        full = FastConv.fastconv(A, B)
        # per dimension, keep indices min(a,b):max(a,b) of the full result
        ranges = [min(a, b):max(a, b) for (a, b) in zip(size(A), size(B))]
        return full[ranges...]
    else
        throw(ArgumentError("shape must be \"full\" or \"valid\", got \"$shape\""))
    end
end

In [None]:
using Plots
gr(size=(600,600),legend=:none)

# Display a stack of images as a grid of heatmaps, one panel per image.
# imgstack is an m x n x p array holding p images, each m x n (image i is imgstack[:,:,i]).
# Panels use yflip = true so that row 1 of each image is drawn at the top, matching the
# other image heatmaps in this notebook (input x0 and kernels w1 during training).
function montage(imgstack)
    panels = [heatmap(imgstack[:,:,i], yflip = true) for i = 1:size(imgstack,3)]
    # options necessary to get nice spacing of the images
    return plot(panels..., legend = :none, axis = nothing)
end

In [None]:
function maxpool(images)
    # 2x2 max pooling for a stack of images
    # IMAGES   m x n x k array: k images, each m x n (m and n must be even)
    # returns MAXIMA, WINNERS
    # MAXIMA   (m/2) x (n/2) x k array, maximum of each non-overlapping 2x2 patch
    # WINNERS  (m/2) x (n/2) x k Int array holding, for each patch, the LINEAR index
    #          into IMAGES of its maximum (not a 1..4 position inside the patch).
    #          maxpoolback scatters deltas with delta_in[winners] = delta_out, which
    #          relies on exactly this encoding.
    # Ties go to the first maximal entry in column-major order within the patch.
    # Throws DimensionMismatch if m or n is odd.

    m, n, k = size(images)  # mxn images, k of them
    if isodd(m) || isodd(n)
        throw(DimensionMismatch("maxpool needs even image dimensions, got $(m)x$(n)"))
    end
    mh, nh = div(m, 2), div(n, 2)
    maxima = similar(images, mh, nh, k)
    winners = zeros(Int, mh, nh, k)
    # Explicit loop rather than findmax over a 2 x m/2 x 2 x n/2 x k reshape: the
    # dims-reducing findmax is findmax(A, region) on Julia 0.6 but findmax(A; dims=...)
    # returning CartesianIndex values on newer Julia. This form works on both.
    for p = 1:k, j = 1:nh, i = 1:mh
        r, c = 2i - 1, 2j - 1      # top-left corner of the patch
        best = images[r, c, p]
        bestidx = r + (c - 1)*m + (p - 1)*m*n
        for dc = 0:1, dr = 0:1     # visit the patch in column-major order
            v = images[r + dr, c + dc, p]
            if v > best
                best = v
                bestidx = (r + dr) + (c + dc - 1)*m + (p - 1)*m*n
            end
        end
        maxima[i, j, p] = best
        winners[i, j, p] = bestidx
    end
    return maxima, winners
end

function maxpoolback(delta_out, winners)
    # backprop through 2x2 max pooling: each pooled delta is routed back to the input
    # position that won the max; every other position gets zero
    #
    # DELTA_OUT  m x n x k array, deltas of the pooled (2x smaller) images
    # WINNERS    same size as DELTA_OUT: the second output of maxpool, i.e. LINEAR
    #            indices into the 2m x 2n x k pre-pooling array (not argmax 1..4)
    # returns DELTA_IN, a 2m x 2n x k Float32 array of deltas for the unpooled images
    # Throws DimensionMismatch if WINNERS and DELTA_OUT differ in size.

    m, n, k = size(delta_out)  # mxn images, k of them
    size(winners) == size(delta_out) ||
        throw(DimensionMismatch("winners has size $(size(winners)), expected $(size(delta_out))"))
    delta_in = zeros(Float32, 2*m, 2*n, k)

    # scatter: indices from maxpool are distinct because the 2x2 patches do not overlap
    delta_in[winners] = delta_out
    return delta_in
end

In [None]:
using MNIST

# Load the MNIST training set. Each column of `train` becomes a 28x28 image, and pixel
# values are divided by 255 (assumes raw 0-255 intensities -- confirm for MNIST.jl).
# Digit 0 is relabeled as class 10 so every label indexes one of the 10 network outputs.
train, trainlabels = traindata()
train = convert(Array{Float32}, reshape(train, 28, 28, size(train, 2)) / 255.0)
trainlabels = map(label -> label == 0 ? Int64(10) : Int64(label), trainlabels)  # tenth output signals a zero
niter = size(train, 3)   # number of training examples
nepoch = 10              # number of epochs through training set

# Nonlinearity and its derivative. Both are applied to whole layers (arrays), so they
# broadcast explicitly; df takes the neuron's OUTPUT y = f(x), not its input.
# The integer literal in df keeps Float32 activations Float32.
f(x) = tanh.(x)
df(y) = 1 .- y.*y
# ReLU alternative (replace both definitions together):
#f(x) = max.(x, 0)
#df(y) = float(y .> 0)

# Float32 literals so that scalar*array products keep the Float32 element type
# used for all weights and activations below
epsinit = 0.1f0;   # scale of weight initialization
eta = 0.01f0;  # learning rate parameter

# initialize two convolution layers 
n1 = 6; n2 = 16;   # numbers of feature maps
w1 = epsinit*randn(Float32,5,5,n1);     # n1 kernels
w2 = epsinit*randn(Float32,5,5,n2,n1);  # n2 x n1 kernels

x0 = zeros(Float32, 32, 32);     # input image
x1 = zeros(Float32, 28, 28, n1);   # valid convolution by w1 reduces image size by 4
x1p = zeros(Float32, 14, 14, n1);  # pooling reduces image size by 2x
x2 = zeros(Float32, 10, 10, n2);   # convolution by w2 reduces image size by 4
x2p = zeros(Float32, 5, 5, n2);    # pooling reduces image size by 2x

# initialize three fully connected layers
n3 = 120; n4 = 84; n5 = 10              # number of neurons per layer
W3 = epsinit*randn(Float32,n3,length(x2p))   # 2D organization of x2p is discarded
W4 = epsinit*randn(Float32,n4,n3)
W5 = epsinit*randn(Float32,n5,n4);

In [None]:
# Train with plain per-example SGD: for every training image (in fixed order, no
# shuffling) run the forward pass, backpropagate the squared error, and update all
# weights; repeat for `nepoch` passes. All weights and activations are globals created
# in the previous cell, so this loop reassigns/updates them in place.
# NOTE(review): relies on being run inside IJulia (IJulia.clear_output below) and on
# Julia 0.6-era functions (tic/toc, indmax).
tic()
for iepoch = 1:nepoch
    errsq = zeros(niter)  # to monitor learning curve during epoch
    errcl = zeros(niter)  # 1.0 where the prediction was wrong, else 0.0
    for iter = 1:niter
        # ---------------- forward pass ----------------
        # zero pad 28x28 image to make it 32x32
        x0 = zeros(Float32, 32,32);
        x0[3:30,3:30] = train[:,:,iter]

        # layer 1: valid 5x5 convolution with each of the n1 kernels (32x32 -> 28x28)
        for i = 1:n1
            x1[:,:,i] = fastconv(x0, w1[:,:,i], "valid"); 
        end
        # nonlinearity, then 2x2 max pooling (28x28 -> 14x14); x1w keeps the
        # argmax indices that maxpoolback needs in the backward pass
        x1 = f(x1)
        x1p, x1w = maxpool(x1)
 
        # layer 2: output map i sums the valid convolutions of every input map j
        # with kernel w2[:,:,i,j] (14x14 -> 10x10)
        x2 = zeros(Float32,size(x2));  # initialize to zero for accumulation
        for i = 1:n2
            for j = 1:n1
                x2[:,:,i] += fastconv(x1p[:,:,j], w2[:,:,i,j], "valid");
            end
        end
        # nonlinearity, then pooling (10x10 -> 5x5)
        x2 = f(x2); 
        x2p, x2w = maxpool(x2)
 
        # fully connected layers 3-5
        # discard 2D organization of x2p by flattening it to the vector x2p[:]
        x3 = f(W3*x2p[:])
        x4 = f(W4*x3)
        x5 = f(W5*x4)

        # predicted class = index of the largest output (10 stands for digit 0)
        prediction = indmax(x5);
        errcl[iter] = float(prediction != trainlabels[iter]);
        # ---------------- backward pass ----------------
        # target: +1 at the true class, -1 everywhere else
        # NOTE(review): d is an n5x1 Matrix while x5 is a Vector; `d - x5` relies on
        # the trailing singleton dimension being accepted, so err and the deltas
        # derived from it are n5x1 matrices -- confirm on the Julia version in use
        d = -ones(Float32, n5,1); d[trainlabels[iter]] = 1;  # output vector
        err = d - x5; 
        errsq[iter] = sum(err.*err)
        # deltas: back-propagated error times the nonlinearity's derivative,
        # which df computes from each layer's output
        delta5 = err.*df(x5);
        delta4 = (W5'*delta5).*df(x4)
        delta3 = (W4'*delta4).*df(x3)
        delta2p = W3'*delta3
        delta2p = reshape(delta2p, size(x2p))  # restore 2D organization
        # undo pooling (deltas go only to the winning positions), then the nonlinearity
        delta2 = maxpoolback(delta2p, x2w).*df(x2)
        # back through the layer-2 convolutions: full convolution with the 180-degree
        # rotated kernels maps the 10x10 deltas back to 14x14 (10 + 5 - 1)
        delta1p = zeros(Float32,size(x1p))
        for j = 1:n1
            for i = 1:n2
                delta1p[:,:,j] += fastconv(delta2[:,:,i], w2[end:-1:1,end:-1:1,i,j], "full")
            end
        end
        delta1 = maxpoolback(delta1p, x1w).*df(x1)

        # weight updates
        # err = d - x5 (target minus output), so adding eta*delta*input' is a
        # gradient-descent step on sum(err.*err)
        W5 += eta*delta5*x4'
        W4 += eta*delta4*x3'
        W3 += eta*delta3*x2p[:]'
        # conv kernel gradients: valid convolution of the 180-degree rotated input
        # maps with the output deltas gives one 5x5 gradient per kernel
        # (14 - 10 + 1 = 5 for layer 2, 32 - 28 + 1 = 5 for layer 1)
        for i = 1:n2
            for j = 1:n1
                w2[:,:,i,j] += eta*fastconv(x1p[end:-1:1,end:-1:1,j], delta2[:,:,i], "valid")
            end
        end
        for i = 1:n1
            w1[:,:,i] += eta*fastconv(x0[end:-1:1,end:-1:1], delta1[:,:,i], "valid")
        end
        # every 500 examples: report the time for the last 500 iterations, restart the
        # timer, and redraw the monitoring dashboard in place of the previous output
        if rem(iter,500) == 0
            toc()
            tic()
            IJulia.clear_output(true)
            plot(
                # running averages of squared error and classification error so far this epoch
                plot(cumsum(errsq[1:iter])./(1:iter),
                    ylabel="sq err"
                ),
                plot(cumsum(errcl[1:iter])./(1:iter),
                    ylabel = "cl err", 
                    title = @sprintf("epoch=%d, iter=%d",iepoch,iter)
                ),
                # current output vector and hidden-layer activation distributions
                bar(x5, xlabel="x5"),
                histogram(x4, xlabel="x4"),
                histogram(x3, xlabel="x3"),
                # current (padded) input image and the n1 first-layer kernels
                heatmap(x0, yflip=true),
                plot(
                    [heatmap(w1[:,:,i],
                             axis = nothing, 
                             color = :grays, 
                             yflip = true
                             ) for i = 1:n1]...), 
                # two rows of three small panels, then the kernel grid in the bottom half
                layout = @layout [a b c; d e f; g{0.5h}]
            ) |> display
            # presumably gives the notebook front end time to render the figure
            sleep(0.01)
        end
    end
end