In [7]:
using Statistics: mean
using ProgressMeter: @showprogress
using Romeo

# DataDeps would otherwise stop and ask for confirmation before downloading a dataset;
# pre-accepting keeps the MNIST download below from blocking on an interactive prompt.
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

# MLDatasets provides the MNIST data; Flux is used only for its `DataLoader` (batching and shuffling)
using MLDatasets, Flux

# Floating-point element type used for the data and for every layer of the network.
# Declared `const` so functions that read this global (e.g. `loader`) see a fixed type
# instead of an untyped, type-unstable global binding.
const Num = Float32

Float32

In [8]:
# Load both MNIST splits through MLDatasets: the training split first, then the test split.
train_data, test_data = map(which -> MLDatasets.MNIST(split=which), (:train, :test))

"""
    loader(data; batchsize::Int=-1, shuffle::Bool=true)

Wrap an MNIST split in a `Flux.DataLoader` of `(features, labels)` pairs.

The features are converted to `Num` and reshaped into a matrix with `28 * 28` rows, one
column per observation (assumes `data.features` is 28×28×N as for MLDatasets MNIST —
confirm). Each target is encoded with `Romeo.onehot(target, 10, Num)`, giving a vector of
one-hot encodings over 10 classes.

# Keywords
- `batchsize`: forwarded unchanged to `Flux.DataLoader`. The default `-1` relies on how
  `DataLoader` treats a negative batch size (presumably yielding single observations —
  confirm against the installed Flux/MLUtils version).
- `shuffle`: forwarded to `Flux.DataLoader`. Defaults to `true` (the previous hard-coded
  behaviour); pass `false` for deterministic evaluation order.
"""
function loader(data; batchsize::Int=-1, shuffle::Bool=true)
    x1dim = reshape(Num.(data.features), 28 * 28, :)
    yhot = Romeo.onehot.(data.targets, 10, Num)
    return Flux.DataLoader((x1dim, yhot); batchsize, shuffle)
end

loader (generic function with 1 method)

In [9]:
# Network: a recurrent layer mapping 14*14-value input steps to a 64-wide tanh state,
# a dense readout to 10 outputs (identity activation), and a softmax over those outputs.
# Layers are constructed in the same order as they appear in the network.
net = let
    recurrent = Romeo.RNN(
        Romeo.RNNCell{Num}(14 * 14 => 64, activation=Romeo.tanh, init=Romeo.glorot_uniform),
    )
    readout = Romeo.Dense{Num}(64 => 10, activation=Romeo.identity, init=Romeo.glorot_uniform)
    Romeo.Network(recurrent, readout, Romeo.Softmax{Num}())
end

"""
    batch_process(model::Romeo.Network, data)

Run one recurrent pass of `model` over a batch and build its cross-entropy loss node.

`data` is an `(x, y)` pair. After resetting the model state with `Romeo.reset!`, the rows
of `x` are fed to `model` as four consecutive blocks: rows 1–196, 197–392, 393–588, and
589 through the last row. Only the output of the final block is kept. With 784-row MNIST
columns this is four steps of 14 * 14 values each.

Returns `(loss_node, ŷ)`, where `loss_node` is `Romeo.crossentropy(ŷ, Romeo.MatrixConstant(y))`
and `ŷ` is the model output for the last block.
"""
function batch_process(model::Romeo.Network, data)
    x, y = data

    Romeo.reset!(model)

    steps = (1:196, 197:392, 393:588, 589:lastindex(x, 1))
    local ŷ
    for rows in steps
        ŷ = model(x[rows, :])
    end

    loss_node = Romeo.crossentropy(ŷ, Romeo.MatrixConstant(y))
    return loss_node, ŷ
end

"""
    loss_and_accuracy(model::Romeo.Network, data)

Evaluate `model` on the whole of `data` as a single batch.

Returns a named tuple `(; loss, acc, split)`:
- `loss`: the value returned by `Romeo.forward!` on the cross-entropy node built by
  `batch_process`;
- `acc`: percentage of observations whose `Romeo.onecold` prediction equals the
  `Romeo.onecold` of the label, rounded to two decimals;
- `split`: `data.split`, identifying which dataset split was evaluated.
"""
function loss_and_accuracy(model::Romeo.Network, data)
    # A batch size equal to the dataset length yields exactly one batch; `only` enforces it.
    # The loader shuffles, but x and y are permuted together, so the metrics are presumably
    # unaffected (up to floating-point summation order).
    (x, y) = only(loader(data; batchsize=length(data)))
    # `y` arrives as a vector of one-hot vectors. `reduce(hcat, y)` stacks them into a label
    # matrix without splatting every observation as a separate `hcat` argument, which is
    # slow for tens of thousands of columns.
    y = reduce(hcat, y)

    loss_node, ŷ = batch_process(model, (x, y))

    # `Romeo.forward!` evaluates the graph; afterwards `ŷ.value` is read as the prediction matrix.
    loss = Romeo.forward!(loss_node)
    predicted = Romeo.onecold.(eachcol(ŷ.value))
    actual = Romeo.onecold.(eachcol(y))
    acc = round(100 * mean(predicted .== actual); digits=2)

    return (; loss, acc, split=data.split)
end

loss_and_accuracy (generic function with 1 method)

In [10]:
# Baseline: test-set metrics of the freshly initialised network. Destructuring keeps only
# the first two fields (loss, accuracy) of the returned named tuple.
test_loss, test_acc = loss_and_accuracy(net, test_data)
@info("Before training", test_loss, test_acc)

┌ Info: Before training
│   test_loss = 3.2658546
│   test_acc = 11.62
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:2


In [11]:
# Training hyperparameters: `η` is the step size passed to `Romeo.Descent`, `epochs` the
# number of passes over the training set, `batchsize` the minibatch size given to `loader`,
# and `threshold` the value forwarded to `Romeo.backward!`.
# NOTE(review): η is a Float64 literal while the model is built with `Num` (Float32);
# confirm Romeo.Descent does not promote parameters to Float64.
settings = (; η = 15e-3, epochs = 5, batchsize = 100, threshold = 0.9f0)

(η = 0.015, epochs = 5, batchsize = 100, threshold = 0.9f0)

In [12]:
optimizer = Romeo.Descent(settings.η)

# Train for `settings.epochs` epochs. Each epoch makes one pass over the (shuffled) training
# loader, then logs full-set train and test metrics. As soon as a NaN loss is seen,
# training stops entirely: the parameters can no longer be trusted, so neither evaluating
# nor continuing into later epochs is meaningful.
for epoch in 1:settings.epochs
    batch_count = 0
    diverged = false

    @time @showprogress for (x, y) in loader(train_data; batchsize=settings.batchsize)
        # For batched data `x` is a matrix and `y` a vector of one-hot vectors; stack the
        # labels into a matrix. (The `ndims` guard covers a loader that yields single
        # observations — presumably only for batchsize ≤ 0; confirm.)
        if ndims(x) > 1
            y = reduce(hcat, y)
        end

        loss_node, _ = batch_process(net, (x, y))
        Romeo.forward!(loss_node)

        if isnan(loss_node.value)
            @error "NaN detected" epoch batch_count loss_node.value
            diverged = true
            break
        end

        Romeo.backward!(loss_node, threshold=settings.threshold)
        Romeo.train!(optimizer, loss_node)

        batch_count += 1
    end

    # `break` above only leaves the batch loop; stop the epoch loop as well instead of
    # evaluating and further training a diverged network.
    diverged && break

    loss, acc, _ = loss_and_accuracy(net, train_data)
    # `local` keeps the global baseline `test_loss`/`test_acc` from being overwritten.
    local test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
    @info epoch acc loss test_acc test_loss batch_count settings.batchsize
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:26[39m[K


 28.598613 seconds (5.03 M allocations: 1.678 GiB, 4.82% gc time, 69.75% compilation time: 2% of which was recompilation)


┌ Info: 1
│   acc = 91.22
│   loss = 0.4065568
│   test_acc = 91.27
│   test_loss = 0.40810195
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:28
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m[K


  7.539057 seconds (617.35 k allocations: 1.385 GiB, 1.88% gc time)


┌ Info: 2
│   acc = 93.57
│   loss = 0.3281799
│   test_acc = 93.84
│   test_loss = 0.29298812
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:28
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m[K


  9.054645 seconds (617.77 k allocations: 1.385 GiB, 9.46% gc time)


┌ Info: 3
│   acc = 93.45
│   loss = 0.360613
│   test_acc = 92.94
│   test_loss = 0.36234134
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:28
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m[K


  8.643232 seconds (617.14 k allocations: 1.385 GiB, 11.22% gc time)


┌ Info: 4
│   acc = 94.02
│   loss = 0.33712512
│   test_acc = 93.94
│   test_loss = 0.32813686
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:28
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m[K


  7.902901 seconds (617.44 k allocations: 1.385 GiB, 1.57% gc time)


┌ Info: 5
│   acc = 93.75
│   loss = 0.34842512
│   test_acc = 93.6
│   test_loss = 0.3562206
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:28
