In [162]:
import torch
from torch.nn.utils.rnn import pad_sequence

In [163]:
a = torch.ones(5)  # a.shape == (5,)
b = torch.ones(3)  # b.shape == (3,)
c = torch.ones(2)  # c.shape == (2,)

# pad_sequence stacks a list of tensors along a new dimension and pads them
# to equal length. Each input sequence has size L x *; the output has size
# T x B x * when batch_first is False, and B x T x * when it is True.
#
#   B: batch size, i.e. the number of tensors in the input list
#   T: length of the longest sequence
#   L: length of an individual sequence
#   *: any number of trailing dimensions, including none
pad_false = pad_sequence([a, b, c])  # batch_first defaults to False
pad_false.shape

torch.Size([5, 3])

In [164]:
# Display the time-major result: each input sequence occupies one column,
# and shorter sequences are zero-filled below their last element.
pad_false

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.]])

In [165]:
# batch_first (bool, optional): the output is B x T x * when True and
# T x B x * otherwise (the default is False). Here the batch is dimension 0.
# padding_value (float, optional): fill value for padded elements (default 0).
pad_true = pad_sequence([a, b, c], batch_first=True, padding_value=-999)
pad_true.shape

torch.Size([3, 5])

In [166]:
# Display the batch-first result: one row per input sequence, with the
# tail of each shorter row filled with the -999 padding value.
pad_true

tensor([[   1.,    1.,    1.,    1.,    1.],
        [   1.,    1.,    1., -999., -999.],
        [   1.,    1., -999., -999., -999.]])

In [167]:
# Small integer fixtures with distinct, easy-to-trace values.
a1 = torch.arange(12).reshape(3, 4)     # [[0..3], [4..7], [8..11]], shape (3, 4)
b1 = -torch.arange(1, 7).reshape(3, 2)  # [[-1, -2], [-3, -4], [-5, -6]], shape (3, 2)
c1 = torch.tensor([[-11, -12, -13]])    # single row, shape (1, 3)

In [168]:
# Shapes of the sequences handed to pad_sequence:
#   a1.t() -> (4, 3)
#   b1.t() -> (2, 3)
#   c1     -> (1, 3)
#   c1.t() -> (3, 1)
# Only dimension 0 is padded; the remaining dimensions must be equal or
# broadcastable. The (3, 1) entry fits because its size-1 trailing dimension
# is broadcast across the 3 columns.
# NOTE(review): relies on broadcasting inside pad_sequence -- confirm this is
# supported on the PyTorch version in use.
d1 = pad_sequence([a1.t(), b1.t(), c1, c1.t()], batch_first=True)
d1.shape

torch.Size([4, 4, 3])

In [169]:
# Swap the last two axes so every batch entry is shown in the row layout of
# the original (untransposed) inputs.
d1.transpose(1, 2)

tensor([[[  0,   1,   2,   3],
         [  4,   5,   6,   7],
         [  8,   9,  10,  11]],

        [[ -1,  -2,   0,   0],
         [ -3,  -4,   0,   0],
         [ -5,  -6,   0,   0]],

        [[-11,   0,   0,   0],
         [-12,   0,   0,   0],
         [-13,   0,   0,   0]],

        [[-11, -12, -13,   0],
         [-11, -12, -13,   0],
         [-11, -12, -13,   0]]])

In [170]:
# Equivalent to the previous cell: the (3, 1) column is expanded to (3, 3)
# explicitly before padding instead of being broadcast during padding.
pad_sequence(
    [a1.t(), b1.t(), c1, c1.t().broadcast_to((3, 3))],
    batch_first=True,
).transpose(1, 2)

tensor([[[  0,   1,   2,   3],
         [  4,   5,   6,   7],
         [  8,   9,  10,  11]],

        [[ -1,  -2,   0,   0],
         [ -3,  -4,   0,   0],
         [ -5,  -6,   0,   0]],

        [[-11,   0,   0,   0],
         [-12,   0,   0,   0],
         [-13,   0,   0,   0]],

        [[-11, -12, -13,   0],
         [-11, -12, -13,   0],
         [-11, -12, -13,   0]]])