### Roughly patterned after LeNet-5
max instead of average pooling

In [1]:
function Base.conv2(A, B, shape::String)   # extend to accept "full" or "valid" as an third argument
    if shape == "full"
        return Base.conv2(A,B)
    elseif shape == "valid"
        ranges = [ min(a,b):max(a,b) for (a,b) in zip(size(A),size(B)) ]
        return Base.conv2(A,B)[ranges...]
    else
        error("shape must be full or valid")
    end
end

In [6]:
using Plots
gr()

# function for displaying a stack of images
# imgstack is mxnxp array that contains p images, each of which is mxn 
function montage(imgstack)
    plot(
        [heatmap(imgstack[:,:,i]) for i=1:size(imgstack,3)]...,
        legend=:none, axis=nothing     # options necessary to get nice spacing of the images
    )
end|

LoadError: syntax: incomplete: premature end of input

In [3]:
function maxpool(images)
    # 2x2 max pooling for set of images
    # returns MAXIMA, WINNERS
    # MAXIMA 2x smaller image, maximum of each 2x2 patch
    # WINNERS 2x smaller image, argmax (1..4) of each 2x2 patch

    m, n, k = size(images)  # mxn images, k of them
    # reshape to 2 x m/2 x 2 x n/2 x k
    maxima, winners = findmax(reshape(images, 2, div(m,2), 2, div(n,2), k), (1,3))
    return squeeze(maxima,(1,3)), squeeze(winners,(1,3))
end

function maxpoolback(delta_out, winners)
    # backprop through 2x2 max pooling
    #
    # DELTA_IN image, deltas
    # DELTA_OUT 2x smaller image, deltas
    # WINNERS 2x smaller image, argmax (1..4) of each 2x2 patch

    m, n, k = size(delta_out);  # mxn images, k of them
    delta_in = zeros(2*m,2*n,k)

    delta_in[winners] = delta_out
    return delta_in
end

maxpoolback (generic function with 1 method)

In [4]:
using MNIST

train, trainlabels = traindata()
train = reshape(train, 28, 28, size(train,2))/255.0
trainlabels = convert(Array{Int64,1},trainlabels)
trainlabels[trainlabels .== 0] = 10;  # tenth output signals a zero
niter = size(train, 3);  # number of training examples
nepoch = 10;  # number of epochs through training set

#f(x) = max(x,0)
#df(y) = float(y.>0)

f(x) = tanh(x)
df(y) = 1. - y.*y

epsinit = 0.1;   # scale of weight initialization
eta = 0.01;  # learning rate parameter

# initialize two convolution layers 
n1 = 6; n2 = 16;   # numbers of feature maps
w1 = epsinit*randn(5,5,n1);     # n1 kernels
w2 = epsinit*randn(5,5,n2,n1);  # n2 x n1 kernels

x0 = zeros(32, 32);     # input image
x1 = zeros(28, 28, n1);   # valid convolution by w1 reduces image size by 4
x1p = zeros(14, 14, n1);  # pooling reduces image size by 2x
x2 = zeros(10, 10, n2);   # convolution by w2 reduces image size by 4
x2p = zeros(5, 5, n2);    # pooling reduces image size by 2x

# initialize three fully connected layers
n3 = 120; n4 = 84; n5 = 10              # number of neurons per layer
W3 = epsinit*randn(n3,length(x2p[:]))   # 2D organization of x2p is discarded
W4 = epsinit*randn(n4,n3)
W5 = epsinit*randn(n5,n4);

10×84 Array{Float64,2}:
 -0.107835    0.119165   -0.0424434   …  -0.142116   -0.0061786  -0.0702819
 -0.0659408  -0.0787761   0.00185246      0.056534   -0.0221143   0.238002 
  0.0669089   0.0287926   0.135284        0.0972418   0.252342   -0.04662  
  0.0104491  -0.184309   -0.0895099       0.0868073   0.0815208  -0.173352 
  0.0877485  -0.0364961  -0.0897693      -0.153143   -0.108352    0.0986245
  0.170154    0.177363   -0.235362    …  -0.190661    0.0871033  -0.0196098
 -0.0346473  -0.0736787  -0.0369671      -0.21548     0.050761    0.239259 
  0.0874269   0.0617412   0.00655789     -0.124906   -0.108902    0.0984346
  0.267911    0.0828776   0.0333724      -0.0264601   0.0259854  -0.0679191
 -0.171294    0.1056     -0.0182364       0.0388453   0.0747895   0.0190685

In [None]:
tic()
for iepoch = 1:nepoch
    errsq = zeros(niter)  # to monitor learning curve during epoch
    errcl = zeros(niter)
    for iter = 1:niter
        # zero pad 28x28 image to make it 32x32
        x0 = zeros(32,32);
        x0[3:30,3:30] = train[:,:,iter]

        for i = 1:n1
            x1[:,:,i] = conv2(x0, w1[:,:,i], "valid"); 
        end
        x1 = f(x1)
        x1p, x1w = maxpool(x1)

        x2 = zeros(size(x2));  # initialize to zero for accumulation
        for i = 1:n2
            for j = 1:n1
                x2[:,:,i] += conv2(x1p[:,:,j], w2[:,:,i,j], "valid");
            end
        end
        x2 = f(x2); 
        x2p, x2w = maxpool(x2)

        # discard 2D organization of x2p by reshaping to x2p(:)
        x3 = f(W3*x2p[:])
        x4 = f(W4*x3)
        x5 = f(W5*x4)
        prediction = indmax(x5);
        errcl[iter] = float(prediction != trainlabels[iter]);
        
        # backward pass
        d = -ones(n5,1); d[trainlabels[iter]] = 1;  # output vector
        err = d - x5; 
        delta5 = err.*df(x5);
        delta4 = (W5'*delta5).*df(x4)
        delta3 = (W4'*delta4).*df(x3)
        delta2p = W3'*delta3
        delta2p = reshape(delta2p, size(x2p))  # restore 2D organization
        delta2 = maxpoolback(delta2p, x2w).*df(x2)
        delta1p = zeros(size(x1p))
        for j = 1:n1
            for i = 1:n2
                delta1p[:,:,j] += conv2(delta2[:,:,i], w2[end:-1:1,end:-1:1,i,j], "full")
            end
        end
        delta1 = maxpoolback(delta1p, x1w).*df(x1)

        # weight updates
        W5 += eta*delta5*x4'
        W4 += eta*delta4*x3'
        W3 += eta*delta3*x2p[:]'
        for i = 1:n2
            for j = 1:n1
                w2[:,:,i,j] += eta*conv2(x1p[end:-1:1,end:-1:1,j], delta2[:,:,i], "valid")
            end
        end
        for i = 1:n1
            w1[:,:,i] += eta*conv2(x0[end:-1:1,end:-1:1], delta1[:,:,i], "valid")
        end
        if rem(iter,100) == 0
            toc()
            tic()
            IJulia.clear_output(true)
            plot(
                plot(cumsum(errsq[1:iter])./(1:iter)),
                plot(cumsum(errcl[1:iter])./(1:iter)),
                bar(x5),
                histogram(x4),
                histogram(x3),
                heatmap(x0,yflip=true),
                plot(
                    [heatmap(w1[:,:,i],
                             axis = nothing, 
                             color = :grays, 
                             yflip = true
                             ) for i = 1:n1]...), 
                layout = @layout [a b c; d e f; g{0.5h}]
            ) |> display
            sleep(0.01)
        end
    end
end