In [1]:
import torch
# Toy 3-dimensional embeddings, one row per token of "Your journey starts with one step".
inputs = torch.tensor(
  [
    [0.43, 0.15, 0.89],  # "Your"    -> x^1
    [0.55, 0.87, 0.66],  # "journey" -> x^2
    [0.57, 0.85, 0.64],  # "starts"  -> x^3
    [0.22, 0.58, 0.33],  # "with"    -> x^4
    [0.77, 0.25, 0.10],  # "one"     -> x^5
    [0.05, 0.80, 0.55],  # "step"    -> x^6
  ]
)

# Version 1: self-attention whose query/key/value weight matrices are raw `nn.Parameter` tensors

In [2]:
from torch import nn

In [3]:
class SelfAttentionV1(nn.Module):
  """Scaled dot-product self-attention with raw ``nn.Parameter`` projection matrices.

  Args:
    ini_dim: size of each input embedding (last dimension of ``x``).
    reduced_dim: size of the query/key/value projections, and of the output.
  """

  def __init__(self, ini_dim, reduced_dim):
    super().__init__()
    # Subclassing nn.Module registers these as parameters, so they appear in
    # .parameters() / state_dict() and the instance is callable as module(x).
    # Creation order (Q, K, V) determines which random draws each matrix gets
    # under a fixed seed, so it is kept stable.
    self.W_Q = nn.Parameter(torch.rand(ini_dim, reduced_dim))
    self.W_K = nn.Parameter(torch.rand(ini_dim, reduced_dim))
    self.W_V = nn.Parameter(torch.rand(ini_dim, reduced_dim))

  def forward(self, x):
    """Return context vectors for ``x``.

    Args:
      x: tensor of shape (num_tokens, ini_dim) or (batch, num_tokens, ini_dim).

    Returns:
      Tensor of shape (num_tokens, reduced_dim) or (batch, num_tokens, reduced_dim).
    """
    query_vectors_matrix = x @ self.W_Q
    key_vectors_matrix = x @ self.W_K
    value_vectors_matrix = x @ self.W_V

    # transpose(-2, -1) equals .T for 2-D input but also works for batched input,
    # where .T would reverse every dimension (including the batch axis).
    attention_score_matrix = query_vectors_matrix @ key_vectors_matrix.transpose(-2, -1)
    # Divide by sqrt(d_k) so large dot products do not saturate the softmax.
    attention_weight_matrix = torch.softmax(attention_score_matrix / (key_vectors_matrix.shape[-1] ** 0.5), dim = -1)

    context_vector_matrix = attention_weight_matrix @ value_vectors_matrix

    return context_vector_matrix


In [4]:
# Fix the RNG so the torch.rand weight initialisation (and hence the output) is reproducible.
torch.manual_seed(123)
self_attention_V1_test = SelfAttentionV1(ini_dim=3, reduced_dim=2)
context_vector_matrix = self_attention_V1_test.forward(x=inputs)
context_vector_matrix

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

# Version 2: self-attention using `nn.Linear` projection layers (bias disabled by default)

In [5]:
class SelfAttentionV2(nn.Module):
  """Scaled dot-product self-attention using ``nn.Linear`` projection layers.

  Args:
    ini_dim: size of each input embedding (last dimension of ``x``).
    reduced_dim: size of the query/key/value projections, and of the output.
    enable_bias: whether the projection layers learn an additive bias.
  """

  def __init__(self, ini_dim, reduced_dim, enable_bias = False):
    super().__init__()
    # Subclassing nn.Module registers these layers as submodules, so their
    # weights appear in .parameters() / state_dict() and module(x) works.
    # Construction order (Q, K, V) fixes which random draws each layer's
    # initialisation consumes under a given seed, so it is kept stable.
    self.W_Q = nn.Linear(ini_dim, reduced_dim, bias = enable_bias)
    self.W_K = nn.Linear(ini_dim, reduced_dim, bias = enable_bias)
    self.W_V = nn.Linear(ini_dim, reduced_dim, bias = enable_bias)

  def forward(self, x):
    """Return context vectors for ``x``.

    Args:
      x: tensor of shape (num_tokens, ini_dim) or (batch, num_tokens, ini_dim).

    Returns:
      Tensor of shape (num_tokens, reduced_dim) or (batch, num_tokens, reduced_dim).
    """
    keys = self.W_K(x)
    values = self.W_V(x)
    queries = self.W_Q(x)

    # transpose(-2, -1) equals .T for 2-D input but also works for batched input,
    # where .T would reverse every dimension (including the batch axis).
    a_score = queries @ keys.transpose(-2, -1)
    # Divide by sqrt(d_k) so large dot products do not saturate the softmax.
    a_weight = torch.softmax(a_score / keys.shape[-1]**0.5, dim = -1)

    context = a_weight @ values
    return context



In [6]:
# Fix the RNG so the nn.Linear weight initialisation (and hence the output) is reproducible.
torch.manual_seed(789)
attention_v2 = SelfAttentionV2(ini_dim=3, reduced_dim=2)
context = attention_v2.forward(x=inputs)
context

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)