## MultiHeadAttention adapted from 
### https://machinelearningmastery.com/how-to-implement-multi-head-attention-from-scratch-in-tensorflow-and-keras/

## Crucially, the tutorial above makes no explicit mention of the relation between the Q, K, V projection dimensions and the number of attention heads.

## Below we make that detail explicit by defining these dimensions as multiples of the attention heads count.

## This works by performing one wide matrix multiplication for each of the Q, K, V projections and then splitting the feature axis of every projected row into equal-sized parts, one part per head. This is why the number of heads must divide both d_k and d_v.

In [None]:
import tensorflow as tf
import numpy as np

In [2]:
class DotProductAttention(tf.keras.layers.Layer):
    """Scaled dot-product attention.

    Computes ``softmax(Q @ K^T / sqrt(d_k)) @ V``, optionally adding a large
    negative bias to masked positions before the softmax.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Apply attention of ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Query tensor; its last axis is contracted against the
                last axis of ``keys``.
            keys: Key tensor with the same last-axis size as ``queries``.
            values: Value tensor, weighted by the attention distribution.
            d_k: Number whose square root divides the raw scores.
            mask: Optional tensor broadcastable to the score tensor. Entries
                equal to 1 are suppressed; entries equal to 0 are kept.

        Returns:
            The attention-weighted combination of ``values``.
        """
        scale = tf.sqrt(tf.cast(d_k, tf.float32))
        logits = tf.matmul(queries, keys, transpose_b=True) / scale

        if mask is not None:
            # A very negative logit receives (numerically) zero softmax weight.
            logits = logits + (-1e9 * mask)

        attention = tf.nn.softmax(logits)
        return tf.matmul(attention, values)

In [46]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are each projected by a single Dense layer to
    widths ``d_k``, ``d_k`` and ``d_v``. The projected feature axis is then
    split into ``h`` equal slices, one per head, so every head works on
    ``d_k // h`` (queries/keys) and ``d_v // h`` (values) features. The head
    outputs are concatenated again and projected to ``d_model``.

    Args:
        h: Number of attention heads.
        d_k: Total width of the query and key projections (all heads together).
        d_v: Total width of the value projection (all heads together).
        d_model: Width of the layer output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``h`` is not positive or does not divide ``d_k`` and
            ``d_v`` evenly.
    """

    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)

        # The per-head split is done with a reshape, which only works when
        # the projection widths are exact multiples of the head count.
        if h <= 0 or d_k % h != 0 or d_v % h != 0:
            raise ValueError(
                f'Number of heads ({h}) must be positive and divide '
                f'd_k ({d_k}) and d_v ({d_v}) evenly.'
            )

        self.attention = DotProductAttention()

        self.heads = h
        self.d_k = d_k
        self.d_v = d_v
        self.W_q = tf.keras.layers.Dense(d_k)
        self.W_k = tf.keras.layers.Dense(d_k)
        self.W_v = tf.keras.layers.Dense(d_v)
        # Layer output will produce model's working shape
        # - the embedding dimensionality
        self.W_o = tf.keras.layers.Dense(d_model)

    def reshape_tensor(self, x, heads, flag):
        """Split a tensor into heads or merge the heads back together.

        Args:
            x: With ``flag`` true, a rank-3 tensor
                ``(batch_size, seq_length, features)``; otherwise a rank-4
                tensor ``(batch_size, heads, seq_length, depth)``.
            heads: Number of heads to split the feature axis into (only used
                when ``flag`` is true).
            flag: True to split into heads, False to merge heads.

        Returns:
            ``(batch_size, heads, seq_length, features // heads)`` when
            splitting, ``(batch_size, seq_length, heads * depth)`` when merging.
        """
        if flag:
            # (batch, seq, features) -> (batch, seq, heads, depth)
            #                        -> (batch, heads, seq, depth)
            x = tf.reshape(x, shape=(tf.shape(x)[0], tf.shape(x)[1], heads, -1))
            x = tf.transpose(x, perm=(0, 2, 1, 3))
        else:
            # (batch, heads, seq, depth) -> (batch, seq, heads, depth)
            #                            -> (batch, seq, heads * depth)
            x = tf.transpose(x, perm=(0, 2, 1, 3))
            x = tf.reshape(x, shape=(tf.shape(x)[0], tf.shape(x)[1], -1))

        return x

    def call(self, queries, keys, values, mask=None):
        """Run multi-head attention.

        Args:
            queries: Rank-3 tensor ``(batch_size, seq_length, features)``.
            keys: Rank-3 tensor ``(batch_size, seq_length, features)``.
            values: Rank-3 tensor ``(batch_size, seq_length, features)``.
            mask: Optional tensor handed to ``DotProductAttention``; it must
                broadcast against the per-head score tensor
                ``(batch_size, heads, query_length, key_length)``.

        Returns:
            Tensor of shape ``(batch_size, query_length, d_model)``.
        """
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, flag=True)
        # Notebook aid: shows how the projection was partitioned across heads.
        print(f'Reshaped query partitioned into heads: {q_reshaped.shape}')
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, flag=True)
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, flag=True)

        # Scores are dot products over the per-head key depth, so that depth
        # (not the combined width of all heads) is the scaling dimension.
        head_depth = self.d_k // self.heads
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, head_depth, mask)
        output = self.reshape_tensor(o_reshaped, self.heads, flag=False)

        return self.W_o(output)

In [47]:
# How many attention heads run in parallel
h = 8

# Total widths of the query/key projection (d_k) and the value projection (d_v).
# Both are 16 * h so that every head receives a 16-dimensional slice.
d_k = d_v = 16 * h

# Output width of the model's sub-layers, and the batch size used in this demo
d_model, batch_size = 512, 1

In [48]:
# Number of positions in each input sequence
input_seq_length = 5

# Uniform random stand-ins for the query, key and value inputs,
# each shaped (batch_size, input_seq_length, feature width)
queries = np.random.random(size=(batch_size, input_seq_length, d_k))
keys = np.random.random(size=(batch_size, input_seq_length, d_k))
values = np.random.random(size=(batch_size, input_seq_length, d_v))

In [49]:
# Build the layer from the hyperparameters defined above
multihead_attention = MultiHeadAttention(h=h, d_k=d_k, d_v=d_v, d_model=d_model)

In [50]:
# Run the layer on the random inputs (no mask) and print the resulting tensor
print(multihead_attention(queries, keys, values))

Reshaped query partitioned into heads: (1, 8, 5, 16)
tf.Tensor(
[[[-0.06934756 -0.4272185  -0.7013557  ... -0.18155214 -0.04901038
    0.10685071]
  [-0.06753655 -0.4271417  -0.6974434  ... -0.17985624 -0.04874355
    0.10335528]
  [-0.07192618 -0.42755052 -0.70126486 ... -0.1846404  -0.04992645
    0.10626048]
  [-0.07006095 -0.42346823 -0.6988238  ... -0.18033272 -0.04631981
    0.10505362]
  [-0.06831628 -0.42368662 -0.7011584  ... -0.18153885 -0.04925524
    0.10744631]]], shape=(1, 5, 512), dtype=float32)
