# Flux.jlでRNN

Based on the JuliaBloggers post "A Basic RNN": https://www.juliabloggers.com/a-basic-rnn/

See also the FluxML model zoo for more example models: https://github.com/FluxML/model-zoo

## Generate Data

In [2]:
using Random

In [1]:
"""
    generate_data(num_samples; rng = Random.default_rng())

Build a toy "sum the sequence" dataset.

Each of the `num_samples` training inputs is a vector whose length is drawn from
`2:7` and whose elements are drawn from the values `1.0, 2.0, …, 10.0`. Its label
is the sum of its elements. The test set is the training set with every input and
every label multiplied by 2, so it contains larger values than the model sees
during training.

# Keywords
- `rng`: random number generator used for all draws. Pass a seeded RNG to make
  the dataset reproducible. The default uses the global RNG, exactly like a bare
  `rand` call.

# Returns
`(train_data, train_labels, test_data, test_labels)`.
"""
function generate_data(num_samples; rng = Random.default_rng())
    # The inner `rand` picks the sequence length, the outer one fills it.
    train_data = [rand(rng, 1.0:10.0, rand(rng, 2:7)) for _ in 1:num_samples]
    train_labels = sum.(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    return train_data, train_labels, test_data, test_labels
end

generate_data (generic function with 1 method)

## Create the Model

In [3]:
using Flux

┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1278


[32m[1mDownloading[22m[39m artifact: CUDA102
[?25l[1A[2K[?25h[32m[1mDownloading[22m[39m artifact: CUDNN_CUDA102
[?25l[1A[2K[?25h[32m[1mDownloading[22m[39m artifact: CUTENSOR_CUDA102
[?25l[1A[2K[?25h

In [4]:
# Single-input, single-output RNN layer. `identity` is the named equivalent of the
# anonymous activation `x -> x`, and it makes the printed model self-describing.
simple_rnn = Flux.RNN(1, 1, identity)

Recur(RNNCell(1, 1, #5))

In [5]:
using Flux: @epochs

# Size of the generated training set and number of passes over it.
num_samples = 1_000
num_epochs = 50

50

In [6]:
# Seed the global RNG so the generated dataset is reproducible between runs.
# NOTE(review): this presumably also fixes Flux's default weight initialisation
# for the RNN below, assuming it draws from the global RNG. Confirm.
Random.seed!(1234)

# Generate both the training and the test split with `generate_data` from above.
train_data, train_labels, test_data, test_labels = generate_data(num_samples)

# Fresh single-unit RNN. `identity` is the named form of the activation `x -> x`.
simple_rnn = Flux.RNN(1, 1, identity)

Recur(RNNCell(1, 1, #7))

In [7]:
"""
    eval_model(x)
    eval_model(model, x)

Feed the sequence `x` through the recurrent `model` one element at a time and
return the output produced for the last element. The 1-argument form uses the
global `simple_rnn`.

The hidden state is reset before the run, so the result depends only on `x`.
It is reset again afterwards, so no state leaks into the next call.

# Throws
- `ArgumentError` if `x` is empty, because there is no final output to return.
"""
function eval_model(model, x)
    isempty(x) && throw(ArgumentError("input sequence must not be empty"))
    Flux.reset!(model)
    # Flux discourages `model.(x)` on stateful layers: broadcasting does not
    # guarantee evaluation order. A comprehension visits `x` left to right.
    outputs = [model(xi) for xi in x]
    Flux.reset!(model)
    return outputs[end]
end

eval_model(x) = eval_model(simple_rnn, x)

eval_model (generic function with 1 method)

In [8]:
# Absolute value of the summed residual between the model's final output for
# sequence `x` and the target `y`.
function loss(x, y)
    residual = eval_model(x) .- y
    return abs(sum(residual))
end

# Parameters of `simple_rnn` that the training cell updates.
ps = Flux.params(simple_rnn)

# ADAM optimiser with its default hyper-parameters.
opt = Flux.ADAM()

# Baseline: losses of the untrained model, summed over each split.
println("Training loss before = ", sum(loss.(train_data, train_labels)))
println("Test loss before = ", sum(loss.(test_data, test_labels)))

Training loss before = 60893.60434684237
Test loss before = 122420.30913953678


In [9]:
# callback function during training
# Training-time callback: print the summed test loss. `@show` echoes the
# expression text next to its value.
function evalcb()
    @show(sum(loss.(test_data, test_labels)))
end

# `@epochs` re-runs the whole `train!` expression once per epoch (see the
# "Epoch N" log lines). The throttled callback is therefore rebuilt each epoch.
@epochs num_epochs Flux.train!(
    loss,
    ps,
    zip(train_data, train_labels),
    opt;
    cb = Flux.throttle(evalcb, 1),
)

┌ Info: Epoch 1
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 122185.13194356892
sum(loss.(test_data, test_labels)) = 50602.97853921002


┌ Info: Epoch 2
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 9579.718361095947


┌ Info: Epoch 3
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 8315.15643725877


┌ Info: Epoch 4
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 7546.378205671473


┌ Info: Epoch 5
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 6810.183951943045


┌ Info: Epoch 6
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 5999.728843972561


┌ Info: Epoch 7
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 5385.899285481914


┌ Info: Epoch 8
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 4767.912806135723


┌ Info: Epoch 9
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 3835.6939709043177


┌ Info: Epoch 10
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 2964.399765264607


┌ Info: Epoch 11
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 2175.9858431897237


┌ Info: Epoch 12
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114
┌ Info: Epoch 13
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 1407.6860375724978
sum(loss.(test_data, test_labels)) = 466.0686293698433


┌ Info: Epoch 14
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 219.541676628021


┌ Info: Epoch 15
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 60.676452242215916


┌ Info: Epoch 16
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114
┌ Info: Epoch 17
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 50.21757394452742
sum(loss.(test_data, test_labels)) = 39.93966921727469


┌ Info: Epoch 18
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 117.75765106902307


┌ Info: Epoch 19
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 20.63285573435144


┌ Info: Epoch 20
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 111.60663515657038


┌ Info: Epoch 21
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114
┌ Info: Epoch 22
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 9.88091768983044
sum(loss.(test_data, test_labels)) = 19.489352641001314


┌ Info: Epoch 23
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 14.570188744432366


┌ Info: Epoch 24
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 25.41309618592338


┌ Info: Epoch 25
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 4.430830750730983


┌ Info: Epoch 26
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 85.97691950395873


┌ Info: Epoch 27
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 56.30208064651693


┌ Info: Epoch 28
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114
┌ Info: Epoch 29
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 127.70030155786422
sum(loss.(test_data, test_labels)) = 4.648479264658563


┌ Info: Epoch 30
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 23.31395590225572


┌ Info: Epoch 31
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 22.308025472001773


┌ Info: Epoch 32
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 56.94773638588923


┌ Info: Epoch 33
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 75.94996208865192


┌ Info: Epoch 34
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 156.20153746593473


┌ Info: Epoch 35
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 104.40696220938855


┌ Info: Epoch 36
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 16.810683590949697


┌ Info: Epoch 37
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 71.06113466220454


┌ Info: Epoch 38
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 83.23519764665002


┌ Info: Epoch 39
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 150.88547368324691


┌ Info: Epoch 40
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 101.99075242264308


┌ Info: Epoch 41
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 90.01889395357225


┌ Info: Epoch 42
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 13.842676666718035


┌ Info: Epoch 43
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 6.558015501136127


┌ Info: Epoch 44
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 47.35076637602955


┌ Info: Epoch 45
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114
┌ Info: Epoch 46
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 101.93264343165279
sum(loss.(test_data, test_labels)) = 22.764304933979723


┌ Info: Epoch 47
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 134.96251790609983


┌ Info: Epoch 48
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 7.988061752712836


┌ Info: Epoch 49
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


sum(loss.(test_data, test_labels)) = 37.64791955158063


┌ Info: Epoch 50
└ @ Main C:\Users\yamta\.julia\packages\Flux\05b38\src\optimise\train.jl:114


In [10]:
# Summed test loss of the trained model, for comparison with the baseline.
let final_test_loss = sum(loss.(test_data, test_labels))
    println("Test loss after = ", final_test_loss)
end

Test loss after = 30.506389102767205
