In [6]:
using Flux
using Random

# Experiment size: number of toy sequences to generate, and number of full
# passes over the training set.
num_samples, num_epochs = 1000, 50

"""
    generate_data(num_samples; lengths = 2:7, elements = 1.0:10.0, rng = Random.default_rng())

Build a toy "sum the sequence" dataset.

Return `(train_data, train_labels, test_data, test_labels)` where:
- `train_data` holds `num_samples` sequences. Each sequence's length is drawn uniformly
  from `lengths`, and its entries are drawn uniformly from `elements` (by default the
  whole-number floats `1.0, 2.0, …, 10.0`).
- `train_labels[i] == sum(train_data[i])`.
- `test_data` and `test_labels` are the training sequences and labels multiplied by 2.

# Keywords
- `lengths`: allowed sequence lengths. Every entry must be at least 1, because a
  recurrent model needs at least one element to produce an output.
- `elements`: collection the sequence entries are sampled from.
- `rng`: random number generator; pass a seeded one for a reproducible dataset.

# Throws
- `ArgumentError` if `lengths` is empty or contains a length below 1.
"""
function generate_data(num_samples; lengths = 2:7, elements = 1.0:10.0,
                       rng::AbstractRNG = Random.default_rng())
    if isempty(lengths) || minimum(lengths) < 1
        throw(ArgumentError("lengths must be non-empty with every length ≥ 1, got $lengths"))
    end

    # The inner `rand` (the length) is drawn before the outer one (the entries),
    # matching the draw order of the original implementation.
    train_data = [rand(rng, elements, rand(rng, lengths)) for _ in 1:num_samples]
    train_labels = sum.(train_data)

    # The "test" set is simply the training set doubled. It is NOT an
    # independent held-out set: every test sequence is a deterministic copy of
    # a training sequence. That is acceptable for a toy demo but never for a
    # real evaluation. Doubling does push entries (2-20 by default) and sums
    # outside the training range, so it probes some extrapolation.
    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    return train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

# A 1-input, 1-output RNN with an identity activation. The default `tanh`
# would squash every output into (-1, 1), which cannot represent a running sum.
simple_rnn = Flux.RNN(1, 1, identity)

"""
    eval_model(x)

Feed the sequence `x` through `simple_rnn` one element at a time and return the
output produced after the last element.

The recurrent state is reset both before and after the sequence is processed.
The reset before guarantees a clean start even if a previous call was
interrupted part-way. The reset after leaves the layer in its initial state for
whoever uses it next.

Throws an `ArgumentError` for an empty sequence, which has no final output.
"""
function eval_model(x)
    isempty(x) && throw(ArgumentError("eval_model needs a non-empty sequence"))
    Flux.reset!(simple_rnn)
    # Explicit comprehension rather than `simple_rnn.(x)`: broadcasting or
    # mapping a stateful layer relies on an evaluation order that Julia does
    # not guarantee (see the Flux recurrence docs).
    outputs = [simple_rnn(xi) for xi in x]
    Flux.reset!(simple_rnn)
    return outputs[end]
end

# Per-sequence L1 loss: absolute gap between the model's final output for `x`
# and the target sum `y`.
function loss(x, y)
    residual = eval_model(x) .- y
    return abs(sum(residual))
end

# Trainable parameters of the RNN and the optimiser that updates them.
ps = Flux.params(simple_rnn)
opt = Flux.ADAM()

# Baseline losses before any parameter update.
println("Training loss before = $(sum(loss.(train_data, train_labels)))")
println("Test loss before = $(sum(loss.(test_data, test_labels)))")

# Progress callback for `train!`: prints the summed test loss. `throttle`
# (below) limits it to at most one call per second.
evalcb() = @show(sum(loss.(test_data, test_labels)))

@epochs num_epochs Flux.train!(
    loss, ps, zip(train_data, train_labels), opt;
    cb = Flux.throttle(evalcb, 1),
)

# Summed test loss once all epochs have finished.
println("Test loss after = $(sum(loss.(test_data, test_labels)))")

Training loss before = 28738.599622982303
Test loss before = 56855.57012622556


┌ Info: Epoch 1
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 56838.658664039096
sum(loss.(test_data, test_labels)) = 48325.28622767743


┌ Info: Epoch 2
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 4539.116326307559


┌ Info: Epoch 3
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 4019.4370400076295


┌ Info: Epoch 4
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 3160.8515734057537


┌ Info: Epoch 5
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 2622.056858645405


┌ Info: Epoch 6
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 1987.0751789573087


┌ Info: Epoch 7
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 1032.095960395042


┌ Info: Epoch 8
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 482.4493236747384


┌ Info: Epoch 9
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 17.1694439962181


┌ Info: Epoch 10
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 34.601635672808975


┌ Info: Epoch 11
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 40.11898157809348


┌ Info: Epoch 12
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 40.029688479828735


┌ Info: Epoch 13
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 47.83626396495533


┌ Info: Epoch 14
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 53.64699739702939


┌ Info: Epoch 15
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 18.765808258312177


┌ Info: Epoch 16
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 21.179019729308102


┌ Info: Epoch 17
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 68.30404678891003


┌ Info: Epoch 18
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 27.720244752277303


┌ Info: Epoch 19
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 31.705385287902462


┌ Info: Epoch 20
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 57.94570536804357


┌ Info: Epoch 21
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 13.295772622560666


┌ Info: Epoch 22
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 36.08657506587092


┌ Info: Epoch 23
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 9.924635298620837


┌ Info: Epoch 24
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 33.36238141242671


┌ Info: Epoch 25
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 8.925530578268559


┌ Info: Epoch 26
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 43.54831489538057


┌ Info: Epoch 27
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 24.773282131951945


┌ Info: Epoch 28
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 33.87782247821441


┌ Info: Epoch 29
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 26.441601074066693


┌ Info: Epoch 30
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 73.58762165313166


┌ Info: Epoch 31
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 15.478175814727493


┌ Info: Epoch 32
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 31.700111218454467


┌ Info: Epoch 33
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 32.644846329031246


┌ Info: Epoch 34
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 7.339873224315917


┌ Info: Epoch 35
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 37.25650422510425


┌ Info: Epoch 36
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 9.066719025384039


┌ Info: Epoch 37
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 20.31696387435766


┌ Info: Epoch 38
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 23.67403261655109


┌ Info: Epoch 39
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 16.4697408844107


┌ Info: Epoch 40
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 46.37106998767376


┌ Info: Epoch 41
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 47.724489136499926


┌ Info: Epoch 42
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 15.823037877852178


┌ Info: Epoch 43
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 43.91808554940699


┌ Info: Epoch 44
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 16.251334042385423


┌ Info: Epoch 45
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 36.13019392217473


┌ Info: Epoch 46
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 39.92448349928947


┌ Info: Epoch 47
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 38.06263296285502


┌ Info: Epoch 48
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 42.182911179632406


┌ Info: Epoch 49
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


sum(loss.(test_data, test_labels)) = 22.698289912518508


┌ Info: Epoch 50
└ @ Main /home/emoryfreitas/.julia/packages/Flux/NpkMm/src/optimise/train.jl:119


Test loss after = 63.464639401712056
