# Pack padded / pad packed
Source: https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec

In [1]:
import torch
from torch.nn import LSTM, Embedding
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

We want to run LSTM on a batch of 3 character sequences `['long_str', 'tiny', 'medium']`

 *  Step 1: Construct Vocabulary
 *  Step 2: Load indexed data (list of instances, where each instance is list of character indices)
 *  Step 3: Make Model
 *  Step 4: Pad instances with 0s till max length sequence
 *  Step 5: Sort instances by sequence length in descending order
 *  Step 6: Embed the instances
 *  Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
 *  Step 8: Forward with LSTM
 *  Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
 *  Summary of Shape Transformations

In [4]:
# Toy batch: three character sequences of different lengths.
seqs = [
    'long_str',  # len = 8
    'tiny',      # len = 4
    'medium',    # len = 6
]

In [18]:
# Step 1: Construct Vocabulary
# Index 0 is reserved for the padding token; the remaining entries are the
# distinct characters of all sequences, in sorted order.
unique_chars = set()
for word in seqs:
    unique_chars.update(word)
vocab = ['<pad>', *sorted(unique_chars)]
vocab

['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

In [21]:
# Step 2: Load indexed data (list of instances, where each instance is list of character indices)
# Build the character -> index lookup once, instead of scanning `vocab` with
# list.index() for every single character.
char_to_idx = {ch: idx for idx, ch in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[ch] for ch in word] for word in seqs]
vectorized_seqs

[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

In [22]:
# Step 3: Make Model
EMBEDDING_DIM = 4  # size of each character embedding
HIDDEN_DIM = 5     # size of the LSTM hidden state

# One embedding row per vocabulary entry. No padding_idx is set, so the
# '<pad>' entry gets an ordinary embedding vector like every other index.
embed = Embedding(num_embeddings=len(vocab), embedding_dim=EMBEDDING_DIM)
# batch_first=True: tensors are laid out as (batch, seq, feature).
lstm = LSTM(input_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM, batch_first=True)

In [25]:
# Step 4: Pad instances with 0s till max length sequence
# Record the true (unpadded) length of every instance first; the lengths are
# needed again when the batch is sorted and packed.
seq_lengths = torch.tensor(list(map(len, vectorized_seqs)), dtype=torch.long)

# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

In [31]:
# Allocate a (batch_size X max_seq_len) LongTensor of zeros; 0 is the '<pad>' index.
# Tensor.max() is converted to a plain int here: the builtin max() would iterate
# the tensor element by element and hand torch.zeros a 0-dim tensor as a size.
max_seq_len = int(seq_lengths.max())
seq_tensor = torch.zeros(len(vectorized_seqs), max_seq_len, dtype=torch.long)
seq_tensor

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])

In [34]:
# Copy every index list into the left part of its row; the rest of the row
# keeps the 0 ('<pad>') it was initialised with.
for row, token_ids in enumerate(vectorized_seqs):
    seq_tensor[row, :len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)

In [35]:
# Inspect the padded batch: one row per sequence, right-padded with 0 ('<pad>').
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [12,  5,  8, 14,  0,  0,  0,  0],
        [ 7,  3,  2,  5, 13,  7,  0,  0]])

In [36]:
#  Step 5: Sort instances by sequence length in descending order

In [39]:
# Sort the lengths from longest to shortest. `idxs` is the permutation that was
# applied, so the same reordering can be applied to the padded batch afterwards.
seq_lengths, idxs = torch.sort(seq_lengths, dim=0, descending=True)

In [40]:
# Reorder the rows of the padded batch with the permutation from the length sort,
# so that row k belongs to the k-th longest sequence.
seq_tensor = seq_tensor.index_select(0, idxs)
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [ 7,  3,  2,  5, 13,  7,  0,  0],
        [12,  5,  8, 14,  0,  0,  0,  0]])

In [43]:
# Step 6: Embed the instances
# Every character index is replaced by its embedding vector:
# (batch_size X max_seq_len) -> (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4).
# Padding positions (index 0) are embedded like any other index.
embedded_seq_tensor = embed(seq_tensor)
embedded_seq_tensor

tensor([[[-7.9409e-01, -2.8896e+00,  6.9483e-01, -1.2146e+00],
         [ 1.9987e+00,  6.2710e-02, -2.8276e-01,  1.5810e+00],
         [-2.5376e-01, -1.0364e+00, -4.3581e-01, -8.3744e-01],
         [ 9.3385e-01,  2.1982e-01,  1.0951e-01, -6.4361e-01],
         [-1.1321e+00,  1.6221e+00, -4.6939e-01,  8.1539e-01],
         [ 7.4638e-01, -6.3624e-01,  1.0528e-02, -7.6883e-02],
         [-5.5516e-01, -1.2022e+00,  1.1431e+00, -7.5028e-01],
         [-5.3181e-01,  1.9000e+00,  5.9851e-01,  6.9586e-01]],

        [[-1.4821e+00, -7.2892e-01,  3.9801e-01, -3.1882e-01],
         [ 5.2833e-01, -2.7108e+00,  6.2506e-01,  7.2514e-01],
         [ 5.2835e-01,  1.5062e+00,  9.0398e-01,  2.5230e-03],
         [-5.9331e-01, -1.1998e+00, -1.2029e+00, -8.4205e-01],
         [-3.8441e-01, -8.1292e-01, -3.6298e-01, -1.1046e+00],
         [-1.4821e+00, -7.2892e-01,  3.9801e-01, -3.1882e-01],
         [-1.6660e-01,  1.8722e+00, -2.9307e+00, -4.1734e-01],
         [-1.6660e-01,  1.8722e+00, -2.9307e+00, -4.1

In [44]:
# embedded_seq_tensor =>
#                       [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                         [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                         [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                         [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                         [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r

#                        [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                         [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                         [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>

#                        [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)


In [47]:
# Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
# `lengths` is documented to take a CPU tensor or a list of ints, so the CPU
# tensor is passed directly rather than being converted to a NumPy array first.
# The batch must already be sorted by decreasing length (done in Step 5).
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu(), batch_first=True)
# packed_input is a PackedSequence (a NamedTuple); the attributes used here are data and batch_sizes.

In [49]:
# packed_input.data =>
#                         [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                          [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                          [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                          [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                          [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r

In [50]:
# (batch_sum_seq_len X embedding_dim) = (18 X 4): padding positions are not stored.
packed_input.data.shape

torch.Size([18, 4])

In [51]:
# batch_sizes[t] = number of sequences that still have a real token at time step t.
packed_input.batch_sizes

tensor([3, 3, 3, 3, 2, 2, 1, 1])

In [52]:
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [54]:
# Step 8: Forward with LSTM
# Feeding a PackedSequence returns a PackedSequence (packed_output).
# ht is the final hidden state and ct the matching final cell state.
packed_output, (ht,ct) = lstm(packed_input)

In [56]:
# packed_output is a PackedSequence (a NamedTuple); the attributes used here are data and batch_sizes.

In [58]:
# packed_output.data :
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r

In [59]:
# (batch_sum_seq_len X hidden_dim) = (18 X 5)
packed_output.data.shape

torch.Size([18, 5])

In [63]:
# The packing layout is unchanged by the LSTM: same as packed_input.batch_sizes.
packed_output.batch_sizes

tensor([3, 3, 3, 3, 2, 2, 1, 1])

In [64]:
# visualization :

# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [66]:
# Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
# pad_packed_sequence undoes the packing: it restores a padded
# (batch_size X max_seq_len X hidden_dim) tensor with zeros at the padding positions.
# input_sizes receives the second return value (per the PyTorch docs, the sequence lengths).
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

In [67]:
# output =>
#                          [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                            [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                            [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                            [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                            [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                            [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                            [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                            [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r

#                           [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                            [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                            [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                            [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                            [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                            [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]  <pad>

#                           [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                            [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                            [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                            [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]] <pad>

In [69]:
output.shape  # (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

torch.Size([3, 8, 5])

In [70]:
# Or if you just want the final hidden state?
# Note: the rows follow the length-sorted batch order (long_str, medium, tiny),
# not the original order of `seqs`; undo the `idxs` permutation to restore it.
print(ht[-1])

tensor([[ 0.0044, -0.2581,  0.1075, -0.0793,  0.2602],
        [ 0.1094,  0.0606,  0.0343, -0.0094,  0.3211],
        [ 0.0820,  0.1148, -0.0205,  0.1659,  0.3319]],
       grad_fn=<SelectBackward>)


In [72]:
## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len)                 --> Sort by seqlen ---> (batch_size X max_seq_len)
# (batch_size X max_seq_len)                 --->     Embed     ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)