In [1]:
"""
    RNN{T}

One time step of a vanilla RNN layer, Eq. (5.9):

    h_next = tanh.(h_prev * Wh + x * Wx .+ b)

With `N` = batch size, `D` = input dimension and `H` = hidden dimension, the
matrix products require `x` to be `(N, D)`, `h_prev` to be `(N, H)`, `Wx` to be
`(D, H)` and `Wh` to be `(H, H)`; `b` must broadcast against `(N, H)`.

# Fields
- `Wx`, `Wh`, `b`: layer parameters.
- `dWx`, `dWh`, `db`: gradient buffers with the same shapes as `Wx`, `Wh`, `b`;
  allocated as zeros by the constructor and overwritten in place by `backward`.
- `x`, `hp`, `hn`: input, previous hidden state and next hidden state cached by
  `forward` for `backward` (the counterpart of `self.cache` in the Python
  version). They are unset until `forward` has been called.
"""
mutable struct RNN{T}
    Wx::AbstractMatrix{T}
    Wh::AbstractMatrix{T}
    b::AbstractMatrix{T}
    dWx::AbstractMatrix{T}
    dWh::AbstractMatrix{T}
    db::AbstractMatrix{T}

    # Forward-pass cache (self.cache in the Python version).
    x::AbstractMatrix{T}
    hp::AbstractMatrix{T}
    hn::AbstractMatrix{T}

    function (::Type{RNN})(
        Wx::AbstractMatrix{T}, Wh::AbstractMatrix{T}, b::AbstractMatrix{T}
    ) where {T}
        # Leave the cache fields undefined; `forward` fills them.
        layer = new{T}()
        layer.Wx = Wx
        layer.Wh = Wh
        layer.b = b
        # Allocate gradients in the parameters' element type (e.g. Float32 stays
        # Float32) as plain mutable arrays so `backward` can update them in place.
        layer.dWx = zeros(T, size(Wx))
        layer.dWh = zeros(T, size(Wh))
        layer.db = zeros(T, size(b))
        return layer
    end
end

"""
    forward(rnn::RNN{T}, x, h_prev) -> h_next

Compute one RNN step, Eq. (5.9): `tanh.(h_prev * rnn.Wh + x * rnn.Wx .+ rnn.b)`.

`x` is `(N, D)` and `h_prev` is `(N, H)`; the returned `h_next` is `(N, H)`.
`x`, `h_prev` and `h_next` are stored in `rnn` for use by `backward`.
"""
function forward(rnn::RNN{T}, x::AbstractMatrix{T}, h_prev::AbstractMatrix{T}) where {T}
    # All parameters come from the layer itself, never from global variables.
    t = h_prev * rnn.Wh + x * rnn.Wx .+ rnn.b
    h_next = tanh.(t)

    rnn.x = x
    rnn.hp = h_prev
    rnn.hn = h_next
    return h_next
end

"""
    backward(rnn::RNN{T}, dh_next) -> (dx, dh_prev)

Backpropagate the upstream gradient `dh_next` (shape of `h_next`) through the
step most recently computed by `forward(rnn, ...)`.

Returns the gradients with respect to the input `x` and the previous hidden
state `h_prev`. The parameter gradients are written in place into `rnn.dWx`,
`rnn.dWh` and `rnn.db`.

`db` is filled with the column sums of the pre-activation gradient, broadcast
to the shape of `b`. When `b` has several rows (e.g. `(N, H)`), every row gets
the same sums, i.e. the bias is treated as shared across the batch.
NOTE(review): if per-row biases are intended, the gradient would be `dt`
itself — confirm.

Throws `UndefRefError` if called before `forward`, because the cache is empty.
"""
function backward(rnn::RNN{T}, dh_next::AbstractMatrix{T}) where {T}
    x, h_prev, h_next = rnn.x, rnn.hp, rnn.hn

    # d tanh(t)/dt = 1 - tanh(t)^2, and h_next == tanh(t).
    dt = dh_next .* (1 .- h_next .^ 2)

    # Update every gradient buffer in place, so references to them stay valid.
    rnn.dWx .= x' * dt
    rnn.dWh .= h_prev' * dt
    rnn.db .= sum(dt, dims=1)

    dx = dt * rnn.Wx'
    dh_prev = dt * rnn.Wh'
    return dx, dh_prev
end

backward (generic function with 1 method)

### Example

In [2]:
# Problem sizes:
#   N: batch size
#   H: hidden-layer dimension
#   D: data (input) dimension
N, H, D = 5, 3, 2

# Parameters, all initialised to one.
Wx = ones(D, H)   # input -> hidden
Wh = ones(H, H)   # hidden -> hidden
b = ones(N, H)    # bias, one row per batch element

# Zero initial hidden state and a small input batch holding 0..9 row by row.
hp = zeros(N, H)
x = Float64[0 1
            2 3
            4 5
            6 7
            8 9]
display(x)

5×2 Array{Float64,2}:
 0.0  1.0
 2.0  3.0
 4.0  5.0
 6.0  7.0
 8.0  9.0

In [3]:
# Build the layer from the example parameters and run one step from the zero state.
rnn = RNN(Wx, Wh, b)
h_next = forward(rnn, x, hp)
h_next |> display

5×3 Array{Float64,2}:
 0.964028  0.964028  0.964028
 0.999988  0.999988  0.999988
 1.0       1.0       1.0     
 1.0       1.0       1.0     
 1.0       1.0       1.0     

In [4]:
# Backpropagate, feeding h_next in as the upstream gradient dh_next
# (presumably just a convenient matrix of the right shape for this demo).
dx, dh_prev = backward(rnn, h_next);
# Show dx, dh_prev and the parameter gradients dWx, dWh, db, in that order.
foreach(display, (dx, dh_prev, rnn.dWx, rnn.dWh, rnn.db))

5×2 Array{Float64,2}:
 0.204328     0.204328   
 7.37287e-5   7.37287e-5 
 2.47338e-8   2.47338e-8 
 8.29736e-12  8.29736e-12
 2.66454e-15  2.66454e-15

5×3 Array{Float64,2}:
 0.204328     0.204328     0.204328   
 7.37287e-5   7.37287e-5   7.37287e-5 
 2.47338e-8   2.47338e-8   2.47338e-8 
 8.29736e-12  8.29736e-12  8.29736e-12
 2.66454e-15  2.66454e-15  2.66454e-15

2×3 Array{Float64,2}:
 4.91855e-5  4.91855e-5  4.91855e-5
 0.0681831   0.0681831   0.0681831 

3×3 Array{Float64,2}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

5×3 Array{Float64,2}:
 0.0681339  0.0681339  0.0681339
 0.0681339  0.0681339  0.0681339
 0.0681339  0.0681339  0.0681339
 0.0681339  0.0681339  0.0681339
 0.0681339  0.0681339  0.0681339