In [1]:
using MLDatasets

MNIST()

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :train
  features  =>    28×28×60000 Array{Float32, 3}
  targets   =>    60000-element Vector{Int64}

In [2]:
digits = MNIST()
X,y =  digits.features, digits.targets
size(X), extrema(X)
# we can see that the unit range has been normalized
# also known as min-max scaling which scales features
# to lie in the interval [0; 1]

((28, 28, 60000), (0.0f0, 1.0f0))

In [3]:
trainset = MNIST(:train)
# or call digits.split
Xtrain, ytrain = trainset[:] # return all observations
summary(Xtrain), summary(ytrain)

("28×28×60000 Array{Float32, 3}", "60000-element Vector{Int64}")

In [4]:
testset = MNIST(:test)
Xtest, ytest = testset[:]
summary(Xtest), summary(ytest)

("28×28×10000 Array{Float32, 3}", "10000-element Vector{Int64}")

In [33]:
using Flux, OneHotArrays

In [34]:
fluxonehoty = onehotbatch(ytrain, 0:9)

10×60000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1
 ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅

In [35]:

W = rand(Float32, 3, 4)
b = [0.0f0, 0.0f0, 0.0f0]
m(W, b, X) = W * X .+ b

softmax(X) = exp.(X) ./ sum(exp.(X), dims=1)
model(W, b, X) = softmax(m(W, b, X))

model (generic function with 1 method)

In [36]:
using MLUtils

Xtrainflat = flatten(Xtrain)
# 28 x 28 = 284
fluxmodel = Chain(Dense(784 => 10), softmax)

Chain(
  Dense(784 => 10),                     [90m# 7_850 parameters[39m
  Main.softmax,
) 

In [37]:
using MLUtils

function lossfunction(fluxmodel, features, onehotlabels)
    ŷ = fluxmodel(features)
    Flux.logitcrossentropy(ŷ, onehotlabels)
end
lossfunction(fluxmodel, Xtrainflat, fluxonehoty)

2.30521f0

In [41]:
using Statistics
function trainmodel!(loss, model, features, labels)
    dLdm, _, _ = gradient(loss, model, features, labels)
    @. model[1].weight = model[1].weight - 0.000001 * dLdm[:layers][1][:weight]
    @. model[1].bias = model[1].bias - 0.000001 * dLdm[:layers][1][:bias]
end


trainmodel!(lossfunction, fluxmodel, Xtrainflat, fluxonehoty)
mean(Flux.onecold(fluxmodel(Xtrainflat), 0:9) .== ytrain )

0.1182

In [42]:
function trainmodel!(loss, model, features, labels)
   dLdm, _, _ = gradient(loss, model, features, labels)
   @. model.weight = model.weight - 0.000001 * dLdm.weight
   @. model.bias = model.bias - 0.000001 * dLdm.bias
end

for ((xtrain, ytrain),(xval, yval)) in MLUtils.kfolds(shuffleobs((X,onehotbatch(y, 0:9))), k=5)
    # MLUtils.flatten = reshape(x, :, size(x)[end])
    xnorm = flatten(xtrain)

    xnval = flatten(xval)
    
    model = Dense(784 => 10)

    # run an infinite loop that breaks once change in loss is < δ
    loss_init = Inf;
    while true
        trainmodel!(lossfunction, model, xnorm, ytrain)
        # intialize loss value
        if loss_init == Inf
            loss_init = lossfunction(model, xnorm, ytrain)
            println("loss initialized at ", loss_init)
            continue
        end
        # convergence check: break if change in loss is <  (1 / 10³)
        if abs(loss_init - lossfunction(model, xnorm, ytrain)) < 1e-4
            break
        else
            loss_init = lossfunction(model, xnorm, ytrain)
        end
    end
    println(lossfunction(model, xnorm, ytrain))
    println("validation score: ", lossfunction(model, xnval, yval))
end




loss initialized at 2.3488715
2.34887
validation score: 2.3505385
loss initialized at 2.3256462
2.3256445
validation score: 2.3309221
loss initialized at 2.4060252
2.4060235
validation score: 2.4090607
loss initialized at 2.3913388
2.3913374
validation score: 2.3935323
loss initialized at 2.3357458
2.3357444
validation score: 2.3286624
