# Upgrading our MNIST Network

In [2]:
using MLDatasets
train_x, train_y = MNIST.traindata()
test_x,  test_y  = MNIST.testdata();

(images, labels) = (train_x[:,:,1:1000], train_y[1:1000])
images = permutedims(images, (3, 2,1))
test_x = permutedims(test_x, (3, 2,1))
one_hot_labels = zeros(length(labels), 10)
for (i,l) in enumerate(labels)
    one_hot_labels[i, l+1] = 1.0
end
labels = one_hot_labels

test_labels = zeros((length(test_y), 10))

for (i,l) in enumerate(test_y)
    test_labels[i, l+1] = 1.0
end

using Random
Random.seed!(1)

tanh2deriv(output) = 1 - output^2

function softmax(x)
    temp = exp.(x)
    return temp ./ sum(temp, dims=2)
end

alpha, iterations = (2, 300)
pixels_per_image, num_labels = (784, 10)
batch_size = 128

input_rows = 28
input_cols = 28

kernel_rows = 3
kernel_cols = 3
num_kernels = 16

hidden_size = ((input_rows - kernel_rows) * 
               (input_cols - kernel_cols)) * num_kernels

kernels = 0.02 .* rand(kernel_rows*kernel_cols, num_kernels) .- 0.01
weights_1_2 = 0.2 .* rand(hidden_size, num_labels) .- 0.1

function get_image_section(layer,row_from, row_to, col_from, col_to)
    section = layer[:, row_from:row_to,col_from:col_to]
    return reshape(section, (:, 1, row_to-row_from+1, col_to-col_from+1))
end


for j=1:iterations
    Correct_cnt = 0
    for i = 1:batch_size:size(images, 1)-batch_size
        batch_start, batch_end = i, i+batch_size-1
        layer_0 = images[batch_start:batch_end,:,:]
        
        sects = []
        for row_start=1:size(layer_0, 2)-kernel_rows
            for col_start=1:size(layer_0, 3) - kernel_cols
                sect = get_image_section(layer_0,row_start, row_start+kernel_rows-1, col_start, col_start+kernel_cols-1)
                push!(sects, sect)
            end
        end
        expanded_input = cat(sects...,dims=2)
        es = size(expanded_input)
        expanded_input = permutedims(expanded_input, (4,3,2,1))
        flattened_input = reshape(expanded_input, (:,es[1]*es[2]))'
        kernel_output = flattened_input * kernels
        kernel_output_temp = permutedims(reshape(kernel_output', (:, es[1])), (2,1))
        kernel_output_temp
        layer_1 = tanh.(kernel_output_temp)
        
        dropout_mask = bitrand(size(layer_1))
        layer_1 .*= dropout_mask .* 2
        layer_2 = softmax(layer_1 * weights_1_2)
        
        
        for k=1:batch_size
            Correct_cnt += Int(argmax(layer_2[k,:]) == argmax(labels[batch_start+k-1,:]))
        end
        
        layer_2_delta = (labels[batch_start:batch_end, :] .- layer_2) ./ (batch_size * size(layer_2, 1))
        layer_1_delta = (layer_2_delta * weights_1_2') .* tanh2deriv.(layer_1)
        layer_1_delta .*= dropout_mask
        weights_1_2 .+= alpha .* layer_1' * layer_2_delta
        
        ks = size(kernel_output)
        l1d_reshape = permutedims(reshape(layer_1_delta', (ks[2],ks[1])), (2,1))
        k_update = flattened_input' * l1d_reshape
        kernels .-= alpha .* k_update
    end
    
    test_correct_cnt = 0
    
    for i=1:size(test_x, 1)
        layer_0 = reshape(test_x[i,:,:], (1,28,28))
        
        sects = []
        for row_start=1:size(layer_0, 2)-kernel_rows
            for col_start=1:size(layer_0, 3) - kernel_cols
                sect = get_image_section(layer_0,row_start, row_start+kernel_rows-1, col_start, col_start+kernel_cols-1)
                push!(sects, sect)
            end
        end
        expanded_input = cat(sects...,dims=2)
        es = size(expanded_input)
        expanded_input = permutedims(expanded_input, (4,3,2,1))
        flattened_input = reshape(expanded_input, (:,es[1]*es[2]))'
        kernel_output = flattened_input * kernels
        kernel_output_temp = permutedims(reshape(kernel_output', (:, es[1])), (2,1))
        kernel_output_temp
        layer_1 = tanh.(kernel_output_temp)
        
        dropout_mask = bitrand(size(layer_1))
        layer_1 .*= dropout_mask .* 2
        layer_2 = softmax(layer_1 * weights_1_2)
        test_correct_cnt += Int(argmax(dropdims(layer_2;dims=1)) == argmax(test_labels[i,:]))
    end 
         
    if (j%1 == 0)
        println("I: $(j) Test accuracy: $(test_correct_cnt/size(test_x, 1)) Train accuracy: $(Correct_cnt/size(images, 1)) ")
    end             
end

I: 1 Test accuracy: 0.0574 Train accuracy: 0.058 
I: 2 Test accuracy: 0.0546 Train accuracy: 0.038 
I: 3 Test accuracy: 0.0602 Train accuracy: 0.041 
I: 4 Test accuracy: 0.0594 Train accuracy: 0.058 
I: 5 Test accuracy: 0.0666 Train accuracy: 0.051 
I: 6 Test accuracy: 0.0771 Train accuracy: 0.067 
I: 7 Test accuracy: 0.0899 Train accuracy: 0.082 
I: 8 Test accuracy: 0.1092 Train accuracy: 0.087 
I: 9 Test accuracy: 0.1299 Train accuracy: 0.127 
I: 10 Test accuracy: 0.1556 Train accuracy: 0.127 
I: 11 Test accuracy: 0.1964 Train accuracy: 0.162 
I: 12 Test accuracy: 0.2218 Train accuracy: 0.188 
I: 13 Test accuracy: 0.2399 Train accuracy: 0.227 
I: 14 Test accuracy: 0.2549 Train accuracy: 0.228 
I: 15 Test accuracy: 0.2429 Train accuracy: 0.229 
I: 16 Test accuracy: 0.1622 Train accuracy: 0.198 
I: 17 Test accuracy: 0.0963 Train accuracy: 0.12 
I: 18 Test accuracy: 0.0687 Train accuracy: 0.055 
I: 19 Test accuracy: 0.0525 Train accuracy: 0.054 
I: 20 Test accuracy: 0.0515 Train accurac