In [131]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [132]:
def attention(Q, K, V):
    downscaler = torch.sqrt(torch.tensor(Q.size(-1)))

    # attention scores (b, m, n)
    S = torch.matmul(Q, K.transpose(-2, -1)) / downscaler  # 4,2,10 -> 4,10,2

    attention_weights = nn.functional.softmax(S, dim=1)
    
    out = torch.matmul(
        attention_weights,
        V
    )
    return out, attention_weights

In [None]:
def multihead_attention(Q, K, V, nheads):

    Q = Q.chunk(nheads, dim=-1)
    K = K.chunk(nheads, dim=-1)
    V = V.chunk(nheads, dim=-1)

    outputs = []
    weights = []
    for i in range(nheads):
        Q_i = Q[i]
        K_i = K[i]
        V_i = V[i]

        downscaler = torch.sqrt(torch.tensor(Q_i.size(-1)))
        
        # attention scores (b, m, n)
        S = torch.matmul(Q_i, K_i.transpose(-2, -1)) / downscaler  # 4,2,10 -> 4,10,2
    
        attention_weights = nn.functional.softmax(S, dim=1)
        
        out = torch.matmul(
            attention_weights,
            V_i
        )
        outputs.append(out)
        weights.append(attention_weights)

    out = torch.cat(outputs, dim=-1)
    attention_weights = torch.stack(weights, dim=1)

    return out, attention_weights

In [133]:
def one_hot(word, vocab):
    one_hot_vector = [0] * len(vocab)
    one_hot_vector[vocab.index(word)] = 1
    return torch.tensor(one_hot_vector).unsqueeze(0)

In [134]:
sentence = 'The quick brown fox jumps over the lazy dog'

vocab = sentence.split()

vocab_size = len(tokens)

vocab_indices = torch.tensor([vocab.index(x) for x in vocab])

In [135]:
print(vocab)

['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']


In [136]:
embedding_dim = 10

embedding_Q = nn.Embedding(vocab_size, embedding_dim)
embedding_K = nn.Embedding(vocab_size, embedding_dim)
embedding_V = nn.Embedding(vocab_size, embedding_dim)

In [137]:
query = 'fox'

query_encoded = one_hot(query, vocab)
query_embedding = embedding_Q(query_encoded)

keys_encoded = torch.tensor(range(len(vocab))).unsqueeze(0)
keys_embedding = embedding_K(keys_encoded)

values_embedding = embedding_V(vocab_indices.unsqueeze(0))

In [138]:
scores, weights = attention(query_embedding, keys_embedding, values_embedding)

print(scores.shape)
print(weights.shape)

torch.Size([1, 9, 10])
torch.Size([1, 9, 9])


In [139]:
scores_map = {}

query_index = vocab.index(query)

attention_for_query = weights[0][:, vocab.index(query)].tolist()

for i, word in enumerate(vocab):
    scores_map[word] = attention_for_query[i]

from pprint import pprint
pprint(scores_map)

{'The': 0.11266357451677322,
 'brown': 0.11266357451677322,
 'dog': 0.11266357451677322,
 'fox': 0.09869139641523361,
 'jumps': 0.11266357451677322,
 'lazy': 0.11266357451677322,
 'over': 0.11266357451677322,
 'quick': 0.11266357451677322,
 'the': 0.11266357451677322}


In [140]:
# batch, m, d
Q = torch.randn([1, 2, 10])
# batch, n, d
K = torch.randn([1, 2, 10])
V = torch.randn([1, 2, 10])