## RNN中常涉及到的操作

In [1]:
import torch
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence, pack_padded_sequence, pad_sequence, pack_sequence

## pad_sequence

In [2]:
# Build four 1-D float sequences of lengths 1..4 ([1], [1, 2], [1, 2, 3],
# [1, 2, 3, 4]) and right-pad them with zeros up to the longest length.
# batch_first=True lays the result out as (batch, max_len).
caption = pad_sequence(
    [Tensor(list(range(1, length + 1))) for length in range(1, 5)],
    batch_first=True,
)
caption

tensor([[1., 0., 0., 0.],
        [1., 2., 0., 0.],
        [1., 2., 3., 0.],
        [1., 2., 3., 4.]])

## pack_sequence

In [3]:
# Pack four variable-length sequences (lengths 1..4) directly, without padding.
# They are listed shortest-first, so enforce_sorted=False lets pack_sequence
# sort them by decreasing length itself and record the permutation.
pack_sequence(
    [Tensor(list(range(1, length + 1))) for length in range(1, 5)],
    enforce_sorted=False,
)

PackedSequence(data=tensor([1., 1., 1., 1., 2., 2., 2., 3., 3., 4.]), batch_sizes=tensor([4, 3, 2, 1]), sorted_indices=tensor([3, 2, 1, 0]), unsorted_indices=tensor([3, 2, 1, 0]))

In [4]:
# Pack three sequences that all have length 3. enforce_sorted is left at its
# default: equal lengths already satisfy the decreasing-length requirement.
# The packed data is interleaved time step by time step across the batch.
pack_sequence([Tensor(row) for row in ([1, 1, 1], [0, 2, 2], [0, 0, 3])])

PackedSequence(data=tensor([1., 0., 0., 1., 2., 0., 1., 2., 3.]), batch_sizes=tensor([3, 3, 3]), sorted_indices=None, unsorted_indices=None)

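## sort & reindex: sort the lengths in descending order and reorder the batch to match (pack_padded_sequence expects this when enforce_sorted=True, its default)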

In [5]:
# Unpadded length of each row of `caption`, in batch order.
# NOTE: the Tensor(...) constructor produces a float tensor here, as the
# displayed output shows.
lengths = Tensor([1, 2, 3, 4])
(caption, lengths)

(tensor([[1., 0., 0., 0.],
         [1., 2., 0., 0.],
         [1., 2., 3., 0.],
         [1., 2., 3., 4.]]),
 tensor([1., 2., 3., 4.]))

In [6]:
# Order the lengths from longest to shortest. sort_index holds, for each
# sorted entry, its position in the original batch.
sort_lengths, sort_index = lengths.sort(dim=0, descending=True)
(sort_lengths, sort_index)

(tensor([4., 3., 2., 1.]), tensor([3, 2, 1, 0]))

In [7]:
# Reorder the rows of `caption` (dim 0) with the permutation from the sort,
# so that row i of sort_caption has length sort_lengths[i].
sort_caption = caption.index_select(0, sort_index)
sort_caption

tensor([[1., 2., 3., 4.],
        [1., 2., 3., 0.],
        [1., 2., 0., 0.],
        [1., 0., 0., 0.]])

## pack_padded_sequence

In [8]:
# Pack the zero-padded batch, dropping the padding positions.
# enforce_sorted is left at its default (True), so the rows must already be
# ordered by decreasing length -- which is why the batch was sorted first.
# `sort_lengths` is a float tensor; recent PyTorch releases only accept an
# int64 tensor (or a plain list) for `lengths`, so cast explicitly.
temp = pack_padded_sequence(
    input=sort_caption,
    lengths=sort_lengths.to(torch.int64),
    batch_first=True,
)
temp

PackedSequence(data=tensor([1., 1., 1., 1., 2., 2., 2., 3., 3., 4.]), batch_sizes=tensor([4, 3, 2, 1]), sorted_indices=None, unsorted_indices=None)

## pad_packed_sequence

In [9]:
# Inverse of packing: rebuild the padded (batch, max_len) tensor, filling the
# missing positions with padding_value, and return it together with the
# length of each sequence.
pad_packed_sequence(
    sequence=temp,
    padding_value=0,
    batch_first=True,
)

(tensor([[1., 2., 3., 4.],
         [1., 2., 3., 0.],
         [1., 2., 0., 0.],
         [1., 0., 0., 0.]]),
 tensor([4, 3, 2, 1]))