In [45]:
using MLDatasets
using Flux: onehotbatch
include("src/JW_CNN.jl")
# Fetch both MNIST splits through MLDatasets. The test split is assigned last,
# so it is the value this notebook cell displays (features 28×28×N Float32,
# targets Int labels — see the summary printed below).
train_data = MLDatasets.MNIST(; split = :train)
test_data = MLDatasets.MNIST(; split = :test)

JW_CNN module loaded
Number of threads: 8




dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :test
  features  =>    28×28×10000 Array{Float32, 3}
  targets   =>    10000-element Vector{Int64}

In [46]:
"""
    relu(x)

Rectified linear unit, applied element-wise: `max(0, xᵢ)` for every element of
`x` (a scalar input gives a scalar result). The zero is built with `zero(xᵢ)`,
so the element type of `x` is preserved (e.g. `Int8` is not promoted to `Int64`).
"""
function relu(x)
    return @. max(zero(x), x)
end

"""
    relu_derivative(x)

Element-wise derivative of `relu`: `true` where `xᵢ > 0` and `false` elsewhere
(the derivative at exactly 0 is taken as 0). Returns a `BitArray` for array
input and a `Bool` for scalar input; `Bool`s act as 1/0 when multiplied into a
gradient.
"""
function relu_derivative(x)
    return @. x > 0
end

"""
    identity(x)

Linear (pass-through) activation: return `x` unchanged.

Note: this defines a new function in the current module that shadows
`Base.identity`; the behavior is the same.
"""
function identity(x)
    return x
end

"""
    identity_derivative(x)

Element-wise derivative of `identity`: a one for every element of `x`. The
result has the same shape and element type as `x`; a scalar input yields the
scalar `one(x)`.
"""
function identity_derivative(x)
    # Broadcasting `one` keeps eltype(x): Float32 activations get Float32 ones
    # instead of being promoted to Float64. Scalars stay scalars rather than
    # becoming 0-dimensional arrays.
    return one.(x)
end

identity_derivative (generic function with 1 method)

In [47]:
# Hyper-parameters for JW_CNN.NeuralNetwork — presumably (learning rate,
# batch size); TODO confirm against src/JW_CNN.jl.
network = JW_CNN.NeuralNetwork(0.01, 100)

# Layers are appended in exactly this order. The 400 inputs of the first dense
# layer equal 5 * 5 * 16. That is consistent with unpadded 3×3 convolutions and
# 2×2 pooling that floors odd sizes (28 → 26 → 13 → 11 → 5). Confirm against
# JW_CNN whenever the layer stack changes.
for layer in (
    JW_CNN.ConvLayer(3, 3, 6),
    JW_CNN.MaxPoolLayer(2, 2),
    JW_CNN.ConvLayer(3, 3, 16),
    JW_CNN.MaxPoolLayer(2, 2),
    JW_CNN.FlattenLayer(),
    JW_CNN.FCLayer(400, 84, relu, relu_derivative),
    JW_CNN.FCLayer(84, 10, identity, identity_derivative),
)
    JW_CNN.add_layer!(network, layer)
end

# Insert a singleton channel axis: 28×28×N images become 28×28×1×N.
inputs = reshape(train_data.features, 28, 28, 1, :);
test_input = reshape(test_data.features, 28, 28, 1, :);

# One-hot encode the digit labels 0:9 into 10×N matrices. The outer
# reshape(…, 10, :) keeps that same layout, so it is effectively a no-op.
targets = reshape(onehotbatch(train_data.targets, 0:9), 10, :);
test_targets = reshape(onehotbatch(test_data.targets, 0:9), 10, :);

In [48]:
# Baseline evaluation of the untrained network on the test split. As the output
# below shows, it logs and returns a (loss, accuracy in %) tuple.
JW_CNN.test(network, test_input, test_targets)  

┌ Info: Test Loss: 2.2962243509292604, Test Accuracy: 9.78
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:100


(2.2962243509292604, 9.78)

In [49]:
# Train for 3 epochs (the final positional argument; the log shows "Epoch 1"
# through "Epoch 3"). After every epoch the network is evaluated on
# test_input/test_targets, so the test split doubles as the validation set.
# NOTE(review): hold out part of the training data for per-epoch monitoring if
# the final test accuracy must remain an unbiased estimate.
JW_CNN.train(network, inputs, targets, test_input, test_targets, 3);

 38.540767 seconds (669.42 k allocations: 2.487 GiB, 0.89% gc time, 7.54% compilation time)


┌ Info: Epoch 1
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59
┌ Info: Test Loss: 0.32834180533885954, Test Accuracy: 89.33
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:100
┌ Info: Epoch 2
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59


 35.349587 seconds (62.61 k allocations: 2.433 GiB, 0.28% gc time)


┌ Info: Test Loss: 0.21462852627038956, Test Accuracy: 93.56
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:100
┌ Info: Epoch 3
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:59


 35.460798 seconds (62.62 k allocations: 2.433 GiB, 0.27% gc time)


┌ Info: Test Loss: 0.1674371740221977, Test Accuracy: 94.98
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:100


In [50]:
# Final evaluation of the trained network on the test split. It returns
# (loss, accuracy in %), which should agree with the last per-epoch evaluation
# logged during training.
JW_CNN.test(network, test_input, test_targets)  

┌ Info: Test Loss: 0.16743717635050415, Test Accuracy: 94.98
└ @ Main.JW_CNN c:\Users\jakub\Documents\JW_CNN\src\JW_CNN.jl:100


(0.16743717635050415, 94.98)