Credit: https://medium.com/tensorflow/a-transformer-chatbot-tutorial-with-tensorflow-2-0-88bf59e66fe2

# __Transformer__

In [None]:
import tensorflow as tf

## <font color=pink>__Attention__</font>

Like many sequence-to-sequence models, the Transformer consists of an encoder and a decoder. However, instead of recurrent or convolutional layers, it uses multi-head attention layers, each of which runs several scaled dot-product attention operations in parallel.

<div>
<img src="attention_workflow.png" width="600"/>
</div>

### <font color=lightblue>Scaled dot product attention:</font>

The scaled dot-product attention function takes three inputs: $Q$ (query), $K$ (key), and $V$ (value). The equation used to calculate the attention weights is:

<div>
<img src="attention_weight_equation.png" width="600" style="background-color: white"/>
</div>

The softmax normalization is applied along the $key$ axis, so for each $query$ the resulting weights sum to 1 and decide how much importance each $key$ position receives. The output is the product of these attention weights and $value$. As a result, the words we want to focus on are kept largely intact, while irrelevant words are suppressed.



In [2]:
def scaled_dot_product_attention(query, key, value, mask):
    matmul_qk = tf.matmul(query, key, transpose_b=True)

    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)

    # add the mask to zero out padding tokens.
    if mask is not None:
        logits += (mask * -1e9)

    attention_weights = tf.nn.softmax(logits, axis=-1)

    return tf.matmul(attention_weights, value)
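
To make the effect of the softmax and the mask concrete, here is a minimal sketch in plain Python (no TensorFlow) that computes the attention weights of a single query over three keys. The helper names `softmax` and `attention_weights` and all the numbers are made up for illustration. The mask follows the same convention as the function above, where a 1 marks a position to ignore.

```python
import math


def softmax(row):
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(query, keys, mask=None):
    """Weights of one query vector over a list of key vectors."""
    depth = len(keys[0])
    # Scaled dot product between the query and every key.
    logits = [
        sum(q * k for q, k in zip(query, key)) / math.sqrt(depth)
        for key in keys
    ]
    if mask is not None:
        # A mask value of 1 pushes the logit towards minus infinity.
        logits = [logit + m * -1e9 for logit, m in zip(logits, mask)]
    return softmax(logits)


# Illustrative values only: one query, three keys, three values.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

# The third position is treated as padding and masked out.
weights = attention_weights(query, keys, mask=[0, 0, 1])

# Weighted sum of the value vectors, one output component at a time.
output = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(2)]

print(weights)
print(output)
```

The weights sum to 1, the masked position receives a weight of zero even though its key matches the query, and the first key gets more weight than the second because it is better aligned with the query.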

### <font color=lightblue>Multi-head Attention Layer</font>


Sequential models let us build models very quickly by simply stacking layers on top of each other; however, more complicated and non-sequential models need the Functional API or Model subclassing. The $tf.keras$ API allows us to mix and match these styles. My favourite feature of Model subclassing is how easy it makes debugging. I can set a breakpoint in the $call()$ method and inspect each layer's inputs and outputs like NumPy arrays, which makes debugging a lot simpler.

Here, we use the same subclassing approach, deriving from $tf.keras.layers.Layer$, to implement our $MultiHeadAttention$ layer.

Multi-head attention consists of four parts:

- Linear layers and split into heads.
- Scaled dot-product attention.
- Concatenation of heads.
- Final linear layer.

Each multi-head attention block takes a dictionary as input, which consists of the query, key, value, and mask. Notice that when subclassing is combined with the Functional API, the inputs have to be passed as a single argument, hence we wrap query, key, value, and mask in a dictionary.

The inputs are then put through dense layers and split up into multiple heads. The scaled_dot_product_attention() function defined above is applied to all heads at once (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention outputs of the heads are then concatenated and put through a final dense layer.
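
The steps above can be followed by tracking tensor shapes alone. The sketch below is plain Python and does not touch TensorFlow. The function name `trace_shapes`, the stage names, and the example sizes (batch of 64, sequence length 40, `d_model` of 256, 8 heads) are assumptions chosen for illustration. It also assumes self-attention, where the query and key sequences have the same length.

```python
def trace_shapes(batch_size, seq_len, d_model, num_heads):
    """Return the tensor shape after each stage of multi-head attention."""
    # d_model must divide evenly across the heads.
    assert d_model % num_heads == 0
    depth = d_model // num_heads
    return {
        # Output of each of the query/key/value dense layers.
        "dense_output": (batch_size, seq_len, d_model),
        # The last axis is split into (num_heads, depth).
        "reshaped": (batch_size, seq_len, num_heads, depth),
        # The heads axis moves next to the batch axis.
        "split_heads": (batch_size, num_heads, seq_len, depth),
        # One seq_len x seq_len weight matrix per head.
        "attention_logits": (batch_size, num_heads, seq_len, seq_len),
        # Weighted sum of the values, still separate per head.
        "attention_output": (batch_size, num_heads, seq_len, depth),
        # Heads are moved back and merged into the last axis.
        "concatenated": (batch_size, seq_len, d_model),
    }


# Hypothetical sizes, chosen only for illustration.
shapes = trace_shapes(batch_size=64, seq_len=40, d_model=256, num_heads=8)
for stage, shape in shapes.items():
    print(f"{stage:>18}: {shape}")
```

The concatenated output has the same shape as the dense-layer output, which is why the final dense layer can map it back to `d_model` units.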

Instead of using one single attention head, query, key, and value are split into multiple heads because this allows the model to jointly attend to information from different representation subspaces at different positions. After the split, each head has a reduced dimensionality, so the total computation cost is the same as that of single-head attention with full dimensionality.
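
The cost claim can be checked with a rough count of multiply-add operations in the two matrix products of the attention step. This is only a back-of-the-envelope sketch: the function name `attention_multiply_adds` and the sizes are hypothetical, and the count ignores the dense layers, the softmax, and the mask.

```python
def attention_multiply_adds(seq_len, depth, num_heads=1):
    """Rough multiply-add count for the two matmuls in attention."""
    # Q K^T: seq_len x seq_len dot products, each of length depth.
    logits_cost = seq_len * seq_len * depth
    # weights V: seq_len x depth outputs, each summing seq_len terms.
    output_cost = seq_len * seq_len * depth
    return num_heads * (logits_cost + output_cost)


# Hypothetical sizes: sequence length 40, d_model 256, 8 heads.
seq_len, d_model, num_heads = 40, 256, 8

single_head = attention_multiply_adds(seq_len, d_model)
multi_head = attention_multiply_adds(
    seq_len, d_model // num_heads, num_heads=num_heads)

print(single_head, multi_head)
```

Both counts come out equal, because each head works on `d_model // num_heads` dimensions and there are `num_heads` of them.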

In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d_model, num_heads, name="multi_head_attention"):
        super(MultiHeadAttention, self).__init__(name=name)
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.query_dense = tf.keras.layers.Dense(units=d_model)
        self.key_dense = tf.keras.layers.Dense(units=d_model)
        self.value_dense = tf.keras.layers.Dense(units=d_model)

        self.dense = tf.keras.layers.Dense(units=d_model)

    def split_heads(self, inputs, batch_size):
        inputs = tf.reshape(
            inputs, shape=(batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    def call(self, inputs):
        query, key, value, mask = (inputs['query'], inputs['key'],
                                   inputs['value'], inputs['mask'])
        batch_size = tf.shape(query)[0]

        # linear layers
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)

        # split heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

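        # scaled dot-product attention, applied to all heads at once
        scaled_attention = scaled_dot_product_attention(query, key, value, mask)

        # move the heads axis back: (batch_size, seq_len, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # concatenation of heads
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        # final linear layer
        outputs = self.dense(concat_attention)

        return outputs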