In [3]:
# Opracowane na podstawie https://minpy.readthedocs.io/en/latest/tutorial/rnn_mnist.html
using MLDatasets, Flux
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

net = Chain(
    RNN((14 * 14) => 64, tanh),
    Dense(64 => 10, identity),
)

Chain(
  Recur(
    RNNCell(196 => 64, tanh),           [90m# 16_768 parameters[39m
  ),
  Dense(64 => 10),                      [90m# 650 parameters[39m
) [90m        # Total: 6 trainable arrays, [39m17_418 parameters,
[90m          # plus 1 non-trainable, 64 parameters, summarysize [39m68.406 KiB.

In [9]:
using Statistics: mean  # standard library
function loss_and_accuracy(model, data)
    (x,y) = only(loader(data; batchsize=length(data)))
    Flux.reset!(model)
    ŷ = model(x[  1:196,:])
    display(x[  1:196,:])
    ŷ = model(x[197:392,:])
    ŷ = model(x[393:588,:])
    ŷ = model(x[589:end,:])
    loss = Flux.logitcrossentropy(ŷ, y)  # did not include softmax in the model
    acc = round(100 * mean(Flux.onecold(ŷ) .== Flux.onecold(y)); digits=2)
    (; loss, acc, split=data.split)  # return a NamedTuple
end

@show loss_and_accuracy(net, test_data);  # accuracy about 10%, before training

train_log = []
settings = (;
    eta = 15e-3,
    epochs = 5,
    batchsize = 100,
)

opt_state = Flux.setup(Descent(settings.eta), net);

196×10000 Matrix{Float32}:
 0.0        0.0  0.0       0.0  …  0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0  …  0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 0.0        0.0  0.0       0.0     0.0        0.0       0.0       0.0
 ⋮                              ⋱                                 
 0.0117647  0.0  0.960784  0.0     0.0        0.992157  0.960784  0.0
 0.0        0.0  0.988235  0.0     0.0196078  0.992157  0.427451  0.0
 0.0        0.0  0.290196  0.0     0.403922   0.988235  0.0       

loss_and_accuracy(net, test_data) = nothing


In [11]:
using ProgressMeter

for epoch in 1:settings.epochs
    @time for (x,y) in loader(train_data, batchsize=settings.batchsize)
        Flux.reset!(net)
        grads = Flux.gradient(model -> let
                ŷ = model(x[  1:196,:])
                ŷ = model(x[197:392,:])
                ŷ = model(x[393:588,:])
                ŷ = model(x[589:end,:])
                Flux.logitcrossentropy(ŷ, y)
            end, net)
        Flux.update!(opt_state, net, grads[1])
    end
    
    loss, acc, _ = loss_and_accuracy(net, train_data)
    test_loss, test_acc, _ = loss_and_accuracy(net, test_data)
    @info epoch acc test_acc
    nt = (; epoch, loss, acc, test_loss, test_acc) 
    push!(train_log, nt)
end

10×60000 Matrix{Float32}:
 -0.213655   -0.332257    0.0995641  …  -0.73018     0.208137    0.0638069
  0.133647   -0.361153    0.302384      -0.36615     0.176838   -0.478121
 -0.199316   -0.822529   -0.426453      -0.580882    0.0509286  -0.96359
 -0.273503   -0.0808809  -0.43193        0.0125504  -0.0694752   0.471783
  0.52688    -0.301246    0.214178      -0.720722    0.28851    -0.296031
 -0.123806   -0.0633075   0.131075   …  -0.712444    0.198246    0.161361
 -0.262311   -0.384467   -0.313968      -0.0554412  -0.13532    -0.799805
  0.250082    1.53114    -0.0790297      2.16761     0.028474    1.38873
 -0.330798   -0.0832345  -0.467489      -0.406275   -0.36744     0.434755
  0.0197034   0.570053   -0.0917165      1.00348    -0.109199    0.362981

10×10000 Matrix{Float32}:
 -0.291492    -0.312734   -0.33603    …  -0.740249  -0.126825  -0.348973
 -0.00408123  -0.0189533   0.0530193     -0.452925   0.282625  -0.0890535
 -0.651393    -0.168041   -0.207269      -0.460895  -0.140149   0.166949
 -0.20202     -0.0866933  -0.194527       0.509621  -0.06125    0.330517
 -0.430193    -0.806245    0.660921      -0.726146   0.241251  -0.109103
 -0.151222    -0.146755   -0.0665411  …  -0.553347   0.145752  -0.120537
 -0.324633    -0.662058   -0.274155      -0.680688  -0.129487  -0.19254
  1.76352      1.11444     0.421602       0.994752   0.126187   0.144873
 -0.307364    -0.311092   -0.261874       0.334594  -0.282712  -0.101142
  0.671077     0.966145    0.141297       0.944174  -0.108451   0.0743774

 11.171415 seconds (22.20 M allocations: 4.006 GiB, 1.55% gc time, 93.73% compilation time)


┌ Info: 1
│   acc = 89.74
│   test_acc = 90.27
└ @ Main /home/maciek/Templates/AutomaticDifferention/addons/rnn.ipynb:18


10×60000 Matrix{Float32}:
 -0.200779   -0.553401   -0.225164   …  -0.500487   -0.219185   -0.624833
 -0.154705   -0.0121608   0.351074      -0.214506    0.247815   -0.314294
 -0.345252   -0.403799   -0.0847212     -0.854151   -0.130951   -0.898983
 -0.280145   -0.343904   -0.0643481     -0.316027    0.044967   -0.316845
 -0.0429397  -0.079158    0.244391      -0.14866     0.446781   -0.546441
  0.0860017   0.0244702   0.256749   …  -0.0225209   0.431617    0.0588507
 -0.993717   -0.365934   -0.173278      -0.683438   -0.28239    -1.20909
  1.1833      1.62554     0.190716       1.57744     0.490149    1.95773
 -0.161488   -0.527498   -0.401482      -0.434034   -0.274072   -0.591619
  0.899614    0.889598   -0.13617        0.688616    0.0765667   1.83453

10×10000 Matrix{Float32}:
 -0.326752    -0.728735  -0.598657  …  -0.493208  -0.209173   -1.72137
  0.209063    -0.371232  -0.478126     -0.243721   0.277037   -0.93502
  0.181186    -0.832342  -1.03505      -0.26147   -0.105785   -0.63891
  0.00680358  -0.313581  -0.371656     -0.153781   0.0417817   0.950983
  0.176378    -0.256782  -0.140796     -0.136459   0.234216   -0.349626
  0.169193    -0.253577  -0.239673  …   0.156352   0.204077   -0.998898
 -0.115184    -0.46146   -1.05378      -1.19511   -0.212033   -0.629795
  0.127235     2.36       1.03322       1.10211    0.187509    2.94148
 -0.301738    -0.709066  -0.530417      0.174638  -0.339448   -0.774505
  0.0113742    1.0227     1.76178       0.92415   -0.0668638   2.59625

  0.601947 seconds (541.34 k allocations: 2.630 GiB, 18.56% gc time)


┌ Info: 2
│   acc = 91.87
│   test_acc = 92.35
└ @ Main /home/maciek/Templates/AutomaticDifferention/addons/rnn.ipynb:18


10×60000 Matrix{Float32}:
 -0.791523   -1.09391    -0.400168   …  -0.603425   -1.47814    -0.287702
 -0.539374   -0.914684   -0.163786      -0.0208185  -0.0301288  -0.375645
 -0.272132   -0.0951207  -0.14889       -1.0495     -0.302379   -0.8125
 -0.0571119   0.106665    0.158798      -0.553878    0.933078    1.34394
 -0.210254    0.0907017   0.380034      -0.137197    0.407035   -0.88684
 -0.0505518  -0.758102    0.0834419  …   0.388627   -0.251212    0.209671
 -0.973693   -0.506666   -0.596113      -1.01909    -0.231229   -1.02338
  1.78963     2.00512     0.269586       0.848024    1.1465      1.59169
  0.0581551  -0.238166   -0.267497      -0.375554   -0.556782    0.37214
  1.48198     2.58684     0.609732       0.692398    1.95898     0.904081

10×10000 Matrix{Float32}:
 -0.881124   -0.323084   -0.267111   …  -1.67933   -0.806752  -0.280699
 -0.646654    0.235787    0.302173      -0.977954  -0.617877  -1.24712
 -0.833811   -0.12343    -0.430103      -0.512515  -1.05076   -1.65652
 -0.0438811  -0.395326   -0.206539       0.624714  -0.180314   0.115356
 -0.481023    0.275634    0.0562282     -0.560826  -0.557521  -0.792409
 -0.0171473  -0.0827337  -0.0381311  …  -0.926414   0.116595   0.432635
 -0.878833   -0.351964   -0.619013      -0.423391  -1.18423   -1.23753
  3.13178     0.97081     0.955328       3.0837     3.23488    3.04162
 -0.318903   -0.455663   -0.257382      -0.944428  -0.414176   0.320164
  1.91899     0.131579    0.521838       2.82635    2.02911    1.34232

  0.693958 seconds (541.34 k allocations: 2.630 GiB, 32.35% gc time)


┌ Info: 3
│   acc = 93.1
│   test_acc = 93.5
└ @ Main /home/maciek/Templates/AutomaticDifferention/addons/rnn.ipynb:18


10×60000 Matrix{Float32}:
 -1.28521   -0.537523  -0.173786  …  -0.862117   -0.611252   -1.40596
 -0.870656   0.413839   0.298526     -0.580864   -0.342229   -0.904176
 -1.16746   -0.424038  -0.36892      -1.02765    -0.801219   -0.803036
 -0.304882  -0.304341  -0.339997     -0.653834   -0.627162    0.408093
 -0.42949   -0.196915   0.473593      0.0949538   0.0116838  -1.07041
 -0.300876   0.120563   0.42761   …   0.106235    0.0837578  -0.404439
 -0.805428  -0.362169  -0.646465     -0.26861    -0.602566   -0.618098
  2.78755    1.17654    0.146687      2.62382     2.02927     3.60371
 -0.405765  -0.5689    -0.519167     -0.693966   -0.197536   -0.846186
  2.36855    0.543934   0.11836       1.24618     1.38361     2.33999

10×10000 Matrix{Float32}:
 -0.374351   -0.625781   -1.2332     …  -0.487508   -0.485577   -1.28841
  0.437545    0.0842115  -0.0587161     -0.391084    0.201528   -0.909035
  0.0310331  -0.0704714   0.44149       -1.65051    -0.112272   -0.788349
 -0.0782153  -0.10479     0.489724      -0.260045   -0.444428   -0.0873738
  0.185604    0.692803    0.772594      -0.91813     0.590256   -0.461712
  0.411619    0.377017   -0.089136   …   0.546782    0.0960975  -0.308214
 -0.225612   -0.359729   -0.0951772     -1.6593     -0.258183   -0.212443
  0.284547    0.87434     0.302544       2.59344     0.695252    3.1472
 -0.557289   -0.413551   -0.56032        0.0353868  -0.429349   -0.791719
 -0.189694    0.292076    1.23124        1.64902    -0.0138773   1.7838

  0.472516 seconds (541.34 k allocations: 2.630 GiB, 2.38% gc time)


┌ Info: 4
│   acc = 93.9
│   test_acc = 93.88
└ @ Main /home/maciek/Templates/AutomaticDifferention/addons/rnn.ipynb:18


10×60000 Matrix{Float32}:
 -0.432462   -0.620048   -0.474095   …  -0.305789   -0.521863   -1.34476
  0.480387    0.245895   -0.115739       0.276305    0.408333   -1.13653
  0.0804422  -0.0407934  -0.887538      -0.0156336   0.0826199  -0.835716
 -0.0700365   0.272757    0.0370414      0.186325   -0.017948    0.75463
  0.160817    0.126342   -0.478252       0.20585     0.267944   -0.990837
  0.473615    0.137719    0.315655   …   0.569842    0.49055    -0.377923
 -0.255719   -0.520423   -1.26891       -0.496233   -0.251506   -0.52922
  0.325616    0.756927    1.08231        0.358873    0.513236    3.19574
 -0.617947   -0.446685   -0.523259      -0.640891   -0.574454   -0.869935
 -0.212122    0.660088    1.35033       -0.01321    -0.0306596   2.37239

10×10000 Matrix{Float32}:
 -0.43869    -1.32467   -0.48793    …  -1.86499   -1.39773    -1.42661
 -0.0573533  -0.349788   0.226358      -0.997718   0.158519   -0.697928
 -0.661876   -1.6127    -0.304049      -1.88248   -0.635982   -0.787849
  0.163891   -1.11236   -0.409118       1.19799    0.292961    0.566025
 -0.641763   -0.248268   0.13731       -2.17349    0.525081   -0.562812
  0.388677    0.208352   0.0405928  …   0.18522   -0.0356763  -0.0317152
 -1.15554    -0.775594  -0.31554       -1.34807    0.195189   -0.202108
  1.20535     1.55674    0.991164       4.68596    0.860269    3.19678
 -0.331855   -0.71974   -0.441225      -0.582652  -0.77032    -0.919591
  1.23635     2.23694    0.178571       2.87533    1.62226     2.23182

  0.715945 seconds (541.34 k allocations: 2.630 GiB, 31.78% gc time)


┌ Info: 5
│   acc = 94.49
│   test_acc = 94.55
└ @ Main /home/maciek/Templates/AutomaticDifferention/addons/rnn.ipynb:18


In [4]:
Flux.reset!(net)
x1, y1 = first(loader(train_data)); # (28×28×1×1 Array{Float32, 3}, 10×1 OneHotMatrix(::Vector{UInt32}))
y1hat = net(x1[  1:196,:])
y1hat = net(x1[197:392,:])
y1hat = net(x1[393:588,:])
y1hat = net(x1[589:end,:])
@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))

@show loss_and_accuracy(net, train_data);

hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)) = [5 5]
loss_and_accuracy(net, train_data) = (loss = 0.19652821f0, acc = 94.38, split = :train)
