In [90]:
using Flux
using Base.Iterators: repeated, partition
using DelimitedFiles

## manual RNN

In [30]:
Wxh = randn(5, 10)
Whh = randn(5, 5)
b   = randn(5)


function rnn(h, x)
  h = tanh.(Wxh * x .+ Whh * h .+ b)
  return h, h
end

x = rand(10) # dummy data
h = rand(5)  # initial hidden state

h, y = rnn(h, x)

([-0.99856, 0.968526, -0.45685, 0.739666, 0.999984], [-0.99856, 0.968526, -0.45685, 0.739666, 0.999984])

## Ready RNN

In [238]:
m = Chain(
  Flux.RNN(5, 2),
  Dense(2, 5)
  )

function loss(xs, ys)
  l = Flux.mse(m(xs), ys)
  println(l)
  Flux.truncate!(m)
  return l
end

opt = ADAM(params(m), 0.01)

#43 (generic function with 1 method)

In [239]:
@show size(params(m)[1])
@show size(params(m)[2])
@show size(params(m)[3])
@show size(params(m)[4])
@show size(params(m)[5])
@show size(params(m)[6])

size((params(m))[1]) = (2, 5)
size((params(m))[2]) = (2, 2)
size((params(m))[3]) = (2,)
size((params(m))[4]) = (2,)
size((params(m))[5]) = (5, 2)
size((params(m))[6]) = (5,)


(5,)

In [240]:
m_lstm = Chain(
  Flux.LSTM(5, 2),
  Dense(2, 5)
  )

@show size(params(m_lstm)[1])
@show size(params(m_lstm)[2])
@show size(params(m_lstm)[3])
@show size(params(m_lstm)[4])
@show size(params(m_lstm)[5])
@show size(params(m_lstm)[6])

size((params(m_lstm))[1]) = (8, 5)
size((params(m_lstm))[2]) = (8, 2)
size((params(m_lstm))[3]) = (8,)
size((params(m_lstm))[4]) = (2,)
size((params(m_lstm))[5]) = (2,)
size((params(m_lstm))[6]) = (5, 2)


(5, 2)

In [241]:
d,h = readdlm("rnn_data.csv",',',header=true)

([178.0 109.0; 1457.0 1361.0; … ; 356.0 333.0; 1954.0 1016.0], AbstractString["x" "y"])

In [242]:
x=d[:,1];
y=d[:,2];

In [243]:
Xs=collect(partition(x,5)); #partiion each 10 elements as batch
Ys=collect(partition(y,5));

In [244]:
Xs

40-element Array{Array{Float64,1},1}:
 [178.0, 1457.0, 177.0, 2211.0, 671.0]  
 [1191.0, 481.0, 537.0, 1530.0, 280.0]  
 [4258.0, 208.0, 193.0, 249.0, 887.0]   
 [1264.0, 163.0, 1706.0, 130.0, 121.0]  
 [200.0, 15.0, 118.0, 176.0, 1341.0]    
 [1521.0, 2806.0, 212.0, 973.0, 1141.0] 
 [385.0, 192.0, 368.0, 268.0, 298.0]    
 [423.0, 252.0, 479.0, 5352.0, 306.0]   
 [2055.0, 517.0, 2087.0, 134.0, 355.0]  
 [206.0, 3572.0, 97.0, 338.0, 84.0]     
 [108.0, 726.0, 355.0, 237.0, 260.0]    
 [459.0, 343.0, 165.0, 230.0, 421.0]    
 [417.0, 630.0, 802.0, 541.0, 232.0]    
 ⋮                                      
 [1032.0, 2543.0, 6164.0, 714.0, 1462.0]
 [528.0, 421.0, 303.0, 1967.0, 111.0]   
 [2105.0, 477.0, 806.0, 813.0, 746.0]   
 [322.0, 469.0, 390.0, 622.0, 346.0]    
 [757.0, 598.0, 769.0, 1843.0, 266.0]   
 [435.0, 519.0, 892.0, 3370.0, 455.0]   
 [572.0, 360.0, 4840.0, 545.0, 584.0]   
 [2675.0, 1805.0, 644.0, 626.0, 537.0]  
 [1476.0, 631.0, 588.0, 2080.0, 393.0]  
 [374.0, 544.0, 330

In [245]:
dataset=collect(zip(Xs,Ys) );

In [246]:
Flux.train!(loss, dataset, opt)

533603.4893578403 (tracked)
657872.1255309687 (tracked)
3.3951454120171554e6 (tracked)
542049.8238607257 (tracked)
330842.55057825125 (tracked)
2.0020564210947428e6 (tracked)
56210.75553586813 (tracked)
5.4853134147440735e6 (tracked)
1.645133324784847e6 (tracked)
2.423962767930881e6 (tracked)
66512.82006851111 (tracked)
69698.01971671465 (tracked)
189181.15724903328 (tracked)
242833.2515386131 (tracked)
864144.8449620723 (tracked)
246061.5820698685 (tracked)
1.0582942415753312e6 (tracked)
69172.57743917863 (tracked)
116488.2225837993 (tracked)
868797.1661508481 (tracked)
393134.2700466471 (tracked)
2.947929043654736e6 (tracked)
77536.2407785309 (tracked)
525706.4876866997 (tracked)
1.4582918076257245e6 (tracked)
59315.31796561587 (tracked)
180150.40089882832 (tracked)
478159.10003341193 (tracked)
8.876239084243275e6 (tracked)
780194.1613656018 (tracked)
1.002797405284827e6 (tracked)
137511.34900527768 (tracked)
762743.7608346053 (tracked)
1.3169785809434145e6 (tracked)
4.65425121383290

In [229]:
m(Xs[1])

Tracked 5-element Array{Float64,1}:
  0.3362286152424313
  1.7986437340884203
 -0.553178836489123 
 -0.6213351530805327
  0.7704204754144215