In [1]:
import math
import torch
from torch import nn


In [2]:
def _sequence_mask(X, valid_len, value=0):
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor


    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [16]:
x = torch.arange(0, 16).reshape(2, 2, 4)
len = torch.tensor([2, 3])
print("x:", x)
print("len:", len)
maxlen = x.size(1)
mask = torch.arange((maxlen))[None]
print(mask, mask.shape)
len_ = len[:, None]
print(len_, len_.shape)
x[~mask] = 0
print(x)

x: tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
len: tensor([2, 3])
tensor([[0, 1]]) torch.Size([1, 2])
tensor([[2],
        [3]]) torch.Size([2, 1])
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0]]])


In [4]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))


tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3106, 0.2816, 0.4078, 0.0000]],

        [[0.4544, 0.5456, 0.0000, 0.0000],
         [0.3031, 0.2567, 0.1966, 0.2435]]])

In [5]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))


tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3853, 0.2295, 0.3852, 0.0000]],

        [[0.3378, 0.6622, 0.0000, 0.0000],
         [0.1530, 0.3257, 0.3673, 0.1539]]])