In [6]:
import tensorflow as tf

In [8]:
class EA(tf.keras.layers.Layer):
    """External attention (EA) layer.

    Projects the input with a learnable query layer, scores the result against
    a learnable key memory ``M_k`` (``num_memory_units`` slots), normalises the
    scores, and reads them back through a learnable value memory ``M_v``.

    Args:
        channel_in: Width of the query projection and of the output feature
            axis. Set it equal to the input feature size to keep the shape.
        double_normalization: If True (default), the softmax over axis 1 is
            followed by an L1 normalisation over the memory-slot axis.
        num_memory_units: Number of memory slots (output width of ``M_k``).
            Defaults to 64, the value that was previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).

    NOTE(review): the softmax runs over axis 1, so inputs are assumed to be
    laid out as (batch, tokens, features).
    """

    def __init__(self, channel_in:int, double_normalization:bool = True,
                 num_memory_units:int = 64, **kwargs):
        super(EA, self).__init__(**kwargs)
        self.channel_in = channel_in
        self.double_normalization = double_normalization
        self.num_memory_units = num_memory_units

        # Sub-layers are created in the original order and with the original
        # attribute names, so existing checkpoints keep matching.
        self.query_linear = tf.keras.layers.Dense(self.channel_in)
        self.M_k = tf.keras.layers.Dense(self.num_memory_units, use_bias = False)
        self.M_v = tf.keras.layers.Dense(self.channel_in, use_bias = False)

    def call(self, X):
        """Apply external attention.

        Args:
            X: Input tensor; features on the last axis, assumed
                (batch, tokens, features).

        Returns:
            Tensor with the same leading axes as ``X`` and last axis
            ``channel_in``.
        """
        queries = self.query_linear(X)
        attn = self.M_k(queries)  # last axis: num_memory_units
        # First normalisation: softmax across axis 1 (tokens), not across slots.
        attn = tf.nn.softmax(attn, axis = 1)
        if self.double_normalization:
            # Second normalisation: L1 over memory slots; 1e-9 avoids 0/0.
            attn = attn / (1e-9 + tf.reduce_sum(attn, axis = -1, keepdims=True))
        return self.M_v(attn)

    def get_config(self):
        """Return the constructor arguments so Keras can serialise/rebuild the layer."""
        config = super(EA, self).get_config()
        config.update({
            "channel_in": self.channel_in,
            "double_normalization": self.double_normalization,
            "num_memory_units": self.num_memory_units,
        })
        return config

In [11]:
# Smoke test for EA on a 224x224 token grid with 512 channels.
# NOTE: one float32 tensor of this shape is ~1.6 GB; the input, projected
# queries and output together need several GB of memory.
EA_BATCH, EA_TOKENS, EA_CHANNELS = 16, 224 * 224, 512

ea = EA(EA_CHANNELS)
ea_input = tf.random.truncated_normal(shape = [EA_BATCH, EA_TOKENS, EA_CHANNELS])
ea_output = ea(ea_input)
# The output width equals channel_in, so the output shape matches the input shape.
assert ea_output.shape == (EA_BATCH, EA_TOKENS, EA_CHANNELS), ea_output.shape
ea_output

<tf.Tensor: shape=(16, 50176, 512), dtype=float32, numpy=
array([[[-0.02060293,  0.01195729, -0.00103426, ...,  0.00756243,
          0.00128338, -0.01149979],
        [-0.01484309,  0.01611494, -0.00801398, ...,  0.00659138,
         -0.00814292,  0.00331159],
        [-0.00568072,  0.01222262,  0.00312036, ...,  0.01746604,
         -0.01203682,  0.00945936],
        ...,
        [-0.02103423,  0.00386796, -0.00616206, ...,  0.01989715,
         -0.00283405, -0.01433011],
        [-0.00876631,  0.01146064, -0.00395893, ...,  0.0249698 ,
         -0.00787579, -0.01132022],
        [ 0.0126468 ,  0.00051271,  0.0255545 , ...,  0.01408626,
         -0.02686938, -0.00952071]],

       [[-0.00188944,  0.00270604, -0.00572482, ...,  0.01414756,
         -0.01240578,  0.00533654],
        [-0.00711065,  0.01863265, -0.00748485, ...,  0.01167414,
         -0.00468139, -0.010774  ],
        [-0.02110499,  0.01313176,  0.01994543, ...,  0.00863925,
         -0.01002276,  0.01685289],
        .

In [28]:
class MHEA(tf.keras.layers.Layer):
    """Multi-head external attention layer.

    Projects the input to ``n_channels`` features and splits them into
    ``num_heads`` heads of width ``n_channels // num_heads``. External
    attention (shared ``M_k`` / ``M_v`` memories) is applied per head, the
    heads are merged, and the result is projected to ``channel_in`` by ``W_o``.

    Args:
        channel_in: Output width of the final projection ``W_o``. Set it equal
            to the input feature size to keep the shape.
        n_channels: Width of the query projection, split across heads.
        num_heads: Number of heads; must be positive and divide ``n_channels``.
        double_normalization: If True (default), the softmax over the token
            axis is followed by an L1 normalisation over the memory-slot axis.
        num_memory_units: Number of memory slots (output width of ``M_k``).
            Defaults to 64, the value that was previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``n_channels``.
    """

    def __init__(self, channel_in:int, n_channels:int, num_heads:int, double_normalization:bool = True,
                 num_memory_units:int = 64, **kwargs):
        super(MHEA, self).__init__(**kwargs)
        if num_heads <= 0 or n_channels % num_heads != 0:
            raise ValueError(
                f"n_channels ({n_channels}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.channel_in = channel_in
        self.n_channels = n_channels
        self.num_heads = num_heads
        self.double_normalization = double_normalization
        self.num_memory_units = num_memory_units
        # Static per-head width; keeping it a Python int gives downstream Dense
        # layers a known last dimension even in graph mode.
        self.head_dim = self.n_channels // self.num_heads

        self.query_linear = tf.keras.layers.Dense(self.n_channels)
        # Keras Permute dims are 1-indexed and exclude batch: swaps axes 1 and 2.
        self.P = tf.keras.layers.Permute((2,1,3))
        self.M_k = tf.keras.layers.Dense(self.num_memory_units, use_bias = False)
        self.M_v = tf.keras.layers.Dense(self.head_dim, use_bias = False)
        self.W_o = tf.keras.layers.Dense(self.channel_in)

    def call(self, X):
        """Apply multi-head external attention.

        Args:
            X: Input tensor of rank 3 (the reshape below requires it), assumed
                (batch, tokens, features).

        Returns:
            Tensor of shape (batch, tokens, channel_in).
        """
        y = self.query_linear(X)  # (B, N, n_channels)
        # Only batch and token counts are dynamic. Index tf.shape instead of
        # unpacking it, because iterating a symbolic tensor fails in graph mode
        # (tf.function / model.fit / functional API).
        dynamic_shape = tf.shape(y)
        batch, tokens = dynamic_shape[0], dynamic_shape[1]

        y = tf.reshape(y, [batch, tokens, self.num_heads, self.head_dim])
        y = self.P(y)  # (B, H, N, head_dim)

        attn = self.M_k(y)  # (B, H, N, num_memory_units)
        # First normalisation: softmax across tokens (axis 2 after the permute).
        attn = tf.nn.softmax(attn, axis = 2)
        if self.double_normalization:
            # Second normalisation: L1 over memory slots; 1e-9 avoids 0/0.
            attn = attn / (1e-9 + tf.reduce_sum(attn, axis = -1, keepdims=True))

        y = self.M_v(attn)  # (B, H, N, head_dim)
        y = self.P(y)       # (B, N, H, head_dim)
        y = tf.reshape(y, [batch, tokens, self.n_channels])
        return self.W_o(y)

    def get_config(self):
        """Return the constructor arguments so Keras can serialise/rebuild the layer."""
        config = super(MHEA, self).get_config()
        config.update({
            "channel_in": self.channel_in,
            "n_channels": self.n_channels,
            "num_heads": self.num_heads,
            "double_normalization": self.double_normalization,
            "num_memory_units": self.num_memory_units,
        })
        return config

In [29]:
# Smoke test for MHEA: 8 heads over a 768-wide projection of an 8x8 token grid
# with 2048 input channels; W_o maps back to 2048.
MHEA_BATCH, MHEA_TOKENS, MHEA_CHANNELS = 16, 8 * 8, 2048

mhea = MHEA(MHEA_CHANNELS, 768, 8)
mhea_input = tf.random.truncated_normal(shape = [MHEA_BATCH, MHEA_TOKENS, MHEA_CHANNELS])
mhea_output = mhea(mhea_input)
# channel_in equals the input width, so the output shape must match the input shape.
assert mhea_output.shape == mhea_input.shape, mhea_output.shape
mhea_output

<tf.Tensor: shape=(16, 64, 2048), dtype=float32, numpy=
array([[[-1.05365822e-02, -1.25911860e-02,  1.95513833e-02, ...,
         -4.30652639e-03, -1.65949315e-02, -2.00314075e-03],
        [ 3.92880710e-03, -5.85939698e-02, -6.49548136e-04, ...,
          4.47908835e-03,  2.23287493e-02,  6.87561231e-03],
        [-2.51609236e-02,  3.76725779e-03,  1.23794777e-02, ...,
         -6.65128184e-03, -1.53609393e-02,  1.11376327e-02],
        ...,
        [ 1.29869962e-02, -1.56708099e-02, -5.61497873e-03, ...,
          1.79542210e-02, -5.17591182e-03,  4.27605473e-02],
        [-6.80662133e-03, -3.63170542e-02,  6.15066802e-03, ...,
          1.73358023e-02, -1.92838125e-02,  1.96269080e-02],
        [ 7.39938579e-03,  7.69017404e-03,  2.67918967e-03, ...,
         -9.98527650e-03, -1.05643971e-02,  3.19369771e-02]],

       [[ 2.01098137e-02,  1.74857560e-04, -1.41844042e-02, ...,
         -2.27681622e-02, -2.40693311e-03, -2.97437119e-03],
        [-1.23568298e-02,  7.83501100e-03,  2.0