# Minimal tutorial on packing and unpacking sequences in PyTorch
# aka how to use `pack_padded_sequence` and  `pad_packed_sequence`

This is a Jupyter version of [@Tushar-N's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e) with comments from [@HarshTrivedi's repo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial).


In [None]:
# from https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
#     Step 1: Construct Vocabulary
#     Step 2: Load indexed data (list of instances, where each instance is list of character indices)
#     Step 3: Make Model
#  *  Step 4: Pad instances with 0s till max length sequence
#  *  Step 5: Sort instances by sequence length in descending order
#  *  Step 6: Embed the instances
#  *  Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
#  *  Step 8: Forward with LSTM
#  *  Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
#  *  Summary of Shape Transformations

In [None]:
# The batch we want to run through the LSTM: three character sequences
# of unequal length.
seqs = [
    'long_str',  # 8 characters
    'tiny',      # 4 characters
    'medium',    # 6 characters
]

In [None]:
## Step 1: Construct Vocabulary ##
##------------------------------##
# '<pad>' is placed first so that the padding token gets index 0, the same
# value the padded batch tensor is filled with in Step 4. The remaining
# characters are de-duplicated with a set comprehension (no throwaway list)
# and sorted so the index assignment is deterministic.
vocab = ['<pad>'] + sorted({char for seq in seqs for char in seq})

In [None]:
# Display the vocabulary built in Step 1.
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']
vocab

In [None]:
## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##
##-------------------------------------------------------------------------------------------------##
# Build the character -> index table once; calling vocab.index() for every
# character would rescan the vocabulary list each time.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]

In [None]:
# Display the index-encoded sequences (one list of vocabulary indices per string).
# => [[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]
vectorized_seqs

In [None]:
## Step 3: Make Model ##
##--------------------##
# Name the layer sizes once so the embedding's output width and the LSTM's
# input width cannot drift apart.
embedding_dim = 4
hidden_dim = 5

# One embedding row per vocabulary entry (including '<pad>').
embed = Embedding(len(vocab), embedding_dim)
# Two stacked LSTM layers; batch_first=True means inputs and outputs are laid
# out as (batch, seq, feature).
lstm = LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=2, batch_first=True)

In [None]:
## Step 4: Pad instances with 0s till max length sequence ##
##--------------------------------------------------------##

# Record the length of every sequence in the batch.
seq_lengths = LongTensor([len(seq) for seq in vectorized_seqs])
# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

# Allocate the padded batch directly as int64, and hand torch.zeros a plain
# Python int for the width, rather than creating a float tensor and casting it.
max_seq_len = int(seq_lengths.max())
seq_tensor = torch.zeros((len(vectorized_seqs), max_seq_len), dtype=torch.long)
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

# Copy each sequence into the left part of its row; the remainder of the row
# keeps the value 0, which is the index of '<pad>'.
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [None]:
# Display the per-sequence lengths; when cells run in order this is still the
# unsorted [8, 4, 6].
seq_lengths

In [None]:
# Display the zero-padded batch, shape (3 X 8), rows still in the original
# order (long_str, tiny, medium) when cells run in order.
seq_tensor

In [None]:
## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

# Order the lengths from longest to shortest and keep the permutation that
# produced that order, so the same reordering can be applied to the data.
seq_lengths, perm_idx = torch.sort(seq_lengths, dim=0, descending=True)

# Reorder the rows of the padded batch with that permutation.
seq_tensor = seq_tensor.index_select(0, perm_idx)
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [None]:
# Display the permutation produced by the sort in Step 5.
# For lengths [8, 4, 6] sorted in descending order this is [0, 2, 1].
perm_idx

In [None]:
# Display the padded batch after Step 5; rows are now ordered by decreasing
# length (long_str, medium, tiny).
seq_tensor

In [None]:
## Step 6: Embed the instances ##
##-----------------------------##

# Look up an embedding vector for every index in the padded, sorted batch
# (padding positions included).
embedded_seq_tensor = embed(input=seq_tensor)
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

In [None]:
# Display the embedded batch produced in Step 6.
embedded_seq_tensor

In [None]:
# Expected: (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)
embedded_seq_tensor.shape

In [None]:
## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
##--------------------------------------------------------------------------------##

# Lengths are passed as a plain list of ints, one of the documented input
# types for `lengths` (the other being a CPU tensor). The batch was sorted by
# decreasing length in Step 5, which is the order packing expects by default.
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.tolist(), batch_first=True)
# packed_input is a PackedSequence (a namedtuple); the two fields used here
# are `data` and `batch_sizes`.
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [None]:
# Expected: (batch_sum_seq_len X embedding_dim) = (18 X 4)
packed_input.data.shape

In [None]:
## Step 8: Forward with LSTM ##
##---------------------------##

# Run the packed batch through the LSTM. The call returns the packed
# per-timestep outputs plus a (hidden, cell) pair of final states.
packed_output, final_states = lstm(packed_input)
ht, ct = final_states
# packed_output is a PackedSequence; the fields of interest are `data` and `batch_sizes`.
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [None]:
# Expected: (batch_sum_seq_len X hidden_dim) = (18 X 5)
packed_output.data.shape

In [None]:
# Display the final hidden state returned by the LSTM in Step 8.
ht

In [None]:
# Per the torch.nn.LSTM docs h_n is (num_layers X batch_size X hidden_dim),
# which for this model and batch is (2 X 3 X 5).
ht.shape

In [None]:
# Display the final cell state returned by the LSTM in Step 8.
ct

In [None]:
# Per the torch.nn.LSTM docs c_n is (num_layers X batch_size X hidden_dim),
# which for this model and batch is (2 X 3 X 5).
ct.shape

In [None]:
## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##
##--------------------------------------------------------------------------------##

# Turn the packed LSTM output back into a zero-padded tensor, if you need one.
# The second element returned holds the length of each sequence in the batch.
unpacked = pad_packed_sequence(packed_output, batch_first=True)
output, input_sizes = unpacked
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)
# NOTE: rows follow the length-sorted order from Step 5 (long_str, medium,
# tiny); index with perm_idx.argsort() to get back the original batch order.

In [None]:
# Display the unpacked (re-padded) LSTM output from Step 9.
output

In [None]:
# Expected: (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)
output.shape

In [None]:
# If only the final hidden state is needed, skip unpacking: take the last
# entry along ht's first axis (the top LSTM layer, per the torch.nn.LSTM docs).
final_hidden = ht[-1]
print(final_hidden)

## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)