In [1]:
using MLDatasets
using Flux: onehotbatch
include("src/JW_CNN.jl")
# Load both MNIST splits. Each split exposes `features` (28×28×N Float32 images)
# and `targets` (an N-element vector of Int digit labels).
train_data = MLDatasets.MNIST(; split = :train)
test_data = MLDatasets.MNIST(; split = :test)

JW_CNN module loaded
Number of threads: 8


dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :test
  features  =>    28×28×10000 Array{Float32, 3}
  targets   =>    10000-element Vector{Int64}

In [2]:
"""
    relu(x)

Rectified linear unit, applied element-wise: every negative entry of `x` becomes zero
and every non-negative entry is passed through unchanged.
"""
relu(x) = max.(0, x)

"""
    relu_derivative(x)

Element-wise derivative of `relu`: `true` where `x > 0`, `false` elsewhere (including 0).
The result is a Boolean array, which behaves like 1/0 in arithmetic.
"""
relu_derivative(x) = x .> 0

"""
    identity(x)

Return `x` unchanged. It is used as the (linear) output-layer activation.

NOTE(review): this shadows `Base.identity`, which behaves identically. If `identity` is
referenced in `Main` before this definition runs, Julia rejects the definition, so
consider deleting it and using `Base.identity` directly.
"""
identity(x) = x

"""
    identity_derivative(x)

Derivative of the identity activation: an array of ones with the same size as `x`.

The element type follows `x` (promoted to a float type), so `Float32` activations such
as the MNIST features yield a `Float32` derivative. The previous `ones(size(x))` always
produced `Float64`, which silently promoted the backward pass to double precision.
Non-float inputs (e.g. `Int`) still get `Float64` ones, as before.
"""
identity_derivative(x) = ones(float(eltype(x)), size(x))

identity_derivative (generic function with 1 method)

In [3]:
# Build a LeNet-style CNN: two conv + max-pool stages, a flatten, then two dense layers.
# NOTE(review): the two positional arguments are presumably the learning rate (0.01)
# and batch size (100) — confirm against JW_CNN.NeuralNetwork.
network = JW_CNN.NeuralNetwork(0.01, 100)
JW_CNN.add_layer!(network, JW_CNN.ConvLayer(3, 3, 6))
JW_CNN.add_layer!(network, JW_CNN.MaxPoolLayer(2, 2))
JW_CNN.add_layer!(network, JW_CNN.ConvLayer(3, 3, 16))
JW_CNN.add_layer!(network, JW_CNN.MaxPoolLayer(2, 2))
JW_CNN.add_layer!(network, JW_CNN.FlattenLayer())
# 400 = 16 channels × 5 × 5, assuming unpadded 3×3 convs and 2×2 pooling
# (28 → 26 → 13 → 11 → 5) — TODO confirm ConvLayer/MaxPoolLayer semantics.
JW_CNN.add_layer!(network, JW_CNN.FCLayer(400, 84, relu, relu_derivative))
# 10 outputs, one per digit class (matches onehotbatch(…, 0:9)). The identity activation
# leaves raw scores; presumably the loss in JW_CNN applies softmax — confirm.
JW_CNN.add_layer!(network, JW_CNN.FCLayer(84, 10, identity, identity_derivative))

# Insert a singleton channel axis: each 28×28×N feature array becomes 28×28×1×N.
inputs = reshape(train_data.features, 28, 28, 1, :);
test_input = reshape(test_data.features, 28, 28, 1, :);

# One-hot encode the digit labels 0–9 into 10×N matrices (one column per sample).
# NOTE(review): onehotbatch(…, 0:9) appears to already be 10×N, so the reshape is
# likely a no-op; it is kept to preserve the original behaviour exactly.
targets = reshape(onehotbatch(train_data.targets, 0:9), 10, :);
test_targets = reshape(onehotbatch(test_data.targets, 0:9), 10, :);

In [4]:
# Baseline: evaluate the untrained network on the test set. Returns a (loss, accuracy)
# tuple; accuracy appears to be reported as a percentage.
JW_CNN.test(network, test_input, test_targets)  

┌ Info: Test Loss: 2.3019017076492307, Test Accuracy: 8.03
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:99


(2.3019017076492307, 8.03)

In [5]:
# Train for 3 epochs; the test split is evaluated after every epoch (see the logs).
# NOTE(review): the test split doubles as the validation set here. If those per-epoch
# numbers drive any tuning, the final test accuracy is no longer a held-out estimate.
JW_CNN.train(network, inputs, targets, test_input, test_targets, 3);

┌ Info: Epoch 1
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59


 16.745253 seconds (4.03 M allocations: 2.421 GiB, 1.23% gc time, 42.00% compilation time)


┌ Info: Test Loss: 0.34539958968758583, Test Accuracy: 89.4
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:99
┌ Info: Epoch 2
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59


 11.700898 seconds (255.40 k allocations: 2.165 GiB, 1.57% gc time)


┌ Info: Test Loss: 0.35192971155047414, Test Accuracy: 88.34
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:99
┌ Info: Epoch 3
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59


 11.881479 seconds (255.40 k allocations: 2.165 GiB, 1.49% gc time)


┌ Info: Test Loss: 0.2227177070081234, Test Accuracy: 93.07
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:99


In [6]:
# Final evaluation of the trained network on the test set. The result should agree with
# the last epoch's report up to small floating-point differences.
JW_CNN.test(network, test_input, test_targets)  

┌ Info: Test Loss: 0.2227177080512047, Test Accuracy: 93.07
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:99


(0.2227177080512047, 93.07)