In [20]:
import math
import torch
from torch import nn


In [21]:
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [22]:
masked_softmax(torch.rand(2, 2, 4), valid_lens=None)


tensor([[[0.3201, 0.2253, 0.2136, 0.2410],
         [0.1726, 0.2211, 0.2170, 0.3894]],

        [[0.2736, 0.2098, 0.2464, 0.2702],
         [0.2715, 0.3360, 0.1922, 0.2003]]])

In [14]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))


tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3989, 0.2740, 0.3271, 0.0000]],

        [[0.6673, 0.3327, 0.0000, 0.0000],
         [0.3016, 0.2115, 0.1732, 0.3136]]])

In [8]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))


tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4593, 0.3411, 0.1996, 0.0000]],

        [[0.5672, 0.4328, 0.0000, 0.0000],
         [0.2390, 0.1963, 0.2267, 0.3380]]])