## Keras custom MultiHead-Attention

Source paper ("Attention Is All You Need"): https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf

Other reference implementations:

https://github.com/Lsdefine/attention-is-all-you-need-keras/blob/master/transformer.py

https://github.com/benjamintaiwo/Attention/blob/master/transformer.py

In [1]:
from tensorflow.keras import backend as K
from tensorflow.keras import initializers, regularizers, activations
from tensorflow.keras.initializers import Zeros, Ones
from tensorflow.keras.layers import Layer,Embedding
import numpy as np

# Reset the Keras backend's global state before (re)defining the layers below,
# so that re-running the notebook does not build on top of a previous run.
K.clear_session()

In [2]:
class MultiHeadSelf_Attention(Layer):
    """Multi-head scaled dot-product self-attention (encoder style).

    The layer uses its single input as query, key and value. For every head
    the input is linearly projected to ``dk = emb_size // num_heads``
    features (followed by ``activation``), scaled dot-product attention is
    computed, and the concatenated head outputs are projected back to
    ``emb_size`` with the output matrix ``Wo``.

    Args:
        num_heads: number of attention heads; must divide the size of the
            last input dimension.
        activation: activation (name or callable) applied to the q/k/v
            projections of every head.
        initializer: initializer (name or instance) for all weight matrices.
        regularizer: optional regularizer (name or instance) for all weight
            matrices.
        **kwargs: standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).

    Input:  3D tensor ``(batch, input_size, emb_size)``.
    Output: 3D tensor ``(batch, input_size, emb_size)``.
    """

    def __init__(self, num_heads=6, activation='relu', initializer='glorot_normal', regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads

        self.activation = activations.get(activation)
        self.initializer = initializers.get(initializer)
        self.regularizer = regularizers.get(regularizer)

    def build(self, input_shape):
        """Create the per-head q/k/v projection weights and the output weight.

        Raises:
            TypeError: if the last input dimension is unknown or is not
                divisible by ``num_heads``.
        """
        ## This implementation assumes that Q, K, V are the same input! (ENCODER)

        # TF1 shapes hold `Dimension` objects (with `.value`), TF2 shapes hold
        # plain ints/None; support both.
        last_dim = input_shape[-1]
        emb_size = getattr(last_dim, "value", last_dim)

        if emb_size is None:
            # TypeError is kept for consistency with the divisibility check below.
            raise TypeError("The last dimension of the input (embedding size) must be defined")

        if emb_size % self.num_heads != 0:
            raise TypeError("Number of dimension of the embedding must be divided by the number of heads")

        # feature dimension per head: dq (query) = dk (key) = dv (value)
        dk = emb_size // self.num_heads
        # plain Python float so it adapts to the dtype of the tensor it divides
        self.scaling_factor = float(np.sqrt(dk))

        # Weights of the linear projections, one matrix per head
        self.multihead_Wq = []
        self.multihead_Wk = []
        self.multihead_Wv = []

        # NOTE: weights are created in the order Wq_i, Wk_i, Wv_i per head so
        # that the layer's weight list keeps a stable, predictable ordering.
        for head_i in range(self.num_heads):
            self.multihead_Wq.append(self.add_weight(name="Wq_{}".format(head_i),
                                                     shape=[emb_size, dk],
                                                     initializer=self.initializer,
                                                     regularizer=self.regularizer))
            self.multihead_Wk.append(self.add_weight(name="Wk_{}".format(head_i),
                                                     shape=[emb_size, dk],
                                                     initializer=self.initializer,
                                                     regularizer=self.regularizer))
            self.multihead_Wv.append(self.add_weight(name="Wv_{}".format(head_i),
                                                     shape=[emb_size, dk],
                                                     initializer=self.initializer,
                                                     regularizer=self.regularizer))

        # Output weight: maps the concatenated heads back to emb_size
        self.Wo = self.add_weight(name="Wo",
                                  shape=[dk * self.num_heads, emb_size],  # equal to (emb_size, emb_size)
                                  initializer=self.initializer,
                                  regularizer=self.regularizer)

        super().build(input_shape)

    def call(self, x):
        """Apply multi-head self-attention to ``x`` (batch, input_size, emb_size)."""
        # Q, K, V are the same input!
        q = k = v = x

        # outputs of the individual heads
        heads = []

        for head_i in range(self.num_heads):
            # projections have shape (batch, input_size, dk)
            q_projection = self.activation(K.dot(q, self.multihead_Wq[head_i]))
            k_projection = self.activation(K.dot(k, self.multihead_Wk[head_i]))
            v_projection = self.activation(K.dot(v, self.multihead_Wv[head_i]))

            ### START SCALED DOT PRODUCT
            # transpose the last two axes: (batch, dk, input_size)
            k_transpose = K.permute_dimensions(k_projection, (0, 2, 1))

            # similarity scores, shape (batch, input_size, input_size)
            q_k_sim = K.batch_dot(q_projection, k_transpose) / self.scaling_factor

            # softmax over the key axis, i.e. row-wise on the
            # (input_size, input_size) score matrix
            softmax_weight = K.softmax(q_k_sim, axis=2)

            # weighted sum of the values, shape (batch, input_size, dk)
            heads.append(K.batch_dot(softmax_weight, v_projection))
            ### END SCALED DOT PRODUCT

        # concatenate the heads along the feature axis:
        # (batch, input_size, dk * num_heads)
        multihead_concat = K.concatenate(heads, axis=2)
        return K.dot(multihead_concat, self.Wo)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'activation': activations.serialize(self.activation),
            'initializer': initializers.serialize(self.initializer),
            'regularizer': regularizers.serialize(self.regularizer),
        })
        return config

In [2]:
class LayerNormalization(Layer):
    """Normalize the last axis of the input and apply a learned scale and shift.

    Computes ``gamma * (x - mean) / (std + eps) + beta`` where ``mean`` and
    ``std`` are taken over the last axis of ``x``.

    Args:
        eps: small constant added to the standard deviation to avoid a
            division by zero.
        **kwargs: standard ``Layer`` keyword arguments.
    """

    def __init__(self, eps=1e-6, **kwargs):
        self.eps = eps
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create ``gamma`` (ones) and ``beta`` (zeros), one value per feature of the last axis."""
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:], initializer=Ones())
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:], initializer=Zeros())
        super().build(input_shape)

    def call(self, x):
        """Return the normalized, scaled and shifted tensor (same shape as ``x``)."""
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        # eps is added to std (not to the variance) to keep the division finite
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def get_config(self):
        """Return the constructor arguments so a non-default ``eps`` survives serialization."""
        config = super().get_config()
        config['eps'] = self.eps
        return config


In [None]:
class PositionFFN(Layer):
    """Placeholder for the position-wise feed-forward layer (not implemented yet).

    The class previously had no body, which is a syntax error. Until the
    layer is written, calling it fails loudly instead of silently passing
    the input through.
    """

    def call(self, x):
        # TODO: implement the position-wise feed-forward transformation.
        raise NotImplementedError("PositionFFN is not implemented yet")

In [5]:
### TEST

from tensorflow.keras.models import Sequential

# Smoke test: an embedding (vocabulary of 1000 tokens, 768 features) followed
# by one 12-head self-attention layer (768 / 12 = 64 features per head).
model = Sequential([
    Embedding(1000, 768),
    MultiHeadSelf_Attention(num_heads=12),
])

# Print the layer-by-layer output shapes and parameter counts.
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_2 (Embedding)      (None, None, 768)         768000    
_________________________________________________________________
multi_head_self__attention_2 (None, None, 768)         2359296   
_________________________________________________________________
multi_head_self__attention_3 (None, None, 768)         2359296   
_________________________________________________________________
multi_head_self__attention_4 (None, None, 768)         2359296   
_________________________________________________________________
multi_head_self__attention_5 (None, None, 768)         2359296   
_________________________________________________________________
multi_head_self__attention_6 (None, None, 768)         2359296   
_________________________________________________________________
multi_head_self__attention_7 (None, None, 768)         2359296   
__________