# In[3]:
import tensorflow as tf

class SpatialRNNCell(tf.keras.layers.Layer):
    """Directional spatial IRNN that propagates context across a feature map.

    Along the chosen direction the hidden state is updated as
    ``h_k = relu(alpha * h_{k-1} + x_k)`` with ``h_{-1} = 0``. Every output
    location therefore aggregates information from all earlier locations on
    its row ('left'/'right') or column ('up'/'down').

    Inputs are indexed with axis 1 as height and axis 2 as width, i.e. this
    assumes a channels-last (batch, height, width, channels) tensor.

    Args:
        direction: One of 'right', 'left', 'up', 'down'. 'right' propagates
            context left-to-right along the width axis and 'down' propagates
            top-to-bottom along the height axis. 'left' and 'up' are the
            reverse directions.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``direction`` is not one of the supported values.
    """

    _DIRECTIONS = ('right', 'left', 'up', 'down')

    def __init__(self, direction, **kwargs):
        super(SpatialRNNCell, self).__init__(**kwargs)
        # Validate eagerly so a typo fails at construction, not deep inside call().
        if direction not in self._DIRECTIONS:
            raise ValueError(
                f"direction must be one of {self._DIRECTIONS}, got {direction!r}"
            )
        self.direction = direction
        # In the literature the recurrent weight is initialized as an identity
        # matrix; a single scalar initialized to 1 plays the same role here.
        self.alpha = self.add_weight(shape=(1,), initializer="ones", trainable=True)

    def call(self, inputs):
        """Run the recurrence along this cell's direction.

        Args:
            inputs: 4-D feature map, indexed as (batch, height, width, channels).

        Returns:
            A tensor with the same shape as ``inputs`` holding the
            ReLU-activated hidden states.
        """
        if self.direction in ('left', 'right'):
            # Scan over width: (B, H, W, C) -> (W, B, H, C) and back.
            perm, inverse_perm = [2, 0, 1, 3], [1, 2, 0, 3]
        else:
            # Scan over height: (B, H, W, C) -> (H, B, W, C); self-inverse.
            perm, inverse_perm = [1, 0, 2, 3], [1, 0, 2, 3]
        # 'left' and 'up' move information towards lower indices, so scan backwards.
        reverse = self.direction in ('left', 'up')

        # tf.scan iterates over axis 0, so move the recurrence axis to the front.
        # Unlike a Python range() over tf.shape(), this works in graph mode.
        sequence = tf.transpose(inputs, perm)

        def step(previous_hidden, current_input):
            return tf.nn.relu(self.alpha * previous_hidden + current_input)

        hidden = tf.scan(
            step,
            sequence,
            initializer=tf.zeros_like(sequence[0]),
            reverse=reverse,
        )
        return tf.transpose(hidden, inverse_perm)

    def get_config(self):
        """Include ``direction`` so the layer can be re-created via from_config."""
        config = super(SpatialRNNCell, self).get_config()
        config.update({"direction": self.direction})
        return config


class AttentionEstimatorNetwork(tf.keras.layers.Layer):
    """Small convolutional head that predicts four attention maps.

    Two 3x3, 'same'-padded convolutions (32 filters, ReLU) are followed by a
    linear 1x1 convolution with 4 output channels. With the default stride of
    1 and 'same' padding, the output keeps the spatial size of the input.
    """

    def __init__(self, **kwargs):
        super(AttentionEstimatorNetwork, self).__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=(3, 3), activation='relu', padding='same'
        )
        self.conv2 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=(3, 3), activation='relu', padding='same'
        )
        # No activation: the attention scores are left unnormalized.
        self.conv3 = tf.keras.layers.Conv2D(
            filters=4, kernel_size=(1, 1), activation=None, padding='same'
        )

    def call(self, inputs):
        """Apply conv1 -> conv2 -> conv3 to ``inputs`` and return the 4-channel result."""
        features = inputs
        for conv_layer in (self.conv1, self.conv2, self.conv3):
            features = conv_layer(features)
        return features


class DirectionAwareSpatialContextModule(tf.keras.layers.Layer):
    """Direction-aware spatial context (DSC) module.

    Runs two rounds of four-direction spatial RNNs. In each round, the
    per-direction context features are multiplied by attention maps predicted
    from the input, concatenated along the last axis, and reduced by a 1x1
    convolution with ReLU (32 filters). Both rounds use the same attention
    maps and the same per-direction RNN cells.

    Assumes a channels-last (batch, height, width, channels) input, matching
    the axis convention of SpatialRNNCell. NOTE(review): confirm with callers.
    """

    # Order must match the channel order of the 4-channel attention output,
    # which is split into one map per direction in this order.
    DIRECTIONS = ('right', 'left', 'up', 'down')

    def __init__(self, **kwargs):
        super(DirectionAwareSpatialContextModule, self).__init__(**kwargs)
        self.attention_network = AttentionEstimatorNetwork()
        self.spatial_rnn_cells = {
            direction: SpatialRNNCell(direction=direction)
            for direction in self.DIRECTIONS
        }
        # 1x1 conv reducing the first-round concat (4 * C_in channels) to 32.
        self.hidden_to_hidden_conv = tf.keras.layers.Conv2D(32, (1, 1), activation='relu')
        # Separate 1x1 conv for the second round. Its input has 4 * 32 = 128
        # channels, so reusing hidden_to_hidden_conv would fail with a shape
        # mismatch whenever C_in != 32 (and would tie two distinct convs otherwise).
        self.output_conv = tf.keras.layers.Conv2D(32, (1, 1), activation='relu')

    def _weighted_context(self, features, attention_maps):
        """Run every directional RNN on `features`, weight by its attention map, concat on the last axis."""
        weighted = [
            self.spatial_rnn_cells[direction](features) * attention
            for direction, attention in zip(self.DIRECTIONS, attention_maps)
        ]
        return tf.concat(weighted, axis=-1)

    def call(self, inputs):
        """Compute DSC features for `inputs`; the result has 32 channels."""
        # One single-channel attention map per direction, broadcast over channels.
        attention_maps = tf.split(
            self.attention_network(inputs),
            num_or_size_splits=len(self.DIRECTIONS),
            axis=-1,
        )

        # First round: context from the raw inputs, reduced to 32 channels.
        first_round = self.hidden_to_hidden_conv(
            self._weighted_context(inputs, attention_maps)
        )

        # Second round reuses the SAME attention weights, then the final 1x1 conv + ReLU.
        return self.output_conv(self._weighted_context(first_round, attention_maps))
