In [1]:
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated
using FluxExtensions

push!(LOAD_PATH, "../")
using KNNmem

include("../train_and_track.jl");



In [2]:
# Prepare data

imgs = MNIST.images()
X = hcat(float.(reshape.(imgs, :))...)

labels = MNIST.labels()
Y = labels
oneHotY = Flux.onehotbatch(Y, 0:9) # for softmax

tX = hcat(float.(reshape.(MNIST.images(:test), :))...)
tY = MNIST.labels(:test)
oneHotTY = Flux.onehotbatch(tY, 0:9)

10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false  false  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false  false  false   true     false  false   true  false  false
 false  false  false  false  false  …  false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false

In [3]:
# Model with memory

memoryModel = Chain(
  FluxExtensions.ResDense(28^2, 32, relu),
  FluxExtensions.ResDense(32, 10, relu))

memory = KNNmemory(1000, 10, 128, 10)

# Model without memory

classicModel = Chain(
  FluxExtensions.ResDense(28^2, 32, relu),
  FluxExtensions.ResDense(32, 10, relu),
  softmax)

Chain(ResDense(Dense(32, 32, NNlib.relu)), ResDense(Dense(10, 10, NNlib.relu)), NNlib.softmax)

In [4]:
# Training setup

memLoss(x, y) = trainQuery!(memory, memoryModel(x), y)
memAccuracy(x, y) = mean(query(memory, memoryModel(x)) .== y)
memOpt = ADAM(params(memoryModel))

classicLoss(x, y) = crossentropy(classicModel(x), y)
classicAccuracy(x, y) = mean(argmax(classicModel(x)) .== argmax(y))
classicOpt = ADAM(params(classicModel))

iterations = 1000
batchSize = 1000
printInterationCount = 100



100

In [None]:
# Training

memHistory = trainAndTrack!(memLoss, memOpt, iterations, batchSize, X, Y, tX, tY, printInterationCount)
classicHistory = trainAndTrack!(classicLoss, classicOpt, iterations, batchSize, X, oneHotY, tX, oneHotTY, printInterationCount)

In [None]:
# Accuracy comparison

memAccuracy(tX, tY)

In [None]:
classicAccuracy(tX, oneHotTY)

In [None]:
# Plot training

using Plots
pyplot()

plot(memHistory)
plot!(title = "Training with memory")

In [None]:
plot(classicHistory)
plot!(title = "Training without memory")