In this exercise, you will implement `Multi-Head Attention` to solve a toy sequence-modeling problem. The concept of `Multi-Head Attention` comes from the well-known paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762), which introduced the Transformer model. Please read the paper carefully and answer the questions below. Understanding the concepts described in this paper will help you understand many modern neural network models, and it is also necessary if you choose to work on the NLP project later.

If you have trouble understanding the paper, you can read the [illustrated transformer blog](https://jalammar.github.io/illustrated-transformer/) first.

i) The biggest benefit of using Transformers instead of RNN and convolution-based models is the possibility to parallelize computations during training. Why is parallelization not possible with RNN and convolution-based models for sequence processing, but possible with Transformers? *Note*: parallelization can be applied only to the Encoder part of the Transformer. (0.5 points)  

ii) In explaining the concept of `self-attention`, the paper mentions 3 matrices `Q`, `K` and `V`, which serve as inputs to the self-attention sublayer. Explain how these matrices are computed in the encoder and in the decoder. What role does each of these matrices play? (1 point)  

iii) How is Multi-Head Attention better than Single-Head Attention? (0.5 points)

i) The feed-forward layer does not have those dependencies, however, and thus the various paths can be executed in parallel while flowing through the feed-forward layer.

iii) It gives the attention layer multiple “representation subspaces”. With multi-headed attention we have not only one, but multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder). Each of these sets is randomly initialized. Then, after training, each set is used to project the input embeddings (or vectors from lower encoders/decoders) into a different representation subspace.

### Task description
Given an input sequence `XY[0-5]+` (two digits X and Y followed by a sequence of digits in the range from 0 to 5 inclusive), the task is to count the number of occurrences of X and Y in the remaining substring and then calculate the difference #X - #Y.

Example:  
Input: `1214211`  
Output: `2`  
Explanation: there are 3 `1`'s and 1 `2` in the sequence `14211`, `3-1=2`  
  
The model must learn this relationship between the symbols of the sequence and predict the output. This task can be solved with a multi-head attention network.
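
As a reference for what the network has to learn, the target can be computed directly in plain Python. This is only a sketch of the rule stated above; the helper name `count_difference` is hypothetical and not part of the exercise template.

```python
def count_difference(sequence: str) -> int:
    """Return #X - #Y, where X and Y are the first two symbols of `sequence`
    and the counts are taken over the remaining substring."""
    x, y, rest = sequence[0], sequence[1], sequence[2:]
    return rest.count(x) - rest.count(y)


# The example from the task description: X=1, Y=2, remainder "14211".
result = count_difference("1214211")
print(result)
```

When X and Y are the same digit, the two counts cancel and the target is 0.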

$\color{red}{\textbf{Note}}$: In all your implementations, you're allowed to use only basic PyTorch operations. No APIs from external libraries such as Huggingface transformers should be used to solve any part of the exercise.

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython.display import HTML, Image

torch.manual_seed(0)

<torch._C.Generator at 0x104752730>

In [2]:
SEQ_LEN = 5
VOCAB_SIZE = 6
NUM_TRAINING_STEPS = 25000
BATCH_SIZE = 64

#### Data generation function 
Fill in the code to calculate the ground truth output for the random sequence and store it in `gts`.    

Why are we offsetting the ground truth? In other words, why does the ground truth need to be non-negative?

$\color{red}{\text{Ans:}}$

In [3]:
# This function generates data samples as described at the beginning of the
# script
def get_data_sample(batch_size=1):
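    # `high` is exclusive in torch.randint, so this samples digits
    # 0..VOCAB_SIZE-1 (0 to 5 inclusive, as in the task description).
    random_seq = torch.randint(low=0, high=VOCAB_SIZE,
                               size=[batch_size, SEQ_LEN + 2])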
    ############################################################################
    # TODO: Calculate the ground truth output for the random sequence and store
    # it in 'gts'.
    ############################################################################
    # X and Y are the first two symbols; keep a trailing dimension of 1 so
    # they broadcast against the remaining SEQ_LEN symbols.
    x, y = random_seq[:, 0:1], random_seq[:, 1:2]
    rest = random_seq[:, 2:]
    # #X - #Y, counted over the remaining substring only
    gts = (rest == x).sum(dim=1) - (rest == y).sum(dim=1)

    # Ensure that GT is non-negative
    ############################################################################
    # TODO: Why is this needed?
    ############################################################################
    gts += SEQ_LEN
    return random_seq, gts.type(torch.LongTensor)
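
One way to look at the offset is to enumerate the values the raw difference can take. `nn.CrossEntropyLoss` with class-index targets expects indices in `[0, C)`, so negative differences cannot be used as labels directly. The sketch below is plain Python and independent of the notebook code; fixing X=0 and Y=1 is an arbitrary choice made only for the illustration.

```python
from itertools import product

SEQ_LEN = 5         # length of the counted part, as in the notebook
DIGITS = range(6)   # digits 0..5 from the task description

# Fix X=0 and Y=1 and enumerate every possible remainder of length SEQ_LEN.
diffs = {rest.count(0) - rest.count(1)
         for rest in product(DIGITS, repeat=SEQ_LEN)}

raw_range = (min(diffs), max(diffs))          # smallest and largest #X - #Y
shifted = sorted(d + SEQ_LEN for d in diffs)  # labels after adding SEQ_LEN
num_classes = len(shifted)

print(raw_range, shifted, num_classes)
```

Under these assumptions the shifted labels are consecutive non-negative integers, and their count is the number of output classes the final layer needs.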

#### Scaled Dot-Product Attention
Implement a naive version of the Attention mechanism in the following class. Please do not deviate from the given structure. If you have ideas about how to optimize the implementation you can however note them in a comment or provide an additional implementation.  
For implementation, refer to Section 3.2.1 and Figure 2 (left) in the paper. Keep the parameters to the forward pass trainable.
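
For reference, the operation defined in Section 3.2.1 of the paper (Equation 1) is the following, where $d_k$ is the dimension of the queries and keys and the softmax is taken row-wise, i.e. over the keys for each query:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$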

In [4]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
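    def forward(self, q, k, v, d_k):
        # q, k, and v are batch-first: [batch, seq_len, d_k]
        # Similarity of every query with every key: [batch, len_q, len_k]
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Scale by sqrt(d_k) so the softmax does not saturate for large d_k
        scores = scores / math.sqrt(d_k)
        # Normalize over the key dimension to obtain attention weights
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of the values: [batch, len_q, d_k]
        return torch.matmul(weights, v)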

#### Multi-Head Attention
Implement Multi-Head Attention mechanism on top of the Single-Head Attention mechanism in the following class. Please do not deviate from the given structure. If you have ideas about how to optimize the implementation you can however note them in a comment or provide an additional implementation.  
For implementation, refer to Section 3.2.2 and Figure 2 (right) in the paper. Keep the parameters to the forward pass trainable.
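
For reference, Section 3.2.2 of the paper defines multi-head attention with $h$ heads as:

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O},
\qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V})
$$

The paper uses $d_k = d_v = d_{\text{model}} / h$. A single projection of size $d_{\text{model}} \times d_{\text{model}}$ whose output is cut into $h$ equal slices corresponds to stacking the per-head matrices $W_i^{Q}$ (and likewise for $K$ and $V$) side by side.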

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_r = self.embed_dim // self.num_heads   # to evenly split q, k, and v across heads.
        self.attention = Attention()

        self.q_linear_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_linear_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_linear_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.final_linear_proj = nn.Linear(self.embed_dim, self.embed_dim)
        
        # xavier initialization for linear layer weights
        nn.init.xavier_uniform_(self.q_linear_proj.weight)
        nn.init.xavier_uniform_(self.k_linear_proj.weight)
        nn.init.xavier_uniform_(self.v_linear_proj.weight)
        nn.init.xavier_uniform_(self.final_linear_proj.weight)

    def forward(self, q, k, v):
        # q, k, and v are batch-first

        ########################################################################
        # TODO: Implement multi-head attention as described in Section 3.2.2
        # of the paper.
        ########################################################################
        # shapes of q, k, v are [bsize, SEQ_LEN + 2, hidden_dim]
        q = self.q_linear_proj(q)
        k = self.k_linear_proj(k)
        v = self.v_linear_proj(v)

        # Run attention on each head's slice of the projected features.
        # Collecting the outputs in a list also covers num_heads == 1.
        heads = []
        for i in range(self.num_heads):
            start, end = i * self.dim_r, (i + 1) * self.dim_r
            heads.append(self.attention(q[:, :, start:end],
                                        k[:, :, start:end],
                                        v[:, :, start:end],
                                        self.dim_r))
        # Concatenate the heads along the feature dimension and project
        return self.final_linear_proj(torch.cat(heads, dim=2))
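
The task text invites notes on optimizing the implementation. One common idea is to replace the Python loop over heads with a reshape, so that all heads go through a single batched matrix product. The following is a minimal sketch of that idea, not the required solution; the function names `split_heads`, `batched_heads_attention` and `looped_heads_attention` are hypothetical, and the linear projections are left out so that only the head handling is compared.

```python
import math
import torch


def split_heads(x, num_heads):
    # [batch, seq, embed] -> [batch, heads, seq, head_dim]
    b, s, e = x.shape
    return x.view(b, s, num_heads, e // num_heads).transpose(1, 2)


def batched_heads_attention(q, k, v, num_heads):
    # All heads are processed by one batched matmul.
    qh, kh, vh = (split_heads(t, num_heads) for t in (q, k, v))
    d_k = qh.size(-1)
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(d_k)
    out = torch.softmax(scores, dim=-1) @ vh        # [b, heads, seq_q, d_k]
    b, h, s, d = out.shape
    # Move heads next to the feature dim and merge them again.
    return out.transpose(1, 2).reshape(b, s, h * d)


def looped_heads_attention(q, k, v, num_heads):
    # Reference: one slice of the feature dimension per head.
    d_k = q.size(-1) // num_heads
    heads = []
    for i in range(num_heads):
        sl = slice(i * d_k, (i + 1) * d_k)
        scores = q[..., sl] @ k[..., sl].transpose(-2, -1) / math.sqrt(d_k)
        heads.append(torch.softmax(scores, dim=-1) @ v[..., sl])
    return torch.cat(heads, dim=-1)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 7, 8) for _ in range(3))
batched = batched_heads_attention(q, k, v, num_heads=4)
looped = looped_heads_attention(q, k, v, num_heads=4)
same = torch.allclose(batched, looped, atol=1e-6)
print(batched.shape, same)
```

Both variants compute the same result; the batched form only trades the loop for two reshapes.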

#### Encoding Layer
Implement the Encoding Layer of the network.  
Refer to the following figure from the paper for the architecture of the Encoding layer. 

In [6]:
Image(url='https://i.stack.imgur.com/eAKQu.png')

In [7]:
class EncodingLayer(nn.Module):
    def __init__(self, num_hidden, num_heads):
        super().__init__()
        self.mh = MultiHeadAttention(embed_dim=num_hidden, num_heads=num_heads)
        # TODO: add necessary member variables
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        self.normalize_1= nn.LayerNorm([SEQ_LEN+2, num_hidden])
        self.feed_forward = nn.Linear(num_hidden, num_hidden)
        self.normalize_2 = nn.LayerNorm([SEQ_LEN+2, num_hidden])

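    def forward(self, x):
        # Self-attention sublayer with residual connection and layer norm
        res = x
        x = self.mh(x, x, x)
        x = self.normalize_1(res + x)

        # Feed-forward sublayer with residual connection and layer norm
        res = x
        x = self.feed_forward(x)
        x = self.normalize_2(res + x)
        return x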

####  Network definition
Implement the forward pass of the complete network.
The network must do the following:
1. calculate embeddings of the input (with size equal to `num_hidden`)
2. perform positional encoding
3. perform forward pass of a single Encoding layer
4. perform forward pass of a single Decoder layer
5. apply fully connected layer on the output

Because we are dealing with a simple task, the whole Decoder layer can be replaced with a single MultiHeadAttention block. Since our task is not sequence-to-sequence (Seq2Seq), but rather the classification of a sequence, the query (`Q` matrix) for the MultiHeadAttention block can be another learnable parameter (`nn.Parameter`) instead of processed output embedding.

In the forward pass we must add a (trainable) positional encoding to our input embedding. Why is this needed? Can you think of another similar task where the positional encoding would not be necessary?

$\color{red}{\text{Ans:}}$

In [8]:
# Network definition
class Net(nn.Module):
    def __init__(self, num_encoding_layers=1, num_hidden=64, num_heads=4):
        super().__init__()
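        # Learnable decoder query, shared across the batch. The leading
        # dimension of 1 broadcasts against the keys in torch.matmul, so the
        # model does not depend on BATCH_SIZE.
        q = torch.empty([1, SEQ_LEN + 2, num_hidden])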
        nn.init.normal_(q)
        self.q = nn.Parameter(q, requires_grad=True)
        # self.layers =  nn.ModuleList([torch.copy.deepcopy(EncodingLayer()) for i in range(num_encoding_layers)])
        self.encoding =EncodingLayer(num_hidden,num_heads)
        # TODO: implement
        self.embd = nn.Embedding(VOCAB_SIZE, num_hidden).requires_grad_(False)
        self.num_hidden = num_hidden
        self.decoding = MultiHeadAttention(embed_dim=num_hidden, num_heads=num_heads)
        # Labels lie in 0..2*SEQ_LEN after the offset, i.e. 2*SEQ_LEN + 1 classes
        self.last_fwd = nn.Linear((SEQ_LEN + 2) * num_hidden, 2 * SEQ_LEN + 1)
        self.sftmx = torch.nn.Softmax(dim=-1)
        
    def forward(self, x):
        # TODO: implement
        
        x = self.embd(x)
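        # Fixed (non-trainable) sinusoidal positional encoding, Section 3.5
        # of the paper. Each even/odd pair of dimensions (2j, 2j+1) shares
        # one frequency:
        #   PE[pos, 2j]   = sin(pos / 10000^(2j / d))
        #   PE[pos, 2j+1] = cos(pos / 10000^(2j / d))
        pos = torch.arange(SEQ_LEN + 2, dtype=torch.float).unsqueeze(1)
        even_idx = torch.arange(0, self.num_hidden, 2, dtype=torch.float)
        angles = pos / (10000 ** (even_idx / self.num_hidden))
        pe = torch.zeros(SEQ_LEN + 2, self.num_hidden)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        x = x + pe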
        
        x= self.encoding(x)
        x = self.decoding(self.q,x,x)

        x = self.last_fwd(x.view(x.size(0), -1))
        return x
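
The task text asks for a trainable positional encoding, while the network above adds a fixed sinusoidal one. A trainable variant could look roughly like the sketch below. The module name `LearnedPositionalEncoding` and the initialization scale are assumptions made for illustration and are not taken from the exercise template.

```python
import torch
import torch.nn as nn


class LearnedPositionalEncoding(nn.Module):
    """Hypothetical helper: adds one trainable vector per sequence position."""

    def __init__(self, seq_len, num_hidden):
        super().__init__()
        # Leading dimension of 1 so the same table is broadcast over the batch.
        self.pos = nn.Parameter(torch.zeros(1, seq_len, num_hidden))
        nn.init.normal_(self.pos, std=0.02)  # small random start (assumed scale)

    def forward(self, x):
        # x: [batch, seq_len, num_hidden]
        return x + self.pos


# Usage with the notebook's sizes: SEQ_LEN + 2 = 7 positions, num_hidden = 64.
pos_enc = LearnedPositionalEncoding(seq_len=7, num_hidden=64)
out = pos_enc(torch.zeros(3, 7, 64))
print(out.shape)
```

Because `pos` is an `nn.Parameter`, it is returned by `net.parameters()` and updated by the optimizer together with the rest of the model.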

        

#### Training
Don't edit the following 2 cells. They must run without errors if you implemented the model correctly.  
The model should converge to nearly 100% accuracy after ~4.5k steps.

In [9]:
# Instantiate the network, loss function, and optimizer
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

In [10]:
# Train the network
num = 0
for i in range(NUM_TRAINING_STEPS):
    
    inputs, labels = get_data_sample(BATCH_SIZE)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    # print(f'labels:{labels}')
    # print(f'output:{torch.argmax(outputs, axis=-1)}')
    accuracy = (torch.argmax(outputs, axis=-1) == labels).float().mean()

    if i % 100 == 0:
        for name,p in net.named_parameters():
            if p.requires_grad == True:
                num+=1
        print('[%d/%d] loss: %.3f, accuracy: %.3f' %
              (i , NUM_TRAINING_STEPS - 1, loss.item(), accuracy.item()))
    if i == NUM_TRAINING_STEPS - 1:
        print('Final accuracy: %.3f, expected %.3f' %
              (accuracy.item(), 1.0))

[0/24999] loss: 2.357, accuracy: 0.141
[100/24999] loss: 1.804, accuracy: 0.312
[200/24999] loss: 1.602, accuracy: 0.438
[300/24999] loss: 1.796, accuracy: 0.281
[400/24999] loss: 1.602, accuracy: 0.391
[500/24999] loss: 1.679, accuracy: 0.344
[600/24999] loss: 1.499, accuracy: 0.562
[700/24999] loss: 1.640, accuracy: 0.469
[800/24999] loss: 1.367, accuracy: 0.516
[900/24999] loss: 1.613, accuracy: 0.312
[1000/24999] loss: 1.530, accuracy: 0.500
[1100/24999] loss: 1.526, accuracy: 0.438
[1200/24999] loss: 1.560, accuracy: 0.422
[1300/24999] loss: 1.469, accuracy: 0.516
[1400/24999] loss: 1.510, accuracy: 0.484
[1500/24999] loss: 1.338, accuracy: 0.562
[1600/24999] loss: 1.323, accuracy: 0.531
[1700/24999] loss: 1.571, accuracy: 0.484
[1800/24999] loss: 1.534, accuracy: 0.391
[1900/24999] loss: 1.554, accuracy: 0.406
[2000/24999] loss: 1.595, accuracy: 0.453
[2100/24999] loss: 1.588, accuracy: 0.375
[2200/24999] loss: 1.501, accuracy: 0.469
[2300/24999] loss: 1.615, accuracy: 0.406
[240