In [1]:
# Opracowane na podstawie https://minpy.readthedocs.io/en/latest/tutorial/rnn_mnist.html
using MLDatasets, Flux
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

net = Chain(
    RNN((14 * 14) => 64, tanh),
    Dense(64 => 10, identity),
)

Chain(
  Recur(
    RNNCell(196 => 64, tanh),           [90m# 16_768 parameters[39m
  ),
  Dense(64 => 10),                      [90m# 650 parameters[39m
) [90m        # Total: 6 trainable arrays, [39m17_418 parameters,
[90m          # plus 1 non-trainable, 64 parameters, summarysize [39m68.406 KiB.

In [2]:
using Statistics: mean  # standard library
function loss_and_accuracy(model, data)
    (x,y) = only(loader(data; batchsize=length(data)))
    Flux.reset!(model)
    ŷ = model(x[  1:196,:])
    ŷ = model(x[197:392,:])
    ŷ = model(x[393:588,:])
    ŷ = model(x[589:end,:])
    loss = Flux.logitcrossentropy(ŷ, y)  # did not include softmax in the model
    acc = round(100 * mean(Flux.onecold(ŷ) .== Flux.onecold(y)); digits=2)
    (; loss, acc, split=data.split)  # return a NamedTuple
end

@show loss_and_accuracy(net, test_data);  # accuracy about 10%, before training

train_log = []
settings = (;
    eta = 15e-3,
    epochs = 5,
    batchsize = 100,
)

opt_state = Flux.setup(Descent(settings.eta), net);

loss_and_accuracy(net, test_data) = (loss = 2.3263302f0, acc = 15.33, split = :test)


In [None]:
using ProgressMeter

for epoch in 1:settings.epochs
    @time for (x,y) in loader(train_data, batchsize=settings.batchsize)
        Flux.reset!(net)
        grads = Flux.gradient(model -> let
                ŷ = model(x[  1:196,:])
                println(size(ŷ) )
                ŷ = model(x[197:392,:])
                ŷ = model(x[393:588,:])
                ŷ = model(x[589:end,:])
                Flux.logitcrossentropy(ŷ, y)
            end, net)
        Flux.update!(opt_state, net, grads[1])
    end
    
    loss, acc, _ = loss_and_accuracy(net, train_data)
    test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
    @info epoch acc test_acc
    nt = (; epoch, loss, acc, test_loss, test_acc) 
    push!(train_log, nt)
end

Float32[-1.8128965 -0.6963422 -0.2815399 -0.64857745 -1.5193468 -0.78652096 -0.6379102 -0.94469887 -1.3035861 -0.95997006 -1.7799201 -0.7145916 -2.1201038 -1.713208 -0.75215465 -0.61507326 -1.6559812 0.8166775 -0.2806108 -1.1479326 -0.6879301 -0.7478409 -1.1139553 -0.8089196 -0.37684116 -0.7145916 -1.2220459 -0.46047443 -0.85042894 -0.6298803 -0.7145916 -1.253432 -1.3867404 -0.97103924 -1.1347003 -1.445074 -0.46844143 -1.6494486 -1.1907831 0.26640418 -1.0692838 -1.1206815 -0.7145916 -1.1865703 -0.5147885 0.034934737 -0.670249 -0.75325984 -0.7145916 -0.7288383 -0.7145916 -1.3005023 0.43205497 -0.7145916 -0.88257354 0.6159475 -0.7145916 -0.62280107 -0.3555229 -1.0630317 -0.59652716 -0.7419934 -0.7145916 -0.99159855 -0.44784033 -0.9384313 -0.7145916 0.22231674 -0.33599824 -0.4313132 -1.379068 0.39344725 -0.9719773 -0.5308203 -0.48112333 -0.7145916 -1.2115333 -0.7145916 1.8281102 -0.8109349 -0.9128069 -0.74696463 -1.3255205 -0.7338511 -1.2407115 -0.6915446 -1.119634 -0.7857814 -0.67193496 

Excessive output truncated after 524290 bytes.

 -0.36808205 -0.97809434 -1.2625178 -0.42282796 -0.6909693 -1.2810824 -0.6558448; 0.18884183 -0.22873464 0.23643382 0.2936058 0.41972578 0.07915567 0.08811772 0.6324926 -0.51325625 0.6204978 0.28426385 0.6204978 -0.051555574 0.36879 0.26077867 0.023550922 0.28036648 -0.42921093 0.6204978 0.56831217 

In [16]:
Flux.reset!(net)
x1, y1 = first(loader(train_data)); # (28×28×1×1 Array{Float32, 3}, 10×1 OneHotMatrix(::Vector{UInt32}))
y1hat = net(x1[  1:196,:])
y2hat = net(x1[197:392,:])
y3hat = net(x1[393:588,:])
y4hat = net(x1[589:end,:])
@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))

@show loss_and_accuracy(net, train_data);

hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)) = [7 5]
loss_and_accuracy(net, train_data) = (loss = 0.14942178f0, acc = 95.7, split = :train)


In [19]:
ŷ

LoadError: UndefVarError: `ŷ` not defined

In [18]:
y2hat

10×1 Matrix{Float32}:
 -1.5439503
  0.32302755
  2.0229397
 -0.14970115
 -2.5560298
  4.495308
 -2.9564798
  2.5527892
 -0.7681055
  1.46721

In [9]:
y1hat

10×1 Matrix{Float32}:
  7.5319905
 -2.5360425
  1.3435442
  1.7613591
 -3.7410812
  3.890618
 -2.5021336
 -2.069571
  1.9236335
 -0.4906675