In [1]:
import torch

In [2]:
# Let's build a tiny toy batch: three sequences of different lengths,
# each time step carrying a single feature.
vec_1, vec_2, vec_3 = (
    torch.FloatTensor(steps)
    for steps in (
        [[1]],
        [[1], [2], [3]],
        [[1], [2]],
    )
)
# deliberately NOT ordered by length, so that sorting below has something to do
unsorted_batch = [vec_1, vec_2, vec_3]

In [3]:
# Retrieve perm_indices by sorting the sequence lengths, longest first.
# Lengths are counts, so build an integer tensor with torch.tensor instead of the
# legacy torch.Tensor constructor, which would turn them into floats.
lengths = torch.tensor([seq.shape[0] for seq in unsorted_batch], dtype=torch.long)
_, perm_indices = lengths.sort(dim=0, descending=True)

# actually sort the sequences (index the Python list with plain ints)
sorted_batch = [unsorted_batch[i] for i in perm_indices.tolist()]
for d in sorted_batch:
    print(d)

tensor([[1.],
        [2.],
        [3.]])
tensor([[1.],
        [2.]])
tensor([[1.]])


In [4]:
from dataset import invert_permutation

# invert_permutation is a project helper; the check below shows that indexing
# sorted_batch with its result restores the original (unsorted) order.
invert_perm_index = invert_permutation(perm_indices)

# sanity check: this should print the sequences in their original order
for sorted_position in invert_perm_index:
    print(sorted_batch[sorted_position])

tensor([[1.]])
tensor([[1.],
        [2.],
        [3.]])
tensor([[1.],
        [2.]])


In [5]:
# Now pack the length-sorted batch into a single PackedSequence.
# Don't worry if its repr looks odd: `data` holds the time steps of all
# sequences interleaved, and `batch_sizes` tells how many sequences are
# still running at each time step.
from torch.nn.utils.rnn import pack_sequence
packed_batch = pack_sequence(sequences=sorted_batch)
packed_batch


PackedSequence(data=tensor([[1.],
        [1.],
        [1.],
        [2.],
        [2.],
        [3.]]), batch_sizes=tensor([3, 2, 1]))

In [6]:
from torch import nn
# NOTE: batch_first only describes the layout of padded Tensor inputs/outputs;
# it has no effect when the input is a PackedSequence.
rnn = nn.RNN(input_size=1, hidden_size=1, batch_first=True)
# Feed in the PackedSequence by calling the module itself, not rnn.forward(...):
# the call operator is the supported entry point and also runs registered hooks.
packed_output, h_n = rnn(packed_batch)

# and the output sequence of rnn is also a PackedSequence
print("type(packed_output): ", type(packed_output))
# but the final hidden output is, naturally, just a Tensor
print("type(h_n): ", type(h_n))

type(packed_output):  <class 'torch.nn.utils.rnn.PackedSequence'>
type(h_n):  <class 'torch.Tensor'>


## If you are interested in the output sequence

In [7]:
# Unpacking: convert the PackedSequence back into one zero-padded tensor,
# together with the true length of every sequence.
from torch.nn.utils.rnn import pad_packed_sequence

# NOTE: this rebinds `lengths` (previously the tensor used for sorting)
# to the lengths reported by pad_packed_sequence.
unpacked_output, lengths = pad_packed_sequence(sequence=packed_output, batch_first=True)
unpacked_output

tensor([[[-0.3993],
         [-0.1942],
         [ 0.4553]],

        [[-0.3993],
         [-0.1942],
         [ 0.0000]],

        [[-0.3993],
         [ 0.0000],
         [ 0.0000]]], grad_fn=<TransposeBackward0>)

In [8]:
# However, the unpacked sequences are still both sorted and padded.
# Remember our inverse permutation index? Let's "unsort" the output.
# CAVEAT: Always use index_select; Don't use in-place assignment
unsorted_output = unpacked_output.index_select(dim=0, index=invert_perm_index)

# Lastly, we don't want padding in the output, so the lengths have to be
# "unsorted" as well. Bind the result to a NEW name: rebinding `lengths`
# would apply the permutation again each time this cell is re-run and
# the slices below would then cut at the wrong positions.
unsorted_lengths = lengths.index_select(dim=0, index=invert_perm_index)

# slice away the padding (plain ints make the slice bounds explicit)
output_seq = [
    seq[:seq_len]
    for seq, seq_len in zip(unsorted_output, unsorted_lengths.tolist())
]

# And we're done!
output_seq


[tensor([[-0.3993]], grad_fn=<SliceBackward>), tensor([[-0.3993],
         [-0.1942],
         [ 0.4553]], grad_fn=<SliceBackward>), tensor([[-0.3993],
         [-0.1942]], grad_fn=<SliceBackward>)]

## If you are interested in the final hidden output (h_n)

In [12]:
# This case is much easier.
# However, remember that h_n has shape
# (num_layers * num_directions, batch_size, hidden_dim);
# in our minimal case that means (1, 3, 1).
# As a result, the batch axis to "unsort" is dim=1.
print("Shape: ", h_n.shape)
# Bind the result to a new name: rebinding h_n itself would apply the
# permutation once more on every re-run of this cell and scramble the order.
h_n_unsorted = h_n.index_select(dim=1, index=invert_perm_index)
h_n_unsorted

Shape:  torch.Size([1, 3, 1])


tensor([[[-0.3993],
         [ 0.4553],
         [-0.1942]]], grad_fn=<IndexSelectBackward>)