In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
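# Masked Softmax Operation

# Example minibatch: three sequences padded with <blank> to a common length
# of 4 tokens (valid lengths 4, 3 and 2). A masked softmax should give the
# <blank> positions zero weight.
#
# Dive  into  Deep    Learning
# Learn to    code    <blank>
# Hello world <blank> <blank>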

In [7]:
# Random tensor of shape (2, 2, 4), used below to explore how the mask is built.
X_example = torch.rand((2, 2, 4))
X_example

tensor([[[0.7344, 0.4231, 0.5460, 0.2873],
         [0.5177, 0.1597, 0.8488, 0.4319]],

        [[0.3024, 0.9401, 0.4310, 0.6325],
         [0.6354, 0.9085, 0.6409, 0.6713]]])

In [8]:
# Length of axis 1 (the middle axis of the 3-D example tensor).
X_example.shape[1]

2

In [9]:
# Device holding X_example; the mask built in a later cell is created on this same device.
X_example.device

device(type='cpu')

In [12]:
# One valid length per first-axis entry of X_example (there are 2). Each value is
# the per-row threshold that position indices are compared against when the
# mask is built in the next cell.
vali_len_example = torch.tensor([2, 3])
vali_len_example

tensor([2, 3])

In [15]:
# Build a boolean mask by comparing position indices (as a row vector) against
# each valid length (as a column vector): entry [i, j] is True when
# j < vali_len_example[i].
# NOTE(review): X_example is 3-D, so axis 1 is its middle axis (length 2), not
# the last axis (length 4); with lengths [2, 3] every comparison is True.
# masked_softmax flattens its input to 2-D before masking, so there axis 1 IS
# the last axis -- confirm whether size(-1) was intended for this demo.
maxlen = X_example.shape[1]
mask = (torch.arange(maxlen, dtype=torch.float32,
                     device=X_example.device).unsqueeze(0)
        < vali_len_example.unsqueeze(1))

mask

tensor([[True, True],
        [True, True]])

In [17]:
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis.

    Args:
        X: 3D tensor of scores; softmax is taken over its last axis.
        valid_lens: ``None`` (plain softmax), a 1D tensor holding one length
            per first-axis entry of ``X`` (applied to every row of that
            entry), or a 2D tensor holding one length per row of ``X``.

    Returns:
        A new tensor with the same shape as ``X``. In each row, positions at
        index >= the row's valid length are replaced by -1e6 before the
        softmax, so they receive (numerically) zero weight.

    The input ``X`` is never modified.
    """
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        # X is 2D here (rows, maxlen); valid_len holds one length per row.
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        # `reshape` in the caller returns a view whenever it can, so writing
        # into X directly would overwrite the tensor the caller passed in.
        # Fill a private copy instead.
        masked = X.clone()
        masked[~mask] = value
        return masked

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            # Share each batch entry's length across all of its rows.
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [None]:
# 1-D valid_lens: each length applies to every row of the corresponding batch
# entry (rows of the first entry keep 2 positions, rows of the second keep 3).
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[0.4721, 0.5279, 0.0000, 0.0000],
         [0.5189, 0.4811, 0.0000, 0.0000]],

        [[0.3977, 0.3000, 0.3022, 0.0000],
         [0.2685, 0.3930, 0.3385, 0.0000]]])

In [19]:
# 2-D valid_lens: one length per row, so rows within the same batch entry can
# keep different numbers of positions.
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2729, 0.3944, 0.3327, 0.0000]],

        [[0.4275, 0.5725, 0.0000, 0.0000],
         [0.2274, 0.1753, 0.2971, 0.3001]]])

In [21]:
# Batch Matrix Multiplication
# Each of the 2 batch entries multiplies a (3, 4) matrix by a (4, 6) matrix;
# with all-ones inputs every entry of the (3, 6) product equals 4.
Q = torch.ones(2, 3, 4)
K = torch.ones(2, 4, 6)

torch.bmm(Q, K)

tensor([[[4., 4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4., 4.]],

        [[4., 4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4., 4.]]])

In [22]:
# Sanity check: batch-multiplying (2, 3, 4) by (2, 4, 6) should give shape (2, 3, 6).
# presumably d2l.check_shape raises when the shape differs -- confirm against d2l.
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))