In [1]:
using MLDatasets
using Flux
using Random
using OneHotArrays
using StatsBase

In [2]:
trainset = MNIST.traindata()
@info "data" typeof(trainset[1]) size(trainset[1]) typeof(trainset[2]) size(trainset[2])

┌ Info: data
│   typeof(trainset[1]) = Base.ReinterpretArray{FixedPointNumbers.N0f8, 3, UInt8, Array{UInt8, 3}, false}
│   size(trainset[1]) = (28, 28, 60000)
│   typeof(trainset[2]) = Vector{Int64}
│   size(trainset[2]) = (60000,)
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:2


In [3]:
x_train, y_train =trainset[:]
len_train = length(y_train)
x_train_ = reshape(x_train, 28^2, :)
y_train_ = convert(Array{Float32}, onehotbatch(y_train, 0:9))

10×60000 Matrix{Float32}:
 0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  1.0  0.0  0.0  1.0  0.0     0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0     0.0  0.0  0.0  1.0  0.0  0.0  0.0
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  1.0  0.0  0.0  0.0  1.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0
 0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

In [4]:
model = Chain(
  Dense(28^2, 64, relu),
  Dense(64, 10, elu), 
  softmax
)

Chain(
  Dense(784 => 64, relu),               [90m# 50_240 parameters[39m
  Dense(64 => 10, elu),                 [90m# 650 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m50_890 parameters, 199.039 KiB.

In [5]:
lossF = (x, y) -> begin
    y_ = model(x)
    return mean(-sum(y .* log.(y_), dims=1))
end
lossF(x_train_[:, 1:100], y_train_[:, 1:100])

2.398534f0

In [11]:
accuracy = (x_, y_) -> sum(argmax(model(x_), dims=1) .== argmax(y_, dims=1)) / size(y_, 2)
accuracy(x_train_, y_train_)

0.9770166666666666

In [10]:
argmax(model(x_train_), dims=1) .== argmax(y_train_, dims=1)

1×60000 BitMatrix:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  1  1  1  1  1  1  1  1  1  1  1  1

In [6]:
opt = ADAM(0.01)
params = Flux.params(model)
grads = gradient(() -> lossF(x_train_[:, 1:100], y_train_[:, 1:100]), params)

Grads(...)

In [9]:
BATCH_SIZE = 100
for epoch in 1:10
    for i in 1:BATCH_SIZE:len_train
        x = x_train_[:, i:i+BATCH_SIZE-1]
        y = y_train_[:, i:i+BATCH_SIZE-1]
        grads = gradient(() -> lossF(x, y), params)
        Flux.update!(opt, params, grads)
    end
    @info "epoch" epoch accuracy(x_train_, y_train_)
end


┌ Info: epoch
│   epoch = 1
│   accuracy(x_train_, y_train_) = 0.9514666666666667
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9
┌ Info: epoch
│   epoch = 2
│   accuracy(x_train_, y_train_) = 0.9566166666666667
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9


┌ Info: epoch
│   epoch = 3
│   accuracy(x_train_, y_train_) = 0.9487833333333333
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9
┌ Info: epoch
│   epoch = 4
│   accuracy(x_train_, y_train_) = 0.9670166666666666
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9


┌ Info: epoch
│   epoch = 5
│   accuracy(x_train_, y_train_) = 0.9654666666666667
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9
┌ Info: epoch
│   epoch = 6
│   accuracy(x_train_, y_train_) = 0.9729166666666667
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9


┌ Info: epoch
│   epoch = 7
│   accuracy(x_train_, y_train_) = 0.97915
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9
┌ Info: epoch
│   epoch = 8
│   accuracy(x_train_, y_train_) = 0.9797666666666667
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9


┌ Info: epoch
│   epoch = 9
│   accuracy(x_train_, y_train_) = 0.9768333333333333
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9
┌ Info: epoch
│   epoch = 10
│   accuracy(x_train_, y_train_) = 0.9770166666666666
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:9


In [13]:
testset = MNIST.testdata()
@info "data" size(testset[1]) size(testset[2])

┌ Info: data
│   size(testset[1]) = (28, 28, 10000)
│   size(testset[2]) = (10000,)
└ @ Main /Users/zhenchen/Julia_MNIST_VAE/mnist.ipynb:2


In [14]:
x_test, y_test =trainset[:]
x_test_ = reshape(x_test, 28^2, :)
y_test_ = convert(Array{Float32}, onehotbatch(y_test, 0:9))

10×60000 Matrix{Float32}:
 0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  1.0  0.0  0.0  1.0  0.0     0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0     0.0  0.0  0.0  1.0  0.0  0.0  0.0
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  1.0  0.0  0.0  0.0  1.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0
 0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

In [15]:
lossF(x_test_, y_test_)

0.08228423f0

In [16]:
accuracy(x_test_, y_test_)

0.9770166666666666