In [1]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],
    ]  # step     (x^6)
)

In [15]:
# Use query ,key and value vectors instead of the input vectors directly to compute the attention scores.
# Everything else looks the same as the previous example with simple attention

import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        # initialize the weight matrices
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # compute query, key and value vectors
        keys = x @ self.W_key  # multiply 6X3 with 3x2 matrices
        queries = x @ self.W_query
        values = x @ self.W_value

        # compute attention scores
        attention_scores = queries @ keys.T

        # compute scaled dot product attention
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # compute context vectors
        context_vectors = attention_weights @ values
        return context_vectors


# create an instance of the SelfAttention_v1 class
d_in = 3
d_out = 2
attention = SelfAttention_v1(d_in, d_out)
context_vectors = attention(inputs)
print(context_vectors)

tensor([[0.6472, 1.0750],
        [0.6661, 1.1112],
        [0.6652, 1.1094],
        [0.6479, 1.0759],
        [0.6359, 1.0530],
        [0.6585, 1.0963]], grad_fn=<MmBackward0>)


In [16]:
import torch.nn as nn


# using nn.Linear() as way to define the weight matrices
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # initialize the weight matrices
        self.W_query = nn.Linear(d_in, d_out, qkv_bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, qkv_bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, qkv_bias=qkv_bias)

    def forward(self, x):
        # compute query, key and value vectors
        keys = x @ self.W_key  # multiply 6X3 with 3x2 matrices
        queries = x @ self.W_query
        values = x @ self.W_value

        # compute attention scores
        attention_scores = queries @ keys.T

        # compute scaled dot product attention
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # compute context vectors
        context_vectors = attention_weights @ values
        return context_vectors


# create an instance of the SelfAttention_v1 class
d_in = 3
d_out = 2
attention = SelfAttention_v1(d_in, d_out)
context_vectors = attention(inputs)
print(context_vectors)

tensor([[0.7829, 0.8520],
        [0.8152, 0.8843],
        [0.8142, 0.8831],
        [0.7841, 0.8552],
        [0.7744, 0.8429],
        [0.7970, 0.8686]], grad_fn=<MmBackward0>)
