In [41]:
import numpy as np
import ipdb;

In [42]:
class RNN:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None

    def forward(self, x, h_prev):
        Wx, Wh, b = self.params
        t = np.dot(h_prev, Wh) + np.dot(x, Wx) + b
        h_next = np.tanh(t)

        self.cache = (x, h_prev, h_next)
        return h_next

    def backward(self, dh_next):
        Wx, Wh, b = self.params
        x, h_prev, h_next = self.cache

        dt = dh_next * (1 - h_next ** 2)
        db = np.sum(dt, axis=0)
        dWh = np.dot(h_prev.T, dt)
        dh_prev = np.dot(dt, Wh.T)
        dWx = np.dot(x.T, dt)
        dx = np.dot(dt, Wx.T)

        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db

        return dx, dh_prev

In [43]:
class TimeRNN:
    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.layers = None

        self.h, self.dh = None, None
        self.stateful = stateful

    def forward(self, xs):
        Wx, Wh, b = self.params
        N, T, D = xs.shape
        D, H = Wx.shape

        self.layers = []
        hs = np.empty((N, T, H), dtype='f')

        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype='f')

        for t in range(T):
            layer = RNN(*self.params)
#             print('xs[:, t, :]==',xs[:, t, :])
            self.h = layer.forward(xs[:, t, :], self.h)
#             print('self.h==',self.h)
            hs[:, t, :] = self.h
            self.layers.append(layer)

        return hs

    def backward(self, dhs):
        Wx, Wh, b = self.params
        N, T, H = dhs.shape
        D, H = Wx.shape

        dxs = np.empty((N, T, D), dtype='f')
#         ipdb.set_trace()
        print('dxs==',dxs)
        dh = 0
        grads = [0, 0, 0]
        for t in reversed(range(T)):
            
            layer = self.layers[t]
            dx, dh = layer.backward(dhs[:, t, :] + dh)
            dxs[:, t, :] = dx
#             ipdb.set_trace()
            for i, grad in enumerate(layer.grads):
                grads[i] += grad
#                 ipdb.set_trace()

        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh

        return dxs

    def set_state(self, h):
        self.h = h

    def reset_state(self):
        self.h = None

In [44]:
# 初始化参数
np.random.seed(0)
N, T, D, H = 1, 3, 2, 2
xs = np.random.rand(N, T, D)
Wx = np.random.rand(D, H)
Wh = np.random.rand(H, H)
b = np.random.rand(H)
print('xs===',xs)
print('Wx===',Wx)
print('Wh===',Wh)
print('b===',b)
# 创建 TimeRNN 实例
model = TimeRNN(Wx, Wh, b, stateful=False)

# 前向传播
hs = model.forward(xs)
print("Hidden states (hs):")
# print(hs)

# 反向传播
dhs = np.random.rand(N, T, H)
dxs = model.backward(dhs)
print("Input gradients (dxs):")
print(dxs)

# 更新权重
learning_rate = 0.01
for i, param in enumerate(model.params):
    param -= learning_rate * model.grads[i]


xs=== [[[0.5488135  0.71518937]
  [0.60276338 0.54488318]
  [0.4236548  0.64589411]]]
Wx=== [[0.43758721 0.891773  ]
 [0.96366276 0.38344152]]
Wh=== [[0.79172504 0.52889492]
 [0.56804456 0.92559664]]
b=== [0.07103606 0.0871293 ]
Hidden states (hs):
dxs== [[[0.7617584  0.6914766 ]
  [0.9522868  0.95418036]
  [0.97450584 0.9704194 ]]]
Input gradients (dxs):
[[[0.46781406 0.24244477]
  [0.10945939 0.10787515]
  [0.06309512 0.06533264]]]
