In [1]:
import tensorflow as tf

In [98]:
# gating signal (query) before upsampling from layer {i+1}{j-1}
# skip channel @ 0
query = tf.keras.layers.Input((32, 32))
value = tf.keras.layers.Input((128, 16))

In [90]:
attention = tf.keras.layers.AdditiveAttention()([query, value])
attention

ValueError: Exception encountered when calling layer "additive_attention" (type AdditiveAttention).

Dimensions must be equal, but are 32 and 16 for '{{node additive_attention/add}} = AddV2[T=DT_FLOAT](additive_attention/ExpandDims, additive_attention/ExpandDims_1)' with input shapes: [?,16,32,1,32], [?,16,1,128,16].

Call arguments received:
  • inputs=['tf.Tensor(shape=(None, 16, 32, 32), dtype=float32)', 'tf.Tensor(shape=(None, 16, 128, 16), dtype=float32)']
  • mask=None
  • training=False
  • return_attention_scores=False

In [99]:
# require expend as and multiply and add function
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, Activation, UpSampling1D, BatchNormalization, Add, Multiply
from tensorflow.keras import backend as K


class AdditiveAttentionGateLayer(tf.keras.layers.Layer):
    def __init__(self, x_res_kernel=3, **kwargs):
        super().__init__(**kwargs)
        # resampling kernel_size for value signal x
        self.x_res_kernel = x_res_kernel
    
    def call(self, x, g, F_int, *args, **kwargs):
        # expected shapes
        # x -> (b, l_x, c_x)
        # g -> (b, l_g, c_g)
        x_shape = K.int_shape(x)
        g_shape = K.int_shape(g)
        
        # reshape gating signal to F_int filters
        # phi_g = W_g*g -> shape: (b, l_g, F_int)
        phi_g = Conv1D(filters=F_int,
                      kernel_size=1,
                      strides=1,
                      padding='same')(g)
        phi_g = BatchNormalization()(phi_g)
        
        # resampling value (x) signal to shape of gating signal filters
        # theta_x = W_x*x -> shape: (b, l_g, F_int)
        theta_x = Conv1D(filters=F_int,
                     kernel_size=self.x_res_kernel,
                     strides=(x_shape[1] // g_shape[1]),
                     padding='same')(x)
        theta_x = BatchNormalization()(theta_x)
        
        sum_xg = Add()([phi_g, theta_x])
        act_of_sum_xg = Activation('relu')(sum_xg)
        
        # calculate alpha as the sigmoid activation of psi * act_of_sum_xg [+ b_psi]
        psi = Conv1D(filters=1, kernel_size=1, padding='same')(act_of_sum_xg)
        alpha = Activation('sigmoid')(psi)
        alpha_shape = K.int_shape(alpha)
        
        # upsample alpha and repeat vector along channel axis to match the shape of x
        upsampled_alpha = UpSampling1D((x_shape[1] // alpha_shape[1]))(alpha)
        repeated_alpha = K.repeat_elements(upsampled_alpha, x_shape[2], axis=2)
        
        # multiply x with attention map alpha
        gated_x = Multiply()([x, repeated_alpha])
        
        return gated_x    

In [106]:
class AdditiveAttentionGateLayer(tf.keras.layers.Layer):
    def __init__(self, x_res_kernel=3, **kwargs):
        """init method of AdditiveAttentionGateLayer

        Args:
            x_res_kernel (int, optional): Size of the x resampling kernel. Defaults to 3.
        """

        super().__init__(**kwargs)
        # resampling kernel_size for value signal x
        self.x_res_kernel = x_res_kernel
    
    def call(self, x: tf.Tensor, g: tf.Tensor, F_int: int, *args, **kwargs):
        """call-function of the Layer

        Args:
            x (tf.Tensor): _description_
            g (tf.Tensor): _description_
            F_int (int): _description_

        Returns:
            _type_: _description_
        """

        # expected shapes
        # x -> (b, l_x, c_x)
        # g -> (b, l_g, c_g)
        x_shape = K.int_shape(x)
        g_shape = K.int_shape(g)
        
        # reshape gating signal to F_int filters
        # phi_g = W_g*g -> shape: (b, l_g, F_int)
        phi_g = Conv1D(filters=F_int,
                      kernel_size=1,
                      strides=1,
                      padding='same',
                      name='Conv1D/phi_g')(g)
        phi_g = BatchNormalization(name='BN/phi_g')(phi_g)
        
        # resampling value (x) signal to shape of gating signal filters
        # theta_x = W_x*x -> shape: (b, l_g, F_int)
        theta_x = Conv1D(filters=F_int,
                     kernel_size=self.x_res_kernel,
                     strides=(x_shape[1] // g_shape[1]),
                     padding='same',
                     name='Conv1D/theta_x')(x)
        theta_x = BatchNormalization(name='BN/theta_x')(theta_x)
        
        sum_phi_theta = Add(name='Add/sum_phi_theta')([phi_g, theta_x])
        act_of_sum_xg = Activation('relu', name='Activation/sum_phi_theta')(sum_phi_theta)
        
        # calculate alpha as the sigmoid activation of psi * act_of_sum_xg [+ b_psi]
        psi = Conv1D(filters=1, kernel_size=1, padding='same', name='Conv1D/psi')(act_of_sum_xg)
        alpha = Activation('sigmoid', name='Activation/alpha')(psi)
        
        # upsample alpha -> shape: (bs, l, 1), therefore add and drop temporary "channel" dimension
        # and repeat vector along channel axis to match the shape of x
        reshaped_alpha = alpha[..., tf.newaxis]
        upsampled_alpha = tf.image.resize(reshaped_alpha, size=(x_shape[1] , 1), method=ResizeMethod.BILINEAR)
        repeated_alpha = K.repeat_elements(upsampled_alpha, x_shape[2], axis=2)
        repeated_alpha = K.squeeze(repeated_alpha, axis=3)
        
        # multiply x with attention map alpha
        x_gated = Multiply(name='Multiply/x_gated')([x, repeated_alpha])
        
        return x_gated

In [107]:
AdditiveAttentionGateLayer(3, name='Fu/nny_Name')(value, query, 32)

K.int_shape(x) (None, 128, 16)
(None, 32, 1)
K.int_shape(upsampled_alpha) (None, 128, 1, 1)
K.int_shape(x) (None, 128, 16)
K.int_shape(repeated_alpha) (None, 128, 16)


<KerasTensor: shape=(None, 128, 16) dtype=float32 (created by layer 'Fu/nny_Name')>

https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/layers/grid_attention_layer.py

In [1]:
from tensorflow.keras.layers import Conv2D, Activation, UpSampling2D, BatchNormalization
from tensorflow.keras import backend as K


def AttnGatingBlock(x, g, inter_shape):
    # https://github.com/robinvvinod/unet/blob/master/layers2D.py
    
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(g)

    # Getting the gating signal to the same number of filters as the inter_shape
    phi_g = Conv2D(filters=inter_shape,
                   kernel_size=1,
                   strides=1,
                   padding='same')(g)

    # Getting the x signal to the same shape as the gating signal
    theta_x = Conv2D(filters=inter_shape,
                     kernel_size=3,
                     strides=(shape_x[1] // shape_g[1],
                              shape_x[2] // shape_g[2]),
                     padding='same')(x)

    # Element-wise addition of the gating and x signals
    add_xg = add([phi_g, theta_x])
    add_xg = Activation('relu')(add_xg)

    # 1x1x1 convolution
    psi = Conv2D(filters=1, kernel_size=1, padding='same')(add_xg)
    psi = Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(psi)

    # Upsampling psi back to the original dimensions of x signal
    upsample_sigmoid_xg = UpSampling2D(size=(shape_x[1] // shape_sigmoid[1],
                                             shape_x[2] //
                                             shape_sigmoid[2]))(psi)

    # Expanding the filter axis to the number of filters in the original x signal
    upsample_sigmoid_xg = expend_as(upsample_sigmoid_xg, shape_x[3])

    # Element-wise multiplication of attention coefficients back onto original x signal
    attn_coefficients = multiply([upsample_sigmoid_xg, x])

    # Final 1x1x1 convolution to consolidate attention signal to original x dimensions
    output = Conv2D(filters=shape_x[3],
                    kernel_size=1,
                    strides=1,
                    padding='same')(attn_coefficients)
    output = BatchNormalization()(output)
    return output

In [10]:
from tensorflow.keras.layers import Conv3D, Activation, UpSampling3D
from tensorflow.keras import backend as K

def AttnGatingBlock(x, g, inter_shape):
    # https://github.com/robinvvinod/unet/blob/master/layers2D.py
    
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(g)

    # Getting the gating signal to the same number of filters as the inter_shape
    phi_g = Conv3D(filters=inter_shape,
                   kernel_size=1,
                   strides=1,
                   padding='same')(g)

    # Getting the x signal to the same shape as the gating signal
    theta_x = Conv3D(filters=inter_shape,
                     kernel_size=3,
                     strides=(shape_x[1] // shape_g[1],
                              shape_x[2] // shape_g[2],
                              shape_x[3] // shape_g[3]),
                     padding='same')(x)

    # Element-wise addition of the gating and x signals
    add_xg = add([phi_g, theta_x])
    add_xg = Activation('relu')(add_xg)

    # 1x1x1 convolution
    psi = Conv3D(filters=1, kernel_size=1, padding='same')(add_xg)
    psi = Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(psi)

    # Upsampling psi back to the original dimensions of x signal
    upsample_sigmoid_xg = UpSampling3D(
        size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2],
              shape_x[3] // shape_sigmoid[3]))(psi)

    # Expanding the filter axis to the number of filters in the original x signal
    upsample_sigmoid_xg = expend_as(upsample_sigmoid_xg, shape_x[4])

    # Element-wise multiplication of attention coefficients back onto original x signal
    attn_coefficients = multiply([upsample_sigmoid_xg, x])

    # Final 1x1x1 convolution to consolidate attention signal to original x dimensions
    output = Conv3D(filters=shape_x[4],
                    kernel_size=1,
                    strides=1,
                    padding='same')(attn_coefficients)
    output = BatchNormalization()(output)
    return output

In [None]:
# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')

# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(max_tokens, dimension)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)

# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
    filters=100,
    kernel_size=4,
    # Use 'same' padding so outputs have the same shape as inputs.
    padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)

# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
    [query_seq_encoding, value_seq_encoding])

# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
    query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
    query_value_attention_seq)

# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
    [query_encoding, query_value_attention])

# Add DNN layers, and create Model.
# ...

In [74]:
import numpy as np
from tensorflow.image import ResizeMethod

# batch_size = 4
# ts_length = 8
base_line = np.array([0, 0, 1, 1, 1, 0, 0, 1])
np_images = np.array([base_line for _ in range(4)])

# add width and channel dimension
# np_images = np_images.reshape((4,8,1,1))

tf_images = tf.constant(np_images)
tf_images = tf_images[..., tf.newaxis, tf.newaxis]
tf_images

<tf.Tensor: shape=(4, 8, 1, 1), dtype=int32, numpy=
array([[[[0]],

        [[0]],

        [[1]],

        [[1]],

        [[1]],

        [[0]],

        [[0]],

        [[1]]],


       [[[0]],

        [[0]],

        [[1]],

        [[1]],

        [[1]],

        [[0]],

        [[0]],

        [[1]]],


       [[[0]],

        [[0]],

        [[1]],

        [[1]],

        [[1]],

        [[0]],

        [[0]],

        [[1]]],


       [[[0]],

        [[0]],

        [[1]],

        [[1]],

        [[1]],

        [[0]],

        [[0]],

        [[1]]]])>

In [75]:
result = tf.image.resize(tf_images, size=(16,1), method=ResizeMethod.BILINEAR)
result

<tf.Tensor: shape=(4, 16, 1, 1), dtype=float32, numpy=
array([[[[0.  ]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[0.75]],

        [[0.25]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]]],


       [[[0.  ]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[0.75]],

        [[0.25]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]]],


       [[[0.  ]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[1.  ]],

        [[0.75]],

        [[0.25]],

        [[0.  ]],

        [[0.  ]],

        [[0.25]],

        [[0.75]],

        [[1.  ]]],


       [[[0.  ]],

        