In [2]:
using DifferentialEquations, DiffEqFlux, Lux, Random, Optimization, OptimizationOptimisers, ComponentArrays, Zygote, Statistics
using MLDatasets: MNIST

# Build the MNIST dataset object once and take both fields from it; calling `MNIST()`
# separately for features and targets would construct the dataset twice.
# `imgs` is indexed with three subscripts and `nums` with one further down in this file.
imgs, nums = let mnist = MNIST()
    mnist.features, mnist.targets
end

"""
    onehot(nums::AbstractVector; nclasses::Integer = 10)

Return an `nclasses × length(nums)` matrix of `Int` in which column `j` holds a single `1`
in row `label + 1`, where `label` is the `j`-th element of `nums`. All other entries are `0`.

Labels are zero-based: an integer label outside `0:nclasses-1` triggers a `BoundsError`
from the array write. The default `nclasses = 10` reproduces the ten-row output used for
digit labels.
"""
function onehot(nums::AbstractVector; nclasses::Integer = 10)
    ret = zeros(Int, nclasses, length(nums))
    # `enumerate` counts columns from 1 no matter how `nums` itself is indexed.
    for (j, label) in enumerate(nums)
        ret[label + 1, j] = 1
    end
    return ret
end

# Front end: flatten each input and map the 784 values to 20 features through tanh.
down = Chain(
    FlattenLayer(),
    Dense(784, 20, tanh),
)

# Unseeded generator shared by both `Lux.setup` calls below (so runs are not reproducible).
rng = MersenneTwister()
down_p, down_st = Lux.setup(rng, down)

# Network wrapped by the neural ODE; it maps 20 features back to 20 features.
nn = Chain(
    Dense(20, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 20, tanh),
)

# Neural ODE over the time span (0, 1) with the Tsit5 solver. The solver options ask that
# neither intermediate steps nor the initial state be saved.
nn_ode = NeuralODE(
    nn,
    (0.0f0, 1.0f0),
    Tsit5();
    save_everystep = false,
    reltol = 1e-3,
    abstol = 1e-3,
    save_start = false,
)

# Final linear layer: 20 features to 10 outputs, with no activation.
fc = Dense(20, 10)

# Full model. The `convert` layer applies `last` to the output of the ODE layer before `fc`.
m = Chain(; down = down, nn_ode = nn_ode, convert = WrappedFunction(last), fc = fc)

# Initialise parameters and state; the parameters are wrapped in a ComponentArray.
ps, st = let init = Lux.setup(rng, m)
    ComponentArray(init[1]), init[2]
end

"""
    logitcrossentropy(ŷ, y)

Cross-entropy between the raw scores `ŷ` and the targets `y`: take `logsoftmax` of `ŷ`
along the first dimension, weight by `y`, sum along the first dimension, negate, and
average the result.
"""
function logitcrossentropy(ŷ, y)
    logprobs = logsoftmax(ŷ; dims = 1)
    per_column = -sum(y .* logprobs; dims = 1)
    return mean(per_column)
end
"""
    loss_function(ps, x, y)

Evaluate the global model `m` on `x` with parameters `ps` and the global state `st`, and
return the tuple `(loss, pred)`, where `loss` is `logitcrossentropy(pred, y)`.
"""
function loss_function(ps, x, y)
    # The model also returns a state value, which is not needed here.
    pred, _ = m(x, ps, st)
    loss = logitcrossentropy(pred, y)
    return loss, pred
end

# Objective for the optimiser, differentiated with Zygote. The second argument is ignored;
# `x` and `y` are the batch inputs and targets passed on to `loss_function`.
opt_func = OptimizationFunction(AutoZygote()) do ps, _, x, y
    loss_function(ps, x, y)
end

# Optimisation problem that starts from the initial parameters `ps`.
opt_prob = OptimizationProblem(opt_func, ps)

# Batch size.
const N = 100

# Number of training batches: one fewer than the number of whole batches in `nums`, so the
# final N samples stay out of `data` when `length(nums)` is a multiple of N.
M = length(nums) ÷ N - 1

# One `(images, one-hot labels)` tuple per batch, each covering N consecutive samples.
data = map(1:M) do b
    idx = N*(b-1)+1:N*b
    (imgs[:, :, idx], onehot(nums[idx]))
end

# Train with Adam (step size 0.003), timing the call. Presumably each `(x, y)` tuple in
# `data` reaches the objective as its trailing arguments; confirm against the installed
# Optimization.jl version.
@time res = solve(opt_prob, Adam(0.003), data)

# Predicted class for every column of `x`: the row index of the column's maximum, shifted
# by one so that classes are numbered from zero.
function classify(x)
    return [argmax(col) - 1 for col in eachcol(x)]
end
# Fraction of the last N samples whose predicted class matches the label, using the
# optimised parameters `res.u`.
let
    scores, _ = m(imgs[:, :, end-N+1:end], res.u, st)
    count(classify(scores) .== nums[end-N+1:end]) / N
end

  8.105775 seconds (7.58 M allocations: 7.300 GiB, 13.02% gc time, 18.41% compilation time: 7% of which was recompilation)


0.98