# Setup

In [12]:
import torch

# Basic attention

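Input embeddings: one 3-dimensional vector per token of the sentence "Your journey starts with one step" (row i is x^(i+1)).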

In [13]:
# Six token embeddings, one 3-d row per word of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])


Compute the attention scores for the second token (x^2, "journey", used as the query) by taking the dot product of each input embedding with the query.

In [14]:
# The query is the embedding of the second token, "journey" (x^2).
query = inputs[1]
# Score of each token = dot product of its embedding with the query.
attn_scores_2 = torch.stack([torch.dot(embedding, query) for embedding in inputs])
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

Compare the built-in `torch.dot` with a hand-coded dot product (multiply matching components, then sum); both give the same value.

In [15]:
# Hand-rolled dot product: pair up matching components and accumulate their products.
res = 0.
for x_component, q_component in zip(inputs[0], query):
    res += x_component * q_component
print(res)
# Built-in equivalent, for comparison.
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


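Now we normalize the attention scores (the dot products computed above) so that they sum to 1. Once normalized, they are called attention weights instead of scores. Here we simply divide each score by the sum of all scores. Because the scores are all positive, the weights keep the same relative ordering as the scores; they are just rescaled.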

In [16]:
# Simple normalization: divide each score by the total so the weights sum to 1.
attn_weights_2_tmp = attn_scores_2.div(torch.sum(attn_scores_2))
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", torch.sum(attn_weights_2_tmp))

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


Normally, we use the softmax function to normalize attention scores:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

It applies the exponential function before normalizing. As a result, larger values become even bigger relative to smaller ones, and all outputs become positive while still summing to 1.

Softmax is therefore more "peaked" than simple normalization. This is useful for attention mechanisms: it pushes the mechanism to focus on the most relevant parts of the input and sharpens the separation between important and unimportant inputs.

In [17]:
def softmax_naive(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Compute softmax naively, without numerical stabilization.

    Exponentiates every element of ``x`` and divides by the sum of the
    exponentials along ``dim``. The result is positive and sums to 1 along
    that dimension.

    Args:
        x: Input tensor, e.g. a 1-D vector of attention scores.
        dim: Dimension along which to normalize. Defaults to 0, which is the
            only dimension of a 1-D score vector (the original behavior).

    Returns:
        Tensor with the same shape as ``x``.

    Note:
        This is intentionally naive, for illustration. ``torch.exp`` overflows
        to ``inf`` for large inputs (roughly above 88 in float32), which turns
        the result into ``nan``. Prefer ``torch.softmax`` in real code.
    """
    exp_x = torch.exp(x)  # compute the exponentials once, not twice
    # keepdim=True keeps the reduced axis so the sum broadcasts back against
    # exp_x for any `dim`, not only dim 0.
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

# Normalize the scores with the hand-written softmax and check they sum to 1.
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", torch.sum(attn_weights_2_naive))

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


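In practice, it is better to use PyTorch's built-in softmax than to roll your own. The built-in version is numerically stable: it avoids the overflow that the naive version hits for large scores. It is also well optimized.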

In [18]:
# PyTorch's built-in softmax, applied along the only dimension of the score vector.
attn_weights_2 = attn_scores_2.softmax(dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", torch.sum(attn_weights_2))

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
