# Upgrading our MNIST Network

In [None]:
using MLDatasets
train_x, train_y = MNIST.traindata()
test_x,  test_y  = MNIST.testdata();
(images, labels) = (train_x[:,:,1:1000], train_y[1:1000]); #(col,row,batch)

one_hot_labels = zeros(10, length(labels))
for (i,l) in enumerate(labels)
    one_hot_labels[l+1,i] = 1.0
end

labels = one_hot_labels

test_labels = zeros(10, length(test_y))
for (i,l) in enumerate(test_y)
    test_labels[l+1, i] = 1.0
end

using Random
Random.seed!(1)

tanh2deriv(output) = 1 - output^2

function softmax(x)
    temp = exp.(x)
    return temp ./ sum(temp, dims=1)
end

alpha, iterations = (2, 300)
pixels_per_image, num_labels = (784, 10)
batch_size = 128

input_rows = 28
input_cols = 28

kernel_rows = 3
kernel_cols = 3
num_kernels = 16

hidden_size = ((input_rows - kernel_rows) * 
               (input_cols - kernel_cols)) * num_kernels

kernels = 0.02 .* rand(num_kernels, kernel_rows*kernel_cols) .- 0.01  
weights_1_2 = 0.2 .* rand(num_labels, hidden_size) .- 0.1

function get_image_section(layer,row_from, row_to, col_from, col_to)
    section = layer[col_from:col_to, row_from:row_to, :]
    return reshape(section, (col_to-col_from+1, row_to-row_from+1, 1, :))
end


for j=1:iterations
    Correct_cnt = 0
    for i = 1:batch_size:size(images, 3)-batch_size
        batch_start, batch_end = i, i+batch_size-1
        layer_0 = images[:,:, batch_start:batch_end]
        sects = []
        for col_start=1:size(layer_0, 1) - kernel_cols
            for row_start=1:size(layer_0, 2)-kernel_rows
                sect = get_image_section(layer_0, col_start, col_start+kernel_cols-1, row_start, row_start+kernel_rows-1)
                push!(sects, sect)
            end
        end
        expanded_input = cat(sects...,dims=3) 
        es = size(expanded_input)
        flattened_input = reshape(expanded_input, (:,es[3]*es[4]))
        kernel_output = kernels * flattened_input
        layer_1 = tanh.(reshape(kernel_output, (:, es[4])))
        dropout_mask = bitrand(size(layer_1))
        layer_1 .*= dropout_mask .* 2
        layer_2 = softmax(weights_1_2 * layer_1)

        Correct_cnt += sum(argmax(layer_2;dims=1) .== argmax(labels[:,batch_start:batch_end];dims=1))
        layer_2_delta = (labels[:,batch_start:batch_end] .- layer_2) ./ (batch_size * size(layer_2, 2))
        layer_1_delta = (weights_1_2' * layer_2_delta) .* tanh2deriv.(layer_1)
        layer_1_delta .*= dropout_mask
        weights_1_2 += alpha .* layer_2_delta * layer_1'
        l1d_reshape = reshape(layer_1_delta, size(kernel_output)) 
        k_update = l1d_reshape * flattened_input'
        kernels .-= alpha .* k_update
    end
    
    test_correct_cnt = 0
    
    for i=1:size(test_x, 3)
        layer_0 = test_x[:,:, i]
        sects = []
        for col_start=1:size(layer_0, 1) - kernel_cols
            for row_start=1:size(layer_0, 2)-kernel_rows
                sect = get_image_section(layer_0, col_start, col_start+kernel_cols-1, row_start, row_start+kernel_rows-1)
                push!(sects, sect)
            end
        end
        expanded_input = cat(sects...,dims=3) ##
        es = size(expanded_input)
        flattened_input = reshape(expanded_input, (:,es[3]*es[4]))
        kernel_output = kernels * flattened_input
        layer_1 = tanh.(reshape(kernel_output, (:, es[4])))
        dropout_mask = bitrand(size(layer_1))
        layer_1 .*= dropout_mask .* 2
        layer_2 = softmax(weights_1_2 * layer_1)
        
        test_correct_cnt += Int(argmax(dropdims(layer_2;dims=2)) == argmax(test_labels[:,i]))
    end
    
    if (j%1 == 0)
        println("I: $(j) Test accuracy: $(test_correct_cnt/size(test_x, 3)) Train accuracy: $(Correct_cnt/size(images, 3)) ")
    end  
end

I: 1 Test accuracy: 0.0462 Train accuracy: 0.075 
I: 2 Test accuracy: 0.0505 Train accuracy: 0.039 
I: 3 Test accuracy: 0.0558 Train accuracy: 0.045 
I: 4 Test accuracy: 0.0656 Train accuracy: 0.047 
I: 5 Test accuracy: 0.079 Train accuracy: 0.052 
I: 6 Test accuracy: 0.1011 Train accuracy: 0.062 
I: 7 Test accuracy: 0.1236 Train accuracy: 0.08 
I: 8 Test accuracy: 0.1469 Train accuracy: 0.11 
I: 9 Test accuracy: 0.174 Train accuracy: 0.144 
I: 10 Test accuracy: 0.2044 Train accuracy: 0.167 
I: 11 Test accuracy: 0.2285 Train accuracy: 0.205 
I: 12 Test accuracy: 0.2572 Train accuracy: 0.243 
I: 14 Test accuracy: 0.2553 Train accuracy: 0.256 
I: 15 Test accuracy: 0.2114 Train accuracy: 0.214 
I: 16 Test accuracy: 0.1275 Train accuracy: 0.162 
I: 17 Test accuracy: 0.0781 Train accuracy: 0.089 
I: 18 Test accuracy: 0.052 Train accuracy: 0.047 
I: 19 Test accuracy: 0.045 Train accuracy: 0.05 
I: 20 Test accuracy: 0.043 Train accuracy: 0.031 
I: 21 Test accuracy: 0.0544 Train accuracy: 0.04