In [1]:
using MLUtils: splitobs, unsqueeze
using Flux
using Flux: onehotbatch
using Flux.Data: DataLoader
using MLDatasets, Flux, JLD2, CUDA  # this will install everything if necc.

using Statistics: mean  # standard library
using ImageCore


In [31]:

folder = "runs"  # sub-directory in which to save
isdir(folder) || mkdir(folder)
filename = joinpath(folder, "lenet.jld2")
@show filename
train_data = MLDatasets.MNIST()  # i.e. split=:train
test_data = MLDatasets.MNIST(split=:test)

function loader(data::MNIST=train_data; batchsize::Int=64)
    x4dim = reshape(data.features, 28,28,1,:)   # insert trivial channel dim
    yhot = Flux.onehotbatch(data.targets, 0:9)  # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x4dim, yhot); batchsize, shuffle=true) 
end

x1, y1 = first(loader()); # (28×28×1×64 Array{Float32, 3}, 10×64 OneHotMatrix(::Vector{UInt32}))


In [36]:

lenet = Chain(
    Conv((5, 5), 1=>6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6=>16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu), 
    Dense(84 => 10),
)
# ) |> gpu

# sum(softmax(y1hat); dims=1)

# @show y1hat
# @show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))


10×64 Matrix{Float32}:
 -0.0039766   0.00916768  -0.0995651   -0.0273454  …   0.0706634  -0.0748312
 -0.0816891  -0.0597531   -0.08209     -0.0160386     -0.155549   -0.0893364
  0.0570726   0.0269321    0.0852055   -0.0553192      0.0806991   0.073685
 -0.021709   -0.0199056   -0.0163364   -0.0893441     -0.0567437   0.00870506
 -0.183569   -0.0962708   -0.170495    -0.151974      -0.161699   -0.168436
 -0.0349672   0.0627827   -0.00354918  -0.0508724  …  -0.0322047   0.0960027
 -0.154052   -0.0875013   -0.16829     -0.187637      -0.103768   -0.142866
 -0.0769983  -0.0138387   -0.0464541   -0.0603695     -0.0597908  -0.0429375
  0.0214283  -0.0236385   -0.0170059    0.024636       0.0277364  -0.0282913
  0.0575947   0.0729606    0.0815842    0.0470092      0.0874391   0.0230748

In [37]:

function loss_and_accuracy(model, data::MNIST=test_data)
    (x,y) = only(loader(data; batchsize=length(data)))  # make one big batch
    ŷ = model(x)
    loss = Flux.logitcrossentropy(ŷ, y)  # did not include softmax in the model
    acc = round(100 * mean(Flux.onecold(ŷ) .== Flux.onecold(y)); digits=2)
    (; loss, acc, split=data.split)  # return a NamedTuple
end

@show loss_and_accuracy(lenet);  # accuracy about 10%, before training


In [38]:

#===== TRAINING =====#
settings = (;
    eta = 3e-4,     # learning rate
    lambda = 1e-2,  # for weight decay
    batchsize = 128,
    epochs = 10,
)
train_log = []

opt_rule = OptimiserChain(WeightDecay(settings.lambda), Adam(settings.eta))
opt_state = Flux.setup(opt_rule, lenet);

for epoch in 1:settings.epochs
    # @time will show a much longer time for the first epoch, due to compilation
    @time for (x,y) in loader(batchsize=settings.batchsize)
        grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), lenet)
        Flux.update!(opt_state, lenet, grads[1])
    end
    # Logging & saving, but not on every epoch
    if epoch % 2 == 1
        loss, acc, _ = loss_and_accuracy(lenet)
        test_loss, test_acc, _ = loss_and_accuracy(lenet, test_data)
        @info "logging:" epoch acc test_acc
        nt = (; epoch, loss, acc, test_loss, test_acc)  # make a NamedTuple
        push!(train_log, nt)
    end
end



In [42]:
train_log

5-element Vector{Any}:
 (epoch = 1, loss = 0.21469705f0, acc = 93.69, test_loss = 0.21469706f0, test_acc = 93.69)
 (epoch = 3, loss = 0.12560661f0, acc = 96.26, test_loss = 0.12560661f0, test_acc = 96.26)
 (epoch = 5, loss = 0.105488405f0, acc = 97.11, test_loss = 0.105488405f0, test_acc = 97.11)
 (epoch = 7, loss = 0.09284543f0, acc = 97.37, test_loss = 0.09284544f0, test_acc = 97.37)
 (epoch = 9, loss = 0.08910663f0, acc = 97.32, test_loss = 0.08910663f0, test_acc = 97.32)

In [None]:
JLD2.jldsave(filename; lenet_state = Flux.state(lenet) |> cpu)
println("saved to ", filename, " after ", epoch, " epochs")

# We can re-run the quick sanity-check of predictions:
y1hat = lenet(x1)
@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))


In [17]:

#===== INSPECTION =====#



xtest, ytest = only(loader(test_data, batchsize=length(test_data)));

# There are many ways to look at images, you won't need ImageInTerminal if working in a notebook.
# ImageCore.Gray is a special type, whick interprets numbers between 0.0 and 1.0 as shades:

xtest[:,:,1,5] .|> Gray |> transpose |> cpu


loaded_state = JLD2.load(filename, "lenet_state");
Flux.loadmodel!(lenet2, loaded_state)

@show lenet2(cpu(x1)) ≈ cpu(lenet(x1));