In [1]:
import torch
import torch.nn as nn

In [2]:
# Labels for the six rows of `inputs`, in row order.
words = "Dream big and work for it".split()

# One 3-element row vector per word; row i belongs to words[i].
inputs = torch.tensor(
    [
        [0.72, 0.45, 0.31],  # Dream
        [0.75, 0.20, 0.55],  # big
        [0.30, 0.80, 0.40],  # and
        [0.85, 0.35, 0.60],  # work
        [0.55, 0.15, 0.75],  # for
        [0.25, 0.20, 0.85],  # it
    ]
)


In [3]:
# Row 1 of `inputs` is the single token that the step-by-step cells below
# project into query/key/value space.
x_2 = inputs[1]

# Projection sizes: the input width is read from the data, the output width
# is chosen by hand.
d_in, d_out = inputs.shape[1], 2

In [4]:
# Fix the global RNG so the three projection matrices are reproducible.
torch.manual_seed(123)

# Three (d_in, d_out) matrices drawn with torch.rand, in the order
# query -> key -> value (the draw order determines which values each one gets).
# requires_grad=False: these are used for the manual walkthrough, not trained.
W_query, W_key, W_value = (
    nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [5]:
# Project the single token x_2 with each of the three weight matrices.
query_2 = torch.matmul(x_2, W_query)
key_2   = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)

print(query_2)

tensor([0.3131, 1.0017])


In [6]:
# Project every row of `inputs` at once; each result keeps one row per input row.
queries = torch.matmul(inputs, W_query)
keys    = torch.matmul(inputs, W_key)
values  = torch.matmul(inputs, W_value)

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
print("queries.shape:", queries.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])


In [7]:
# Unnormalised attention score of token 2 with itself: the dot product of its
# query vector with its own key vector (row 1 of `keys`).
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

tensor(0.6990)


In [8]:
# Scores of token 2's query against every key in one matrix product.
# Entry 1 of the result is the same value as attn_score_22.
attn_scores_2 = torch.matmul(query_2, keys.t())
print(attn_scores_2)

tensor([0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624])


In [9]:
# All query-key scores at once: row i holds the scores of query i against
# every key, so row 1 reproduces attn_scores_2.
attn_scores = torch.matmul(queries, keys.t())
print(attn_scores)

tensor([[0.6807, 0.6795, 0.9526, 0.8454, 0.7654, 0.8359],
        [0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624],
        [0.7350, 0.7315, 1.0337, 0.9113, 0.8248, 0.9029],
        [0.8436, 0.8402, 1.1848, 1.0464, 0.9471, 1.0361],
        [0.7080, 0.7025, 1.0003, 0.8764, 0.7929, 0.8699],
        [0.6680, 0.6606, 0.9486, 0.8254, 0.7465, 0.8210]])


In [10]:
# Width of the key vectors; its square root is the scaling factor.
d_k = keys.size(-1)

# Scale the raw scores by sqrt(d_k), then softmax over the last axis so the
# weights for token 2 sum to 1.
attn_weights_2 = (attn_scores_2 / d_k**0.5).softmax(dim=-1)

print(attn_weights_2)
print(d_k)

tensor([0.1531, 0.1528, 0.1873, 0.1725, 0.1627, 0.1715])
2


In [11]:
# Context vector for token 2: the value rows averaged with its attention weights.
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor([0.2274, 0.7362])


In [12]:
# Same scaling + softmax as for the single token, applied row-wise to the
# full score matrix (each row is normalised independently).
attn_weights = (attn_scores / d_k**0.5).softmax(dim=-1)

# One context vector per token: weighted sums of the value rows.
context_vec = torch.matmul(attn_weights, values)

print(context_vec)

tensor([[0.2273, 0.7361],
        [0.2274, 0.7362],
        [0.2276, 0.7363],
        [0.2280, 0.7368],
        [0.2275, 0.7362],
        [0.2275, 0.7360]])


In [13]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)`` initialised with ``torch.rand`` and applied as
    ``x @ W`` (no bias).

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projected queries/keys/values
            and therefore of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) fixes which random draws each
        # matrix receives for a given seed.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys    = x @ self.W_key
        queries = x @ self.W_query
        values  = x @ self.W_value

        # Swap only the last two axes. `.T` reverses *all* axes, which is only
        # correct for 2-D input and breaks the matmul for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax; dim=-1 normalises each
        # query's scores over all keys.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec


In [14]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Same computation as a hand-written ``x @ W`` projection, but the query, key
    and value projections are ``nn.Linear`` layers, which bring their own
    weight initialisation and an optional bias.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projected queries/keys/values
            and therefore of the returned context vectors.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (query, key, value) fixes which random draws each
        # layer receives for a given seed.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)

        # Swap only the last two axes. `.T` reverses *all* axes, which is only
        # correct for 2-D input and breaks the matmul for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax; dim=-1 normalises each
        # query's scores over all keys.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec


In [15]:
# Build the layer from the notebook-level sizes (d_in = inputs.shape[1],
# d_out = 2) instead of repeating the literals, so the cell stays in sync if
# `inputs` changes width.
# No seed is set in this cell: nn.Linear draws its initial weights from the
# global RNG, so the printed values depend on the cells run before this one.
sa = SelfAttention_v2(d_in=d_in, d_out=d_out)
output = sa(inputs)
print(output)

tensor([[0.5269, 0.2695],
        [0.5274, 0.2714],
        [0.5269, 0.2714],
        [0.5278, 0.2726],
        [0.5277, 0.2733],
        [0.5277, 0.2743]], grad_fn=<MmBackward0>)
