In [1]:
import numpy as np
import matplotlib.pyplot as plt

# RNN

In [27]:
class RNN:
    """One time step of a vanilla RNN layer with a tanh activation.

    The forward pass computes ``tanh(h_prev @ Wh + x @ Wx + b)`` and the
    backward pass back-propagates an upstream gradient through that
    expression.

    Attributes:
        params: ``[Wx, Wh, b]`` exactly as passed to the constructor.
        grads: zero-initialised buffers shaped like the matching entry of
            ``params``; ``backward`` overwrites them in place.
        cache: ``(x, h_prev, h_next)`` from the most recent ``forward``
            call, or ``None`` before the first call.
    """

    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        # One gradient buffer per parameter, same shape/dtype as the parameter.
        self.grads = [np.zeros_like(param) for param in self.params]
        self.cache = None

    def forward(self, x, h_prev):
        """Compute the next hidden state and remember the inputs for backward.

        Args:
            x: input for this step; multiplied as ``np.dot(x, Wx)``.
            h_prev: previous hidden state; multiplied as ``np.dot(h_prev, Wh)``.

        Returns:
            The new hidden state ``h_next``.
        """
        Wx, Wh, b = self.params
        pre_activation = np.dot(h_prev, Wh) + np.dot(x, Wx) + b
        h_next = np.tanh(pre_activation)

        self.cache = (x, h_prev, h_next)
        return h_next

    def backward(self, dh_next):
        """Back-propagate ``dh_next`` through the step cached by ``forward``.

        Writes the parameter gradients into ``self.grads`` in place.

        Args:
            dh_next: gradient of the loss with respect to ``h_next``.

        Returns:
            ``(dx, dh_prev)``: gradients with respect to the step's input
            and its previous hidden state.
        """
        Wx, Wh, _ = self.params
        x, h_prev, h_next = self.cache

        # d tanh(t)/dt = 1 - tanh(t)^2, and h_next already holds tanh(t).
        dt = dh_next * (1 - h_next * h_next)

        # Gradients flowing back to the two inputs of the step.
        dx = np.dot(dt, Wx.T)
        dh_prev = np.dot(dt, Wh.T)

        # Parameter gradients; the bias gradient sums over axis 0.
        param_grads = (
            np.dot(x.T, dt),      # dWx
            np.dot(h_prev.T, dt),  # dWh
            np.sum(dt, axis=0),    # db
        )
        # Copy into the existing buffers so external references stay valid.
        for buffer, value in zip(self.grads, param_grads):
            buffer[...] = value

        return dx, dh_prev

In [28]:
# Smoke test of a single RNN step: batch of 3, 5 input features, 6 hidden units.
Wx = np.random.rand(5, 6)
Wh = np.random.rand(6, 6)
# The bias is a length-6 vector (one entry per hidden unit) instead of the
# scalar 1. A scalar would make np.zeros_like(b) a 0-d gradient buffer, and
# RNN.backward could not store the (6,)-shaped db in it. The forward result is
# unchanged because the addition broadcasts identically.
b = np.ones(6)
h_prev = np.arange(18).reshape(3, 6)
x = np.arange(15).reshape(3, 5)

rnn = RNN(Wx, Wh, b)
rnn_out = rnn.forward(x, h_prev)
print(rnn_out.shape)

(3, 6)


In [29]:
# Inspect what forward() stored for the backward pass: the tuple (x, h_prev, h_next).
rnn.cache

(array([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]]),
 array([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17]]),
 array([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]]))

# TimeRNN

In [26]:
class TimeRNN:
    """Runs an ``RNN`` step over every time step of a sequence.

    A fresh ``RNN`` layer is created per time step in ``forward``; all of
    them share the same ``[Wx, Wh, b]`` parameter objects.

    Attributes:
        params: ``[Wx, Wh, b]`` exactly as passed to the constructor.
        grads: zero-initialised buffers shaped like the matching entry of
            ``params``; ``backward`` overwrites them with the gradients
            summed over all time steps.
        layers: the per-step ``RNN`` layers built by the last ``forward``
            call (``None`` before the first call).
        h: hidden state after the last processed time step.
        dh: gradient flowing out of the first time step, set by ``backward``.
        stateful: if True, ``forward`` continues from the hidden state left
            by the previous call; if False, it restarts from zeros each call.
    """

    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        # Holds the per-time-step RNN layers.
        self.layers = None
        # Hidden state carried between calls, and the gradient w.r.t. it.
        self.h, self.dh = None, None
        # Whether the hidden state is kept across forward calls.
        self.stateful = stateful

    def set_state(self, h):
        """Set the hidden state that the next ``forward`` call starts from."""
        self.h = h

    def reset_state(self):
        """Drop the stored hidden state so the next ``forward`` starts from zeros."""
        self.h = None

    def forward(self, xs):
        """Run the RNN over all time steps of ``xs``.

        Args:
            xs: array of shape ``(N, T, D)`` (batch size, sequence length,
                input dimension).

        Returns:
            Array of shape ``(N, T, H)`` holding the hidden state at every
            time step, where ``H`` is the second dimension of ``Wx``.
        """
        Wx = self.params[0]
        N, T, _ = xs.shape
        _, H = Wx.shape

        self.layers = []
        hs = np.empty((N, T, H), dtype=float)

        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype=float)

        for t in range(T):
            layer = RNN(*self.params)
            self.h = layer.forward(xs[:, t, :], self.h)
            hs[:, t, :] = self.h
            self.layers.append(layer)

        # Return only after the loop: returning inside it would stop after the
        # first time step and leave the rest of ``hs`` uninitialised.
        return hs

    def backward(self, dhs):
        """Back-propagate through time over the layers built by ``forward``.

        Args:
            dhs: array of shape ``(N, T, H)`` with the upstream gradient for
                the hidden state at every time step.

        Returns:
            Array of shape ``(N, T, D)`` with the gradient w.r.t. the inputs.
        """
        Wx = self.params[0]
        N, T, _ = dhs.shape
        D, _ = Wx.shape

        dxs = np.empty((N, T, D), dtype=float)
        dh = 0
        grads = [0, 0, 0]
        for t in reversed(range(T)):
            layer = self.layers[t]
            # Each step receives the gradient from the output at time t plus
            # the gradient flowing back from step t + 1.
            dx, dh = layer.backward(dhs[:, t, :] + dh)
            dxs[:, t, :] = dx

            # Parameters are shared across time, so their gradients add up.
            for i, grad in enumerate(layer.grads):
                grads[i] += grad

        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh

        return dxs

In [22]:
# np.empty allocates the array without initialising it, so the values shown are arbitrary.
np.empty((3,4), dtype=float)

array([[1.28822975e-231, 1.28822975e-231, 2.96439388e-323,
        0.00000000e+000],
       [6.79038653e-313, 5.64233733e-067, 4.97285413e-091,
        1.34994249e+161],
       [1.55054610e+184, 2.65815925e+179, 3.99910963e+252,
        8.34402833e-309]])

In [25]:
# Parameters and sample inputs: batch of 3, 5 input features, 6 hidden units.
Wx = np.random.rand(5, 6)
Wh = np.random.rand(6, 6)
# The bias is a length-6 vector (one entry per hidden unit) instead of the
# scalar 1. A scalar would make np.zeros_like(b) a 0-d gradient buffer that
# cannot hold the (6,)-shaped bias gradient produced during backward.
b = np.ones(6)
h_prev = np.arange(18).reshape(3, 6)
x = np.arange(15).reshape(3, 5)

params = [Wx, Wh, b]
