<a href="https://colab.research.google.com/github/hao1zhao/Model/blob/main/lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LSTM

$\mathbf{I}_t$: input gate, $\mathbf{F}_t$: forget gate, $\mathbf{O}_t$: output gate, $\tilde{\mathbf{C}}_t$: candidate memory cell, $\mathbf{C}_t$: memory cell, $\mathbf{H}_t$: hidden state
$$
\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),
\end{aligned}
$$
$$\tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),$$
$$\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t,$$
$$\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).$$





In [1]:
import torch
from torch import nn

In [10]:
# Problem dimensions: batch size, number of time steps, input width, hidden width.
bs = 2
T = 3
i_size = 4
h_size = 5

# Random input sequence and random initial cell / hidden states
# (drawn in the order input, c0, h0).
input = torch.randn(bs, T, i_size)
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, h_size)

# Reference computation with the built-in layer. The initial states are given a
# leading axis of size 1 before being handed to nn.LSTM.
lstm_layer = nn.LSTM(input_size=i_size, hidden_size=h_size, batch_first=True)
output, (h_final, c_final) = lstm_layer(
  input,
  (torch.unsqueeze(h0, 0), torch.unsqueeze(c0, 0)),
)

# List the layer's parameters with their shapes.
for k, v in lstm_layer.named_parameters():
  print(k, v.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])


In [13]:
# unidirectional LSTM
def listm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
  """Run a single-layer, unidirectional LSTM over a batch-first sequence.

  The parameters use the stacked gate layout of ``nn.LSTM``: along the first
  axis, rows ``[0:h]`` belong to the input gate, ``[h:2h]`` to the forget gate,
  ``[2h:3h]`` to the candidate memory cell and ``[3h:4h]`` to the output gate,
  where ``h`` is the hidden size.

  Args:
    input: input sequence, shape (bs, T, i_size).
    initial_states: tuple ``(h0, c0)`` of initial hidden and cell states,
      each of shape (bs, h_size).
    w_ih: input-to-hidden weights, shape (4*h_size, i_size).
    w_hh: hidden-to-hidden weights, shape (4*h_size, h_size).
    b_ih: input-to-hidden bias, shape (4*h_size,).
    b_hh: hidden-to-hidden bias, shape (4*h_size,).

  Returns:
    ``(output, (h_T, c_T))`` where ``output`` has shape (bs, T, h_size) and
    holds the hidden state of every time step, and ``h_T`` / ``c_T`` are the
    hidden and cell states after the last step, each (bs, h_size).
  """
  h0, c0 = initial_states
  bs, T, _ = input.shape  # break down input
  h_size = w_ih.shape[0] // 4

  # The input projection and the bias sum do not depend on the recurrent
  # state, so compute them once for all time steps: (bs, T, 4*h_size).
  # A plain matmul broadcasts over the batch, so the weights need no tiling.
  x_proj = torch.matmul(input, w_ih.t()) + (b_ih + b_hh)

  # Allocate with the dtype/device of `input` so e.g. float64 results are not
  # silently downcast to the default dtype when written into the buffer.
  output = input.new_zeros(bs, T, h_size)

  prev_h = h0
  prev_c = c0
  for t in range(T):
    # Pre-activations of all four gates at step t: (bs, 4*h_size).
    gates = x_proj[:, t, :] + torch.matmul(prev_h, w_hh.t())
    pre_i, pre_f, pre_cc, pre_o = (
      gates[:, k * h_size:(k + 1) * h_size] for k in range(4)
    )

    # gates
    i_t = torch.sigmoid(pre_i)
    f_t = torch.sigmoid(pre_f)
    cc_t = torch.tanh(pre_cc)
    o_t = torch.sigmoid(pre_o)

    # update c and h
    prev_c = f_t * prev_c + i_t * cc_t
    prev_h = o_t * torch.tanh(prev_c)

    output[:, t, :] = prev_h

  return output, (prev_h, prev_c)
#test: run the hand-written forward pass with the parameters of the nn.LSTM layer
output_custom,(h_final_custom,c_final_custom) = listm_forward(
  input,
  (h0,c0),
  lstm_layer.weight_ih_l0,
  lstm_layer.weight_hh_l0,
  lstm_layer.bias_ih_l0,
  lstm_layer.bias_hh_l0,
)
print(f'output:{output}',f'output_custom:{output_custom}')

# Compare numerically instead of by eye, and also check the final states,
# which were previously computed but never inspected. The states returned by
# nn.LSTM carry a leading axis of size 1 (they were passed in with
# unsqueeze(0)), so drop it before comparing with the custom (bs, h_size) ones.
assert torch.allclose(output_custom, output, atol=1e-5), 'per-step outputs differ from nn.LSTM'
assert torch.allclose(h_final_custom, h_final.squeeze(0), atol=1e-5), 'final hidden state differs from nn.LSTM'
assert torch.allclose(c_final_custom, c_final.squeeze(0), atol=1e-5), 'final cell state differs from nn.LSTM'



output:tensor([[[-0.1076, -0.4109,  0.0216, -0.4851,  0.3007],
         [ 0.0926, -0.3760,  0.0407, -0.2144,  0.3392],
         [ 0.2056, -0.2599, -0.0100, -0.1721,  0.2170]],

        [[-0.1218, -0.1491, -0.0662, -0.0635,  0.1405],
         [-0.0518, -0.0886,  0.0478, -0.1821,  0.0289],
         [ 0.1393, -0.0395, -0.0270, -0.1378,  0.1511]]],
       grad_fn=<TransposeBackward0>) output_custom:tensor([[[-0.1076, -0.4109,  0.0216, -0.4851,  0.3007],
         [ 0.0926, -0.3760,  0.0407, -0.2144,  0.3392],
         [ 0.2056, -0.2599, -0.0100, -0.1721,  0.2170]],

        [[-0.1218, -0.1491, -0.0662, -0.0635,  0.1405],
         [-0.0518, -0.0886,  0.0478, -0.1821,  0.0289],
         [ 0.1393, -0.0395, -0.0270, -0.1378,  0.1511]]], grad_fn=<CopySlices>)
