## Attention

![attention](../00–assets/attention.png)

---

## Scaled Dot-Product Attention

<img src="../00–assets/Scaled-Dot-Product.png" width="300px"  height="300px" />

In [5]:
import torch
import torchtext

# torchtext's ScaledDotProduct defaults to batch_first=False, so its inputs are
# laid out as (seq_len, batch * heads, head_dim) -- NOT (batch, tokens, heads).
SEQ_LEN = 21         # L: number of query/key positions
BATCH_X_HEADS = 256  # N * H: batch size and head count flattened into one axis
HEAD_DIM = 3         # E / H: embedding size of a single head

SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)

q = torch.randn(SEQ_LEN, BATCH_X_HEADS, HEAD_DIM)
k = torch.randn(SEQ_LEN, BATCH_X_HEADS, HEAD_DIM)
v = torch.randn(SEQ_LEN, BATCH_X_HEADS, HEAD_DIM)

# attn_output keeps the query's shape (L, N * H, E / H);
# attn_weights comes back as (N * H, L, S), here (256, 21, 21).
attn_output, attn_weights = SDP(q, k, v)

In [6]:
# Show the attention output's shape followed by the attention weights' shape.
print(*(tensor.shape for tensor in (attn_output, attn_weights)))

torch.Size([21, 256, 3]) torch.Size([256, 21, 21])


In [15]:
import torch.nn.functional as F

# Optionally use the context manager to ensure one of the fused kernels is run
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cpu")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cpu")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cpu")
# with torch.backends.cuda.sdp_kernel(enable_math=False):
contextual_embedding = F.scaled_dot_product_attention(query, key, value)

In [23]:
# Inspect shape, dtype and device of the F.scaled_dot_product_attention result.
(
    contextual_embedding.shape,
    contextual_embedding.dtype,
    contextual_embedding.device,
)

(torch.Size([32, 8, 128, 64]), torch.float16, device(type='cpu'))

In [17]:
# Raw attention scores: query-key dot products divided by sqrt(head_dim).
# head_dim is read from the query's last axis instead of hard-coding 64, so the
# scaling stays correct if the tensors above are resized.
head_dim = query.size(-1)
logit_attention = query @ key.transpose(-2, -1) / (head_dim ** 0.5)

# Softmax over the last (key) axis turns each query's scores into weights
# that sum to 1.
attention = F.softmax(logit_attention, dim=-1)

attention.shape

torch.Size([32, 8, 128, 128])

In [18]:
# Weighted sum of the value vectors using the hand-computed attention weights.
contextual_embedding_custom = torch.matmul(attention, value)

In [22]:
# Inspect shape and dtype of the manually computed attention result.
(
    contextual_embedding_custom.shape,
    contextual_embedding_custom.dtype,
)

(torch.Size([32, 8, 128, 64]), torch.float16)

In [25]:
# Exact comparison: True only if both tensors have the same shape and every
# element matches exactly.
# NOTE(review): presumably the manual and built-in paths round differently in
# float16; a tolerance-based check (e.g. torch.allclose) would be the usual way
# to test numerical agreement -- confirm the intent.
contextual_embedding_custom.equal(contextual_embedding)

False

In [26]:
# Elementwise exact comparison: a boolean tensor marking which entries of the
# manual result match the built-in result exactly.
torch.eq(contextual_embedding_custom, contextual_embedding)

tensor([[[[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],

         [[False, False,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [False,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],

         [[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,