## Implementing LSTM, LSTM with projection, and GRU from scratch to verify the PyTorch APIs
Reference: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

`CLASS torch.nn.LSTM(*args, **kwargs)`

Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. For each element in the input sequence, each layer computes the following function:
\begin{align*}
i_t&=\sigma(W_{ii}x_t+b_{ii}+W_{hi}h_{t-1}+b_{hi})\\
f_t&=\sigma(W_{if}x_t+b_{if}+W_{hf}h_{t-1}+b_{hf})\\
g_t&=\tanh(W_{ig}x_t+b_{ig}+W_{hg}h_{t-1}+b_{hg})\\
o_t&=\sigma(W_{io}x_t+b_{io}+W_{ho}h_{t-1}+b_{ho})\\
c_t&=f_t\odot c_{t-1}+i_t\odot g_t\\
h_t&=o_t\odot\tanh(c_t)
\end{align*}
where $h_t$ is the hidden state at time $t$, $c_t$ is the cell state at time $t$, $x_t$ is the input at time $t$, $h_{t-1}$ is the hidden state of the layer at time $t-1$ or the initial hidden state at time $0$, and $i_t$, $f_t$, $g_t$, $o_t$ are the input, forget, cell, and output gates, respectively. $\sigma$ is the sigmoid function, and $\odot$ is the Hadamard product.

In [72]:
import torch
import torch.nn as nn

# Toy sizes: batch, sequence length, input features, hidden features.
bs, T, i_size, h_size = 2, 3, 4, 5
D, num_layers = 1, 1  # D: number of directions (2 if bidirectional), per the nn.LSTM docs

# Random input and initial states; the API docs require the states to be
# [D*num_layers, bs, h_size]. The draw order data -> c0 -> h0 is kept as-is.
data = torch.randn(bs, T, i_size)
c0 = torch.randn(D * num_layers, bs, h_size)
h0 = torch.randn(D * num_layers, bs, h_size)

# Reference result from PyTorch's own LSTM implementation.
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
output, (h_final, c_final) = lstm_layer(data, (h0, c0))
print(output.shape, h_final.shape, c_final.shape)

torch.Size([2, 3, 5]) torch.Size([1, 2, 5]) torch.Size([1, 2, 5])


In [73]:
# Show each learnable tensor of the reference LSTM with its shape; the four
# gates (i, f, g, o) are stacked along dim 0, hence 4*h_size rows.
for param_name, tensor in lstm_layer.named_parameters():
    print(param_name, tensor.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])


In [74]:
# write a LSTM function
def lstm_forward(data, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a single-layer, unidirectional, batch-first LSTM over ``data``.

    Intended to reproduce ``nn.LSTM(i_size, h_size, batch_first=True)``.

    Args:
        data: input sequence of shape [bs, T, i_size].
        initial_states: tuple ``(h0, c0)``. h0 must hold ``bs * h_size``
            elements (e.g. [1, bs, h_size]); c0 must broadcast with
            [bs, h_size].
        w_ih: [4*h_size, i_size] input-to-hidden weights, gate order i, f, g, o.
        w_hh: [4*h_size, h_size] hidden-to-hidden weights, same gate order.
        b_ih, b_hh: [4*h_size] biases.

    Returns:
        output_h: [bs, T, h_size], the hidden state at every time step, with
            the same dtype and device as ``data``.
        (h_n, c_n): final hidden and cell states. Their shape follows
            broadcasting against c0, e.g. [1, bs, h_size] when c0 has that shape.
    """
    h0, c0 = initial_states
    bs, T, _ = data.shape
    h_size = w_ih.shape[0] // 4  # the four gates are stacked along dim 0
    prev_h, prev_c = h0, c0
    # Match the input's dtype/device instead of defaulting to float32 on CPU.
    output_h = data.new_zeros(bs, T, h_size)
    for t in range(T):
        x = data[:, t, :]  # [bs, i_size]
        # ``x @ W.T`` already applies W to every sample in the batch, so the
        # weights need not be tiled per sample. No squeeze() is needed either;
        # a bare squeeze() also dropped the batch dim when bs == 1.
        gates = x @ w_ih.T + b_ih + prev_h.reshape(bs, h_size) @ w_hh.T + b_hh
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=-1)  # each [bs, h_size]
        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)
        o_t = torch.sigmoid(o_t)
        # If c0 is [1, bs, h_size], it broadcasts here and both states keep
        # that leading dim, matching nn.LSTM's (h_n, c_n) layout.
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)
        output_h[:, t, :] = prev_h
    return output_h, (prev_h, prev_c)

# Run the hand-written LSTM with exactly the parameters of `lstm_layer`.
custom_lstm_output, (custom_h_final, custom_c_final) = lstm_forward(
    data,
    (h0, c0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
)
# Each flag should print True when the custom outputs match nn.LSTM's.
print(
    torch.allclose(custom_lstm_output, output),
    torch.allclose(custom_h_final, h_final),
    torch.allclose(custom_c_final, c_final),
)

True True True


### LSTM projection
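Setting `proj_size > 0` makes `nn.LSTM` multiply each hidden state by a learnable projection matrix, $h_t = W_{hr}h_t$, which compresses `h` from `hidden_size` down to `proj_size`. Consequently, the initial hidden state `h_0` must have `proj_size` features, while the cell state `c_0` keeps `hidden_size` features.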

In [75]:
proj_size = 3
# With a projection only h_0 shrinks to proj_size; c_0 (reused below) stays h_size wide.
proj_h_0 = torch.randn(D * num_layers, bs, proj_size)

# Reference LSTM whose hidden states are projected down to proj_size.
lstm_proj_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
lstm_proj_output, (lstm_proj_final_h, lstm_proj_final_c) = lstm_proj_layer(
    data, (proj_h_0, c0)
)
# Note that the final cell state c_n is still h_size wide.
print(lstm_proj_output.shape, lstm_proj_final_h.shape, lstm_proj_final_c.shape)

torch.Size([2, 3, 3]) torch.Size([1, 2, 3]) torch.Size([1, 2, 5])


In [76]:
# Compared with the plain LSTM, this layer adds weight_hr_l0 (the projection),
# and weight_hh_l0 now consumes proj_size-wide hidden states.
for param_name, tensor in lstm_proj_layer.named_parameters():
    print(param_name, tensor.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])


In [77]:
# write a LSTM function with an interface for projecting h to a lower dimension
def lstm_forward(data, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    """Run a single-layer, batch-first LSTM, optionally with an output projection.

    Intended to reproduce ``nn.LSTM(i_size, h_size, batch_first=True,
    proj_size=...)``. When ``w_hr`` is given, each new hidden state is
    projected as ``h_t = h_t @ w_hr.T`` before it is emitted and fed back into
    the recurrence. Without ``w_hr``, this is a plain LSTM.

    Args:
        data: input sequence of shape [bs, T, i_size].
        initial_states: tuple ``(h0, c0)``. h0 must hold ``bs * out_size``
            elements and c0 must broadcast with [bs, h_size]. Here
            ``out_size`` is ``w_hr.shape[0]`` (proj_size) if projecting,
            otherwise h_size.
        w_ih: [4*h_size, i_size] input-to-hidden weights, gate order i, f, g, o.
        w_hh: [4*h_size, out_size] hidden-to-hidden weights.
        b_ih, b_hh: [4*h_size] biases.
        w_hr: optional [proj_size, h_size] projection weights (``weight_hr_l0``).

    Returns:
        output_h: [bs, T, out_size], with the same dtype and device as ``data``.
        (h_n, c_n): h_n reshaped to [1, bs, out_size]. c_n keeps the cell
            width h_size, and its shape follows broadcasting against c0.
    """
    h0, c0 = initial_states
    bs, T, _ = data.shape
    h_size = w_ih.shape[0] // 4  # the four gates are stacked along dim 0
    # The emitted/recurrent state is proj_size wide when projecting, else h_size.
    # (Previously the size variable was only defined when w_hr was given.)
    output_size = h_size if w_hr is None else w_hr.shape[0]

    prev_h = h0.reshape(bs, output_size)
    prev_c = c0
    output_h = data.new_zeros(bs, T, output_size)
    for t in range(T):
        x = data[:, t, :]  # [bs, i_size]
        gates = x @ w_ih.T + b_ih + prev_h @ w_hh.T + b_hh
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=-1)  # each [bs, h_size]
        prev_c = torch.sigmoid(f_t) * prev_c + torch.sigmoid(i_t) * torch.tanh(g_t)
        # reshape (not squeeze) so a batch of size 1 keeps its batch dim
        h_t = (torch.sigmoid(o_t) * torch.tanh(prev_c)).reshape(bs, h_size)
        if w_hr is not None:
            h_t = h_t @ w_hr.T  # [bs, proj_size]
        prev_h = h_t
        output_h[:, t, :] = prev_h
    return output_h, (prev_h.reshape(1, bs, output_size), prev_c)

# Run the hand-written LSTM with the projected layer's parameters,
# including the projection matrix weight_hr_l0.
custom_lstm_p_output, (custom_lstm_p_h_final, custom_lstm_p_c_final) = lstm_forward(
    data,
    (proj_h_0, c0),
    lstm_proj_layer.weight_ih_l0,
    lstm_proj_layer.weight_hh_l0,
    lstm_proj_layer.bias_ih_l0,
    lstm_proj_layer.bias_hh_l0,
    lstm_proj_layer.weight_hr_l0,
)
# Shapes should match those printed for nn.LSTM(proj_size=...) above.
print(custom_lstm_p_output.shape)
print(custom_lstm_p_h_final.shape)
print(custom_lstm_p_c_final.shape)

torch.Size([2, 3, 3])
torch.Size([1, 2, 3])
torch.Size([1, 2, 5])


In [78]:
# Output sequence, final h and final c must all agree with nn.LSTM(proj_size=...).
(
    torch.allclose(custom_lstm_p_output, lstm_proj_output),
    torch.allclose(custom_lstm_p_h_final, lstm_proj_final_h),
    torch.allclose(custom_lstm_p_c_final, lstm_proj_final_c),
)

(True, True, True)

## GRU API
Reference: https://pytorch.org/docs/stable/generated/torch.nn.GRU.html
        
`CLASS torch.nn.GRU(*args, **kwargs)`

Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. For each element in the input sequence, each layer computes the following function:
\begin{align*}
r_t&=\sigma(W_{ir}x_t+b_{ir}+W_{hr}h_{t-1}+b_{hr})\\
z_t&=\sigma(W_{iz}x_t+b_{iz}+W_{hz}h_{t-1}+b_{hz})\\
n_t&=\tanh(W_{in}x_t+b_{in}+r_t\odot(W_{hn}h_{t-1}+b_{hn}))\\
h_t&=(1-z_t)\odot n_t+z_t\odot h_{t-1}
\end{align*}
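where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, $h_{t-1}$ is the hidden state of the layer at time $t-1$ or the initial hidden state at time $0$, and $r_t$, $z_t$, $n_t$ are the reset, update, and new gates, respectively. $\sigma$ is the sigmoid function, and $\odot$ is the Hadamard product. Note that the reset gate $r_t$ multiplies the whole hidden-side term $W_{hn}h_{t-1}+b_{hn}$, including its bias. For that reason, `b_ih` and `b_hh` cannot simply be summed for the new gate (see `gru_forward` below).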

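A GRU has three groups of gate parameters ($r$, $z$, $n$), while an LSTM has four ($i$, $f$, $g$, $o$), and each group has the same shapes. Therefore, for the same input and hidden sizes, a GRU has 3/4 as many parameters as an LSTM.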

In [79]:
bs, T, i_size, h_size = 2, 3, 4, 5
D = 1  # number of directions (2 if bidirectional), per the nn.GRU docs
num_layers = 1
data = torch.randn(bs, T, i_size)
# Unlike the LSTM, a GRU carries only a hidden state (no cell state);
# see the API doc for the dimension requirement.
h0 = torch.randn(D*num_layers, bs, h_size)

# call PyTorch APIs
gru_layer = nn.GRU(i_size, h_size, batch_first=True)  # instantiate a GRU class
gru_output, gru_h_final = gru_layer(data, h0)
print(gru_output.shape, gru_h_final.shape)

torch.Size([2, 3, 5]) torch.Size([1, 2, 5])


In [80]:
# A GRU stacks its three gates (r, z, n) along dim 0, so each tensor has 3*h_size rows.
for param_name, tensor in gru_layer.named_parameters():
    print(param_name, tensor.shape)

weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])


In [81]:
# write a GRU forward function
def gru_forward(data, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a single-layer, unidirectional, batch-first GRU over ``data``.

    Intended to reproduce ``nn.GRU(i_size, h_size, batch_first=True)``.

    Args:
        data: input sequence of shape [bs, T, i_size].
        initial_states: initial hidden state h0 holding ``bs * h_size``
            elements (e.g. [1, bs, h_size]).
        w_ih: [3*h_size, i_size] input-to-hidden weights, gate order r, z, n.
        w_hh: [3*h_size, h_size] hidden-to-hidden weights, same gate order.
        b_ih, b_hh: [3*h_size] biases.

    Returns:
        output_h: [bs, T, h_size], the hidden state at every time step, with
            the same dtype and device as ``data``.
        h_n: final hidden state. Its shape follows broadcasting against h0,
            e.g. [1, bs, h_size] when h0 has that shape.
    """
    h0 = initial_states
    bs, T, _ = data.shape
    h_size = w_ih.shape[0] // 3  # the three gates are stacked along dim 0
    prev_h = h0
    # Match the input's dtype/device instead of defaulting to float32 on CPU.
    output_h = data.new_zeros(bs, T, h_size)
    for t in range(T):
        x = data[:, t, :]  # [bs, i_size]
        # Keep the input-side and hidden-side projections separate: the reset
        # gate must scale the hidden-side new-gate term *including* b_hh.
        # Using matmul (instead of bmm + squeeze()) keeps the batch dim when bs == 1.
        gi = x @ w_ih.T + b_ih                              # [bs, 3*h_size]
        gh = prev_h.reshape(bs, h_size) @ w_hh.T + b_hh     # [bs, 3*h_size]
        i_r, i_z, i_n = gi.chunk(3, dim=-1)
        h_r, h_z, h_n = gh.chunk(3, dim=-1)

        r_t = torch.sigmoid(i_r + h_r)   # reset gate
        z_t = torch.sigmoid(i_z + h_z)   # update gate
        n_t = torch.tanh(i_n + r_t * h_n)  # new gate
        prev_h = (1 - z_t) * n_t + z_t * prev_h  # h_t
        output_h[:, t, :] = prev_h
    return output_h, prev_h

# Feed the GRU layer's own parameters to the hand-written implementation.
custom_gru_output, custom_gru_h_final = gru_forward(
    data,
    h0,
    gru_layer.weight_ih_l0,
    gru_layer.weight_hh_l0,
    gru_layer.bias_ih_l0,
    gru_layer.bias_hh_l0,
)
# Both flags should print True when gru_forward matches nn.GRU.
print(
    torch.allclose(custom_gru_output, gru_output),
    torch.allclose(custom_gru_h_final, gru_h_final),
)

True True
