In [2]:
import keras.backend as K
from keras.layers import Layer, Dense, TimeDistributed, Concatenate, InputSpec, Wrapper, RNN,Conv1D,Lambda,Add
import numpy as np
import tensorflow as tf

Using TensorFlow backend.


In [3]:
class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention layer without trainable weights.

    Called on a list ``[q, k, v]`` of rank-3 tensors and computes
    ``softmax(Q * K^T / sqrt(d_k)) * V``.
    """

    def __init__(self,**kwargs):
        super(ScaledDotProductAttention,self).__init__(**kwargs)

    def build(self,input_shape):
        # Nothing to create: the layer has no weights of its own.
        super(ScaledDotProductAttention,self).build(input_shape)

    def call(self,x,mask=None):
        """
            Attention(Q,K,V)=softmax(Q*K^T / sqrt(d_k))*V

            Arguments:
                x: list ``[q, k, v]``; ``q`` and ``k`` are contracted over
                    axis 2, so they must share that dimension.
                mask: optional tensor added to the raw scores as
                    ``-1e10*(1-mask)``, i.e. entries equal to 0 are suppressed.
                    Assumed broadcastable to the score tensor - TODO confirm
                    against callers.

            Returns:
                A list ``[output, weights]`` (a list, not a tuple, to match
                the other multi-output layers in this file).
        """
        q,k,v=x
        # Static size of the last axis of q; K.int_shape is backend-agnostic.
        d_k=K.int_shape(q)[2]
        weights=K.batch_dot(q,k,axes=[2,2])

        if mask is not None:
            # Large negative bias drives masked scores to ~0 after the softmax.
            weights+=-1e10*(1-mask)

        weights=K.softmax(weights/np.sqrt(d_k))
        output=K.batch_dot(weights,v)
        return [output,weights]

    def compute_output_shape(self,input_shape):
        """Return the shapes of ``[output, weights]`` for ``[q, k, v]`` shapes."""
        q_shape,k_shape,v_shape=input_shape
        output_shape=(q_shape[0],q_shape[1],v_shape[2])
        weights_shape=(q_shape[0],q_shape[1],k_shape[1])
        return [output_shape,weights_shape]
        

In [4]:
class MultiHeadAttention(Layer):
    """Multi-head attention built from ``h`` independent projection triples.

    Arguments:
        h: number of attention heads.
        d_k: size of each head's query/key projection (also used for values).

    The model width is ``d_model = h * d_k``.

    NOTE(review): the projection layers are created inside this layer.
    Depending on the Keras version, their weights may not be registered in
    this layer's ``trainable_weights`` - verify before training.
    """

    def __init__(self,h,d_k,**kwargs):
        # Initialise the base Layer first so its bookkeeping exists before
        # any sub-layer is attached to ``self``.
        super(MultiHeadAttention,self).__init__(**kwargs)
        self.h=h
        self.d_k=d_k
        self.d_v=d_k
        self.d_model=self.h*d_k
        self.sdpa_layer=ScaledDotProductAttention()
        self._output=TimeDistributed(Dense(self.d_model))
        # One bias-free, relu-activated projection per head for q, k and v.
        self._q_layers=[TimeDistributed(Dense(self.d_k,activation="relu",use_bias=False)) for _ in range(self.h)]
        self._k_layers=[TimeDistributed(Dense(self.d_k,activation="relu",use_bias=False)) for _ in range(self.h)]
        self._v_layers=[TimeDistributed(Dense(self.d_v,activation="relu",use_bias=False)) for _ in range(self.h)]

    def build(self,input_shape):
        super(MultiHeadAttention, self).build(input_shape)

    def call(self,x,mask=None):
        """
            MultiHeadAttention(q,k,v)=concat(head_1,...head_h)*W_0
            head_i=Attention(q*W_q_i,k*W_k_i,v*W_v_i)

            Arguments:
                x: list ``[q, k, v]`` of tensors.
                mask: optional mask forwarded unchanged to every head.

            Returns:
                A list ``[output, attentions]``: the projected concatenation
                of all head outputs, and the per-head attention weights
                concatenated along the last axis.
        """
        q,k,v=x
        outputs=[]
        attentions=[]
        for q_proj,k_proj,v_proj in zip(self._q_layers,self._k_layers,self._v_layers):
            head_output,head_attention=self.sdpa_layer([q_proj(q),k_proj(k),v_proj(v)],mask=mask)
            outputs.append(head_output)
            attentions.append(head_attention)

        # K.concatenate also accepts a single tensor, so h == 1 works; the
        # Concatenate layer requires at least two inputs.
        concatenated_outputs=K.concatenate(outputs,axis=-1)
        concatenated_attentions=K.concatenate(attentions,axis=-1)
        output=self._output(concatenated_outputs)
        return [output,concatenated_attentions]

    def compute_output_shape(self,input_shape):
        """Return the shapes of ``[output, attentions]`` for ``[q, k, v]`` shapes."""
        q_shape,k_shape,_=input_shape
        key_length=k_shape[1]
        # Every head contributes key_length columns; unknown lengths stay None.
        attention_width=None if key_length is None else self.h*key_length
        return [(q_shape[0],q_shape[1],self.d_model),
                (q_shape[0],q_shape[1],attention_width)]
        

In [5]:
class PositionWiseFeedForward(Layer):
    """Position-wise feed-forward block made of two kernel-size-1 convolutions.

    Arguments:
        d_model: number of filters of the second (output) convolution.
        d_ff: number of filters of the first, relu-activated convolution.
    """

    def __init__(self,d_model=512,d_ff=2048,**kwargs):
        super(PositionWiseFeedForward,self).__init__(**kwargs)
        # No trailing comma here: d_model must stay an int, not a 1-tuple,
        # because it is passed to Conv1D as the filter count.
        self.d_model=d_model
        self.d_ff=d_ff
        self.conv1=Conv1D(self.d_ff,kernel_size=1,activation='relu')
        self.conv2=Conv1D(self.d_model,kernel_size=1)

    def build(self,input_shape):
        super(PositionWiseFeedForward,self).build(input_shape)

    def call(self,x):
        """Apply the expanding convolution, then project back to ``d_model``."""
        hidden=self.conv1(x)
        return self.conv2(hidden)
        

In [6]:
class LayerNormalization(Layer):
    """Normalises the last axis to zero mean / unit std, then scales and shifts.

    Trainable weights (both of shape ``(input_shape[-1],)``):
        w: multiplicative gain, initialised to ones.
        b: additive bias, initialised to zeros.
    """

    def __init__(self,**kwargs):
        super(LayerNormalization,self).__init__(**kwargs)

    def build(self,input_shape):
        # String identifiers are resolved by Keras, so no initializer classes
        # need to be imported.
        self.w=self.add_weight(name='normalization_weights',shape=(input_shape[-1],),initializer='ones',trainable=True)
        self.b=self.add_weight(name='bias',shape=(input_shape[-1],),initializer='zeros',trainable=True)
        super(LayerNormalization,self).build(input_shape)

    def call(self,x):
        """Return ``w * (x - mean) / (std + 1e-8) + b`` over the last axis."""
        # keepdims=True keeps the reduced axis so the statistics broadcast
        # against x.
        mean=K.mean(x,axis=-1,keepdims=True)
        std=K.std(x,axis=-1,keepdims=True)
        # The small constant guards against division by zero.
        output=self.w*(x-mean)/(std+1e-8)+self.b
        return output

In [7]:
class EncoderLayer(Layer):
    """One encoder block: self-attention and feed-forward, each followed by
    a residual connection and layer normalisation.

    Arguments:
        h: number of attention heads.
        d_k: per-head projection size; the model width is ``h * d_k``.
        d_hidden: inner width of the position-wise feed-forward block.
    """

    def __init__(self,h=8,d_k=64,d_hidden=2048,**kwargs):
        super(EncoderLayer,self).__init__(**kwargs)
        self.h=h
        # Honour the caller-supplied d_k instead of a hard-coded 64.
        self.d_k=d_k
        self.d_model=self.h*self.d_k
        self.d_hidden=d_hidden
        self.mha=MultiHeadAttention(self.h,self.d_k)
        self.ln_1=LayerNormalization()
        self.add_1=Add()
        self.ffwd=PositionWiseFeedForward(d_model=self.d_model,d_ff=self.d_hidden)
        self.ln_2=LayerNormalization()
        self.add_2=Add()

    def call(self,x):
        """Run self-attention and feed-forward sub-blocks on ``x``.

        The residual additions require the last axis of ``x`` to equal
        ``h * d_k``, because both sub-blocks emit tensors of that width.
        """
        # Self-attention sub-block; the attention weights are discarded.
        attended,_=self.mha([x,x,x])
        attended=self.ln_1(self.add_1([x,attended]))

        # Feed-forward sub-block.
        transformed=self.ffwd(attended)
        return self.ln_2(self.add_2([transformed,attended]))
        
        

In [8]:
class DecoderLayer(Layer):
    """One decoder block: self-attention, encoder-decoder attention and
    feed-forward, each followed by a residual connection and layer
    normalisation.

    Arguments:
        h: number of attention heads.
        d_k: per-head projection size; the model width is ``h * d_k``.
        d_hidden: inner width of the position-wise feed-forward block.

    NOTE(review): no mask is passed to the self-attention, so every position
    can attend to every other position - confirm whether a look-ahead mask
    is expected here.
    """

    def __init__(self,h=8,d_k=64,d_hidden=2048,**kwargs):
        # super() must name this class; naming EncoderLayer raises TypeError
        # because a DecoderLayer is not an EncoderLayer instance.
        super(DecoderLayer,self).__init__(**kwargs)
        self.h=h
        # Honour the caller-supplied d_k instead of a hard-coded 64.
        self.d_k=d_k
        self.d_model=self.h*self.d_k
        self.d_hidden=d_hidden
        self.mha_1=MultiHeadAttention(self.h,self.d_k)
        self.ln_1=LayerNormalization()
        self.add_1=Add()
        self.mha_2=MultiHeadAttention(self.h,self.d_k)
        self.ln_2=LayerNormalization()
        self.add_2=Add()
        self.ffwd=PositionWiseFeedForward(d_model=self.d_model,d_ff=self.d_hidden)
        self.ln_3=LayerNormalization()
        self.add_3=Add()

    def call(self,x,encoder_output):
        """Decode ``x`` while attending to ``encoder_output``.

        Returns:
            A list ``[y, s_attn, enc_attn]``: the block output, the
            self-attention weights and the encoder-decoder attention weights.
        """
        # Self-attention over the decoder input.
        y,s_attn=self.mha_1([x,x,x])
        y=self.add_1([x,y])
        y=self.ln_1(y)

        # Encoder-decoder attention: queries come from the decoder state,
        # keys and values from the encoder. The result then has the decoder's
        # sequence length, which the residual addition below requires.
        x,enc_attn=self.mha_2([y,encoder_output,encoder_output])
        x=self.add_2([x,y])
        x=self.ln_2(x)

        # Position-wise feed-forward.
        y=self.ffwd(x)
        y=self.add_3([x,y])
        y=self.ln_3(y)

        return [y,s_attn,enc_attn]
        

In [9]:
class Encoder(Layer):
    """Stack of ``n`` EncoderLayer blocks applied one after another.

    Arguments:
        n: number of stacked encoder layers.
        h: number of attention heads per layer.
        d_k: per-head projection size.
        d_hidden: inner width of each layer's feed-forward block.
    """

    def __init__(self,n=6,h=8,d_k=64,d_hidden=2048,**kwargs):
        super(Encoder,self).__init__(**kwargs)
        self.n=n
        self.h=h
        self.d_k=d_k
        self.d_hidden=d_hidden
        # Build the stack on the instance; a bare ``layers`` name is undefined.
        self.layers=[
            EncoderLayer(h=self.h,d_k=self.d_k,d_hidden=self.d_hidden)
            for _ in range(self.n)
        ]

    def call(self,x):
        """Feed ``x`` through every encoder layer in order and return the result."""
        for layer in self.layers:
            x=layer(x)
        return x

In [10]:
class Decoder(Layer):
    """Stack of ``n`` DecoderLayer blocks applied one after another.

    Arguments:
        n: number of stacked decoder layers.
        h: number of attention heads per layer.
        d_k: per-head projection size.
        d_hidden: inner width of each layer's feed-forward block.
    """

    def __init__(self,n=6,h=8,d_k=64,d_hidden=2048,**kwargs):
        # super() must name this class; naming Encoder raises TypeError
        # because a Decoder is not an Encoder instance.
        super(Decoder,self).__init__(**kwargs)
        self.n=n
        self.h=h
        self.d_k=d_k
        self.d_hidden=d_hidden
        # Build the stack on the instance; a bare ``layers`` name is undefined.
        self.layers=[
            DecoderLayer(h=self.h,d_k=self.d_k,d_hidden=self.d_hidden)
            for _ in range(self.n)
        ]

    def call(self,x,encoder_output):
        """Decode ``x`` through every layer, attending to ``encoder_output``.

        Returns:
            A list ``[x, s_attns, enc_attns]`` where ``x`` is the output of
            the last layer and the other two entries are per-layer lists of
            self-attention and encoder-decoder attention weights.
        """
        s_attns=[]
        enc_attns=[]
        for layer in self.layers:
            # Pass encoder_output by keyword: Keras' Layer.__call__ takes a
            # single positional ``inputs`` and forwards keyword arguments to
            # ``call``.
            x,s_attn,enc_attn=layer(x,encoder_output=encoder_output)
            s_attns.append(s_attn)
            enc_attns.append(enc_attn)
        return [x,s_attns,enc_attns]