## <font color='darkblue'>Preface</font>
<b><font size='3ptx'>In this tutorial, you will discover how to implement multi-head attention from scratch in TensorFlow and Keras.</font></b>

([article source](https://machinelearningmastery.com/how-to-implement-multi-head-attention-from-scratch-in-tensorflow-and-keras/)) We have already familiarised ourselves with the theory behind the [**Transformer model**](https://machinelearningmastery.com/the-transformer-model/) and its [**attention mechanism**](https://machinelearningmastery.com/the-transformer-model/), and we have already started our journey of implementing a complete model by seeing how to [implement the scaled dot-product attention](https://machinelearningmastery.com/how-to-implement-scaled-dot-product-attention-from-scratch-in-tensorflow-and-keras). **We shall now progress one step further by encapsulating the scaled dot-product attention into a multi-head attention mechanism, of which it is a core component. Our end goal remains the application of the complete model to Natural Language Processing** (NLP).

After completing this tutorial, you will know:
* The layers that form part of the multi-head attention mechanism.
* How to implement the multi-head attention mechanism from scratch.  

### <font color='darkgreen'>Tutorial Overview</font>
This tutorial is divided into two parts; they are:
1. <font size='3ptx'><b><a href='#sect1'>Recap of the Transformer Architecture</a></b></font>
    * <b><a href='#sect1_1'>The Transformer Multi-Head Attention</a></b>
    * <b><a href='#sect1_2'>Implementing Multi-Head Attention From Scratch</a></b>
2. <font size='3ptx'><b><a href='#sect2'>Testing Out the Code</a></b></font>

### <font color='darkgreen'>Prerequisites</font>
For this tutorial, we assume that you are already familiar with:
* [The concept of attention](https://machinelearningmastery.com/what-is-attention/)
* [The Transformer attention mechanism](https://machinelearningmastery.com/the-transformer-attention-mechanism)
* [The Transformer model](https://machinelearningmastery.com/the-transformer-model/)
* [The scaled dot-product attention](https://machinelearningmastery.com/how-to-implement-scaled-dot-product-attention-from-scratch-in-tensorflow-and-keras)

<a id='sect1'></a>
## <font color='darkblue'>Recap of the Transformer Architecture</font>
<b><font size='3ptx'>Recall that the Transformer architecture follows an encoder-decoder structure.</font></b>

The encoder, on the left-hand side, is tasked with mapping an input sequence to a sequence of continuous representations; the decoder, on the right-hand side, receives the output of the encoder together with the decoder output at the previous time step, to generate an output sequence ([image source: Attention is all you need](https://arxiv.org/abs/1706.03762)).

![Transformer architecture](images/1.PNG)

<br/>

In generating an output sequence, the Transformer does not rely on recurrence or convolutions.

We have seen that the decoder part of the Transformer shares many similarities in its architecture with the encoder. <b>One of the core mechanisms that both the encoder and decoder share is the <font color='darkblue'>multi-head attention mechanism</font></b>.

<a id='sect1_1'></a>
### <font color='darkgreen'>The Transformer Multi-Head Attention</font>
Each multi-head attention block is made up of four consecutive levels:
* **On the first level**, three linear (dense) layers that each receive the queries, keys or values.
* **On the second level**, a scaled dot-product attention function. The operations performed on both first and second levels are repeated `h` times and performed in parallel, according to the number of heads composing the multi-head attention block. 
* **On the third level**, a concatenation operation that joins the outputs of the different heads.
* **On the fourth level**, a final linear (dense) layer that produces the output. 

![Multi-Head Attention](images/2.PNG)

[Recall](https://machinelearningmastery.com/the-transformer-attention-mechanism/) as well the important components that will serve as building blocks for our implementation of the multi-head attention:
* The **queries**, **keys** and **values**: These are the inputs to each multi-head attention block. In the encoder stage, they each carry the same input sequence after it has been embedded and augmented with positional information. Similarly, on the decoder side, the queries, keys and values fed into the first attention block represent the same target sequence, again after it has been embedded and augmented with positional information. The second attention block of the decoder receives the encoder output as the keys and values, and the normalized output of the first decoder attention block as the queries. The dimensionality of the queries and keys is denoted by $d_k$, whereas the dimensionality of the values is denoted by $d_v$.
* The **projection matrices**: When applied to the queries, keys and values, these projection matrices generate different subspace representations of each. Each attention head then works on one of these projected versions of the queries, keys and values. An additional projection matrix is applied to the output of the multi-head attention block, after the outputs of the individual heads have been concatenated. The projection matrices are learned during training.
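
For reference, the four levels and the projection matrices described above correspond to the formulation given by Vaswani et al. (2017). The symbols $W_i^Q$, $W_i^K$, $W_i^V$ (per-head projections) and $W^O$ (output projection) follow the paper's notation rather than names used elsewhere in this tutorial:

$$
\text{head}_i = \text{Attention}\left(QW_i^Q,\; KW_i^K,\; VW_i^V\right), \qquad \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

$$
\text{MultiHead}(Q, K, V) = \text{Concat}\left(\text{head}_1, \dots, \text{head}_h\right)W^O
$$

Here $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times d_{model}}$.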

<br/>

Let’s now see how to implement the multi-head attention from scratch in TensorFlow and Keras.

<a id='sect1_2'></a>
### <font color='darkgreen'>Implementing Multi-Head Attention From Scratch</font>
Let’s start by creating the class, <b><font color='blue'>MultiHeadAttention</font></b>, that inherits from the [**Layer**](https://keras.io/api/layers/base_layer/) base class in Keras, and initialize several instance attributes that we shall be working with (<font color='brown'>attribute descriptions may be found in the comments</font>):
```python
class MultiHeadAttention(Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.attention = DotProductAttention()  # Scaled dot product attention 
        self.heads = h  # Number of attention heads to use
        self.d_k = d_k  # Dimensionality of the linearly projected queries and keys
        self.d_v = d_v  # Dimensionality of the linearly projected values
        self.W_q = Dense(d_k)  # Learned projection matrix for the queries
        self.W_k = Dense(d_k)  # Learned projection matrix for the keys
        self.W_v = Dense(d_v)  # Learned projection matrix for the values
        self.W_o = Dense(d_model)  # Learned projection matrix for the multi-head output
        ...
```

<br/>

Note that we have also created an instance of the <b><font color='blue'>DotProductAttention</font></b> class that we had implemented earlier, and assigned it to the `attention` attribute. Recall that we had implemented the <b><font color='blue'>DotProductAttention</font></b> class as follows:
```python
from tensorflow import matmul, math, cast, float32
from tensorflow.keras.layers import Layer
from keras.backend import softmax

# Implementing the Scaled-Dot Product Attention
class DotProductAttention(Layer):
    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        # Scoring the queries against the keys after transposing the latter, and scaling
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        # Apply mask to the attention scores
        if mask is not None:
            scores += -1e9 * mask

        # Computing the weights by a softmax operation
        weights = softmax(scores)

        # Computing the attention by a weighted sum of the value vectors
        return matmul(weights, values)
```

<br/>

Next, we will be reshaping the linearly projected queries, keys and values in such a manner as to allow the attention heads to be computed in parallel. 

The queries, keys and values will be fed as input into the multi-head attention block having a shape of (<font color='brown'>batch size, sequence length, model dimensionality</font>), where <b>the `batch size` is a hyperparameter of the training process, the `sequence length` defines the maximum length of the input/output phrases, and the `model dimensionality` is the dimensionality of the outputs produced by all sub-layers of the model</b>. They are then passed through the respective dense layer to be linearly projected to a shape of `(batch size, sequence length, queries/keys/values dimensionality)`.

The linearly projected queries, keys and values will be rearranged into `(batch size, number of heads, sequence length, depth)`, by first reshaping them into `(batch size, sequence length, number of heads, depth)` and then transposing the second and third dimensions. For this purpose, we will create the method, <font color='blue'>reshape_tensor</font>, as follows:
```python
def reshape_tensor(self, x, heads, flag):
    if flag:
        # Tensor shape after reshaping and transposing: (batch_size, heads, seq_length, -1)
        x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
        x = transpose(x, perm=(0, 2, 1, 3))
    else:
        # Reverting the reshaping and transposing operations: (batch_size, seq_length, d_v)
        x = transpose(x, perm=(0, 2, 1, 3))
        x = reshape(x, shape=(shape(x)[0], shape(x)[1], -1))
    return x
```

<br/>

The <font color='blue'>reshape_tensor</font> method receives the linearly projected queries, keys or values as input (<font color='brown'>while setting the flag to True</font>) to be rearranged as previously explained. Once the multi-head attention output has been generated, this is also fed into the same function (<font color='brown'>this time setting the flag to False</font>) to perform a reverse operation, effectively concatenating the results of all heads together. 
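
To make the two directions of this rearrangement concrete, here is a minimal sketch that uses NumPy as a stand-in for the TensorFlow `reshape` and `transpose` calls. The random array `x` is only a placeholder for the output of a dense projection, and the sizes are the test values used later in this tutorial:

```python
import numpy as np

batch_size, seq_length, d_k, heads = 64, 5, 64, 8  # test values used later on

# Placeholder for the output of a Dense(d_k) projection
x = np.random.random((batch_size, seq_length, d_k))

# Split: (batch, seq, d_k) -> (batch, seq, heads, depth) -> (batch, heads, seq, depth)
split = x.reshape(batch_size, seq_length, heads, -1).transpose(0, 2, 1, 3)
print(split.shape)  # (64, 8, 5, 8)

# Merge: swap the axes back, then flatten the head and depth axes together
merged = split.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, -1)
print(merged.shape)  # (64, 5, 64)

# The round trip recovers the original array exactly
print(np.array_equal(x, merged))  # True
```

With these values each head works on a depth of `d_k / h = 8`, because the 64 features produced by `Dense(d_k)` are divided among the 8 heads. In Vaswani et al. (2017), by contrast, $d_k = d_{model}/h = 64$ is the dimensionality of each individual head.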

Hence, the next step is to feed the linearly projected queries, keys and values into the <font color='blue'>reshape_tensor</font> method to be rearranged, and then proceed to feed them into the scaled dot-product attention function. In order to do so, we will create another method, <font color='blue'>call</font>, as follows:
```python
def call(self, queries, keys, values, mask=None):
    # Rearrange the queries to be able to compute all heads in parallel
    q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
    # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

    # Rearrange the keys to be able to compute all heads in parallel
    k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
    # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

    # Rearrange the values to be able to compute all heads in parallel
    v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)
    # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

    # Compute the multi-head attention output using the reshaped queries, keys and values
    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)
    # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)
    ...
```

<br/>

Note that the <font color='blue'>call</font> method can also receive a mask (<font color='brown'>whose value defaults to None</font>) as input, in addition to the queries, keys and values.

[Recall](https://machinelearningmastery.com/the-transformer-model/) that the Transformer model introduces a <b>look-ahead mask</b> to prevent the decoder from attending to succeeding words, such that the prediction for a particular word can only depend on known outputs for the words that come before it. Furthermore, since the input sequences are zero-padded to a specific sequence length, a `padding mask` needs to be introduced too in order to prevent the zero values from being processed along with the input. These look-ahead and padding masks can be passed on to the scaled dot-product attention through the mask argument.
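
As an illustration of what such masks might look like, the NumPy sketch below builds both kinds using the same convention as the <b><font color='blue'>DotProductAttention</font></b> class above, where a value of 1 marks a position to be suppressed (the scores receive `-1e9 * mask`). The helper names `lookahead_mask` and `padding_mask`, the use of token id 0 for padding, and the broadcast shape of the padding mask are assumptions made for this example only:

```python
import numpy as np

def lookahead_mask(seq_length):
    # 1 above the main diagonal (future positions), 0 on and below it
    return 1.0 - np.tril(np.ones((seq_length, seq_length), dtype=np.float32))

def padding_mask(token_ids):
    # 1 wherever the token id equals 0 (assumed here to be the padding id);
    # the two extra axes let the mask broadcast over heads and query positions
    mask = (np.asarray(token_ids) == 0).astype(np.float32)
    return mask[:, np.newaxis, np.newaxis, :]

def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

la = lookahead_mask(4)
print(la)

# Apply the mask the same way DotProductAttention does, on uniform dummy scores
scores = np.zeros((4, 4), dtype=np.float32)
weights = softmax(scores + -1e9 * la)
print(weights[0])  # the first position can only attend to itself
print(weights[1])  # the second position splits its attention over the first two

print(padding_mask([[5, 7, 0, 0]]).shape)  # (1, 1, 1, 4)
```

Because the masked scores become very large negative numbers, their softmax weights are effectively zero, so the corresponding values do not contribute to the attention output.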

Once we have generated the multi-head attention output from all the attention heads, the final steps are to concatenate all the outputs back together into a tensor of shape `(batch size, sequence length, values dimensionality)`, and to pass the result through one final dense layer. For this purpose, we will be adding the next two lines of code to the <font color='blue'>call</font> method:
```python
...
# Rearrange back the output into concatenated form
output = self.reshape_tensor(o_reshaped, self.heads, False)
# Resulting tensor shape: (batch_size, input_seq_length, d_v)

# Apply one final linear projection to the output to generate the multi-head attention
# Resulting tensor shape: (batch_size, input_seq_length, d_model)
return self.W_o(output)
```

<br/>

Putting everything together, we have the following implementation of the multi-head attention:

```python
from tensorflow import math, matmul, reshape, shape, transpose, cast, float32
from tensorflow.keras.layers import Dense, Layer
from keras.backend import softmax

# Implementing the Scaled-Dot Product Attention
class DotProductAttention(Layer):
    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        # Scoring the queries against the keys after transposing the latter, and scaling
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        # Apply mask to the attention scores
        if mask is not None:
            scores += -1e9 * mask

        # Computing the weights by a softmax operation
        weights = softmax(scores)

        # Computing the attention by a weighted sum of the value vectors
        return matmul(weights, values)

# Implementing the Multi-Head Attention
class MultiHeadAttention(Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.attention = DotProductAttention()  # Scaled dot product attention
        self.heads = h  # Number of attention heads to use
        self.d_k = d_k  # Dimensionality of the linearly projected queries and keys
        self.d_v = d_v  # Dimensionality of the linearly projected values
        self.d_model = d_model  # Dimensionality of the model
        self.W_q = Dense(d_k)  # Learned projection matrix for the queries
        self.W_k = Dense(d_k)  # Learned projection matrix for the keys
        self.W_v = Dense(d_v)  # Learned projection matrix for the values
        self.W_o = Dense(d_model)  # Learned projection matrix for the multi-head output

    def reshape_tensor(self, x, heads, flag):
        if flag:
            # Tensor shape after reshaping and transposing: (batch_size, heads, seq_length, -1)
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # Reverting the reshaping and transposing operations: (batch_size, seq_length, d_v)
            # The attention output is built from the values, so its width is d_v rather than d_k
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], self.d_v))
        return x

    def call(self, queries, keys, values, mask=None):
        # Rearrange the queries to be able to compute all heads in parallel
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

        # Rearrange the keys to be able to compute all heads in parallel
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

        # Rearrange the values to be able to compute all heads in parallel
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

        # Compute the multi-head attention output using the reshaped queries, keys and values
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

        # Rearrange back the output into concatenated form
        output = self.reshape_tensor(o_reshaped, self.heads, False)
        # Resulting tensor shape: (batch_size, input_seq_length, d_v)

        # Apply one final linear projection to the output to generate the multi-head attention
        # Resulting tensor shape: (batch_size, input_seq_length, d_model)
        return self.W_o(output)
```

<a id='sect2'></a>
## <font color='darkblue'>Testing Out the Code</font>
We will be working with the parameter values specified in the paper, [Attention Is All You Need, by Vaswani et al. (2017)](https://arxiv.org/abs/1706.03762):

```python
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Batch size from the training process
...
```

<br/>

As for the sequence length, and the queries, keys and values, we will be working with dummy data for the time being until we arrive at the stage of [training the complete Transformer model](https://machinelearningmastery.com/training-the-transformer-model) in a separate tutorial, at which point we will be using actual sentences:

```python
...
input_seq_length = 5  # Maximum length of the input sequence

queries = random.random((batch_size, input_seq_length, d_k))
keys = random.random((batch_size, input_seq_length, d_k))
values = random.random((batch_size, input_seq_length, d_v))
...
```

<br/>

<b>In the complete Transformer model, values for the sequence length, and the queries, keys and values will be obtained through a process of word tokenization and embedding</b>. We will be covering this in a separate tutorial. 

Returning to our testing procedure, the next step is to create a new instance of the <b><font color='blue'>MultiHeadAttention</font></b> class, assigning it to the `multihead_attention` variable:

```python
...
multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
...
```

<br/>

Since the <b><font color='blue'>MultiHeadAttention</font></b> class inherits from the [**Layer**](https://keras.io/api/layers/base_layer/) base class, the <font color='blue'>call()</font> method of the former will be automatically invoked by the magic `__call__()` method of the latter. The final step is to pass in the input arguments and print the result:

```python
...
print(multihead_attention(queries, keys, values))
```

<br/>
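
The dispatch from `__call__()` to `call()` can be pictured with the simplified pure-Python sketch below. `ToyLayer` and `Doubler` are hypothetical classes written for this illustration and are not part of Keras; the real `Layer.__call__()` performs additional work, such as building the layer's weights on first use, before delegating:

```python
class ToyLayer:
    """Hypothetical stand-in for a layer base class."""

    def __call__(self, *args, **kwargs):
        # Calling the instance like a function forwards to the subclass's call()
        return self.call(*args, **kwargs)


class Doubler(ToyLayer):
    def call(self, x):
        return 2 * x


layer = Doubler()
print(layer(21))  # 42
```

This is why the test code calls `multihead_attention(queries, keys, values)` directly instead of `multihead_attention.call(...)`.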

Tying everything together produces the following code listing:

```python
from numpy import random

input_seq_length = 5  # Maximum length of the input sequence
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Batch size from the training process

queries = random.random((batch_size, input_seq_length, d_k))
keys = random.random((batch_size, input_seq_length, d_k))
values = random.random((batch_size, input_seq_length, d_v))

multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
print(multihead_attention(queries, keys, values))
```

```
tf.Tensor(
[[[ 0.31182265  0.19591606 -0.1745555  ... -0.0810058   0.3352888
   -0.09868204]
  [ 0.31025675  0.19259396 -0.17539161 ... -0.08174489  0.3322592
   -0.09594787]
  [ 0.31222063  0.1926596  -0.17362075 ... -0.08272956  0.3357062
   -0.09639519]
  [ 0.30701864  0.19324327 -0.1784897  ... -0.08069471  0.3368063
   -0.09762746]
  [ 0.3132949   0.19455978 -0.17680576 ... -0.08266836  0.33526513
   -0.09668439]]

 [[ 0.30330384  0.31986487 -0.41438392 ... -0.18172678  0.38058582
   -0.21451142]
  [ 0.2987469   0.32107183 -0.4159746  ... -0.1853449   0.3806158
   -0.2141388 ]
  [ 0.3031524   0.31956658 -0.41565073 ... -0.18619305  0.3800994
   -0.21355376]
  [ 0.30263156  0.32314113 -0.41567734 ... -0.18400706  0.37728742
   -0.21264756]
  [ 0.30143106  0.319786   -0.41592628 ... -0.18569006  0.3804273
   -0.21345107]]

 [[ 0.38233978  0.2469757  -0.2247107  ... -0.35052028  0.38104048
   -0.20756163]
  [ 0.38205683  0.24726364 -0.22410625 ... -0.35173076  0.38397768
   -0.206451
```

Running this code produces an output of shape `(batch size, sequence length, model dimensionality)`. Note that you will likely see a different output, because the queries, keys and values are randomly generated and the parameters of the dense layers are randomly initialized.
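
If you would like to check the expected output shape without running TensorFlow, the following NumPy sketch mirrors the same sequence of operations. It is an independent approximation rather than the Keras layer itself: the function names `split_heads`, `merge_heads` and `multihead_attention_np` are hypothetical, random matrices stand in for the learned dense kernels, and the dense-layer biases are omitted:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def split_heads(x, heads):
    # (batch, seq, width) -> (batch, heads, seq, width / heads)
    b, s, _ = x.shape
    return x.reshape(b, s, heads, -1).transpose(0, 2, 1, 3)

def merge_heads(x):
    # (batch, heads, seq, depth) -> (batch, seq, heads * depth)
    b, _, s, _ = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, s, -1)

def multihead_attention_np(queries, keys, values, W_q, W_k, W_v, W_o, heads, d_k, mask=None):
    q = split_heads(queries @ W_q, heads)
    k = split_heads(keys @ W_k, heads)
    v = split_heads(values @ W_v, heads)
    # Scaled dot-product scores, scaled by sqrt(d_k) as in the tutorial's code
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    if mask is not None:
        scores = scores + -1e9 * mask
    weights = softmax(scores)
    return merge_heads(weights @ v) @ W_o

rng = np.random.default_rng(0)
h, d_k, d_v, d_model = 8, 64, 64, 512
batch_size, input_seq_length = 64, 5

queries = rng.random((batch_size, input_seq_length, d_k))
keys = rng.random((batch_size, input_seq_length, d_k))
values = rng.random((batch_size, input_seq_length, d_v))

# Random stand-ins for the kernels of W_q, W_k, W_v and W_o
W_q = rng.standard_normal((d_k, d_k)) * 0.1
W_k = rng.standard_normal((d_k, d_k)) * 0.1
W_v = rng.standard_normal((d_v, d_v)) * 0.1
W_o = rng.standard_normal((d_v, d_model)) * 0.1

out = multihead_attention_np(queries, keys, values, W_q, W_k, W_v, W_o, h, d_k)
print(out.shape)  # (64, 5, 512)
```

The numeric values will not match those of the Keras layer, since the weights differ, but the shape at each stage follows the same pattern.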

## <font color='darkblue'>Further Reading</font>
This section provides more resources on the topic if you are looking to go deeper.

* **Books**
    * [Advanced Deep Learning with Python, 2019.](https://www.amazon.com/Advanced-Deep-Learning-Python-next-generation/dp/178995617X)
    * [Transformers for Natural Language Processing, 2021.](https://www.amazon.com/Transformers-Natural-Language-Processing-architectures/dp/1800565798)
* **Papers**
    * [Attention Is All You Need, 2017.](https://arxiv.org/abs/1706.03762)
* **Others**
    * [【機器學習2021】自注意力機制 (Self-attention) (上)](https://www.youtube.com/watch?v=hYdO9CscNes) (Machine Learning 2021: Self-attention, Part 1)
    * [【機器學習2021】自注意力機制 (Self-attention) (下)](https://www.youtube.com/watch?v=gmsMY5kc-zw) (Machine Learning 2021: Self-attention, Part 2)