In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention

In [97]:
class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        """
        Parameters:
        - d_in (int): size of input - the hidden size aka d_model
        - d_out (int): size of output - the length of the k and q vectors
        """
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        attention = F.softmax(scores, dim=2)
        hidden_states = torch.bmm(attention, values)
        return hidden_states        

In [106]:
SOS_token = 0
EOS_token = 0

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split())
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

In [107]:
def convert2tensors(sentence):
    words_list = sentence.lower().split(' ')
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

In [110]:
sentence = "How are you doing ? I am good and you ?"
convert2tensors(sentence)

tensor([[3, 6, 1, 7, 2, 8, 4, 9, 5, 1, 2]])

In [124]:
HIDDEN_SIZE = 10
VOCAB_SIZE = len(words2index)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = attention(embedded)

# Multihead Attention

In [125]:
class MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.out = nn.Linear(hidden_size, hidden_size)
        self.heads = nn.ModuleList([
            Attention(hidden_size, hidden_size // num_heads)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        outputs = [head(x) for head in self.heads]
        outputs = torch.cat(outputs, dim=2)
        hidden_states = self.out(outputs)
        return hidden_states

In [127]:
NUM_HEADS = 3
HIDDEN_SIZE = 12
VOCAB_SIZE = len(words2index)

multi_att = MultiheadAttention(HIDDEN_SIZE, NUM_HEADS)
embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = multi_att(embedded)

In [128]:
hidden_states.shape

torch.Size([1, 5, 12])

In [129]:
hidden_states

tensor([[[ 0.0108,  0.0522,  0.1712,  0.2370, -0.0738, -0.2954,  0.0350,
          -0.0551,  0.0423, -0.1061, -0.0453, -0.0933],
         [ 0.0239,  0.0556,  0.1297,  0.1967, -0.0386, -0.2766,  0.0329,
          -0.0894,  0.0359, -0.1271, -0.0508, -0.1351],
         [-0.0068,  0.1204,  0.0525,  0.1229,  0.0398, -0.2572,  0.0457,
          -0.1234,  0.0786, -0.1508, -0.0979, -0.2468],
         [-0.0199,  0.0468,  0.1574,  0.2117, -0.0436, -0.2889,  0.0429,
          -0.0565,  0.0578, -0.1142, -0.0644, -0.1291],
         [-0.0412,  0.0683,  0.1340,  0.1760, -0.0055, -0.2682,  0.0511,
          -0.0569,  0.0727, -0.1113, -0.0643, -0.1850]]],
       grad_fn=<ViewBackward0>)

In [137]:
class MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.qkv_linear = nn.Linear(hidden_size, hidden_size * 3)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_length, hidden_size = x.size()

        # (batch_size, seq_length, hidden_size * 3)
        qkv = self.qkv_linear(x)
        # (batch_size, seq_length, num_heads, head_dim * 3)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        # (batch_size, num_heads, seq_length, head_dim * 3)
        qkv = qkv.transpose(1, 2)
        # (batch_size, num_heads, seq_length, head_dim)
        queries, keys, values = qkv.chunk(3, dim=-1)

        # (batch_size, num_heads, seq_length, seq_length)
        scores = torch.matmul(queries, keys.transpose(2, 3))
        # (batch_size, num_heads, seq_length, seq_length)
        scores = scores / (self.head_dim ** 0.5)
        # (batch_size, num_heads, seq_length, seq_length)
        attention = F.softmax(scores, dim=-1)
        # (batch_size, num_heads, seq_length, head_dim)
        context = torch.matmul(attention, values)
        # (batch_size, seq_length, num_heads, head_dim)
        context = context.transpose(1, 2)
        # (batch_size, seq_length, hidden_size)
        context = context.reshape(batch_size, seq_length, hidden_size)
        # (batch_size, seq_length, hidden_size)
        output = self.out(context)
        return output
        
        
        

In [138]:
NUM_HEADS = 3
HIDDEN_SIZE = 12
VOCAB_SIZE = len(words2index)

multi_att = MultiheadAttention(HIDDEN_SIZE, NUM_HEADS)
embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = multi_att(embedded)

In [139]:
hidden_states.size()

torch.Size([1, 5, 12])