## Construct a single-layer RNN

In [2]:
import torch
import torch.nn as nn

# Fix the RNG so Restart & Run All reproduces every random tensor and weight in this notebook.
torch.manual_seed(0)

single_rnn = nn.RNN(4, 3, 1, batch_first=True)  # input_size, hidden_size, num_layers
data = torch.randn(1, 2, 4)  # batch_size * sequence_length * input_features
output, h_n = single_rnn(data)
# output: bs * sl * (D*hidden_size), where D = 2 if bidirectional=True, otherwise D = 1 (batch_first applies).
print(output, output.shape)
# h_n: (D*num_layers) * bs * hidden_size -- batch_first does NOT apply to hidden states;
# it holds the final hidden state for each element in the batch.
print(h_n, h_n.shape)

tensor([[[ 0.2249,  0.1533, -0.0034],
         [ 0.3328, -0.1285, -0.3007]]], grad_fn=<TransposeBackward1>) torch.Size([1, 2, 3])
tensor([[[ 0.3328, -0.1285, -0.3007]]], grad_fn=<StackBackward0>) torch.Size([1, 1, 3])


## Construct a single-layer bidirectional RNN

In [4]:
bidirectional_rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=1,
                           batch_first=True, bidirectional=True)
bi_output, bi_h_n = bidirectional_rnn(data)
# Forward and backward outputs are joined on the feature axis: bs * sl * (2*hidden_size).
print(bi_output, bi_output.size())
# One final state per direction, stacked on the first axis: (2*num_layers) * bs * hidden_size.
print(bi_h_n, bi_h_n.size())

tensor([[[-0.0799, -0.2464, -0.3475, -0.4973, -0.0137,  0.0037],
         [-0.4458, -0.0231, -0.1119, -0.3422,  0.0224, -0.0042]]],
       grad_fn=<TransposeBackward1>) torch.Size([1, 2, 6])
tensor([[[-0.4458, -0.0231, -0.1119]],

        [[-0.4973, -0.0137,  0.0037]]], grad_fn=<StackBackward0>) torch.Size([2, 1, 3])


## Verify the PyTorch RNN API by implementing the math by hand
Reference: https://pytorch.org/docs/stable/generated/torch.nn.RNN.html

Applies a multi-layer Elman RNN with $\tanh$ or $\text{ReLU}$ non-linearity to an input sequence. For each element in the input sequence, each layer computes the following function:
$$
h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})
$$
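where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, and $h_{t-1}$ is the hidden state of the same layer at time $t-1$ (or the initial hidden state $h_0$ at time $0$). Here $x_t$ and $h_t$ are row vectors, $W_{ih}$ has shape $(H_{out}, H_{in})$ and $W_{hh}$ has shape $(H_{out}, H_{out})$. For column vectors the same update reads $h_t = \tanh(W_{ih}x_t + b_{ih} + W_{hh}h_{t-1} + b_{hh})$.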

`batch_first` – If `True`, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). **Note that this does not apply to hidden or cell states.** See the Inputs/Outputs sections below for details. Default: `False`

`h_0`: tensor of shape $(D \cdot \text{num\_layers}, H_{out})$ for unbatched input or $(D \cdot \text{num\_layers}, N, H_{out})$ containing the initial hidden state for the input sequence batch. Here $N$ is the batch size, and $D = 2$ if `bidirectional=True`, otherwise $D = 1$. Defaults to zeros if not provided.


In [22]:
# Sizes for the hand-written RNN check. D = 2 for a bidirectional RNN, otherwise 1.
batch_size = 2
T = 4  # sequence length
input_size, hidden_size = 2, 3
num_layers = 1
D = 1
data = torch.randn(batch_size, T, input_size)  # bs * sl * num_in_feat
# Initial hidden features: (D*num_layers) * bs * hidden_size -- batch_first does not apply here.
h_prev = torch.zeros(D * num_layers, batch_size, hidden_size)

# Build the RNN from the same num_layers / D that size h_prev, so the shapes always agree.
rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers,
             batch_first=True, bidirectional=(D == 2))
rnn_output, final_state = rnn(data, h_prev)
print(rnn_output, rnn_output.shape)  # bs * sl * (D*h_out)
print(final_state, final_state.shape)  # (D*num_layers) * N * h_out

tensor([[[-0.7509,  0.0179,  0.3901],
         [-0.3166, -0.2233, -0.1727],
         [-0.2619, -0.0492, -0.1413],
         [ 0.1561,  0.3625, -0.1462]],

        [[-0.7074,  0.4189,  0.6072],
         [-0.5832, -0.2940,  0.1098],
         [-0.7248,  0.0336,  0.4950],
         [-0.4684, -0.1603, -0.0052]]], grad_fn=<TransposeBackward1>) torch.Size([2, 4, 3])
tensor([[[ 0.1561,  0.3625, -0.1462],
         [-0.4684, -0.1603, -0.0052]]], grad_fn=<StackBackward0>) torch.Size([1, 2, 3])


### Inspect the weights of an RNN layer

In [18]:
# weight_ih_l0 is (hidden_size, input_size), weight_hh_l0 is (hidden_size, hidden_size), biases are (hidden_size,).
for param_name, param in rnn.named_parameters():
    print(param_name, param.size())

weight_ih_l0 torch.Size([3, 2])
weight_hh_l0 torch.Size([3, 3])
bias_ih_l0 torch.Size([3])
bias_hh_l0 torch.Size([3])


In [194]:
def rnn_forward(data, weight_ih, weight_hh, bias_ih, bias_hh, h_prev=None):
    """Run a single-layer, unidirectional tanh RNN over a batch-first sequence.

    Computes h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh) at every step,
    matching the outputs of a one-layer ``nn.RNN(..., batch_first=True)``.

    Args:
        data: input of shape (batch, seq_len, input_size).
        weight_ih: input-to-hidden weights, shape (hidden_size, input_size).
        weight_hh: hidden-to-hidden weights, shape (hidden_size, hidden_size).
        bias_ih: input-to-hidden bias, shape (hidden_size,).
        bias_hh: hidden-to-hidden bias, shape (hidden_size,).
        h_prev: initial hidden state of shape (1, batch, hidden_size); ``batch_first``
            does not apply to hidden states. Defaults to zeros, as in ``nn.RNN``.

    Returns:
        (output, h_n): ``output`` has shape (batch, seq_len, hidden_size) and holds
        the hidden state at every step. ``h_n`` has shape (1, batch, hidden_size)
        and holds the final hidden state.

    Raises:
        ValueError: if ``h_prev`` is given with a shape other than (1, batch, hidden_size).
    """
    bs, T, _ = data.shape
    h_dim = weight_ih.shape[0]

    if h_prev is None:
        h = data.new_zeros(bs, h_dim)
    else:
        if h_prev.shape != (1, bs, h_dim):
            raise ValueError(
                f"h_prev must have shape (1, {bs}, {h_dim}), got {tuple(h_prev.shape)}"
            )
        h = h_prev[0]  # (batch, hidden_size)

    # The input projection does not depend on the recurrence, so compute it for
    # every time step in one batched matmul: (batch, seq_len, hidden_size).
    input_proj = data @ weight_ih.T + bias_ih

    steps = []
    for t in range(T):
        h = torch.tanh(input_proj[:, t] + h @ weight_hh.T + bias_hh)  # (batch, hidden_size)
        steps.append(h)

    # Stack once at the end instead of re-concatenating inside the loop (O(T) vs O(T^2) copying).
    if steps:
        output = torch.stack(steps, dim=1)
    else:
        output = input_proj.new_zeros(bs, 0, h_dim)  # empty sequence: no steps, state unchanged
    return output, h.unsqueeze(0)

In [195]:
# Feed the nn.RNN layer's own parameters into the hand-written forward pass.
custom_rnn_output, custom_final_state = rnn_forward(data, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0, h_prev)
print(custom_rnn_output, custom_rnn_output.shape)  # bs * sl * h_out
print(custom_final_state, custom_final_state.shape)  # (D*num_layers) * N * h_out

tensor([[[-0.7509,  0.0179,  0.3901],
         [-0.3166, -0.2233, -0.1727],
         [-0.2619, -0.0492, -0.1413],
         [ 0.1561,  0.3625, -0.1462]],

        [[-0.7074,  0.4189,  0.6072],
         [-0.5832, -0.2940,  0.1098],
         [-0.7248,  0.0336,  0.4950],
         [-0.4684, -0.1603, -0.0052]]], grad_fn=<PermuteBackward0>) torch.Size([2, 4, 3])
tensor([[[ 0.1561,  0.3625, -0.1462],
         [-0.4684, -0.1603, -0.0052]]], grad_fn=<PermuteBackward0>) torch.Size([1, 2, 3])


In [196]:
# Fail loudly on a mismatch so Restart & Run All catches regressions, and still display the result.
rnn_matches = (torch.allclose(custom_rnn_output, rnn_output), torch.allclose(custom_final_state, final_state))
assert all(rnn_matches), f"hand-written RNN disagrees with nn.RNN: {rnn_matches}"
rnn_matches

(True, True)

In [210]:
def bidirectional_rnn_forward(data, weight_ih, weight_hh, bias_ih, bias_hh, h_prev,
                              weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse):
    """Run a single-layer bidirectional tanh RNN built from two ``rnn_forward`` passes.

    Args:
        data: input of shape (batch, seq_len, input_size).
        weight_ih, weight_hh, bias_ih, bias_hh: forward-direction parameters.
        h_prev: forward-direction initial hidden state, passed to ``rnn_forward``.
        weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse:
            reverse-direction parameters.
        h_prev_reverse: reverse-direction initial hidden state, passed to ``rnn_forward``.

    Returns:
        (output, final_state): ``output`` concatenates the forward and reverse
        outputs on the last axis. ``final_state`` concatenates the forward and
        reverse final states on axis 0 (forward first).
    """
    forward_output, f_final_state = rnn_forward(data, weight_ih, weight_hh, bias_ih, bias_hh, h_prev)
    # The reverse direction is an ordinary RNN run over the time-reversed sequence (dim 1 is time).
    backward_output, b_final_state = rnn_forward(data.flip(dims=(1,)), weight_ih_reverse, weight_hh_reverse,
                                                 bias_ih_reverse, bias_hh_reverse, h_prev_reverse)
    # Flip the reverse outputs back so position t of both halves refers to the same input step.
    output = torch.cat([forward_output, backward_output.flip(dims=(1,))], dim=-1)
    final_state = torch.cat([f_final_state, b_final_state], dim=0)
    return output, final_state

In [215]:
bi_rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size,
                batch_first=True, bidirectional=True)
# Random initial state with one slice per direction: 2 * bs * hidden_size.
bi_h_prev = torch.randn(2, batch_size, hidden_size)
bi_rnn_pytorch_api_output, bi_rnn_pytorch_api_final_state = bi_rnn(data, bi_h_prev)
print(bi_rnn_pytorch_api_output, bi_rnn_pytorch_api_output.size())
print(bi_rnn_pytorch_api_final_state, bi_rnn_pytorch_api_final_state.size())

tensor([[[ 0.1752,  0.0359, -0.9121, -0.7831, -0.9604, -0.4477],
         [-0.0909,  0.6418, -0.2060,  0.3008, -0.7607, -0.0813],
         [ 0.0593,  0.5991, -0.3838, -0.2978, -0.6920, -0.6060],
         [ 0.1246,  0.7805, -0.6746, -0.8776, -0.9304,  0.3111]],

        [[-0.2166, -0.5227, -0.6141, -0.9348, -0.9833, -0.7267],
         [-0.0757,  0.2857, -0.2566, -0.0837, -0.7599, -0.3712],
         [-0.5078, -0.1643, -0.4274, -0.7758, -0.7244, -0.4726],
         [ 0.2099,  0.5098, -0.3588, -0.6636, -0.2914,  0.6622]]],
       grad_fn=<TransposeBackward1>) torch.Size([2, 4, 6])
tensor([[[ 0.1246,  0.7805, -0.6746],
         [ 0.2099,  0.5098, -0.3588]],

        [[-0.7831, -0.9604, -0.4477],
         [-0.9348, -0.9833, -0.7267]]], grad_fn=<StackBackward0>) torch.Size([2, 2, 3])


In [216]:
# A bidirectional layer carries a second parameter set with the `_reverse` suffix.
for param_name, param in bi_rnn.named_parameters():
    print(param_name, param.size())

weight_ih_l0 torch.Size([3, 2])
weight_hh_l0 torch.Size([3, 3])
bias_ih_l0 torch.Size([3])
bias_hh_l0 torch.Size([3])
weight_ih_l0_reverse torch.Size([3, 2])
weight_hh_l0_reverse torch.Size([3, 3])
bias_ih_l0_reverse torch.Size([3])
bias_hh_l0_reverse torch.Size([3])


In [217]:
# Slice 0 of bi_h_prev is the forward direction's initial state, slice 1 the reverse one.
# Slicing with [k:k+1] keeps the leading axis, giving the 1 * bs * hidden_size shape each pass expects.
custom_bi_rnn_output, custom_bi_rnn_final_state = bidirectional_rnn_forward(
    data,
    weight_ih=bi_rnn.weight_ih_l0,
    weight_hh=bi_rnn.weight_hh_l0,
    bias_ih=bi_rnn.bias_ih_l0,
    bias_hh=bi_rnn.bias_hh_l0,
    h_prev=bi_h_prev[0:1],
    weight_ih_reverse=bi_rnn.weight_ih_l0_reverse,
    weight_hh_reverse=bi_rnn.weight_hh_l0_reverse,
    bias_ih_reverse=bi_rnn.bias_ih_l0_reverse,
    bias_hh_reverse=bi_rnn.bias_hh_l0_reverse,
    h_prev_reverse=bi_h_prev[1:2],
)

torch.Size([2, 4, 3])


In [218]:
# Output: bs * sl * (2*hidden_size); final state: 2 * bs * hidden_size.
print(custom_bi_rnn_output, custom_bi_rnn_output.size())
print(custom_bi_rnn_final_state, custom_bi_rnn_final_state.size())

tensor([[[ 0.1752,  0.0359, -0.9121, -0.7831, -0.9604, -0.4477],
         [-0.0909,  0.6418, -0.2060,  0.3008, -0.7607, -0.0813],
         [ 0.0593,  0.5991, -0.3838, -0.2978, -0.6920, -0.6060],
         [ 0.1246,  0.7805, -0.6746, -0.8776, -0.9304,  0.3111]],

        [[-0.2166, -0.5227, -0.6141, -0.9348, -0.9833, -0.7267],
         [-0.0757,  0.2857, -0.2566, -0.0837, -0.7599, -0.3712],
         [-0.5078, -0.1643, -0.4274, -0.7758, -0.7244, -0.4726],
         [ 0.2099,  0.5098, -0.3588, -0.6636, -0.2914,  0.6622]]],
       grad_fn=<CatBackward0>) torch.Size([2, 4, 6])
tensor([[[ 0.1246,  0.7805, -0.6746],
         [ 0.2099,  0.5098, -0.3588]],

        [[-0.7831, -0.9604, -0.4477],
         [-0.9348, -0.9833, -0.7267]]], grad_fn=<CatBackward0>) torch.Size([2, 2, 3])


In [220]:
# Fail loudly on a mismatch so Restart & Run All catches regressions, and still display the result.
bi_rnn_matches = (torch.allclose(custom_bi_rnn_output, bi_rnn_pytorch_api_output),
                  torch.allclose(custom_bi_rnn_final_state, bi_rnn_pytorch_api_final_state))
assert all(bi_rnn_matches), f"hand-written bidirectional RNN disagrees with nn.RNN: {bi_rnn_matches}"
bi_rnn_matches

(True, True)