# Minimal tutorial on packing and unpacking sequences in PyTorch, aka how to use `pack_padded_sequence` and `pad_packed_sequence`

This is a Jupyter version of [@Tushar-N's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e) with comments from [@Harsh Trivedi's repo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial).


In [1]:
# from https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
#     Step 1: Construct Vocabulary
#     Step 2: Load indexed data (list of instances, where each instance is list of character indices)
#     Step 3: Make Model
#  *  Step 4: Pad instances with 0s till max length sequence
#  *  Step 5: Sort instances by sequence length in descending order
#  *  Step 6: Embed the instances
#  *  Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
#  *  Step 8: Forward with LSTM
#  *  Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
#  *  Summary of Shape Transformations

In [2]:
# The batch we want to run the LSTM on: three character sequences of
# different lengths, so that padding and packing are actually needed.
seqs = [
    'long_str',  # len = 8
    'tiny',      # len = 4
    'medium',    # len = 6
]

In [3]:
## Step 1: Construct Vocabulary ##
##------------------------------##
# Collect every distinct character in the batch, sort them for a stable
# ordering, and put '<pad>' in front so that it gets index 0 (the value
# used for padding later on).
unique_chars = set(''.join(seqs))
vocab = ['<pad>'] + sorted(unique_chars)
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

In [4]:
# Show the vocabulary; '<pad>' sits at index 0.
vocab

['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

In [5]:
## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##
##-------------------------------------------------------------------------------------------------##
# Build the character -> index mapping once, so each lookup is a dict hit
# instead of a linear `vocab.index(...)` scan for every character.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]

In [6]:
# Show each sequence as a list of vocab indices.
vectorized_seqs

[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

In [7]:
## Step 3: Make Model ##
##--------------------##
EMBEDDING_DIM = 4  # size of each character vector fed to the LSTM
HIDDEN_DIM = 5     # size of the LSTM hidden / cell state

# Keep this construction order: both modules draw their initial weights
# from the global RNG.
embed = Embedding(num_embeddings=len(vocab), embedding_dim=EMBEDDING_DIM)
lstm = LSTM(input_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM, batch_first=True)

In [9]:
## Step 4: Pad instances with 0s till max length sequence ##
##--------------------------------------------------------##

# get the length of each seq in your batch
seq_lengths = torch.tensor([len(seq) for seq in vectorized_seqs], dtype=torch.long)
# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

# Allocate the padded batch directly as int64 zeros (0 is the '<pad>' index),
# rather than building a float tensor and casting it. An empty batch gets
# width 0 instead of failing on max() of an empty tensor.
max_seq_len = int(seq_lengths.max()) if len(vectorized_seqs) > 0 else 0
seq_tensor = torch.zeros((len(vectorized_seqs), max_seq_len), dtype=torch.long)
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

# Copy each sequence into the left part of its row; the rest stays 0 (<pad>).
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = torch.tensor(seq, dtype=torch.long)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [10]:
# Per-sequence lengths, still in the original (unsorted) batch order.
seq_lengths

tensor([8, 4, 6])

In [11]:
# Padded batch of token ids, still in the original (unsorted) order.
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [12,  5,  8, 14,  0,  0,  0,  0],
        [ 7,  3,  2,  5, 13,  7,  0,  0]])

In [12]:
## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

# torch.sort returns (sorted values, original positions of those values).
# perm_idx is kept so the same reordering can be applied to the batch, and so
# it can be inverted later if the original order is needed again.
seq_lengths, perm_idx = torch.sort(seq_lengths, dim=0, descending=True)
seq_tensor = seq_tensor.index_select(0, perm_idx)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]           # long_str
#                [ 7  3  2  5 13  7  0  0]           # medium
#                [12  5  8 14  0  0  0  0]]          # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [13]:
# Row i of the sorted batch came from original row perm_idx[i].
perm_idx

tensor([0, 2, 1])

In [14]:
# Padded batch after sorting (longest sequence first).
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [ 7,  3,  2,  5, 13,  7,  0,  0],
        [12,  5,  8, 14,  0,  0,  0,  0]])

In [15]:
## Step 6: Embed the instances ##
##-----------------------------##

# Look up an embedding vector for every token id in the sorted, padded batch
# (padding positions included).
# NOTE: the numbers in the comment below are illustrative and come from a
# different run than the cell output shown after this cell. The '<pad>' rows
# are non-zero because `embed` was created without `padding_idx`. Packing
# (Step 7) drops those rows, so they never reach the LSTM.
embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
#                       [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                         [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                         [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                         [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                         [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r

#                        [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                         [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                         [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>

#                        [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

In [16]:
# Embedded batch from this run: (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4).
embedded_seq_tensor

tensor([[[-1.1597,  0.8898, -0.7697, -0.1434],
         [ 0.3070,  0.3274, -0.6213,  1.3172],
         [ 1.0697, -1.9666,  0.6205,  0.9001],
         [-0.9073,  0.5386,  0.2677, -0.5309],
         [-0.2918,  0.6517, -1.7233,  0.7375],
         [-0.8174, -1.0337, -0.6716,  2.1332],
         [ 0.4409,  0.9304, -0.1408,  1.3347],
         [ 0.1776,  0.3898,  0.9927,  0.7810]],

        [[ 0.3164, -0.4242, -0.1908,  0.3838],
         [-0.4709,  1.0608, -1.3358, -0.0262],
         [ 0.8089, -0.0299,  0.3917, -1.2011],
         [ 1.2057, -1.7734, -0.3940,  1.7063],
         [ 0.7673,  0.5728,  1.6220,  0.9624],
         [ 0.3164, -0.4242, -0.1908,  0.3838],
         [-1.4086, -0.0051,  2.1867,  0.3257],
         [-1.4086, -0.0051,  2.1867,  0.3257]],

        [[ 0.4409,  0.9304, -0.1408,  1.3347],
         [ 1.2057, -1.7734, -0.3940,  1.7063],
         [ 1.0697, -1.9666,  0.6205,  0.9001],
         [-1.2739,  0.3627, -1.3498,  0.8632],
         [-1.4086, -0.0051,  2.1867,  0.3257],
         

In [17]:
## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
##--------------------------------------------------------------------------------##

# `lengths` can be passed directly as a 1-D int64 tensor on the CPU; there is
# no need to round-trip through NumPy.
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu(), batch_first=True)
# packed_input is a PackedSequence (a NamedTuple) with attributes `data` and
# `batch_sizes`. Newer PyTorch versions also carry `sorted_indices` and
# `unsorted_indices`.
#
# `data` keeps only the real (non-pad) timesteps, in time-major order: the
# t=0 element of every sequence, then every t=1 element, and so on.
# (The vectors below reuse the illustrative values from the Step 6 comment.)
#
# packed_input.data =>
#                         [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                          [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                          [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                          [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                          [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# (batch_sizes[t] = how many sequences are still active at timestep t)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [18]:
# Only the 18 real (non-pad) timesteps survive packing, each embedding_dim wide.
packed_input.data.shape

torch.Size([18, 4])

In [19]:
## Step 8: Forward with LSTM ##
##---------------------------##

packed_output, (ht, ct) = lstm(packed_input)
# packed_output is again a PackedSequence with the same time-major layout as
# packed_input: row k of packed_output.data is the LSTM output for row k of
# packed_input.data.
#
# packed_output.data :
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])
#
# ht / ct hold the final hidden / cell state of every sequence, taken at its
# last real timestep (not at the padded end), in the sorted batch order:
# ht.shape, ct.shape : (num_layers * num_directions X batch_size X hidden_dim) = (1 X 3 X 5)

In [20]:
# One hidden_dim-sized LSTM output per real (non-pad) timestep.
packed_output.data.shape

torch.Size([18, 5])

In [21]:
# Final hidden state of each sequence, equal to its output at its last real
# timestep; rows follow the sorted batch order (long_str, medium, tiny).
ht

tensor([[[-0.0833,  0.0381, -0.1947, -0.0692,  0.2844],
         [-0.0571, -0.0250,  0.0093, -0.2057,  0.3676],
         [-0.2217,  0.0282, -0.2632, -0.1919,  0.4664]]],
       grad_fn=<StackBackward>)

In [22]:
# Final cell state of each sequence; same layout as ht.
ct

tensor([[[-0.1472,  0.0989, -0.4971, -0.1882,  0.6185],
         [-0.0932, -0.0559,  0.0166, -0.5489,  0.6335],
         [-0.2833,  0.1876, -0.3830, -0.2727,  0.8058]]],
       grad_fn=<StackBackward>)

In [23]:
## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##
##--------------------------------------------------------------------------------##

# unpack your output if required. Positions past each sequence's length are
# filled with 0. The second return value holds each sequence's length
# (here equal to the sorted seq_lengths).
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
#                          [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                            [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                            [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                            [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                            [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                            [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                            [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                            [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r

#                           [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                            [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                            [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                            [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                            [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                            [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]  <pad>

#                           [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                            [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                            [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                            [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]] <pad>
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# The rows of `output` are still in the sorted order from Step 5
# (long_str, medium, tiny). Sorting perm_idx yields its inverse permutation,
# which restores the original batch order (long_str, tiny, medium).
_, unperm_idx = perm_idx.sort(0)
output_original_order = output[unperm_idx]

In [24]:
# Padded LSTM outputs in sorted batch order; positions past each sequence's
# length are 0.
output

tensor([[[-0.0500,  0.0737, -0.2306, -0.0916,  0.1281],
         [-0.0580,  0.0823, -0.1971, -0.1181,  0.2464],
         [-0.1110, -0.2148, -0.1024, -0.1155,  0.4252],
         [-0.0759,  0.1037, -0.1557, -0.1975,  0.2558],
         [-0.0470,  0.0539, -0.2519, -0.1898,  0.4495],
         [-0.3216,  0.0305, -0.3365, -0.0784,  0.5821],
         [-0.1083,  0.0540, -0.2394, -0.0532,  0.3978],
         [-0.0833,  0.0381, -0.1947, -0.0692,  0.2844]],

        [[-0.0470,  0.0499, -0.0118, -0.1345,  0.1937],
         [ 0.0115,  0.0615, -0.1478, -0.1867,  0.2706],
         [ 0.0273,  0.1251,  0.1285, -0.1488,  0.2152],
         [-0.0741, -0.1454,  0.0729, -0.1980,  0.5337],
         [-0.0280, -0.1729, -0.0034, -0.1446,  0.2195],
         [-0.0571, -0.0250,  0.0093, -0.2057,  0.3676],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0083,  0.0506, -0.0877, -0.0824,  0.0445],
         [-0.0949, -0.1112, -0.0777, -0.1488

In [25]:
# Or if you just want the final hidden state?
# ht is (num_layers * num_directions X batch_size X hidden_dim) = (1 X 3 X 5)
# here. Taking the last entry along dim 0 leaves one hidden_dim vector per
# sequence, in the sorted batch order (long_str, medium, tiny).
last_layer_ht = ht[-1]
print(last_layer_ht)

## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len)                 ---> Sort by seqlen ---> (batch_size X max_seq_len)
# (batch_size X max_seq_len)                 --->      Embed     ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack      ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM      ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->     UnPack     ---> (batch_size X max_seq_len X hidden_dim)

tensor([[-0.0833,  0.0381, -0.1947, -0.0692,  0.2844],
        [-0.0571, -0.0250,  0.0093, -0.2057,  0.3676],
        [-0.2217,  0.0282, -0.2632, -0.1919,  0.4664]],
       grad_fn=<SelectBackward>)
