In [3]:
using Statistics: mean
using ProgressMeter: @showprogress
using Romeo
using LinearAlgebra: norm

# MLDatasets provides the MNIST dataset; Flux is used only for its DataLoader
using MLDatasets, Flux

In [4]:
# Element type passed to Romeo's one-hot encoder and layer constructors in the cells below.
Num = Float32

# MNIST train and test splits as provided by MLDatasets.
train_data = MLDatasets.MNIST(split=:train)
test_data = MLDatasets.MNIST(split=:test)

"""
    loader(data; batchsize::Int=-1)

Wrap `data` in a shuffling `Flux.DataLoader`.

`data.features` is reshaped so that it has `28 * 28` rows, and every entry of
`data.targets` is passed through `Romeo.onehot` together with `10` and `Num`.
`batchsize` is forwarded to the loader unchanged.
"""
function loader(data; batchsize::Int=-1)
    # One encoded target per entry of `data.targets`.
    encoded_targets = map(t -> Romeo.onehot(t, 10, Num), data.targets)

    # Collapse the leading dimensions into a single axis of length 28 * 28.
    flat_features = reshape(data.features, 28 * 28, :)

    return Flux.DataLoader(
        (flat_features, encoded_targets);
        batchsize=batchsize,
        shuffle=true,
    )
end

loader (generic function with 1 method)

In [5]:
# Model definition: a recurrent layer (14 * 14 => 64, tanh), a dense layer
# (64 => 10, identity) and a softmax. Note that 14 * 14 = 196 matches the
# 196-row slices fed to the network in `batch_process`.
net = Romeo.Network(
    Romeo.RNN(Romeo.RNNCell{Num}(14 * 14 => 64, activation=Romeo.tanh, init=Romeo.glorot_uniform)),
    Romeo.Dense{Num}(64 => 10, activation=Romeo.identity, init=Romeo.glorot_uniform),
    Romeo.Softmax{Num}()
)

# Training hyper-parameters.
# `η` is handed to `Romeo.Descent` and `epochs` bounds the training loop.
# NOTE(review): `batchsize`, `hi` and `lo` are not read by any of the visible
# cells — confirm whether they are leftovers or should be wired in.
settings = (;
    η = 1e-2,
    epochs = 5,
    batchsize = 100,
    hi = 2,
    lo = 0.0001
)

(η = 0.01, epochs = 5, batchsize = 100, hi = 2, lo = 0.0001)

In [6]:
"""
    batch_process(model::Romeo.Network, data)

Build the loss node for one `(x, y)` pair.

`model` is reset and then called on four consecutive row slices of `x`
(rows 1-196, 197-392, 393-588 and 589 to the last row). Only the output of the
final call is kept.

Returns a tuple `(loss, ŷ)`: the `Romeo.crossentropy` node comparing the final
output against `y` wrapped in a `Romeo.MatrixConstant`, and that final output.
"""
function batch_process(model::Romeo.Network, data)
    inputs, targets = data

    # presumably clears state left over from the previous sequence — confirm in Romeo
    Romeo.reset!(model)

    row_slices = (1:196, 197:392, 393:588, 589:lastindex(inputs, 1))

    # Each call overwrites the previous result; the last one is what gets scored.
    local prediction
    for rows in row_slices
        prediction = model(inputs[rows, :])
    end

    loss = Romeo.crossentropy(prediction, Romeo.MatrixConstant(targets))
    return loss, prediction
end

"""
    loss_and_accuracy(model::Romeo.Network, data)

Evaluate `model` on every observation of `data` (any split) in a single batch.

Returns the named tuple `(; loss, acc, split)`, where `loss` is the value
returned by `Romeo.forward!` on the cross-entropy node, `acc` is the percentage
of columns whose `Romeo.onecold` prediction equals the `Romeo.onecold` target
(rounded to two digits), and `split` is `data.split`.
"""
function loss_and_accuracy(model::Romeo.Network, data)
    # One batch as large as the dataset; `only` fails loudly if the loader
    # yields anything other than exactly one batch.
    (x, targets) = only(loader(data; batchsize=length(data)))

    # Join the per-observation target vectors column-wise. `reduce(hcat, ...)`
    # is used instead of `hcat(targets...)` so that tens of thousands of
    # vectors are not splatted into a single varargs call.
    y = reduce(hcat, targets)

    loss_node, ŷ = batch_process(model, (x, y))

    # The forward pass must run before `ŷ.value` is read below.
    loss = Romeo.forward!(loss_node)

    predicted = Romeo.onecold.(eachcol(ŷ.value))
    expected = Romeo.onecold.(eachcol(y))
    acc = round(100 * mean(predicted .== expected); digits=2)

    return (; loss, acc, split=data.split)
end

loss_and_accuracy (generic function with 1 method)

In [7]:
# Baseline: loss and accuracy on the test split before any training step has run.
test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
@info "Before training" test_loss test_acc

┌ Info: Before training
│   test_loss = 3.5006883
│   test_acc = 9.49
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:2


In [8]:
# Gradient-descent training: one pass over the shuffled training split per epoch,
# followed by an evaluation on both splits.
optimizer = Romeo.Descent(settings.η)
for epoch in 1:settings.epochs
    # Set when a NaN loss is seen. A bare `break` only leaves the inner batch
    # loop, so this flag is what stops the epoch loop as well.
    diverged = false

    # NOTE(review): `loader` is called without `batchsize`, so its default of -1
    # applies and `settings.batchsize` is never used — confirm this is intended.
    @time @showprogress for (x, y) in loader(train_data)
        # When `x` has more than one dimension, `y` is a collection of target
        # vectors; join them column-wise into a matrix.
        if ndims(x) > 1
            y = reduce(hcat, y)
        end

        loss_node, _ = batch_process(net, (x, y))

        Romeo.forward!(loss_node)

        # Bail out before backpropagating a NaN loss into the parameters.
        if isnan(loss_node.value)
            @error "NaN detected"
            diverged = true
            break
        end

        Romeo.backward!(loss_node)
        Romeo.train!(optimizer, loss_node)
    end

    # Stop training entirely instead of evaluating and starting another epoch.
    diverged && break

    loss, acc, _ = loss_and_accuracy(net, train_data)
    test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
    @info epoch acc loss test_acc test_loss
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:12[39m[K8[39m[K


 73.147304 seconds (54.72 M allocations: 59.743 GiB, 1.74% gc time, 20.55% compilation time: 1% of which was recompilation)


┌ Info: 1
│   acc = 89.5
│   loss = 0.33622417
│   test_acc = 90.01
│   test_loss = 0.32715535
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:23
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:00[39m[K


 60.894325 seconds (48.33 M allocations: 59.327 GiB, 3.53% gc time)


┌ Info: 2
│   acc = 89.85
│   loss = 0.33075008
│   test_acc = 90.0
│   test_loss = 0.32828578
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:23
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:03[39m[K


 63.914541 seconds (48.33 M allocations: 59.327 GiB, 4.23% gc time)


┌ Info: 3
│   acc = 89.27
│   loss = 0.34383067
│   test_acc = 89.33
│   test_loss = 0.33818325
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:23
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:01[39m[K


 61.579762 seconds (48.33 M allocations: 59.327 GiB, 3.36% gc time)


┌ Info: 4
│   acc = 89.16
│   loss = 0.35573757
│   test_acc = 89.83
│   test_loss = 0.33765224
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:23
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:02[39m[K


 62.111161 seconds (48.33 M allocations: 59.327 GiB, 3.46% gc time)


┌ Info: 5
│   acc = 88.76
│   loss = 0.36127394
│   test_acc = 89.13
│   test_loss = 0.34825683
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:23
