In [1]:
import tensorflow as tf
import keras
class Attention(keras.layers.Layer):
    """Attention pooling over the sequence axis with one learned query vector.

    Every timestep of the input is projected to a key
    (``x @ key_extract + key_bias``) and scored against a single trainable
    ``Query`` vector. The scores are softmax-normalised over the sequence
    axis, and the output is the attention-weighted sum of the *input* vectors.

    Args:
        key_dim: size of the key projection. If None, the input's feature
            size (``input_shape[2]``) is used.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Input shape:
        3D tensor ``(batch, seq, features)``.

    Output shape:
        ``(batch, features, 1)``.
    """

    def __init__(self, key_dim=None, **kwargs):
        # Conventional Keras order: initialise the base Layer, then set attributes.
        super(Attention, self).__init__(**kwargs)
        self.key_dim = key_dim

    def build(self, input_shape):
        """Create the key projection matrix, key bias and query vector."""
        n_features = int(input_shape[2])
        # key_dim=None used to crash on int(None); fall back to the feature size.
        key_dim = n_features if self.key_dim is None else int(self.key_dim)

        # Weights initializer function
        w_initializer = keras.initializers.glorot_uniform()
        # Biases initializer function
        b_initializer = keras.initializers.Zeros()

        # NOTE: weight names (including the 'feaure_bias' typo) are kept
        # unchanged so previously saved weights still load by name.

        # Projects each input vector to a key of size key_dim.
        self.key_extract = self.add_weight(name='feature_extract',
                                           shape=(n_features, key_dim),
                                           initializer=w_initializer,
                                           trainable=True)
        # Key bias, broadcast over batch and sequence.
        self.key_bias = self.add_weight(name='feaure_bias',
                                        shape=(1, key_dim),
                                        initializer=b_initializer,
                                        trainable=True)
        # Single learned query that every key is compared against.
        self.Query = self.add_weight(name='Query',
                                     shape=(key_dim, 1),
                                     initializer=w_initializer,
                                     trainable=True)

        super(Attention, self).build(input_shape)

    def call(self, x):
        """Return the attention-weighted sum of ``x`` over its sequence axis."""
        # (batch, seq, features) -> (batch, seq, key_dim)
        keys = tf.tensordot(x, self.key_extract, axes=[2, 0]) + self.key_bias
        # One similarity score per timestep: (batch, seq, 1)
        similar_logits = tf.tensordot(keys, self.Query, axes=[2, 0])
        # Normalise over the sequence axis so each sample's weights sum to 1.
        attention_weights = tf.nn.softmax(similar_logits, axis=1)
        # x^T @ w: (batch, features, seq) @ (batch, seq, 1) -> (batch, features, 1)
        return tf.matmul(x, attention_weights, transpose_a=True)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[2], 1)

    def get_config(self):
        """Add ``key_dim`` so the layer can be re-created from its config."""
        config = super(Attention, self).get_config()
        config['key_dim'] = self.key_dim
        return config


Using TensorFlow backend.


In [0]:
import tensorflow as tf
import keras
class Self_Attention(keras.layers.Layer):
    """Single-head dot-product self-attention without a value projection.

    Keys and queries are affine projections of the input. Every position
    attends over all positions, and the output is the attention-weighted sum
    of the raw input vectors, so the output shape equals the input shape.

    Args:
        key_dim: size of the key/query projections. If None, the input's
            feature size (``input_shape[2]``) is used.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Input shape:
        3D tensor ``(batch, seq, features)``.

    Output shape:
        Same as the input.
    """

    def __init__(self, key_dim=None, **kwargs):
        # Conventional Keras order: initialise the base Layer, then set attributes.
        super(Self_Attention, self).__init__(**kwargs)
        self.key_dim = key_dim

    def build(self, input_shape):
        """Create the key and query projection weights and biases."""
        n_features = int(input_shape[2])
        # key_dim=None used to crash on int(None); fall back to the feature size.
        key_dim = n_features if self.key_dim is None else int(self.key_dim)

        # Weights initializer function
        w_initializer = keras.initializers.glorot_uniform()
        # Biases initializer function
        b_initializer = keras.initializers.Zeros()

        # NOTE: weight names (including the 'feaure_bias' typo) are kept
        # unchanged so previously saved weights still load by name.

        # Key projection and bias.
        self.key_extract = self.add_weight(name='feature_extract',
                                           shape=(n_features, key_dim),
                                           initializer=w_initializer,
                                           trainable=True)
        self.key_bias = self.add_weight(name='feaure_bias',
                                        shape=(1, key_dim),
                                        initializer=b_initializer,
                                        trainable=True)
        # Query projection and bias (computed per position from the same input).
        self.query_extract = self.add_weight(name='q_extract',
                                             shape=(n_features, key_dim),
                                             initializer=w_initializer,
                                             trainable=True)
        self.query_bias = self.add_weight(name='q_bias',
                                          shape=(1, key_dim),
                                          initializer=b_initializer,
                                          trainable=True)

        super(Self_Attention, self).build(input_shape)

    def call(self, x):
        """Return, for every position, the attention-weighted sum of ``x``."""
        # (batch, seq, features) -> (batch, seq, key_dim)
        keys = tf.tensordot(x, self.key_extract, axes=[2, 0]) + self.key_bias
        query = tf.tensordot(x, self.query_extract, axes=[2, 0]) + self.query_bias

        # Query/key similarity: (batch, seq_q, seq_k)
        similar_logits = tf.matmul(query, keys, transpose_b=True)

        # Normalise over the KEY axis (last) so each query's weights sum to 1
        # and each output row is a convex combination of the input rows.
        # (axis=1 would normalise over queries instead.)
        attention_weights = tf.nn.softmax(similar_logits, axis=-1)

        # (batch, seq_q, seq_k) @ (batch, seq_k, features) -> (batch, seq_q, features)
        return tf.matmul(attention_weights, x)

    def compute_output_shape(self, input_shape):
        # Values are the raw inputs, so the shape is unchanged.
        return input_shape

    def get_config(self):
        """Add ``key_dim`` so the layer can be re-created from its config."""
        config = super(Self_Attention, self).get_config()
        config['key_dim'] = self.key_dim
        return config

In [0]:
import tensorflow as tf
import keras
def Multi_Attention(in_feature, N_Heads, key_dim=100):
    """Multi-head self-attention built from independent ``Self_Attention`` heads.

    The feature axis (axis 2) of ``in_feature`` is split into ``N_Heads``
    equal slices. Each slice goes through its own ``Self_Attention(key_dim)``
    layer, and the per-head outputs are concatenated back along axis 2.

    Args:
        in_feature: Keras tensor, assumed ``(batch, seq, features)``.
        N_Heads: number of heads; must evenly divide the feature size.
        key_dim: key/query projection size used by every head.

    Returns:
        Keras tensor with the same shape as ``in_feature``.

    Raises:
        ValueError: if ``N_Heads < 1`` or the static feature size is not
            divisible by ``N_Heads``.
    """
    if N_Heads < 1:
        raise ValueError("N_Heads must be >= 1, got %r" % (N_Heads,))
    n_features = keras.backend.int_shape(in_feature)[2]
    if n_features is not None and n_features % N_Heads != 0:
        raise ValueError("feature size %d is not divisible by N_Heads=%d"
                         % (n_features, N_Heads))

    if N_Heads == 1:
        # Concatenate needs >= 2 inputs; one head is plain self-attention.
        return Self_Attention(key_dim)(in_feature)

    # Referenced via keras.layers so this cell does not depend on a later import cell.
    splits = keras.layers.Lambda(
        lambda x: tf.split(x, axis=2, num_or_size_splits=N_Heads))(in_feature)
    heads = [Self_Attention(key_dim)(part) for part in splits]
    return keras.layers.Concatenate(axis=2)(heads)

In [0]:
from keras.preprocessing import sequence
from keras.models import Sequential,Model
from keras.layers import Dense, Dropout, Activation, Input,LSTM,Lambda,Concatenate
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalMaxPooling1D
from keras.datasets import imdb

In [0]:
# Integer token ids, 3 per sample, consumed by the Embedding layer.
Inpu = Input(shape=(3,))

In [0]:
# Map each token id (vocabulary of 10000) to a 100-dim vector.
Embeddings = Embedding(input_dim=10000, output_dim=100)(Inpu)

In [0]:
# 10 self-attention heads, each over a 100 / 10 = 10-dim slice of the embedding.
Features = Multi_Attention(Embeddings, N_Heads=10)

In [0]:
# Pool the sequence with a single learned query (key size 100).
Aggregation = Attention(key_dim=100)(Features)

In [0]:
# Drop the trailing singleton axis: (batch, 100, 1) -> (batch, 100).
Aggregation = keras.layers.Flatten()(Aggregation)

In [0]:
# 3-class softmax classifier head.
Prediction = Dense(units=3, activation="softmax")(Aggregation)

In [0]:
# Wire the functional graph from token ids to class probabilities.
model = Model(outputs=Prediction, inputs=Inpu)

In [34]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 3)                 0         
_________________________________________________________________
embedding_3 (Embedding)      (None, 3, 100)            1000000   
_________________________________________________________________
model_3 (Model)              multiple                  22000     
_________________________________________________________________
attention_3 (Attention)      (None, 100, 1)            10200     
_________________________________________________________________
flatten_3 (Flatten)          (None, 100)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 3)                 303       
Total params: 1,032,503
Trainable params: 1,032,503
Non-trainable params: 0
_________________________________________________________________


In [0]:
## Example: wrap Multi_Attention in a standalone sub-Model

def Create_Feature_Model(in_features, N_Heads, key_dim=100,
                         name="Multi_Head_Attention"):
    """Wrap ``Multi_Attention`` in a standalone Keras ``Model``.

    Args:
        in_features: size of the feature axis of the input; must be
            divisible by ``N_Heads``.
        N_Heads: number of self-attention heads.
        key_dim: key/query size of every head (default matches
            ``Multi_Attention``'s default).
        name: model name. Override it when nesting several of these
            sub-models in one parent model, because Keras model names must be
            unique there.

    Returns:
        ``Model`` mapping ``(batch, seq, in_features)`` to the same shape,
        where ``seq`` may vary between batches.
    """
    inputs = Input(shape=(None, in_features))
    features = Multi_Attention(inputs, N_Heads=N_Heads, key_dim=key_dim)
    return Model(inputs=inputs, outputs=features, name=name)

In [0]:
# Standalone 10-head feature extractor for 100-dim inputs of any sequence length.
multi_self_a = Create_Feature_Model(N_Heads=10, in_features=100)