In [1]:
import torch
import torch.nn as nn


In [10]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    The input is projected into queries, keys and values by three separate
    ``nn.Linear`` layers. Each token's output is the softmax-weighted sum of
    all value vectors, where the weights come from query/key dot products
    scaled by the square root of the key dimension.

    Args:
        d_in: Size of the last dimension of the input (per-token features).
        d_out: Size of the query/key/value projections and of the output.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order is kept (query, key, value) so that seeded
        # initialisation yields the same weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute context vectors for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two axes: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present
        # (``.T`` would reverse *all* dimensions there).
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_k) before the softmax; normalise over the key axis
        # so each query's weights sum to 1.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        context_vector = attn_weights @ values
        return context_vector

In [11]:
# Toy example: one short sentence, already split into tokens.
tokenized_input = ["Your", "journey", "starts", "with", "one", "step"]
# Every token in the sentence is unique, so numbering them by position
# yields the token -> id lookup table.
vocab = {token: token_id for token_id, token in enumerate(tokenized_input)}
# Number of features in each token's embedding vector.
embedding_dim = 3

In [12]:
# Build a random embedding table with one row per vocabulary entry.
# The global RNG is seeded first so the table is reproducible across runs.
torch.manual_seed(789)
embeddings = torch.randn((len(vocab), embedding_dim))

In [13]:
# Look up each token's row in the embedding table and stack the rows,
# in sentence order, into a single tensor.
inputs = torch.stack(
    [embeddings[vocab[token]] for token in tokenized_input]
)

In [14]:
# Initialize and apply self-attention.
# d_in is tied to embedding_dim (instead of a hard-coded 3) so the layer keeps
# matching the embedding width if that setting is changed; d_out=2 is the
# size of the projected query/key/value space.
sa_v2 = SelfAttention_v2(d_in=embedding_dim, d_out=2)
print(sa_v2(inputs))

tensor([[-0.0427,  0.0316],
        [-0.0348,  0.0876],
        [-0.0490, -0.0365],
        [-0.0355,  0.1067],
        [-0.0369,  0.1218],
        [-0.0342,  0.0994]], grad_fn=<MmBackward0>)
