<a href="https://colab.research.google.com/github/mrdbourke/learn-transformers/blob/main/attention_mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [WIP] Attention mechanism 

**Focus:** Build intuition to build up to replicating the original Transformer paper.

### This notebook

* Recreate self-attention as per Transformer paper
* Recreate multi-head attention as per Transformer paper

### Later
* Recreate Transformer model architecture
* Train on a simple example 

Sources:

* Transformer paper: https://arxiv.org/abs/1706.03762
* The annotated transformer: http://nlp.seas.harvard.edu/2018/04/01/attention.html 
* https://lilianweng.github.io/posts/2018-06-24-attention/#self-attention
* https://jaykmody.com/blog/attention-intuition/
* Compact transformers - https://medium.com/pytorch/training-compact-transformers-from-scratch-in-30-minutes-with-pytorch-ff5c21668ed5 

In [44]:
import torch
from torch import nn

import torch.nn.functional as F

## Simple scaled-dot-product-attention (no mask) 

TK:
- Can I replicate this in Google Sheets?
- Turn this function into the same format as the transformer paper (e.g. figure 2)

In [None]:
def attention(query, key, value):  
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  q_k = F.softmax(torch.matmul(query, key.T)/torch.sqrt(d_k), dim=-1)
  return torch.matmul(q_k, value.T)

In [61]:
torch.manual_seed(42)
x = torch.randn(10, 10)

attention(query=x, key=x, value=x)

tensor([[ 1.8933, -0.3887, -0.7377, -0.2461, -1.5391,  0.0764, -0.0861, -1.0471,
          0.0089,  0.1876],
        [ 0.4219, -0.7769,  0.8384,  0.3348,  0.0515,  0.2086,  1.2753,  0.2287,
         -0.2691, -0.0226],
        [ 0.4126, -0.0101,  0.7777, -0.7219, -0.8859, -0.6319,  0.0446, -0.0064,
         -0.4851, -0.0692],
        [-0.5196, -0.3668,  1.1495, -0.2994, -0.7661,  0.0208,  1.2501,  0.0915,
         -0.0215,  0.3415],
        [ 0.2543, -0.6703,  1.2743, -0.2598, -1.0734, -0.5371,  1.3273,  0.0430,
          0.4069,  0.1190],
        [-0.1206,  0.2336,  0.8385,  0.4162,  0.3648, -0.2610,  0.5973, -0.0789,
          0.2875,  0.0406],
        [-0.0315,  1.2026,  0.6622,  0.0972, -1.1233, -1.1447,  1.7388,  0.3997,
          0.6841, -0.7225],
        [-0.6842, -0.1752,  1.0953, -0.3012, -0.6531, -0.2085,  0.7526, -0.3217,
         -0.1753,  0.0630],
        [-0.0737, -0.3024,  0.6222,  0.0706, -0.5312, -0.0393,  0.9918, -0.3184,
         -0.1715,  0.0771],
        [ 0.2698,  

In [42]:
x.size(-1)

10

## TODO: Simple scaled-dot-product-attention (with mask) 

UPTOHERE:
* see: https://jaykmody.com/blog/gpt-from-scratch/#causal 
* And see: https://github.com/facebookresearch/xformers/blob/main/xformers/components/attention/attention_mask.py 
  * Default to causal mask: https://github.com/facebookresearch/xformers/blob/97daac83cece6d3d77bb09479777ad6e8ef7dfed/xformers/components/attention/attention_mask.py#LL74C16-L74C16 (`make_causal()`) 

In [109]:
# Make causal mask, see: https://jaykmody.com/blog/gpt-from-scratch/#causal
additive_mask = torch.triu(
    # torch.ones(x.shape[0], x.shape[0]) * float("inf"),
    torch.ones(x.shape[0], x.shape[0]) * -1e10, # can use -1e10 to prevent nans
    diagonal=1
)

additive_mask

tensor([[ 0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10,
         -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10,
         -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+10, -1.0000e+10,
         -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+10,
         -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+0

In [110]:
def attention_with_mask(query, key, value, mask=None):  
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  q_k = torch.matmul(query, key.T) / torch.sqrt(d_k)
  print(q_k.shape)

  
  print(f"q_k: {q_k}")

  # Apply attention mask
  if mask is not None:
    q_k = q_k + mask

  print(f"q_k with mask: {q_k}")

  # Softmax
  attn = F.softmax(q_k, dim=-1)

  return torch.matmul(attn, value), attn

attention_with_mask(query=x, key=x, value=x, mask=additive_mask)

torch.Size([10, 10])
torch.Size([10, 10])
q_k: tensor([[ 6.0130e+00, -7.9063e-01, -1.6862e+00, -5.4706e-01, -1.2946e+00,
         -1.5651e-01, -9.5597e-01, -1.4247e-01,  4.3972e-01, -5.1831e-01],
        [-7.9063e-01,  2.3093e+00, -5.7266e-01,  6.2206e-01,  3.1750e-01,
         -9.4774e-01,  1.2001e-01,  1.7332e-01,  1.3094e+00, -7.1141e-01],
        [-1.6862e+00, -5.7266e-01,  3.3930e+00,  3.3083e-03,  1.3948e-01,
         -2.0294e-01,  2.6444e+00,  1.1129e+00, -1.9716e-01,  2.9590e-01],
        [-5.4706e-01,  6.2206e-01,  3.3083e-03,  2.3570e+00,  2.2022e+00,
         -6.3796e-01,  1.2451e+00,  7.0473e-01,  8.8856e-01, -9.7713e-01],
        [-1.2946e+00,  3.1750e-01,  1.3948e-01,  2.2022e+00,  4.3364e+00,
          2.9038e-01, -3.1773e-01,  1.8019e+00,  2.3513e-01, -5.8754e-01],
        [-1.5651e-01, -9.4774e-01, -2.0294e-01, -6.3796e-01,  2.9038e-01,
          1.7630e+00, -5.9041e-01, -8.1987e-02, -7.1263e-01,  1.0189e+00],
        [-9.5597e-01,  1.2001e-01,  2.6444e+00,  1.2451e+00

(tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
          -0.7521,  1.6487],
         [-0.2925, -1.2790, -0.6577, -0.6261, -0.7064,  0.6764,  1.5697, -0.2219,
          -0.5084,  0.4917],
         [-0.7351,  1.0349,  0.7731,  1.6162,  1.2376,  1.2712,  0.6256,  1.2893,
          -0.2397,  0.0589],
         [-0.2166,  0.6004, -1.0463, -0.6979, -0.1510,  1.4382,  0.5008, -0.3120,
           0.1167, -0.4545],
         [-1.3844,  0.9470, -0.9017, -0.6031, -1.1193,  2.0389, -1.0030, -0.4560,
          -0.7730, -0.6367],
         [-0.0906,  0.6622, -0.3702,  0.5163, -0.5373, -0.0254, -0.8781, -0.1037,
          -0.2517,  0.4371],
         [-0.1764,  1.6977, -0.9631,  1.3173,  1.3377,  0.9195,  1.9504,  0.5652,
           0.2647, -0.1750],
         [-0.8553,  1.1510, -0.3746,  0.3669,  0.0303,  0.8280,  0.4054, -0.3187,
          -1.3058, -0.4890],
         [-0.1844,  0.3723, -0.7960, -0.2162,  0.2121,  0.7338,  0.8952, -0.2755,
          -0.4961,  0.1444],
 

torch.Size([10, 10])
torch.Size([10, 10])
tensor([[ 6.0130e+00, -7.9063e-01, -1.6862e+00, -5.4706e-01, -1.2946e+00,
         -1.5651e-01, -9.5597e-01, -1.4247e-01,  4.3972e-01, -5.1831e-01],
        [-7.9063e-01,  2.3093e+00, -5.7266e-01,  6.2206e-01,  3.1750e-01,
         -9.4774e-01,  1.2001e-01,  1.7332e-01,  1.3094e+00, -7.1141e-01],
        [-1.6862e+00, -5.7266e-01,  3.3930e+00,  3.3083e-03,  1.3948e-01,
         -2.0294e-01,  2.6444e+00,  1.1129e+00, -1.9716e-01,  2.9590e-01],
        [-5.4706e-01,  6.2206e-01,  3.3083e-03,  2.3570e+00,  2.2022e+00,
         -6.3796e-01,  1.2451e+00,  7.0473e-01,  8.8856e-01, -9.7713e-01],
        [-1.2946e+00,  3.1750e-01,  1.3948e-01,  2.2022e+00,  4.3364e+00,
          2.9038e-01, -3.1773e-01,  1.8019e+00,  2.3513e-01, -5.8754e-01],
        [-1.5651e-01, -9.4774e-01, -2.0294e-01, -6.3796e-01,  2.9038e-01,
          1.7630e+00, -5.9041e-01, -8.1987e-02, -7.1263e-01,  1.0189e+00],
        [-9.5597e-01,  1.2001e-01,  2.6444e+00,  1.2451e+00, -3.

RuntimeError: ignored

## TODO: Why scaled?

TL;DR softmax can get out of hand with large values

In [52]:
small_values = torch.tensor([1, 2, 3], dtype=torch.float32) # need dtype otherwise error
big_values = small_values * 10 
huge_values = big_values * 10

small_softmax = F.softmax(small_values, dim=0)
big_softmax = F.softmax(big_values, dim=0)
huge_softmax = F.softmax(huge_values, dim=0)

print(f"Small values: {small_values}\nSmall softmax: {small_softmax}\n")
print(f"Big values: {big_values}\nBig softmax: {big_softmax}\n")
print(f"Huge values: {huge_values}\nHuge softmax: {huge_softmax}\n")

Small values: tensor([1., 2., 3.])
Small softmax: tensor([0.0900, 0.2447, 0.6652])

Big values: tensor([10., 20., 30.])
Big softmax: tensor([2.0611e-09, 4.5398e-05, 9.9995e-01])

Huge values: tensor([100., 200., 300.])
Huge softmax: tensor([0.0000e+00, 3.7835e-44, 1.0000e+00])



## TODO: Why dot-product?

TL;DR dot product measures how closely two vectors are related 

* big values = close
* negative values = far away
* zero value = same direction? (TK - fix this)

See:
* 3blue1brown on dot product - https://www.youtube.com/watch?v=LyGKycYT2v0

## TODO: Replicate PyTorch's `scaled_dot_product_attention` 

(minus all the fancy optimizations, the library can do those for us)

See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

Also see: https://github.com/facebookresearch/xformers/blob/main/xformers/components/attention/core.py#L297 

In [57]:
# Optionally use the context manager to ensure one of the fused kerenels is run
query = torch.rand(32, 8, 128, 64)
key = torch.rand(32, 8, 128, 64)
value = torch.rand(32, 8, 128, 64)

F.scaled_dot_product_attention(query,key,value)

tensor([[[[0.5064, 0.5570, 0.4984,  ..., 0.4607, 0.5188, 0.5203],
          [0.5054, 0.5610, 0.4979,  ..., 0.4616, 0.5145, 0.5213],
          [0.5082, 0.5588, 0.4970,  ..., 0.4598, 0.5156, 0.5171],
          ...,
          [0.5053, 0.5617, 0.4959,  ..., 0.4591, 0.5203, 0.5212],
          [0.5072, 0.5619, 0.4990,  ..., 0.4625, 0.5145, 0.5201],
          [0.5055, 0.5632, 0.4981,  ..., 0.4622, 0.5161, 0.5197]],

         [[0.4851, 0.5138, 0.5436,  ..., 0.4788, 0.5471, 0.4800],
          [0.4856, 0.5199, 0.5413,  ..., 0.4726, 0.5412, 0.4804],
          [0.4879, 0.5136, 0.5421,  ..., 0.4744, 0.5447, 0.4798],
          ...,
          [0.4885, 0.5149, 0.5462,  ..., 0.4739, 0.5466, 0.4794],
          [0.4873, 0.5157, 0.5460,  ..., 0.4710, 0.5407, 0.4825],
          [0.4868, 0.5140, 0.5460,  ..., 0.4763, 0.5440, 0.4813]],

         [[0.5155, 0.4867, 0.4495,  ..., 0.5153, 0.4683, 0.4834],
          [0.5125, 0.4893, 0.4498,  ..., 0.5186, 0.4674, 0.4860],
          [0.5111, 0.4905, 0.4516,  ..., 0

## TODO: Why multi-head attention?

TL;DR more opportunities to learn (e.g. 8x64 scaled dot-product attention = better than 1*512)

TK:
- One big matrix multiplication better than lots of small ones
- Just perform a `nn.Linear()` then break it up