### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

##### Multihead Attention
- Source
    - [Algorithm Whiteboard : Attenton by Rasa](https://youtu.be/yGTUuEx3GkA)
    - [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf)



<img src="../assets/scp-mha.png" width="1100" height="200"/>

<img src="../assets/sda.png" width="550" height="700"/> <img src="../assets/mha.png" width="550" height="700"/> 

In [None]:
class MHA(nn.Module):
    """
        Multihead attention block used as attention mechanism in Transformer model
    """
    def __init__(self, num_heads = 8, embedding_dim=256):
        super(MHA, self).__init__()
        
        self.num_heads = num_heads
        
        # fully connected layers 
        self.fc_query = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.fc_key = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.fc_value = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.fc_out = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        
        
        # each linear
        self.scale = torch.sqrt(torch.tensor([embedding_dim]))
    
    def forward(self, query, key, value):
        """
            keys.shape = queries.shape == values.shape -> [batch_size, seq_len, embedding_dim]
        """
        # pass them to linear layer
        Q = self.fc_query(query)
        K = self.fc_key(key)
        V = self.fc_value(value)
        
        # compute the attention weights and apply softmax to compute the attention values
        attention_weights = torch.einsum('bij,bjk->bik', queries, keys.permute(0, 2, 1))
        
#         attention = F.softmax(attention_weights, dim=1)
#         print(attention_weights.shape, attention.shape)
        
        # contextual embedding is weigted sum of value, attention is weight
        return attention_weights
        
        
        
        

tensor([16.])

In [3]:
mha = MHA()

In [4]:
keys = torch.randn(10, 32, 100)
queries = torch.randn(10, 32, 100)
values = torch.randn(10, 32, 100)

In [5]:
attention_weights = mha(keys, queries, values)

In [6]:
attention_weights.shape

torch.Size([10, 32, 32])

In [9]:
attention_weights.softmax(dim=3).shape

IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

In [10]:
attention_weights

torch.Size([10, 32, 32])

<img src="../assets/transformer.png"/> 

In [43]:
keys.permute?

[0;31mDocstring:[0m
permute(*dims) -> Tensor

Returns a view of the original tensor with its dimensions permuted.

Args:
    *dims (int...): The desired ordering of dimensions

Example:
    >>> x = torch.randn(2, 3, 5)
    >>> x.size()
    torch.Size([2, 3, 5])
    >>> x.permute(2, 0, 1).size()
    torch.Size([5, 2, 3])
[0;31mType:[0m      builtin_function_or_method
