In [1]:
import torch
import torch.nn.functional as F

In [4]:
# Fix the global RNG so every tensor below is reproducible on re-run.
torch.manual_seed(42)
x = torch.rand(1, 2, 3)  # axes: (batch, sample, embedding dim)

# Project the same input three times with independent random affine maps.
# Each pass draws its (3, 3) weight first and its (3,) bias second, so the
# RNG stream is consumed in the order query -> key -> value.
query, key, value = (
    F.linear(x, weight=torch.rand(3, 3), bias=torch.rand(3)) for _ in range(3)
)


def scaled_dot_product_causal_attention(q, k, v):
  """Causal scaled dot-product attention over the last two axes.

  Any number of leading axes (batch, heads, ...) is supported; they are
  carried through unchanged by the batched matrix multiplications.

  Args:
    q: query tensor of shape (..., L, E).
    k: key tensor of shape (..., S, E).
    v: value tensor of shape (..., S, Ev).

  Returns:
    A tuple ``(output, attn_filter)`` where ``output`` has shape (..., L, Ev)
    and ``attn_filter`` holds the softmax attention weights of shape
    (..., L, S). Position i only attends to positions j <= i.
  """
  # Similarity scores: contract over the embedding axis. Using the last two
  # axes (instead of fixed indices 1 and 2) keeps all leading dims untouched.
  attn_filter = q @ k.transpose(-2, -1)
  # Lower-triangular (top-left aligned) boolean mask, created on the same
  # device as the scores so masked_fill works for non-CPU tensors too.
  num_queries, num_keys = attn_filter.shape[-2:]
  mask = torch.tril(
      torch.ones(num_queries, num_keys, dtype=torch.bool, device=attn_filter.device)
  )
  # Future positions get -inf so softmax assigns them exactly zero weight.
  attn_filter = attn_filter.masked_fill(~mask, float('-inf'))
  # Divide by sqrt(embedding dim) so score magnitude does not grow with E
  # (-inf entries stay -inf).
  attn_filter = attn_filter / (k.shape[-1] ** 0.5)
  attn_filter = F.softmax(attn_filter, dim=-1)
  # Weighted sum of the value vectors.
  output = attn_filter @ v
  return output, attn_filter

# Run the hand-written attention on the toy projections; the attention
# weights are kept alongside the result so they can be inspected later.
output, attn_filter = scaled_dot_product_causal_attention(q=query, k=key, v=value)
output  # bare last expression -> rendered as the cell's output




tensor([[[1.5899, 1.7676, 1.9138],
         [1.6904, 1.7871, 1.7771]]])

In [5]:
# Cross-check the hand-written result against PyTorch's built-in
# implementation with causal masking enabled.
expected_output = F.scaled_dot_product_attention(
    query=query, key=key, value=value, is_causal=True
)
print(torch.allclose(input=output, other=expected_output))

True
