In [44]:
import torch
import math

In [45]:
# Toy problem dimensions used for the random test tensors below
# (the tensors are created as torch.rand(N, L, D)).
N = 2
L = 4
D = 6

In [190]:
def scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=False):
    """
    Compute softmax(query @ key^T / sqrt(D)) @ value.

    query: ... x L x D
    key: ... x S x D
    value: ... x S x Dv
    attn_mask: optional mask broadcastable to the ... x L x S score tensor.
        Positions where the mask is False (or zero) are excluded from
        attention. Note: a non-boolean mask is interpreted by truthiness
        here; it is NOT added to the scores as a float bias.
    is_causal: if True, query position i only attends to key positions <= i
        (lower-triangular mask aligned to the top-left corner).

    Returns a tensor of shape ... x L x Dv.

    Raises AssertionError if attn_mask is given together with is_causal=True.
    """
    assert not (attn_mask is not None and is_causal), "Both cant be True"

    # Take the sequence lengths from the inputs themselves instead of a
    # notebook-level global, so any L (and S != L) works.
    tgt_len, src_len = query.size(-2), key.size(-2)

    scores = (query @ key.transpose(-1, -2)) * (1 / math.sqrt(query.size(-1)))

    # Filling the scores with -inf is equivalent to adding a 0/-inf bias,
    # but it broadcasts over batch dimensions and stays on the scores'
    # device and dtype.
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask.logical_not(), float("-inf"))

    if is_causal:
        causal_keep = torch.ones(
            tgt_len, src_len, dtype=torch.bool, device=query.device
        ).tril()
        scores = scores.masked_fill(causal_keep.logical_not(), float("-inf"))

    attn = torch.nn.functional.softmax(scores, dim=-1)
    return attn @ value

In [203]:
# Draw the three random N x L x D inputs in order: query, key, value.
query, key, value = (torch.rand(N, L, D) for _ in range(3))

# Display the shapes as a sanity check.
tuple(t.shape for t in (query, key, value))

(torch.Size([2, 4, 6]), torch.Size([2, 4, 6]), torch.Size([2, 4, 6]))

In [192]:
# Smoke run of the custom implementation with default (unmasked) settings.
res = scaled_dot_product_attention(query=query, key=key, value=value)

# Verification

## Self Attention

In [193]:
# Reference output from PyTorch's built-in implementation, called without
# attn_mask or is_causal.
orig = torch.nn.functional.scaled_dot_product_attention(query, key, value)

In [194]:
# Custom implementation on the same inputs, for comparison with `orig`.
res = scaled_dot_product_attention(query=query, key=key, value=value)

In [195]:
# Compare the absolute error: without abs() any negative difference,
# however large, would satisfy `< 1e-5` and the check would pass wrongly.
torch.all((orig - res).abs() < 1e-5)

tensor(True)

## Causal Attention

In [196]:
# Reference output from PyTorch's built-in implementation with is_causal=True.
orig = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)

In [197]:
# Custom implementation with the causal (lower-triangular) mask enabled.
res = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=True)

In [198]:
# Compare the absolute error: without abs() any negative difference,
# however large, would satisfy `< 1e-5` and the check would pass wrongly.
torch.all((orig - res).abs() < 1e-5)

tensor(True)

## Masked Attention

In [199]:
# Random boolean mask (True = position may be attended to). The diagonal is
# forced to True so that no query row is fully masked: an all-False row
# turns every score in that row into -inf, softmax then yields NaN, and NaN
# fails the `< 1e-5` comparison regardless of whether the code is correct.
mask = (torch.rand(L, L) > 0.5) | torch.eye(L, dtype=torch.bool)

In [200]:
# Reference output from PyTorch's built-in implementation with the boolean
# attention mask applied.
orig = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask)

In [201]:
# Custom implementation with the same boolean attention mask.
res = scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=mask)

In [202]:
# Compare the absolute error: without abs() any negative difference,
# however large, would satisfy `< 1e-5` and the check would pass wrongly.
torch.all((orig - res).abs() < 1e-5)

tensor(True)