In [1]:
import torch

In [2]:
# Toy 3-d embeddings for the sentence "Your journey starts with one step";
# row i holds the embedding x_i of the i-th token.
input_embeddings = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x_0: "Your"
        [0.55, 0.87, 0.66],  # x_1: "journey"
        [0.57, 0.85, 0.64],  # x_2: "starts"
        [0.22, 0.58, 0.33],  # x_3: "with"
        [0.77, 0.25, 0.10],  # x_4: "one"
        [0.05, 0.80, 0.55],  # x_5: "step"
    ]
)

## Define the projection weight matrices: W_query, W_key, W_value

In [4]:
# Seed the RNG so the randomly initialised projection matrices, and every output
# derived from them below, are the same on every Restart & Run All.
torch.manual_seed(123)

d_in = input_embeddings.shape[1]  # size of each input embedding (3 for the toy data above)
d_out = 2  # size of the query/key/value vectors; kept small so outputs are easy to read

# Frozen (requires_grad=False) matrices: this step-by-step walkthrough does no training.
W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

## Project the input embeddings into queries, keys and values with W_query, W_key and W_value

In [13]:
# One query vector per token: (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out)
queries = torch.matmul(input_embeddings, W_query)
display(queries.shape, queries)

torch.Size([6, 2])

tensor([[-1.3912,  0.7242],
        [-1.0950, -0.1293],
        [-1.0720, -0.1688],
        [-0.5321, -0.0841],
        [-0.3561, -0.8315],
        [-0.7993,  0.2957]])

In [12]:
# One key vector per token: (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out)
keys = torch.matmul(input_embeddings, W_key)
display(keys.shape, keys)

torch.Size([6, 2])

tensor([[ 1.7963, -0.3983],
        [ 0.4943,  1.1740],
        [ 0.4606,  1.1562],
        [ 0.1204,  0.8522],
        [-0.2745,  0.5091],
        [ 0.4742,  1.0558]])

In [14]:
# One value vector per token: (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out)
values = torch.matmul(input_embeddings, W_value)
display(values.shape, values)

torch.Size([6, 2])

tensor([[-2.6303,  0.0225],
        [-2.5330,  0.5452],
        [-2.5124,  0.5169],
        [-1.2311,  0.4201],
        [-1.4368, -0.1399],
        [-1.5307,  0.7086]])

## Compute the attention scores, attention weights and context vector for a single query (x_2, "starts")

In [15]:
# how much the third word should attend to itself
query_2 = queries[2]
key_2 = keys[2]
attn_score_22 = query_2 @ key_2
attn_score_22

tensor(-0.6889)

In [16]:
# how much the third word should attend to all other words
# Score the third token's query against every key (its own included): one score per token.
attn_scores_2 = torch.matmul(query_2, keys.T)
attn_scores_2

tensor([-1.8584, -0.7280, -0.6889, -0.2729,  0.2083, -0.6865])

In [20]:
# The "scaled" in scaled dot-product attention: the scores are divided by sqrt(d_k),
# the dimension of the key vectors. Large unscaled dot products push softmax towards a
# one-hot output, where its gradients become vanishingly small.
d_k = keys.shape[-1]
attn_weights_2 = (attn_scores_2 / d_k**0.5).softmax(dim=-1)
attn_weights_2

tensor([0.0659, 0.1465, 0.1506, 0.2021, 0.2840, 0.1509])

In [22]:
# Context vector for x_2: the attention-weighted sum of all value vectors.
context_vec_2 = torch.matmul(attn_weights_2, values)
context_vec_2

tensor([-1.8106,  0.3113])

## Compute all context vectors at once with a self-attention module

In [28]:
class SelfAttn(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three linear layers and
    returns ``softmax(Q @ K^T / sqrt(d_out)) @ V``, i.e. one context vector per token.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of each context vector.
        qkv_bias: whether the query/key/value projections include a bias term.
            Defaults to ``False``.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False):
        super().__init__()
        # nn.Linear instead of raw nn.Parameter(torch.randn(...)) matrices: its weights
        # are trainable and drawn from U(-1/sqrt(d_in), 1/sqrt(d_in)), which keeps the
        # projections small. It stores its weight as (d_out, d_in) and computes x @ W.T.
        # The creation order (query, key, value) fixes which random numbers each gets.
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension ``d_out``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Swap only the last two dims so batched (3-D) input works too;
        # `keys.T` would reverse *all* dims of a 3-D tensor.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

In [29]:
# Seed before construction so the nn.Linear weights, and hence the output, are reproducible.
torch.manual_seed(789)
attn_layer = SelfAttn(d_in=d_in, d_out=d_out)
attn_layer(input_embeddings)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)