In [1]:
import keras.backend as K
from keras.layers import Layer, Dense, TimeDistributed, Concatenate, InputSpec, Wrapper, RNN
import numpy as np
import tensorflow as tf

Using TensorFlow backend.


In [2]:
class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention: ``softmax(Q*K^T / sqrt(d_k)) * V``.

    The layer is called on a list ``[q, k, v]`` of three tensors and returns a
    list ``[output, weights]``. It creates no weights of its own.
    """

    def __init__(self,**kwargs):
        super(ScaledDotProductAttention,self).__init__(**kwargs)

    def build(self,input_shape):
        # Nothing to create here; the parent call only marks the layer as built.
        super(ScaledDotProductAttention,self).build(input_shape)

    def call(self,x,mask=None):
        """
            Attention(Q,K,V)=softmax(Q*K^T / sqrt(d_k))*V

            Arguments:
                x: list ``[q, k, v]``. ``q`` and ``k`` are contracted over
                    axis 2, so they are expected to be rank-3 tensors sharing
                    the same last dimension.
                mask: optional tensor where 1 keeps a score and 0 suppresses
                    it. Assumed to broadcast against the
                    ``(batch, len_q, len_k)`` score tensor - TODO confirm
                    against callers.

            Returns:
                ``[output, weights]``: the attended values and the softmax
                attention weights.
        """
        q,k,v=x

        # Prefer the static feature size; fall back to the dynamic shape when
        # it is unknown instead of crashing on sqrt(None).
        d_k=K.int_shape(q)[-1]
        if d_k is None:
            scale=K.sqrt(K.cast(K.shape(q)[-1],K.dtype(q)))
        else:
            scale=np.sqrt(d_k)

        weights=K.batch_dot(q,k,axes=[2,2])

        if mask is not None:
            # Cast first so boolean masks are accepted as well as 0/1 floats.
            mask=K.cast(mask,K.dtype(weights))
            # A large negative score becomes ~0 probability after the softmax.
            weights=weights+(-1e10)*(1-mask)

        weights=K.softmax(weights/scale)
        output=K.batch_dot(weights,v)
        # Multi-output Keras layers must return a list, not a tuple.
        return [output,weights]

    def compute_output_shape(self,input_shape):
        """Return the shapes of ``[output, weights]`` for ``[q, k, v]`` shapes."""
        q_shape,k_shape,v_shape=input_shape
        output_shape=(q_shape[0],q_shape[1],v_shape[-1])
        weights_shape=(q_shape[0],q_shape[1],k_shape[1])
        return [output_shape,weights_shape]
        

In [4]:
class MultiHeadAttention(Layer):
    """Multi-head attention made of ``h`` scaled dot-product attention heads.

    The layer is called on a list ``[q, k, v]``. Each head projects ``q``,
    ``k`` and ``v`` with its own bias-free, ReLU-activated ``Dense`` layer
    applied per timestep, runs ``ScaledDotProductAttention`` on the
    projections, and the concatenated head outputs go through a final
    per-timestep ``Dense`` layer.

    Arguments:
        h: number of attention heads.
    """

    def __init__(self,h,**kwargs):
        self.h=h
        super(MultiHeadAttention,self).__init__(**kwargs)

    def build(self,input_shape):
        # input_shape is the list [q_shape, k_shape, v_shape]; the per-head
        # key size comes from k and the per-head value size from v.
        d_k=input_shape[1][-1]
        d_v=input_shape[2][-1]
        d_model=self.h*d_k
        self._q_layers=[]
        self._k_layers=[]
        self._v_layers=[]
        self.sdpa_layer=ScaledDotProductAttention()
        self._output=TimeDistributed(Dense(d_model))
        # NOTE(review): depending on the Keras version, sub-layers created
        # here may not have their weights tracked by this layer - verify
        # `trainable_weights` on a built model.
        for _ in range(self.h):
            self._q_layers.append(TimeDistributed(Dense(d_k,activation="relu",use_bias=False)))
            self._k_layers.append(TimeDistributed(Dense(d_k,activation="relu",use_bias=False)))
            self._v_layers.append(TimeDistributed(Dense(d_v,activation="relu",use_bias=False)))
        super(MultiHeadAttention, self).build(input_shape)

    def call(self,x,mask=None):
        """
            MultiHeadAttention(q,k,v)=concat(head_1,...head_h)*W_0
            head_i=Attention(q*W_q_i,k*W_k_i,v*W_v_i)

            Arguments:
                x: list ``[q, k, v]`` of input tensors.
                mask: optional mask forwarded unchanged to every head.

            Returns:
                ``[output, concatenated_attentions]``: the projected
                concatenation of all head outputs, and the attention weights
                of all heads concatenated along the last axis.
        """
        q,k,v=x
        outputs=[]
        attentions=[]
        for i in range(self.h):
            qi=self._q_layers[i](q)
            ki=self._k_layers[i](k)
            vi=self._v_layers[i](v)
            output,attention=self.sdpa_layer([qi,ki,vi],mask=mask)
            outputs.append(output)
            attentions.append(attention)

        # Concatenate requires at least two inputs, so a single head is
        # passed through as-is.
        if len(outputs)==1:
            concatenated_outputs=outputs[0]
            concatenated_attentions=attentions[0]
        else:
            concatenated_outputs=Concatenate()(outputs)
            concatenated_attentions=Concatenate()(attentions)
        output=self._output(concatenated_outputs)
        return [output,concatenated_attentions]

    def compute_output_shape(self,input_shape):
        """Return the shapes of ``[output, concatenated_attentions]``."""
        q_shape,k_shape,_=input_shape
        len_k=k_shape[1]
        # Each head contributes len_k attention columns; unknown stays unknown.
        attention_dim=None if len_k is None else self.h*len_k
        output_shape=(q_shape[0],q_shape[1],self.h*k_shape[-1])
        attention_shape=(q_shape[0],q_shape[1],attention_dim)
        return [output_shape,attention_shape]

    def get_config(self):
        """Include ``h`` so the layer can be re-created from its config."""
        config=super(MultiHeadAttention,self).get_config()
        config["h"]=self.h
        return config
        