In [66]:
import torch

In [67]:
# Fix the global RNG seed so the torch.randn weight draws below are reproducible.
SEED = 123
torch.manual_seed(SEED)

<torch._C.Generator at 0x7c7630519790>

In [68]:
# One 3-dimensional embedding row per token of the example sentence.
token_embeddings = [
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
]
inputs = torch.tensor(token_embeddings)

In [69]:
# d_out is the hand-picked width of the projected vectors;
# d_in is read off the second dimension of the input tensor.
d_out = 2
d_in = inputs.size(1)

In [70]:
# Query projection: random (d_in, d_out) weights with gradient tracking switched off.
W_q = torch.nn.Parameter(data=torch.randn(d_in, d_out), requires_grad=False)
print(W_q.shape, W_q, sep="\n")

torch.Size([3, 2])
Parameter containing:
tensor([[-0.1115,  0.1204],
        [-0.3696, -0.2404],
        [-1.1969,  0.2093]])


In [71]:
# Key projection: random (d_in, d_out) weights with gradient tracking switched off.
W_k = torch.nn.Parameter(data=torch.randn(d_in, d_out), requires_grad=False)
print(W_k.shape, W_k, sep="\n")

torch.Size([3, 2])
Parameter containing:
tensor([[-0.9724, -0.7550],
        [ 0.3239, -0.1085],
        [ 0.2103, -0.3908]])


In [72]:
# Value projection: random (d_in, d_out) weights with gradient tracking switched off.
W_v = torch.nn.Parameter(data=torch.randn(d_in, d_out), requires_grad=False)
print(W_v.shape, W_v, sep="\n")

torch.Size([3, 2])
Parameter containing:
tensor([[ 0.2350,  0.6653],
        [ 0.3528,  0.9728],
        [-0.0386, -0.8861]])


In [73]:
# Project every input row through W_q to obtain one query vector per token.
queries = inputs @ W_q
print(queries.shape, queries, sep="\n")

torch.Size([6, 2])
tensor([[-1.1686,  0.2019],
        [-1.1729, -0.0048],
        [-1.1438, -0.0018],
        [-0.6339, -0.0439],
        [-0.2979,  0.0535],
        [-0.9596, -0.0712]])


In [74]:
# Project every input row through W_k to obtain one key vector per token.
keys = inputs @ W_k
print(keys.shape, keys, sep="\n")

torch.Size([6, 2])
tensor([[-0.1823, -0.6888],
        [-0.1142, -0.7676],
        [-0.1443, -0.7728],
        [ 0.0434, -0.3580],
        [-0.6467, -0.6476],
        [ 0.3262, -0.3395]])


In [75]:
# Project every input row through W_v to obtain one value vector per token.
values = inputs @ W_v
print(values.shape, values, sep="\n")

torch.Size([6, 2])
tensor([[ 0.1196, -0.3566],
        [ 0.4107,  0.6274],
        [ 0.4091,  0.6390],
        [ 0.2436,  0.4182],
        [ 0.2653,  0.6668],
        [ 0.2728,  0.3242]])


In [76]:
# Unnormalised attention scores: entry (i, j) is the dot product of query i with key j.
attention_scores = queries @ keys.transpose(0, 1)
print(attention_scores.shape, attention_scores, sep="\n")

torch.Size([6, 6])
tensor([[ 0.0740, -0.0216,  0.0126, -0.1230,  0.6250, -0.4498],
        [ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809],
        [ 0.2098,  0.1320,  0.1665, -0.0489,  0.7408, -0.3725],
        [ 0.1458,  0.1061,  0.1254, -0.0118,  0.4384, -0.1919],
        [ 0.0175, -0.0071,  0.0017, -0.0321,  0.1580, -0.1153],
        [ 0.2240,  0.1642,  0.1935, -0.0161,  0.6667, -0.2888]])


In [77]:
# Number of rows in the score matrix, used below as the side length of the square mask.
context_length = attention_scores.size(0)

In [78]:
# Square matrix with ones strictly above the main diagonal and zeros elsewhere;
# the ones flag the score entries that get masked out in the next step.
_attention_scores_mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(_attention_scores_mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [79]:
# Replace every score flagged by the mask (entries above the diagonal) with -inf,
# so that the subsequent softmax assigns those positions zero weight.
# The out-of-place masked_fill is used on purpose: the in-place masked_fill_
# would overwrite attention_scores itself and leave both names aliasing one
# tensor, which makes the cell that printed attention_scores stale.
masked_attention_scores = attention_scores.masked_fill(_attention_scores_mask.bool(), -torch.inf)
print(masked_attention_scores)

tensor([[ 0.0740,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2172,  0.1376,    -inf,    -inf,    -inf,    -inf],
        [ 0.2098,  0.1320,  0.1665,    -inf,    -inf,    -inf],
        [ 0.1458,  0.1061,  0.1254, -0.0118,    -inf,    -inf],
        [ 0.0175, -0.0071,  0.0017, -0.0321,  0.1580,    -inf],
        [ 0.2240,  0.1642,  0.1935, -0.0161,  0.6667, -0.2888]])


In [80]:
# Divide the masked scores by the square root of the key dimension,
# then normalise each row with softmax so the weights in a row sum to 1.
scale = keys.shape[-1] ** 0.5
attention_weights = (masked_attention_scores / scale).softmax(dim=-1)
print(attention_weights.shape, attention_weights, sep="\n")

torch.Size([6, 6])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5141, 0.4859, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3429, 0.3245, 0.3326, 0.0000, 0.0000, 0.0000],
        [0.2596, 0.2524, 0.2559, 0.2322, 0.0000, 0.0000],
        [0.1983, 0.1949, 0.1961, 0.1915, 0.2191, 0.0000],
        [0.1711, 0.1640, 0.1675, 0.1444, 0.2340, 0.1191]])


In [84]:
# Dropout is invoked with training=False, so no entries are dropped here and
# the weights come out unchanged (the printed values match the previous cell).
dropout_rate = 0.6
attention_weights = torch.nn.functional.dropout(
    input=attention_weights,
    p=dropout_rate,
    training=False,
)
print(attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5141, 0.4859, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3429, 0.3245, 0.3326, 0.0000, 0.0000, 0.0000],
        [0.2596, 0.2524, 0.2559, 0.2322, 0.0000, 0.0000],
        [0.1983, 0.1949, 0.1961, 0.1915, 0.2191, 0.0000],
        [0.1711, 0.1640, 0.1675, 0.1444, 0.2340, 0.1191]])


In [81]:
# Each context vector is the attention-weighted sum of the value vectors.
context_vectors = attention_weights @ values
print(context_vectors.shape, context_vectors, sep="\n")

torch.Size([6, 2])
tensor([[ 0.1196, -0.3566],
        [ 0.2611,  0.1216],
        [ 0.3104,  0.2938],
        [ 0.2959,  0.3264],
        [ 0.2888,  0.4031],
        [ 0.2860,  0.4039]])
