### Implementing self-attention from scratch

Code following Sebastian Raschka's article on understanding and coding self-attention: https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention

In [27]:
import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Quick sanity check of the `tensor` constructor imported above.
a = tensor([1, 2, 3])
a

tensor([1, 2, 3])

In [4]:
# Example input sentence; commas are stripped before splitting into words.
sentence = 'Life is short, eat dessert first'

In [8]:
# Vocabulary: each distinct word (commas stripped) -> integer id, in sorted order.
# set() keeps the ids contiguous (0..V-1) even if a word occurs more than once;
# without it a repeated word would keep only its last index and leave a gap.
dc = {word: idx for idx, word in enumerate(sorted(set(sentence.replace(',', '').split())))}
dc

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [10]:
# Encode the sentence (commas stripped) as a tensor of vocabulary ids, in word order.
sentence_int = tensor([dc[word] for word in sentence.replace(',', '').split()])
sentence_int

tensor([0, 4, 5, 2, 1, 3])

In [11]:
# Embedding table size; far larger than the 6-word vocabulary `dc`, so only
# the rows indexed by `sentence_int` are actually looked up.
vocab_size = 50_000

torch.manual_seed(123)
embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=3)
# detach(): keep only the embedding values, without a graph back to `embed`.
embedded_sentence = embed(sentence_int).detach()
embedded_sentence.shape

torch.Size([6, 3])

In [14]:
# One randomly initialized 3-dimensional embedding row per token.
embedded_sentence

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])

In [15]:
# Re-seed: the embedding cell already consumed random numbers, so without this
# the weights below depend on which cells ran before. With the seed they are
# reproducible on Restart & Run All and equal the weights SelfAttention draws
# (same shapes, same order) after torch.manual_seed(123).
torch.manual_seed(123)

d = embedded_sentence.shape[1]  # embedding size of each token
d_q, d_k, d_v = 2, 2, 4  # d_q must equal d_k for query @ key; d_v is free

# Wrapping in nn.Parameter marks the tensors as trainable (requires_grad=True),
# unlike a plain torch.rand tensor.
W_query = nn.Parameter(torch.rand(d, d_q))
W_key = nn.Parameter(torch.rand(d, d_k))
W_value = nn.Parameter(torch.rand(d, d_v))

In [16]:
# nn.Parameter tensors track gradients by default.
W_query.requires_grad

True

In [17]:
# Contrast: a plain tensor from torch.rand does not track gradients.
a = torch.rand(2, 3)
a.requires_grad

False

In [19]:
# Embedding of the second token (index 1), used as the worked query example.
x_2 = embedded_sentence[1]

In [21]:
# Project the single token x_2 into query, key and value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
value_2.shape

torch.Size([4])

In [22]:
# Keys and values for every token at once: [T, d_k] and [T, d_v].
keys, values = embedded_sentence @ W_key, embedded_sentence @ W_value
keys.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 4]))

In [23]:
# query_2 has d_q entries.
query_2.shape

torch.Size([2])

In [24]:
# Unnormalized attention scores: omega_2[j] = query_2 . key_j for every token j.
omega_2 = torch.matmul(query_2, keys.T)
omega_2

tensor([-0.0996,  2.9454, -0.6374, -0.0801,  0.3291, -1.5970],
       grad_fn=<MvBackward0>)

In [25]:
# One unnormalized score per token in the sentence.
omega_2.shape

torch.Size([6])

In [34]:
# Scale the scores by 1/sqrt(d_k), then softmax so the weights sum to 1.
attention_weights_2 = torch.softmax(omega_2 / d_k ** 0.5, dim=0)
attention_weights_2.shape

torch.Size([6])

In [53]:
# Context vector for x_2: attention-weighted sum of all value vectors.
context_vector_2 = torch.matmul(attention_weights_2, values)
context_vector_2

tensor([0.3511, 0.2289, 0.8308, 0.7071], grad_fn=<MvBackward0>)

In [97]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``.
    The projection matrices are ``nn.Parameter`` tensors initialised with
    ``torch.rand`` (uniform on [0, 1)).

    Args:
        d_in: embedding size of each input token (last dim of ``x``).
        d_out_kq: size of the query/key projections; also used for the
            ``sqrt(d_out_kq)`` scaling of the attention scores.
        d_out_v: size of the value projection, i.e. of each output vector.
    """

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: token embeddings of shape ``[T, d_in]``, optionally with leading
               batch dims: ``[..., T, d_in]``.

        Returns:
            Tensor of shape ``[..., T, d_out_v]``.
        """
        x_q = x @ self.W_query  # [..., T, d_out_kq]
        x_k = x @ self.W_key    # [..., T, d_out_kq]
        x_v = x @ self.W_value  # [..., T, d_out_v]
        # transpose(-2, -1) instead of .T: .T reverses *all* dims, which is only
        # correct for 2-D input (and is deprecated for other ranks).
        attn_scores = x_q @ x_k.transpose(-2, -1)  # [..., T, T]
        # Each row (one query token) becomes a distribution over key tokens.
        attn_weights = F.softmax(attn_scores / self.d_out_kq ** 0.5, dim=-1)
        return attn_weights @ x_v  # [..., T, d_out_v]

In [55]:
# Fixed seed so the randomly initialized projection weights are reproducible.
torch.manual_seed(123)
self_attention = SelfAttention(d_in=3, d_out_kq=2, d_out_v=4)

In [56]:
# [T, d_in]: one row per token, 3-dimensional embeddings.
embedded_sentence.shape

torch.Size([6, 3])

In [57]:
# One context vector per token: [T, d_out_v]. Row 1 is the context vector for x_2.
self_attention(embedded_sentence)

tensor([[-0.2921, -0.0851,  0.1581,  0.0157],
        [ 0.3511,  0.2289,  0.8308,  0.7071],
        [-0.3952, -0.1419,  0.0867, -0.0524],
        [-0.2784, -0.0533,  0.1878,  0.0059],
        [-0.1756, -0.0231,  0.2559,  0.1114],
        [-0.5338, -0.2074,  0.0257, -0.1290]], grad_fn=<MmBackward0>)

In [58]:
# Same dimensions, passed by tuple unpacking. Re-seed so the weights (and
# therefore the output) are reproducible and match the keyword-argument
# instance; without it, the result depends on the RNG state left by earlier cells.
torch.manual_seed(123)
a = (3, 2, 4)  # (d_in, d_out_kq, d_out_v)
self_attention = SelfAttention(*a)

In [59]:
# Compare with the keyword-constructed instance above; the values coincide only
# when both instances drew their weights from the same RNG state.
self_attention(embedded_sentence)

tensor([[-0.2921, -0.0851,  0.1581,  0.0157],
        [ 0.3511,  0.2289,  0.8308,  0.7071],
        [-0.3952, -0.1419,  0.0867, -0.0524],
        [-0.2784, -0.0533,  0.1878,  0.0059],
        [-0.1756, -0.0231,  0.2559,  0.1114],
        [-0.5338, -0.2074,  0.0257, -0.1290]], grad_fn=<MmBackward0>)

In [100]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``SelfAttention`` heads.

    Each head has its own query/key/value weights. Their outputs are
    concatenated along the last dimension, so every token ends up with
    ``num_heads * d_out_v`` features.

    Args:
        d_in, d_out_kq, d_out_v: forwarded unchanged to every ``SelfAttention`` head.
        num_heads: number of heads; must be at least 1.

    Raises:
        ValueError: if ``num_heads < 1``. Without this check the error would only
            surface later, in ``forward``, when ``torch.cat`` gets an empty list.
    """

    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = nn.ModuleList(
            [SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)]
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results along the last dim."""
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [101]:
# Three heads, each producing d_out_v=1 feature per token.
mh_attention = MultiHeadAttentionWrapper(d_in=3, d_out_kq=2, d_out_v=1, num_heads=3)

3 2 1 3


In [103]:
# Head outputs are concatenated on the last dim: [T, num_heads * d_out_v].
mh_attention(embedded_sentence).shape

torch.Size([6, 3])

In [99]:
# A single head with d_out_v=1, for comparison with one slice of the multi-head output.
sa = SelfAttention(d_in=3, d_out_kq=2, d_out_v=1)
print(sa(embedded_sentence).shape)

torch.Size([6, 1])


In [104]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from ``x1``; keys and values are projected from
    ``x2``. The two inputs may have different sequence lengths and, via
    ``d_in_kv``, different embedding sizes.

    Args:
        d_in: embedding size of ``x1`` (the query side).
        d_out_kq: size of the query/key projections; also used for the
            ``sqrt(d_out_kq)`` scaling of the attention scores.
        d_out_v: size of the value projection, i.e. of each output vector.
        d_in_kv: embedding size of ``x2`` (the key/value side); defaults to ``d_in``.
    """

    def __init__(self, d_in, d_out_kq, d_out_v, d_in_kv=None):
        super().__init__()
        if d_in_kv is None:
            d_in_kv = d_in
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in_kv, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in_kv, d_out_v))

    def forward(self, x1, x2):
        """Attend from each token of ``x1`` over all tokens of ``x2``.

        Args:
            x1: query-side embeddings, ``[..., T1, d_in]``.
            x2: key/value-side embeddings, ``[..., T2, d_in_kv]``.

        Returns:
            Tensor of shape ``[..., T1, d_out_v]``: one context vector per ``x1`` token.
        """
        x_q = x1 @ self.W_query  # [..., T1, d_out_kq]
        x_k = x2 @ self.W_key    # [..., T2, d_out_kq]
        x_v = x2 @ self.W_value  # [..., T2, d_out_v]
        # transpose(-2, -1) instead of .T so batched (>2-D) inputs work too.
        attn_scores = x_q @ x_k.transpose(-2, -1)  # [..., T1, T2]
        attn_weights = F.softmax(attn_scores / self.d_out_kq ** 0.5, dim=-1)
        return attn_weights @ x_v  # [..., T1, d_out_v]

In [105]:
# Sequence length T (number of tokens); used as the side length of the mask.
block_size = embedded_sentence.shape[0]
block_size

6

In [106]:
# Lower-triangular (causal) mask: row i has ones in columns 0..i only.
mask_simple = torch.ones(block_size, block_size).tril()
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])