# Masking

In [35]:
import torch

def get_attention_mask(
    q_tokens: torch.Tensor,
    k_tokens: torch.Tensor,
    q_pad_idx: int | None,
    k_pad_idx: int | None,
    mask_future: bool,
) -> torch.Tensor:
    """Returns an addative attention mask."""
    B, T = q_tokens.shape
    mask = torch.zeros((B, T, T), dtype=torch.float32)

    # Pad mask.
    if None not in (q_pad_idx, k_pad_idx):
        q_mask = (q_tokens == q_pad_idx).unsqueeze(-1)
        k_mask = (k_tokens == k_pad_idx).unsqueeze(-2)
        print(q_mask.shape, k_mask.shape)
        mask[q_mask | k_mask] = torch.finfo(torch.float32).min

    # Future mask.
    if mask_future:
        future_mask = torch.triu(
            torch.full((T, T), fill_value=True, dtype=torch.bool), diagonal=1
        )
        mask.masked_fill_(future_mask, float("-inf"))

    print(mask.shape)

    return mask.to(device="cpu")

In [36]:
import torch.nn.functional as F

q_tokens = torch.tensor([[1, 2, 3, 5, 5]])
k_tokens = torch.tensor([[1, 2, 3, 4, 7]])

attention_mask = get_attention_mask(
    q_tokens, k_tokens, q_pad_idx=5, k_pad_idx=7, mask_future=True
)
print(attention_mask)

wei = torch.rand((2, 2, 5, 5))

wei = wei + attention_mask
wei = F.softmax(wei, dim=-1)

print(wei)

torch.Size([1, 5, 1]) torch.Size([1, 1, 5])
torch.Size([1, 5, 5])
tensor([[[ 0.0000e+00,        -inf,        -inf,        -inf,        -inf],
         [ 0.0000e+00,  0.0000e+00,        -inf,        -inf,        -inf],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,        -inf,        -inf],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,        -inf],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38]]])
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6311, 0.3689, 0.0000, 0.0000, 0.0000],
          [0.3140, 0.4101, 0.2759, 0.0000, 0.0000],
          [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
          [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5800, 0.4200, 0.0000, 0.0000, 0.0000],
          [0.3043, 0.3177, 0.3780, 0.0000, 0.0000],
          [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
          [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


        [[[1.0000, 0.0000,