In [1]:
import torch

In [2]:
sentence = "Incredible things can be done simply if we are committed to making them happen."

In [4]:
dc = {word: ids for ids, word in enumerate(sorted(sentence.split()))}
dc

{'Incredible': 0,
 'are': 1,
 'be': 2,
 'can': 3,
 'committed': 4,
 'done': 5,
 'happen.': 6,
 'if': 7,
 'making': 8,
 'simply': 9,
 'them': 10,
 'things': 11,
 'to': 12,
 'we': 13}

In [5]:
sentence_tn = torch.tensor([dc[word] for word in sentence.split()])
sentence_tn

tensor([ 0, 11,  3,  2,  5,  9,  7, 13,  1,  4, 12,  8, 10,  6])

In [9]:
torch.manual_seed(1)
embed = torch.nn.Embedding(len(sentence_tn), 32)
embedded_sentence = embed(sentence_tn).detach()
print(embedded_sentence.shape)

torch.Size([14, 32])


In [22]:
# Query, key and value
d = embedded_sentence.shape[1]

dq, dk, dv = 24, 24, 28

Wq = torch.nn.Parameter(torch.rand(dq, d))
Wk = torch.nn.Parameter(torch.rand(dk, d))
Wv = torch.nn.Parameter(torch.rand(dv, d))


In [40]:
Wq.dot(embedded_sentence[1])

RuntimeError: 1D tensors expected, but got 2D and 1D tensors

In [43]:
q5 = Wq.matmul(embedded_sentence[4])
q5, q5.shape

(tensor([ 0.3031,  5.3765,  1.3299,  4.5602,  2.8359,  1.6202,  5.3477,  2.0123,
          4.7996,  5.8271,  4.6035,  1.9264,  3.1095,  4.0134,  4.0661,  3.8443,
          5.0371,  2.1902,  4.0418,  4.5032,  3.6811, -0.6157,  0.2768,  1.8467],
        grad_fn=<MvBackward0>),
 torch.Size([24]))

In [33]:
Wq.matmul(embedded_sentence[1]).shape, Wk.matmul(embedded_sentence[1]).shape, Wv.matmul(embedded_sentence[1]).shape

(torch.Size([24]), torch.Size([24]), torch.Size([28]))

In [30]:
# (24, 32) X (32, 1) = (24, 1)

In [39]:
keys = Wq.matmul(embedded_sentence.T).T
values = Wv.matmul(embedded_sentence.T).T

keys.shape, values.shape

(torch.Size([14, 24]), torch.Size([14, 28]))

In [44]:
# Wij = qi.T * kj --> (14, 32)
w35 = (Wq.matmul(embedded_sentence[2])).dot(keys[4])

In [48]:
w3 = (Wq.matmul(embedded_sentence[2])).matmul(keys.T)
w3.shape, w3

(torch.Size([14]),
 tensor([-335.3750,  -59.7256,  264.1515, -213.3279,  255.0911,  -35.5608,
           82.4628,  -46.8679,   87.6090,  159.8272, -146.9724,   -8.0728,
          -30.5500,  270.8577], grad_fn=<SqueezeBackward4>))

In [52]:
def softmax_transform(w, dk):
    import torch.nn.functional as F
    attention_weights = F.softmax(w/dk**0.5, dim=0)
    return attention_weights

In [69]:
z3 = softmax_transform(w3, dk).matmul(values)
z3

tensor([4.0581, 2.6446, 3.5468, 5.7045, 1.8959, 2.6726, 6.2872, 5.0275, 5.1595,
        5.0305, 6.2471, 4.1249, 3.6407, 5.3727, 5.0096, 4.7481, 2.5611, 1.9688,
        4.3878, 3.3295, 3.4493, 5.9288, 1.9613, 5.0375, 3.5744, 2.4254, 2.5386,
        4.6398], grad_fn=<SqueezeBackward4>)

In [62]:
softmax_transform(w3, dk).shape, values.shape

(torch.Size([14]), torch.Size([14, 28]))

In [67]:
softmax_transform(w3, dk).dot(values[:,0])

tensor(4.0581, grad_fn=<DotBackward0>)

In [68]:
softmax_transform(w3, dk).dot(values[:,-1])

tensor(4.6398, grad_fn=<DotBackward0>)

In [83]:
head = 3
multihead_Wq = torch.nn.Parameter(torch.rand(head, dq, d))
multihead_Wk = torch.nn.Parameter(torch.rand(head, dk, d))
multihead_Wv = torch.nn.Parameter(torch.rand(head, dv, d))


# mulithead_Wq.shape, mulithead_Wk.shape, mulithead_Wv.shape


In [73]:
embedded_sentence[2]

tensor([ 0.6614,  1.1899,  0.8165, -0.9135,  1.3851, -0.8138, -0.9276,  1.1120,
         1.3352,  0.6043, -0.1034, -0.1512, -2.1021, -0.6200, -1.4782, -1.1334,
         0.8738, -0.5603,  1.2858,  0.8168,  0.2053,  0.3051,  0.5357, -0.4312,
         2.5581, -0.2334, -0.0135,  1.8606, -1.9804,  1.7986,  0.1018,  0.3400])

In [77]:
mhk3 = mulithead_Wk.matmul(embedded_sentence[2])
mhv3 = mulithead_Wv.matmul(embedded_sentence[2])
mhk3.shape, mhv3.shape

(torch.Size([3, 24]), torch.Size([3, 28]))

In [79]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 32, 14])


In [85]:
multihead_keys = torch.bmm(multihead_Wk, stacked_inputs)
multihead_values = torch.bmm(multihead_Wv, stacked_inputs)

In [86]:
multihead_keys.shape, multihead_values.shape

(torch.Size([3, 24, 14]), torch.Size([3, 28, 14]))

In [87]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 14, 24])
multihead_values.shape: torch.Size([3, 14, 28])


In [5]:
import torch
text = "learning pytorch is fun, if you understand it."
dc = {s:i for i,s 
      in enumerate(sorted(text.replace(',', '').split()))}
ec = {s:i for i,s in dc.items()}

print(dc, ec)

{'fun': 0, 'if': 1, 'is': 2, 'it.': 3, 'learning': 4, 'pytorch': 5, 'understand': 6, 'you': 7} {0: 'fun', 1: 'if', 2: 'is', 3: 'it.', 4: 'learning', 5: 'pytorch', 6: 'understand', 7: 'you'}


In [8]:
sentence_int = torch.tensor(
    [dc[s] for s in text.replace(',', '').split()]
)
print(sentence_int.shape)

torch.Size([8])


In [9]:
vocab_size = 50_000
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence.shape)

torch.Size([8, 3])


In [10]:
d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

In [19]:
word_1 = embedded_sentence[1]

query_1 = word_1 @ W_query
key_1 = word_1 @ W_key
value_1 = word_1 @ W_value

print(word_1.shape, value_1.shape, W_value.shape)

torch.Size([3]) torch.Size([4]) torch.Size([3, 4])


In [21]:
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

In [50]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out_kq, d_v):
        super(SelfAttention, self).__init__()
        
        self.K = torch.nn.Parameter(torch.rand(d_in, d_out_kq))
        self.Q = torch.nn.Parameter(torch.rand(d_in, d_out_kq))
        self.V = torch.nn.Parameter(torch.rand(d_in, d_v))
        self.d_out_kq = d_out_kq

    def forward(self, X):
        keys = X @ self.K
        query = X @ self.Q
        values = X @ self.V

        print(keys.shape, query.shape)
        unnorm = query @ keys.T

        attention = torch.nn.functional.softmax(unnorm/ self.d_out_kq**0.5, dim=-1)

        context_vector = attention @ values

        return context_vector


class Multihead(torch.nn.Module):
    def __init__(self, d_in, d_out_kq, d_v, num_heads=3):
        super().__init__()
        self.heads = torch.nn.ModuleList([SelfAttention(d_in, d_out_kq, d_v) for _ in range(num_heads)])

    def forward(self, X):
        out = torch.concat([head(X) for head in self.heads], dim=-1)
        return out




attn = Multihead(3,2,4)
X = embedded_sentence

out = attn.forward(X)
out.shape


torch.Size([8, 2]) torch.Size([8, 2])
torch.Size([8, 2]) torch.Size([8, 2])
torch.Size([8, 2]) torch.Size([8, 2])


torch.Size([8, 12])

In [52]:
# out