## Visualising dynamic padding with causal self attention

Typically, the causal self attention behaves as having the upper part of the triangle, disregarding the diagonal line, as being cast as float('-inf'). This ensures that no token at an instance of time can look to the future tokens. However, this assumes that the inputs are all filled to the max block size (T) of the model. 

What happens if the inputs are not in a similar length, and we have to apply padding to them just so that we can feed them in a batch? 

This document attempts to visualise what happens in the hood.

In [1]:
## create a fake scenario of the attention. Goal
## here is to incorporate padding and mask the padding

## first, create the input matrix
import torch
import torch.nn as nn
from torch.functional import F 

input_ids = torch.tensor([[1, 2, 3, 0, 0], # padding_token_id = 0
                  [6, 7, 8, 9, 10], 
                  [11, 12, 0, 0, 0]]) # batch_size x seq_len

B = input_ids.size(0) # 3
T = input_ids.size(1) # 5
C = 4 # n_embed
h_s = 3 # head_size
dropout = 0.1


In [2]:
position_embedding = nn.Embedding(100, C)

x = position_embedding(input_ids)

print(x.size())
print(x)

torch.Size([3, 5, 4])
tensor([[[-0.5822, -0.7915,  1.6073,  0.5012],
         [-1.2199,  1.1544, -1.4182, -0.3525],
         [-1.4957, -0.2008, -0.7390,  0.3103],
         [ 1.0110, -1.1545, -0.6795, -1.0732],
         [ 1.0110, -1.1545, -0.6795, -1.0732]],

        [[ 0.4927, -1.0773, -0.6775, -1.2340],
         [ 0.2375, -0.0274, -0.2853, -0.8249],
         [ 0.0331, -0.8663,  0.3287,  1.3669],
         [ 0.7534,  0.2125,  1.2789,  0.6051],
         [-1.7066,  0.9056, -0.2907, -0.1510]],

        [[-1.3441,  0.6079, -0.7126, -0.5544],
         [-0.3601, -0.1667, -0.5483, -1.1857],
         [ 1.0110, -1.1545, -0.6795, -1.0732],
         [ 1.0110, -1.1545, -0.6795, -1.0732],
         [ 1.0110, -1.1545, -0.6795, -1.0732]]], grad_fn=<EmbeddingBackward0>)


In [3]:
## Creating the scenario of inside the causal self attention

## create the weights of q, k, v
q = nn.Linear(C, h_s)
k = nn.Linear(C, h_s)
v = nn.Linear(C, h_s)

## we will not create the attention mask here. That will be created layer

## create dropout
dropout = nn.Dropout(dropout)

In [4]:
## Normally, this is how the attention mask will look
torch.tril(torch.ones(T, T)) # used to mask the attention matrix


tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [5]:
## perform matrix multiplication of the weights with the input
Q = q(x)
K = k(x)
V = v(x)

In [6]:
## Calculate the numerator of the attention
attention = Q @ K.transpose(-2,-1) * (C ** -0.5) # B,T,T
print(attention.shape)
print(attention)

torch.Size([3, 5, 5])
tensor([[[-0.7093,  0.2768, -0.0566,  0.3933,  0.3933],
         [ 0.1542, -1.2993, -0.8118, -0.0039, -0.0039],
         [-0.1514, -1.3935, -0.9908,  0.1829,  0.1829],
         [ 0.2091, -0.9234, -0.4266, -0.0906, -0.0906],
         [ 0.2091, -0.9234, -0.4266, -0.0906, -0.0906]],

        [[-0.2317, -0.4375,  0.3497,  0.1786, -1.0139],
         [-0.0871, -0.2021,  0.1585,  0.0427, -0.4850],
         [ 0.1949,  0.0660, -0.2792, -0.2333, -0.2043],
         [ 0.3182,  0.3802, -0.3170, -0.3909,  0.7895],
         [ 0.0429, -0.2189, -0.1106, -0.0816, -0.9293]],

        [[-1.0239, -0.5582,  0.1024,  0.1024,  0.1024],
         [-0.9580, -0.5921,  0.0225,  0.0225,  0.0225],
         [-0.7755, -0.5549, -0.0906, -0.0906, -0.0906],
         [-0.7755, -0.5549, -0.0906, -0.0906, -0.0906],
         [-0.7755, -0.5549, -0.0906, -0.0906, -0.0906]]],
       grad_fn=<MulBackward0>)


In [7]:
## here, create the causal mask and the padding mask. The causal mask is similar 
## to the one we usually create. The padding mask is created by checking if the input has
## padding tokens. If it has padding tokens, then the mask is 1, else 0.

# Create causal mask tensor (batch_size, seq_len, seq_len)
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).unsqueeze(0)  # Upper triangular matrix
causal_mask = causal_mask == 1 ## converts into BoolTensor

# Create padding mask tensor (batch_size, seq_len)
padding_mask = (input_ids == 0).unsqueeze(1)  # Padding tokens are zeros

# Combine the masks (padding_mask will have zeros for padded tokens)
mask = causal_mask + padding_mask.unsqueeze(-1) ## True + False = True
mask = mask.squeeze(1)
print(mask.shape)
print(mask)

torch.Size([3, 5, 5])
tensor([[[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]]])


In [8]:
## masked fill the attention matrix with the mask
attention = attention.masked_fill(mask, float('-inf')) # B,T,T
print(attention.size())
print(attention)

torch.Size([3, 5, 5])
tensor([[[-0.7093,    -inf,    -inf,    -inf,    -inf],
         [ 0.1542, -1.2993,    -inf,    -inf,    -inf],
         [-0.1514, -1.3935, -0.9908,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf]],

        [[-0.2317,    -inf,    -inf,    -inf,    -inf],
         [-0.0871, -0.2021,    -inf,    -inf,    -inf],
         [ 0.1949,  0.0660, -0.2792,    -inf,    -inf],
         [ 0.3182,  0.3802, -0.3170, -0.3909,    -inf],
         [ 0.0429, -0.2189, -0.1106, -0.0816, -0.9293]],

        [[-1.0239,    -inf,    -inf,    -inf,    -inf],
         [-0.9580, -0.5921,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf]]],
       grad_fn=<MaskedFillBackward0>)


In [9]:
## now, apply the softmax to the attention matrix
attention = F.softmax(attention, dim=-1)
print(attention.size())
print(attention)

torch.Size([3, 5, 5])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8105, 0.1895, 0.0000, 0.0000, 0.0000],
         [0.5811, 0.1678, 0.2510, 0.0000, 0.0000],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5287, 0.4713, 0.0000, 0.0000, 0.0000],
         [0.3998, 0.3514, 0.2488, 0.0000, 0.0000],
         [0.3241, 0.3448, 0.1717, 0.1595, 0.0000],
         [0.2572, 0.1979, 0.2206, 0.2271, 0.0973]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4095, 0.5905, 0.0000, 0.0000, 0.0000],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan]]],
       grad_fn=<SoftmaxBackward0>)


In [10]:
## apply dropout
attention = dropout(attention)
print(attention.size())
print(attention)

torch.Size([3, 5, 5])
tensor([[[1.1111, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9006, 0.2105, 0.0000, 0.0000, 0.0000],
         [0.6457, 0.1865, 0.2789, 0.0000, 0.0000],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan]],

        [[1.1111, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5875, 0.5236, 0.0000, 0.0000, 0.0000],
         [0.4442, 0.3904, 0.0000, 0.0000, 0.0000],
         [0.3601, 0.3831, 0.1908, 0.1772, 0.0000],
         [0.2857, 0.2199, 0.2451, 0.2523, 0.1081]],

        [[1.1111, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4551, 0.6561, 0.0000, 0.0000, 0.0000],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan]]], grad_fn=<MulBackward0>)


In [11]:
## lastly, multiply the attention with the value matrix
out = attention @ V
print(out.size())
print(out)

torch.Size([3, 5, 3])
tensor([[[ 0.5873, -0.6878, -0.1615],
         [ 0.3643, -0.3133, -0.1221],
         [ 0.1727, -0.0221, -0.1713],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        [[ 0.2732, -0.1841,  0.2674],
         [ 0.2040, -0.0867,  0.3217],
         [ 0.1536, -0.0657,  0.2413],
         [ 0.1762, -0.4011,  0.4033],
         [ 0.1372, -0.4055,  0.3681]],

        [[-0.2119,  1.0908, -0.2169],
         [-0.0069,  0.7432, -0.0770],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]]], grad_fn=<UnsafeViewBackward0>)


In [12]:
linear = nn.Linear(h_s, C)

In [13]:
linear(out)

tensor([[[-0.5453,  0.6221,  0.1022, -0.2521],
         [-0.4833,  0.5375,  0.2063, -0.3706],
         [-0.4284,  0.4553,  0.3124, -0.4826],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan]],

        [[-0.5044,  0.6096,  0.1712, -0.3612],
         [-0.4943,  0.6017,  0.1944, -0.3903],
         [-0.4859,  0.5841,  0.2351, -0.4258],
         [-0.5817,  0.7515,  0.1841, -0.3674],
         [-0.5831,  0.7530,  0.2096, -0.3893]],

        [[-0.1971,  0.0984,  0.5250, -0.7377],
         [-0.2740,  0.2232,  0.3928, -0.6022],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan]]], grad_fn=<ViewBackward0>)

In [14]:
linear(out)

tensor([[[-0.5453,  0.6221,  0.1022, -0.2521],
         [-0.4833,  0.5375,  0.2063, -0.3706],
         [-0.4284,  0.4553,  0.3124, -0.4826],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan]],

        [[-0.5044,  0.6096,  0.1712, -0.3612],
         [-0.4943,  0.6017,  0.1944, -0.3903],
         [-0.4859,  0.5841,  0.2351, -0.4258],
         [-0.5817,  0.7515,  0.1841, -0.3674],
         [-0.5831,  0.7530,  0.2096, -0.3893]],

        [[-0.1971,  0.0984,  0.5250, -0.7377],
         [-0.2740,  0.2232,  0.3928, -0.6022],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan]]], grad_fn=<ViewBackward0>)

In [15]:
linear(out).size()

torch.Size([3, 5, 4])

In [16]:
F.scaled_dot_product_attention(Q, K, V, mask, is_causal=False)

tensor([[[-0.0433,  0.1076,  0.2371],
         [ 0.1520, -0.3164,  0.3542],
         [ 0.2311, -0.4924,  0.4840],
         [ 0.1870, -0.2289,  0.1277],
         [ 0.1870, -0.2289,  0.1277]],

        [[ 0.0853, -0.6197,  0.4003],
         [ 0.0524, -0.5504,  0.3384],
         [ 0.0316,  0.0387,  0.1955],
         [-0.1959,  1.0176, -0.2850],
         [    nan,     nan,     nan]],

        [[ 0.2164, -0.3653,  0.4213],
         [ 0.2311, -0.4924,  0.4840],
         [ 0.1679, -0.1901,  0.3402],
         [ 0.1679, -0.1901,  0.3402],
         [ 0.1679, -0.1901,  0.3402]]], grad_fn=<UnsafeViewBackward0>)

In [21]:
attn_bias = torch.zeros(Q.size(-2), K.size(-2), dtype = Q.dtype)

In [23]:
attn_bias.unsqueeze(0).expand(3, -1, -1)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

In [33]:
~mask.logical_not()

tensor([[[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]]])

In [32]:
attn_bias.unsqueeze(0).expand(3, -1, -1).masked_fill_(~mask.logical_not(), float('-inf'))

  attn_bias.unsqueeze(0).expand(3, -1, -1).masked_fill_(~mask.logical_not(), float('-inf'))


tensor([[[-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf]],

        [[-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf]],

        [[-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf]]])

In [31]:
import math

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

In [32]:
scaled_dot_product_attention(Q, K, V, mask, is_causal=False)

RuntimeError: output with shape [5, 5] doesn't match the broadcast shape [3, 5, 5]

In [33]:
F.scaled_dot_product_attention(Q, K, V, mask, is_causal=False)

tensor([[[-0.0410, -1.0452, -0.2190],
         [ 0.0987, -1.1155, -0.0409],
         [ 0.2031, -1.0693,  0.1037],
         [-0.2137, -0.8735, -0.5728],
         [-0.2137, -0.8735, -0.5728]],

        [[-0.3442,  0.0989, -0.0902],
         [-0.0428, -0.2093,  0.2775],
         [-0.0196, -0.2133,  0.3251],
         [ 0.1189, -0.5733,  0.3315],
         [    nan,     nan,     nan]],

        [[ 0.0716, -0.9821, -0.0615],
         [ 0.2031, -1.0693,  0.1037],
         [-0.3784, -0.5380, -0.5431],
         [-0.3784, -0.5380, -0.5431],
         [-0.3784, -0.5380, -0.5431]]], grad_fn=<UnsafeViewBackward0>)