In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention

In [97]:
class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        """
        Parameters:
        - d_in (int): size of input - the hidden size aka d_model
        - d_out (int): size of output - the length of the k and q vectors
        """
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        attention = F.softmax(scores, dim=2)
        hidden_states = torch.bmm(attention, values)
        return hidden_states        

In [106]:
SOS_token = 0
EOS_token = 0

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split())
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

In [107]:
def convert2tensors(sentence):
    words_list = sentence.lower().split(' ')
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

In [110]:
sentence = "How are you doing ? I am good and you ?"
convert2tensors(sentence)

tensor([[3, 6, 1, 7, 2, 8, 4, 9, 5, 1, 2]])

In [111]:
HIDDEN_SIZE = 10
VOCAB_SIZE = len(words2index)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)

In [120]:
hidden_states = attention(embedded)

In [122]:
hidden_states.shape

torch.Size([1, 5, 10])

In [123]:
hidden_states

tensor([[[-0.1097,  0.1819, -0.2460,  0.1663,  0.2060, -0.2877, -0.7588,
          -0.1137, -0.0173,  0.3440],
         [-0.2254, -0.0467,  0.2279,  0.0389,  0.0265, -0.2272, -0.5867,
           0.0130, -0.1764,  0.1194],
         [-0.2163, -0.1267,  0.3212,  0.1042, -0.0701, -0.3189, -0.6346,
           0.0520, -0.1989,  0.1196],
         [-0.2384, -0.0666,  0.2611,  0.0185,  0.0106, -0.1989, -0.5209,
           0.0261, -0.1854,  0.0587],
         [-0.1606,  0.1044, -0.1331,  0.1148,  0.1526, -0.2480, -0.6132,
          -0.0816, -0.0318,  0.1648]]], grad_fn=<BmmBackward0>)