In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    self.d_in = d_in
    self.d_out = d_out
    self.Q = nn.Linear(d_in, d_out)
    self.K = nn.Linear(d_in, d_out)
    self.V = nn.Linear(d_in, d_out)

  def forward(self, x):
    queries = self.Q(x)
    keys = self.K(x)
    values = self.V(x)
    scores = torch.bmm(queries, keys.transpose(1,2))
    scores = scores/ (self.d_out ** 0.5)
    attention = F.softmax(scores, dim=2)
    hidden_states = torch.bmm(attention, values)
    return hidden_states

In [7]:
SOS_token = 0
EOS_token = 1

index2words = {
    SOS_token: "SOS",
   EOS_token: "EOS"
}
words = "How are you doing ? I am doing good and you ?"
words_list = set(words.lower().split(' '))
for word in words_list:
  index2words[len(index2words)] = word
index2words

{0: 'SOS',
 1: 'EOS',
 2: 'how',
 3: 'doing',
 4: 'i',
 5: 'are',
 6: 'am',
 7: 'and',
 8: 'you',
 9: 'good',
 10: '?'}

In [8]:
words2index = {w: i for i, w in index2words.items()}
words2index

{'SOS': 0,
 'EOS': 1,
 'how': 2,
 'doing': 3,
 'i': 4,
 'are': 5,
 'am': 6,
 'and': 7,
 'you': 8,
 'good': 9,
 '?': 10}

In [12]:
def convert2tensors(sentence):
  words_list = sentence.lower().split(' ')
  indexes = [words2index[word] for word in words_list]

  return torch.tensor(indexes, dtype=torch.long).view(1, -1)

indexes = convert2tensors("How are you doing ?")
indexes.size()

torch.Size([1, 5])

In [21]:
HIDDEN_SIZE = 10
VOCAB_SIZE = len(words2index)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)

input_tensor = convert2tensors("How are you doing ?")
print(input_tensor)
embedded = embedding(input_tensor)
print(embedded)
embedded.size()

tensor([[ 2,  5,  8,  3, 10]])
tensor([[[ 9.3225e-01, -5.8053e-01,  5.6782e-04, -6.3848e-01, -5.2476e-01,
          -1.8279e-01, -1.3435e+00,  6.7209e-01,  1.5825e-01, -1.4818e-01],
         [ 3.7587e-01, -3.8707e-01,  1.3779e+00,  1.0981e+00, -1.0002e+00,
           2.8556e-01, -1.2906e+00, -1.3199e-01, -6.7352e-01,  1.1170e+00],
         [-2.4820e-01,  2.3479e-01,  6.3158e-01,  3.1945e-01,  2.3544e-01,
          -1.0188e+00, -1.8333e-01, -4.1408e-01,  4.5394e-01, -1.6592e+00],
         [ 3.9137e-01, -6.5312e-01,  1.2836e+00,  9.6799e-01,  1.0509e+00,
           9.0970e-01,  2.1353e-01,  1.3721e+00,  1.8131e-01, -4.4169e-01],
         [-1.1123e-01,  5.2791e-01, -7.1280e-01,  1.1345e+00,  4.1034e-01,
           2.1899e+00, -1.0925e+00,  7.4364e-01, -9.7716e-01, -1.3324e+00]]],
       grad_fn=<EmbeddingBackward0>)


torch.Size([1, 5, 10])

In [22]:
hidden_states = attention(embedded)


In [23]:
hidden_states

tensor([[[-0.2396,  0.3525,  0.2216,  0.1476,  0.0265, -0.3573, -0.4059,
          -0.0616, -0.0868, -0.1471],
         [-0.2114,  0.3979,  0.2715,  0.2000,  0.0288, -0.3918, -0.4333,
          -0.0440, -0.1701, -0.2479],
         [-0.2412,  0.3568,  0.2305,  0.1614,  0.0313, -0.3693, -0.4202,
          -0.0552, -0.1124, -0.1677],
         [-0.2070,  0.4210,  0.2743,  0.2330,  0.0286, -0.4104, -0.4461,
          -0.0282, -0.1941, -0.2775],
         [-0.2062,  0.4410,  0.3058,  0.2665,  0.0399, -0.4149, -0.4192,
           0.0038, -0.2088, -0.2959]]], grad_fn=<BmmBackward0>)