## Multi-layer perceptron (MLP) in Flux

In [None]:
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated

In [None]:
# Classify MNIST digits with a simple multi-layer perceptron.
# Flatten every training image into a float column vector and place the columns
# side by side, so X holds one image per column (a single large batch).
imgs = MNIST.images()
X = hcat([float(vec(img)) for img in imgs]...)

# One-hot encode the digit labels over the classes 0:9; column j of Y is the
# target for column j of X.
labels = MNIST.labels()
Y = onehotbatch(labels, 0:9)

In [None]:
# Two-layer MLP: 28^2 input pixels -> 32 hidden units (ReLU) -> 10 outputs,
# turned into class probabilities by the final softmax.
m = Chain(Dense(28^2, 32, relu),
          Dense(32, 10),
          softmax)

# To train on the GPU, uncomment these lines (requires the CuArrays package).
# The data matrices are the uppercase X and Y defined above.
# using CuArrays
# X, Y = cu(X), cu(Y)
# m = mapleaves(cu, m)

# Cross-entropy between the model's predicted probabilities and the one-hot targets.
loss(x, y) = crossentropy(m(x), y)
# Fraction of samples whose predicted class matches the target class, using the
# `argmax` imported from Flux above (presumably per-column for matrix inputs —
# confirm against the installed Flux version). On Julia >= 0.7, `mean` also
# needs `using Statistics`.
accuracy(x, y) = mean(argmax(m(x)) .== argmax(y))

In [None]:
# Lazily yield the same full (X, Y) batch 200 times; each element handed to
# `Flux.train!` triggers one full-batch update.
dataset = Iterators.repeated((X, Y), 200)

In [None]:
?dataset

In [None]:
# Callback that prints the loss on the full training set; it is throttled below
# so it fires at most once per timeout interval.
evalcb = () -> @show(loss(X, Y))
# Stochastic gradient descent over all trainable parameters of `m`.
opt = SGD(params(m))

# One full-batch training step. `train!` iterates over its data argument and
# splats each element into `loss`, so the batch must be wrapped in a collection:
# a bare `(X, Y)` would be iterated as X and then Y, calling `loss(X...)`.
Flux.train!(loss, [(X, Y)], opt, cb = throttle(evalcb, 1))
# Only a cell's final value is displayed, so print this intermediate result.
@show accuracy(X, Y)

# Test set accuracy
tX = hcat(float.(reshape.(MNIST.images(:test), :))...)
tY = onehotbatch(MNIST.labels(:test), 0:9)

# If using CuArrays, move the test data to the GPU as well:
# tX, tY = cu(tX), cu(tY)
accuracy(tX, tY)

In [None]:
# Continue training: one update per element of `dataset`, printing the
# training loss at most once per 10-unit throttle interval.
Flux.train!(loss, dataset, opt; cb = throttle(evalcb, 10))
accuracy(X, Y)

In [None]:
# Accuracy on the MNIST test split, flattened and one-hot encoded exactly like
# the training data.
teX = hcat([float(vec(img)) for img in MNIST.images(:test)]...)
teY = onehotbatch(MNIST.labels(:test), 0:9)
accuracy(teX, teY)

In [None]:
# Accuracy on the MNIST training split, reloaded explicitly with `:train`
# (presumably the same data as X, Y above — confirm MNIST.images' default).
trX = hcat([float(vec(img)) for img in MNIST.images(:train)]...)
trY = onehotbatch(MNIST.labels(:train), 0:9)
accuracy(trX, trY)

#### Train for more iterations

In [None]:
# Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))