# 双方向LSTMを計算するためのTimeBiLSTMクラスを実装する

In [1]:
import numpy as np
from common.time_layers import TimeLSTM

### [演習]
* 以下のTimeBiLSTMクラスを完成させましょう

In [2]:
class TimeBiLSTM:
    """Bidirectional LSTM assembled from two TimeLSTM layers.

    One layer reads the input as given; the other reads it reversed along
    axis 1 (presumably the time axis). Their outputs are joined along axis 2.
    """

    def __init__(self, Wx1, Wh1, b1, Wx2, Wh2, b2, stateful=False):
        """Build both directional layers and pool their parameters.

        Wx1, Wh1, b1 : arguments handed to the forward-direction TimeLSTM
        Wx2, Wh2, b2 : arguments handed to the backward-direction TimeLSTM
        stateful     : passed unchanged to both TimeLSTM layers
        """
        self.forward_lstm = TimeLSTM(Wx1, Wh1, b1, stateful)
        self.backward_lstm = TimeLSTM(Wx2, Wh2, b2, stateful)

        # Expose the parameters and gradients of both layers as one
        # collection each, forward-direction entries first.
        fwd, bwd = self.forward_lstm, self.backward_lstm
        self.params = fwd.params + bwd.params
        self.grads = fwd.grads + bwd.grads

    def forward(self, xs):
        """Run both directions over xs and join the results.

        xs : input data, indexed as (batch, step, feature)

        Returns the forward-direction output and the backward-direction
        output concatenated along axis 2.
        """
        left_to_right = self.forward_lstm.forward(xs)

        # Feed the sequence reversed along axis 1, then flip the result back
        # so that step t of both outputs refers to the same input position.
        right_to_left = self.backward_lstm.forward(xs[:, ::-1])[:, ::-1]

        return np.concatenate((left_to_right, right_to_left), axis=2)

    def backward(self, dhs):
        """Back-propagate through both directions.

        dhs : gradient with respect to the concatenated output of forward()

        Returns the sum of the input gradients produced by the two layers.
        """
        # The first half of axis 2 belongs to the forward-direction layer and
        # the remainder to the backward-direction layer.
        half = dhs.shape[2] // 2
        grad_fwd_out = dhs[:, :, :half]
        grad_bwd_out = dhs[:, :, half:]

        dxs_fwd = self.forward_lstm.backward(grad_fwd_out)

        # Mirror the reversals done in forward(): flip going in, flip coming out.
        dxs_bwd = self.backward_lstm.backward(grad_bwd_out[:, ::-1])[:, ::-1]

        return dxs_fwd + dxs_bwd

In [7]:
# Smoke test: push random data through TimeBiLSTM once in each direction
# (forward pass, then backward pass) and print the results.

# Vocabulary size (exclusive upper bound of the random integer inputs)
V = 3
# Dimension after embedding
D = 3
# Number of hidden-layer nodes
H = 4
# Number of samples
N = 3
# Number of words per sample
T = 5

# Random weights scaled by 1/sqrt(number of input rows); zero biases.
rn = np.random.randn
Wx1 = rn(D, 4 * H) / np.sqrt(D)
Wh1 = rn(H, 4 * H) / np.sqrt(H)
b1 = np.zeros(4 * H)
Wx2 = rn(D, 4 * H) / np.sqrt(D)
Wh2 = rn(H, 4 * H) / np.sqrt(H)
b2 = np.zeros(4 * H)

# Build the model
tb = TimeBiLSTM(Wx1, Wh1, b1, Wx2, Wh2, b2)

# Random integer inputs in [0, V), sampled directly in their final shape.
xs = np.random.randint(0, V, size=(N, T, D))
print("xs=", xs)
print()

# Forward pass
out = tb.forward(xs)
print("out=", out)
print()

# Backward pass: the upstream gradient is 2 * H wide along the last axis,
# matching the concatenated output of the two directions.
dhs = rn(N, T, 2 * H)
dxs = tb.backward(dhs)
print("dxs=", dxs)
print()

xs= [[[2 0 2]
  [0 2 0]
  [0 0 2]
  [2 2 1]
  [1 2 2]]

 [[2 1 1]
  [1 0 2]
  [2 2 2]
  [2 0 1]
  [0 1 0]]

 [[1 0 2]
  [0 1 0]
  [2 1 2]
  [1 1 2]
  [2 0 0]]]

out= [[[-0.4442251  -0.24769719 -0.00777811  0.06645191 -0.02615059
   -0.06437904 -0.00236975  0.00141259]
  [-0.06663981  0.3259179   0.08993442  0.26715383  0.00809839
    0.02414597  0.07904651  0.10039306]
  [-0.1940498   0.10111785  0.01061701  0.13221097  0.0040839
   -0.03207225 -0.01238621 -0.22779877]
  [-0.24917865 -0.03778956  0.01361174  0.38776007 -0.04340029
   -0.199775    0.00074464 -0.01183767]
  [-0.30091047  0.0860653   0.01471565  0.3778718   0.0064317
   -0.13469179 -0.0029961  -0.05469006]]

 [[-0.29237068 -0.24848768 -0.03015092  0.20260546 -0.08885594
   -0.1520348  -0.0133807  -0.07769755]
  [-0.5294353  -0.2660507  -0.03474002  0.11668624 -0.08295359
   -0.04915077 -0.0089692  -0.16220717]
  [-0.5363357  -0.09463511 -0.06448951  0.31671545 -0.04969817
   -0.18180534 -0.0017607  -0.03310262]
  [-0.4572