https://machinelearningmastery.com/the-attention-mechanism-from-scratch/

In [3]:
from numpy import array
from numpy import random
from numpy import dot
from scipy.special import softmax

random.seed(42)

# encoder representations of four different words
word_1 = array([1, 0, 0])
word_2 = array([0, 1, 0])
word_3 = array([1, 1, 0])
word_4 = array([0, 0, 1])

# stacking the word embeddings into a single array
words = array(
    [
        word_1,
        word_2,
        word_3,
        word_4,
    ]
)

# generating the weight matrices
W_Q = random.randint(3, size=(3, 3))
W_K = random.randint(3, size=(3, 3))
W_V = random.randint(3, size=(3, 3))

# generating the queries, keys and values
Q = words @ W_Q
K = words @ W_K
V = words @ W_V

# scoring the query vectors against all key vectors
scores = Q @ K.transpose()

# computing the weights by a softmax operation
weights = softmax(scores / K.shape[1] ** 0.5, axis=1)

# computing the attention by a weighted sum of the value vectors
attention = weights @ V

print('attention')
print(attention)
print('weights')
print(weights)

attention
[[0.98522025 1.74174051 0.75652026]
 [0.90965265 1.40965265 0.5       ]
 [0.99851226 1.75849334 0.75998108]
 [0.99560386 1.90407309 0.90846923]]
weights
[[2.36089863e-01 7.38987555e-03 7.49130386e-01 7.38987555e-03]
 [4.54826323e-01 4.51736775e-02 4.54826323e-01 4.51736775e-02]
 [2.39275049e-01 7.43870015e-04 7.59237211e-01 7.43870015e-04]
 [8.99501754e-02 2.81554063e-03 9.05653685e-01 1.58059922e-03]]


In [2]:
weights.shape

(4, 4)

In [16]:
scores

array([[ 8,  2, 10,  2],
       [ 4,  0,  4,  0],
       [12,  2, 14,  2],
       [10,  4, 14,  3]])

In [4]:
words @ words.T

array([[1, 0, 1, 0],
       [0, 1, 1, 0],
       [1, 1, 2, 0],
       [0, 0, 0, 1]])

In [15]:
words @ W_Q @ W_K.T @ words.T

array([[ 8,  2, 10,  2],
       [ 4,  0,  4,  0],
       [12,  2, 14,  2],
       [10,  4, 14,  3]])

https://github.com/sooftware/attentions/blob/master/attentions.py

In [5]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from typing import Optional, Tuple

In [15]:
class ScaledDotProductAttention(nn.Module):
    """
    https://paperswithcode.com/method/scaled
    Scaled Dot-Product Attention proposed in "Attention Is All You Need"
    Compute the dot products of the query with all keys, divide each by sqrt(dim),
    and apply a softmax function to obtain the weights on the values

    Args: dim, mask
        dim (int): dimention of attention
        mask (torch.Tensor): tensor containing indices to be masked

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
        - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context**: tensor containing the context vector from attention mechanism.
        - **attn**: tensor containing the attention (alignment) from the encoder outputs.
    """

    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask.view(score.size()), -float("Inf"))

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn

In [34]:
input = torch.randn(10, 3).unsqueeze(dim=0)
mat2 = torch.randn(3, 4).unsqueeze(dim=0)
# res = torch.bmm(input, mat2)
input.shape, mat2.shape

(torch.Size([1, 10, 3]), torch.Size([1, 3, 4]))

In [10]:
scaled_dot_attn = ScaledDotProductAttention(dim=2)
scaled_dot_attn.forward(
    query=torch.tensor(Q, dtype=torch.float32).unsqueeze(dim=0),
    key=torch.tensor(K, dtype=torch.float32).unsqueeze(dim=0),
    value=torch.tensor(V, dtype=torch.float32).unsqueeze(dim=0),
)

(tensor([[[0.9944, 1.7971, 0.8027],
          [0.9442, 1.4442, 0.5000],
          [0.9997, 1.8040, 0.8043],
          [0.9988, 1.9427, 0.9439]]]),
 tensor([[[1.9448e-01, 2.7946e-03, 7.9993e-01, 2.7946e-03],
          [4.7210e-01, 2.7904e-02, 4.7210e-01, 2.7904e-02],
          [1.9551e-01, 1.6605e-04, 8.0416e-01, 1.6605e-04],
          [5.5740e-02, 8.0097e-04, 9.4306e-01, 3.9493e-04]]]))

In [14]:
class DotProductAttention(nn.Module):
    """
    Compute the dot products of the query with all values and apply a softmax function to obtain the weights on the values
    """

    def __init__(self,):
        super(DotProductAttention, self).__init__()

    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, hidden_dim, input_size = query.size(0), query.size(2), value.size(1)
        print(batch_size, hidden_dim, input_size)
        score = torch.bmm(query, value.transpose(1, 2))
        attn = F.softmax(score.view(-1, input_size), dim=1).view(
            batch_size, -1, input_size
        )
        context = torch.bmm(attn, value)

        return context, attn

In [13]:
dot_attn = DotProductAttention()
dot_attn.forward(
    query=torch.tensor(Q, dtype=torch.float32).unsqueeze(dim=0),
    value=torch.tensor(V, dtype=torch.float32).unsqueeze(dim=0),
)

1 3 4


(tensor([[[0.8808, 1.7616, 0.8808],
          [0.8808, 1.3808, 0.5000],
          [0.9820, 1.8628, 0.8808],
          [0.9526, 1.9051, 0.9526]]]),
 tensor([[[0.1050, 0.1050, 0.7758, 0.0142],
          [0.4404, 0.0596, 0.4404, 0.0596],
          [0.1171, 0.0158, 0.8650, 0.0021],
          [0.0452, 0.0452, 0.9074, 0.0022]]]))

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

### "attention" only on embeddings

In [18]:
import torch

# hidden states -> torch.Size([2, 1024, 1024])
# hidden states -> torch.Size([2, 1024, 1024])

a = torch.randn([2, 1024, 1024])
b = torch.randn([2, 1024, 1024])
scaled_dot_attn = ScaledDotProductAttention(dim=2)
scaled_dot_attn.forward(
    query=a,
    key=a,
    value=a,
)[0].shape

torch.Size([2, 1024, 1024])

In [19]:
torch.nn.functional.scaled_dot_product_attention(
	query=a,
	key=a,
	value=a,
).shape

torch.Size([2, 1024, 1024])