In [34]:
import torch
import torch.nn as nn

In [35]:
# Seed the global RNG so the random input below (and anything initialised from
# the global RNG afterwards) is identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

bs,T=2,3  # batch size, input sequence length
input_size,hidden_size = 2,3 # input feature size, hidden-state feature size
input = torch.randn(bs,T,input_size)  # randomly initialised input sequence, shape (bs, T, input_size)
h_prev = torch.zeros(bs,hidden_size) # initial hidden state (all zeros), shape (bs, hidden_size)

In [36]:
# Step 1: run the reference implementation, PyTorch's built-in nn.RNN
rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
# h_prev is (bs, hidden_size); indexing with None adds the leading axis of
# size 1 that the initial-state argument carries here.
rnn_output, state_finall = rnn(input, h_prev[None])

# Per-step outputs first, then the final hidden state
print(rnn_output, state_finall, sep="\n")

tensor([[[-0.7709,  0.7301, -0.9299],
         [-0.6976, -0.8241, -0.1903],
         [-0.6485, -0.2633, -0.1093]],

        [[-0.2035,  0.7439, -0.1369],
         [-0.4805, -0.5790,  0.1787],
         [-0.6185,  0.4854, -0.4907]]], grad_fn=<TransposeBackward1>)
tensor([[[-0.6485, -0.2633, -0.1093],
         [-0.6185,  0.4854, -0.4907]]], grad_fn=<StackBackward0>)


In [37]:
# Step 2: hand-written rnn_forward that reproduces the computation inside nn.RNN
def rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev):
    """Run a single-layer, unidirectional tanh RNN over a batch-first sequence.

    Every time step applies
    ``h_t = tanh(x_t @ weight_ih.T + bias_ih + h_{t-1} @ weight_hh.T + bias_hh)``.

    Args:
        input: input sequence, shape (bs, T, input_size).
        weight_ih: input-to-hidden weight, shape (h_dim, input_size).
        weight_hh: hidden-to-hidden weight, shape (h_dim, h_dim).
        bias_ih: input-to-hidden bias, shape (h_dim,).
        bias_hh: hidden-to-hidden bias, shape (h_dim,).
        h_prev: initial hidden state, shape (bs, h_dim).

    Returns:
        A tuple ``(h_out, h_n)``: ``h_out`` holds the hidden state of every
        time step with shape (bs, T, h_dim); ``h_n`` is the hidden state after
        the last step with a leading axis added, shape (1, bs, h_dim). When
        T == 0, ``h_n`` is simply the initial state.
    """
    bs, T, _ = input.shape
    h_dim = weight_ih.shape[0]
    # Allocate the output with the input's dtype and device; a bare
    # torch.zeros(...) would silently downcast float64 results to float32
    # and always place the buffer on the CPU.
    h_out = input.new_zeros((bs, T, h_dim))

    # The weights do not change across time steps, so transpose them once
    # here. A plain matmul broadcasts over the batch, which removes the need
    # to tile the weights per sample.
    w_ih_t = weight_ih.transpose(0, 1)  # input_size * h_dim
    w_hh_t = weight_hh.transpose(0, 1)  # h_dim * h_dim

    for t in range(T):
        x_t = input[:, t, :]  # features of the current time step, bs * input_size
        w_times_x = x_t @ w_ih_t  # bs * h_dim
        w_times_h = h_prev @ w_hh_t  # bs * h_dim
        h_prev = torch.tanh(w_times_x + bias_ih + w_times_h + bias_hh)

        h_out[:, t, :] = h_prev

    return h_out, h_prev.unsqueeze(0)

In [38]:
# Verify: feed nn.RNN's own parameters into the hand-written rnn_forward
custom_rnn_output, custom_state_finall = rnn_forward(
    input,
    weight_ih=rnn.weight_ih_l0,
    weight_hh=rnn.weight_hh_l0,
    bias_ih=rnn.bias_ih_l0,
    bias_hh=rnn.bias_hh_l0,
    h_prev=h_prev,
)
# Per-step outputs first, then the final hidden state
print(custom_rnn_output, custom_state_finall, sep="\n")

tensor([[[-0.7709,  0.7301, -0.9299],
         [-0.6976, -0.8241, -0.1903],
         [-0.6485, -0.2633, -0.1093]],

        [[-0.2035,  0.7439, -0.1369],
         [-0.4805, -0.5790,  0.1787],
         [-0.6185,  0.4854, -0.4907]]], grad_fn=<CopySlices>)
tensor([[[-0.6485, -0.2633, -0.1093],
         [-0.6185,  0.4854, -0.4907]]], grad_fn=<UnsqueezeBackward0>)


In [39]:
# Both lines print True when the hand-written RNN agrees with nn.RNN
print(
    torch.allclose(rnn_output, custom_rnn_output),
    torch.allclose(state_finall, custom_state_finall),
    sep="\n",
)

True
True


In [52]:
# Step 3: hand-written bidirectional_rnn_forward that reproduces a bidirectional RNN
def bidirectional_rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev,
                              weight_ih_reverse,weight_hh_reverse,bias_ih_reverse,
                              bias_hh_reverse,h_prev_reverse):
    """Run a single-layer bidirectional tanh RNN built from two rnn_forward passes.

    The forward direction reads ``input`` as given; the backward direction
    reads it reversed along the time axis with its own set of parameters.

    Args:
        input: input sequence, shape (bs, T, input_size).
        weight_ih, weight_hh, bias_ih, bias_hh: forward-direction parameters,
            passed straight to rnn_forward.
        h_prev: forward-direction initial hidden state, shape (bs, h_dim).
        weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse:
            backward-direction parameters, passed straight to rnn_forward.
        h_prev_reverse: backward-direction initial hidden state, shape (bs, h_dim).

    Returns:
        A tuple ``(h_out, h_n)``: ``h_out`` has shape (bs, T, 2 * h_dim) with
        the forward features in the first h_dim channels and the backward
        features in the last h_dim channels; ``h_n`` has shape (2, bs, h_dim)
        with the forward final state at index 0 and the backward final state
        at index 1.
    """
    # Forward layer: process the sequence in its original time order.
    forward_output, forward_state = rnn_forward(
        input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev)

    # Backward layer: flip the input along the time axis before processing.
    backward_output, backward_state = rnn_forward(
        torch.flip(input, [1]), weight_ih_reverse, weight_hh_reverse,
        bias_ih_reverse, bias_hh_reverse, h_prev_reverse)

    # The backward pass produced its steps in reversed time order, so flip it
    # back to line up with the forward output before joining the features.
    # Concatenating (rather than copying into a torch.zeros buffer) keeps the
    # dtype and device of the per-direction results.
    h_out = torch.cat([forward_output, torch.flip(backward_output, [1])], dim=-1)

    # rnn_forward already returns each direction's last state as
    # (1, bs, h_dim); joining them on the leading axis gives (2, bs, h_dim)
    # and, unlike indexing [:, -1, :], does not fail on an empty sequence.
    h_n = torch.cat([forward_state, backward_state], dim=0)

    return h_out,h_n

# Check bidirectional_rnn_forward against PyTorch's bidirectional nn.RNN
bi_rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size,
                batch_first=True, bidirectional=True)
# NOTE: this rebinds h_prev to shape (2, bs, hidden_size), one zero initial
# state per direction; any code that still expects the earlier
# (bs, hidden_size) h_prev must recreate it first.
h_prev = torch.zeros(2, bs, hidden_size)
bi_rnn_output, bi_state_finall = bi_rnn(input, h_prev)

# List every parameter of the bidirectional RNN (forward and *_reverse sets)
for name, param in bi_rnn.named_parameters():
    print(name, param)

weight_ih_l0 Parameter containing:
tensor([[ 0.5458,  0.5512],
        [-0.5077, -0.0750],
        [ 0.3572,  0.1419]], requires_grad=True)
weight_hh_l0 Parameter containing:
tensor([[-0.4093,  0.2012,  0.0746],
        [-0.5619, -0.3820, -0.4060],
        [-0.4412,  0.2706, -0.2816]], requires_grad=True)
bias_ih_l0 Parameter containing:
tensor([-0.5063, -0.1391, -0.0587], requires_grad=True)
bias_hh_l0 Parameter containing:
tensor([ 0.0343, -0.2352,  0.3234], requires_grad=True)
weight_ih_l0_reverse Parameter containing:
tensor([[ 0.1298,  0.5538],
        [ 0.4151,  0.2533],
        [-0.4401,  0.5322]], requires_grad=True)
weight_hh_l0_reverse Parameter containing:
tensor([[-0.4232,  0.2246,  0.4265],
        [ 0.3016, -0.4142, -0.3064],
        [-0.1960,  0.2845,  0.3770]], requires_grad=True)
bias_ih_l0_reverse Parameter containing:
tensor([-0.4372, -0.2452,  0.4506], requires_grad=True)
bias_hh_l0_reverse Parameter containing:
tensor([ 0.3957, -0.4655, -0.2143], requires_grad=True)

In [53]:
# Run the hand-written bidirectional RNN with bi_rnn's own parameters.
# h_prev[0] / h_prev[1] are the initial states handed to the forward /
# backward direction respectively.
custom_bi_rnn_output, custom_bi_state_finall = bidirectional_rnn_forward(
    input,
    weight_ih=bi_rnn.weight_ih_l0,
    weight_hh=bi_rnn.weight_hh_l0,
    bias_ih=bi_rnn.bias_ih_l0,
    bias_hh=bi_rnn.bias_hh_l0,
    h_prev=h_prev[0],
    weight_ih_reverse=bi_rnn.weight_ih_l0_reverse,
    weight_hh_reverse=bi_rnn.weight_hh_l0_reverse,
    bias_ih_reverse=bi_rnn.bias_ih_l0_reverse,
    bias_hh_reverse=bi_rnn.bias_hh_l0_reverse,
    h_prev_reverse=h_prev[1],
)

In [54]:
# Reference result from nn.RNN: per-step outputs, then final states
print("Pytorch API output")
print(bi_rnn_output, bi_state_finall, sep="\n")

# Result of the hand-written implementation, in the same order
print("\n custom bidirectional_rnn_forward function output:")
print(custom_bi_rnn_output, custom_bi_state_finall, sep="\n")

# Both lines print True when the two implementations agree
print(
    torch.allclose(bi_rnn_output, custom_bi_rnn_output),
    torch.allclose(bi_state_finall, custom_bi_state_finall),
    sep="\n",
)

Pytorch API output
tensor([[[-0.8470,  0.5436, -0.3571,  0.0393, -0.8730,  0.9124],
         [ 0.5724, -0.0194,  0.7805,  0.5884, -0.2443,  0.7351],
         [-0.4845, -0.7670, -0.1836,  0.0907, -0.5768,  0.3587]],

        [[-0.9186, -0.4089,  0.0847, -0.9221, -0.7344, -0.9120],
         [ 0.3084, -0.3562,  0.7382, -0.0584, -0.1021, -0.5778],
         [-0.8156, -0.4316, -0.3803, -0.3811, -0.7703,  0.1212]]],
       grad_fn=<TransposeBackward1>)
tensor([[[-0.4845, -0.7670, -0.1836],
         [-0.8156, -0.4316, -0.3803]],

        [[ 0.0393, -0.8730,  0.9124],
         [-0.9221, -0.7344, -0.9120]]], grad_fn=<StackBackward0>)

 custom bidirectional_rnn_forward function output:
tensor([[[-0.8470,  0.5436, -0.3571,  0.0393, -0.8730,  0.9124],
         [ 0.5724, -0.0194,  0.7805,  0.5884, -0.2443,  0.7351],
         [-0.4845, -0.7670, -0.1836,  0.0907, -0.5768,  0.3587]],

        [[-0.9186, -0.4089,  0.0847, -0.9221, -0.7344, -0.9120],
         [ 0.3084, -0.3562,  0.7382, -0.0584, -0.1021,