In [1]:
using Statistics: mean
using ProgressMeter: @showprogress
using Romeo
using LinearAlgebra: norm

# Flux is used only for loading the MNIST dataset
using MLDatasets, Flux

# Element type used throughout the model and the one-hot targets. Declared `const`
# so functions referencing it (e.g. `loader`) see a concretely-typed global.
const Num = Float32

Float32

In [12]:
# Training hyper-parameters bundled in a NamedTuple.
settings = (
    η = 0.015,        # learning rate handed to Romeo.Descent
    epochs = 5,       # number of passes over the training split
    batchsize = 100,  # observations per mini-batch
    hi = 1,           # bound passed to Romeo.clip! after each backward pass
)

(η = 0.015, epochs = 5, batchsize = 100, hi = 1)

In [3]:
# Load both MNIST splits (train first, then test).
train_data, test_data = map(s -> MLDatasets.MNIST(split=s), (:train, :test))

"""
    loader(data; batchsize=-1)

Wrap an MNIST split as a shuffled `Flux.DataLoader` yielding `(x, y)` pairs.

Each 28×28 image in `data.features` is flattened into one 784-long column of `x`.
Each label in `data.targets` becomes `Romeo.onehot(label, 10, Num)`, so `y` is a
vector of one-hot columns rather than a matrix.

`batchsize` is forwarded unchanged to `Flux.DataLoader`. The default of `-1`
presumably yields single, unbatched observations — confirm in the DataLoader docs.
"""
function loader(data; batchsize::Int=-1)
    pixels = reshape(data.features, 784, :)  # 784 = 28 * 28 pixels per image
    targets = map(label -> Romeo.onehot(label, 10, Num), data.targets)
    return Flux.DataLoader((pixels, targets); batchsize=batchsize, shuffle=true)
end

loader (generic function with 1 method)

In [8]:
# Sequence classifier: an RNN cell reads 14 * 14 = 196 input rows per step into a
# 64-wide state; a dense layer plus softmax then map that state to 10 outputs.
# Layers are constructed in the same order as before, so random initialisation
# draws happen in the same sequence.
net = let
    cell = Romeo.RNNCell{Num}(
        14 * 14 => 64; activation=Romeo.tanh, init=Romeo.glorot_uniform
    )
    recurrent = Romeo.RNN(cell)
    head = Romeo.Dense{Num}(64 => 10; activation=Romeo.identity, init=Romeo.glorot_uniform)
    Romeo.Network(recurrent, head, Romeo.Softmax{Num}())
end

"""
    batch_process(model, data; chunk=14 * 14)

Feed one batch through `model` as a sequence and build its cross-entropy loss node.

`data` is an `(x, y)` tuple. `x` holds one observation per column. Its rows are
split into consecutive chunks of `chunk` rows: with the default of 196 (14 × 14),
a 784-row MNIST batch becomes four steps. The chunks are passed to `model` in
order, after resetting it with `Romeo.reset!`. Only the output of the last step
is kept.

# Keywords
- `chunk`: rows per step; must match the input size of the model's first layer.

# Returns
`(loss_node, ŷ)`. `loss_node` is `Romeo.crossentropy(ŷ, Romeo.MatrixConstant(y))`;
callers run `Romeo.forward!` on it before reading values.

# Throws
- `ArgumentError` if `chunk` is not positive.
- `DimensionMismatch` if `x` has no rows or its row count is not a multiple of `chunk`.
"""
function batch_process(model::Romeo.Network, data; chunk::Integer=14 * 14)
    (x, y) = data

    chunk > 0 || throw(ArgumentError("chunk must be positive, got $chunk"))
    nrows = size(x, 1)
    # Every step must present exactly `chunk` rows to the recurrent layer.
    if nrows == 0 || nrows % chunk != 0
        throw(DimensionMismatch(
            "input has $nrows rows, which is not a positive multiple of chunk = $chunk"
        ))
    end

    Romeo.reset!(model)

    # Each chunk is one time step. The model is stateful across calls until the
    # next reset!, so only the prediction after the final step is used.
    local ŷ
    for rows in Iterators.partition(1:nrows, chunk)
        ŷ = model(x[rows, :])
    end

    return Romeo.crossentropy(ŷ, Romeo.MatrixConstant(y)), ŷ
end

"""
    loss_and_accuracy(model, data)

Evaluate `model` on every observation of `data` in a single batch.

Returns a NamedTuple `(; loss, acc, split)`:
- `loss`: the value returned by `Romeo.forward!` on the loss node from `batch_process`.
- `acc`: the percentage (rounded to 2 digits) of columns whose `Romeo.onecold` of
  the prediction equals `Romeo.onecold` of the target.
- `split`: `data.split`.
"""
function loss_and_accuracy(model::Romeo.Network, data)
    # A single batch containing the whole dataset; `only` insists on exactly one.
    (x, y) = only(loader(data; batchsize=length(data)))
    # `y` arrives as a vector of one-hot columns. `reduce(hcat, …)` builds the matrix
    # without splatting tens of thousands of arguments into `hcat`.
    y = reduce(hcat, y)

    loss_node, ŷ = batch_process(model, (x, y))

    loss = Romeo.forward!(loss_node)
    # ŷ.value is read after the forward pass (presumably populated by it).
    hits = Romeo.onecold.(eachcol(ŷ.value)) .== Romeo.onecold.(eachcol(y))
    acc = round(100 * mean(hits); digits=2)

    return (; loss, acc, split=data.split)
end

loss_and_accuracy (generic function with 1 method)

In [5]:
# Baseline test-split metrics for the freshly initialised network.
test_loss, test_acc = loss_and_accuracy(net, test_data)
@info("Before training", test_loss, test_acc)

┌ Info: Before training
│   test_loss = 3.2900946
│   test_acc = 14.54
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:2


In [13]:
optimizer = Romeo.Descent(settings.η)
# Build the training loader once instead of re-encoding 60k targets every epoch.
# With `shuffle=true`, DataLoader reshuffles each time iteration restarts, so
# every epoch still sees a fresh order.
train_loader = loader(train_data; batchsize=settings.batchsize)
for epoch in 1:settings.epochs
    batch_count = 0
    local diverged = false

    @time @showprogress for (x, y) in train_loader
        # Batched targets come as a vector of one-hot columns; stack them into a matrix.
        if ndims(x) > 1
            y = reduce(hcat, y)
        end

        loss_node, _ = batch_process(net, (x, y))
        Romeo.forward!(loss_node)

        if isnan(loss_node.value)
            @error "NaN detected" epoch batch_count loss = loss_node.value
            diverged = true
            break
        end

        Romeo.backward!(loss_node)
        Romeo.clip!(net, settings.hi)
        Romeo.train!(optimizer, loss_node)

        batch_count += 1
    end

    # `break` above only leaves the batch loop. Stop the epoch loop as well, rather
    # than evaluating and further updating a network whose loss has become NaN.
    diverged && break

    loss, acc, _ = loss_and_accuracy(net, train_data)
    # `local` keeps the pre-training global test metrics from being overwritten.
    local test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
    @info epoch acc loss test_acc test_loss batch_count settings.batchsize
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m[K


  7.823538 seconds (755.75 k allocations: 1.221 GiB, 21.94% gc time, 6.73% compilation time)


┌ Info: 1
│   acc = 87.25
│   loss = 0.41411597
│   test_acc = 87.42
│   test_loss = 0.41309115
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:30
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m[K


  6.977171 seconds (568.37 k allocations: 1.209 GiB, 22.29% gc time)


┌ Info: 2
│   acc = 85.17
│   loss = 0.46834666
│   test_acc = 85.48
│   test_loss = 0.4621819
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:30
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m[K


  6.855787 seconds (568.36 k allocations: 1.209 GiB, 20.60% gc time)


┌ Info: 3
│   acc = 84.42
│   loss = 0.49583346
│   test_acc = 84.89
│   test_loss = 0.48098692
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:30
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m[K


  6.236138 seconds (567.82 k allocations: 1.209 GiB, 21.01% gc time)


┌ Info: 4
│   acc = 85.0
│   loss = 0.4740931
│   test_acc = 85.06
│   test_loss = 0.47280595
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:30
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m[K


  6.359483 seconds (567.83 k allocations: 1.209 GiB, 23.07% gc time)


┌ Info: 5
│   acc = 85.38
│   loss = 0.47610918
│   test_acc = 85.83
│   test_loss = 0.4591336
│   batch_count = 600
│   settings.batchsize = 100
└ @ Main /home/sebastian/.julia/dev/Romeo/example/mnist.ipynb:30
