<a href="https://colab.research.google.com/github/contextrcnn2/Context-R-CNN/blob/main/library2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf

# Large negative offset added to the attention logits of invalid (padded)
# context entries, so that after softmax they receive ~zero weight.
_PADDING_VALUE = -100_000

class FreezableBatchNorm(tf.keras.layers.BatchNormalization):
  """Batch normalization whose training mode can be pinned off at construction.

  Args:
    training: If exactly ``False``, the layer always runs in inference mode,
      ignoring the ``training`` flag supplied at call time. Any other value
      (including the default ``None``) defers to the call-time flag.
    **kwargs: Forwarded to ``tf.keras.layers.BatchNormalization``.
  """

  def __init__(self, training=None, **kwargs):
    super().__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Only an explicit False freezes the layer; otherwise honor the caller.
    effective_training = False if self._training is False else training
    return super().call(inputs, training=effective_training)

class ContextProjection(tf.keras.layers.Layer):
  """Dense projection followed by batch normalization and a ReLU6.

  Computes ``relu6(batch_norm(dense(x)))``, mapping the last axis of the input
  to ``projection_dimension`` features.

  Args:
    projection_dimension: Number of output units of the dense projection.
    **kwargs: Forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, projection_dimension, **kwargs):
    # Initialize the base Layer before attaching sub-layers so Keras attribute
    # tracking is in place when they are assigned.
    super(ContextProjection, self).__init__(**kwargs)
    self.batch_norm = FreezableBatchNorm(
        epsilon=0.001,
        center=True,
        scale=True,
        momentum=0.97,
        trainable=True)
    self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                            use_bias=True)
    self.projection_dimension = projection_dimension

  def build(self, input_shape):
    """Builds the dense projection and the batch norm for ``input_shape``.

    Args:
      input_shape: A ``tf.TensorShape`` or a plain tuple/list of dims.
    """
    # Normalize so slicing/concatenation work whether Keras hands us a
    # TensorShape or a tuple.
    input_shape = tf.TensorShape(input_shape)
    self.projection.build(input_shape)
    # The batch norm sees the projected tensor: every leading dim is kept and
    # only the last dim becomes projection_dimension. (Using [:1] here would
    # build the batch norm with the wrong rank for inputs of rank > 2.)
    self.batch_norm.build(
        input_shape[:-1].concatenate([self.projection_dimension]))
    super(ContextProjection, self).build(input_shape)

  def call(self, input_features, is_training=False):
    """Projects, batch-normalizes and applies ReLU6.

    Args:
      input_features: Tensor whose last axis is projected.
      is_training: Forwarded to the batch norm as its ``training`` flag.

    Returns:
      Tensor with the same leading dims and a last dim of
      ``projection_dimension``.
    """
    projected = self.projection(input_features)
    return tf.nn.relu6(self.batch_norm(projected, training=is_training))

In [None]:
class AttentionBlock(tf.keras.layers.Layer):
  """Attention from per-proposal box features onto a bank of context features.

  Box features are spatially average-pooled and projected to queries; context
  features are projected to keys and values. Dot-product attention, divided by
  ``attention_temperature`` before the softmax and restricted to valid context
  entries, produces one vector per proposal slot, which is then projected to
  ``output_dimension``.

  Args:
    bottleneck_dimension: Size of the query/key/value projections.
    attention_temperature: Divisor applied to attention logits before softmax.
    output_dimension: Size of the output projection. If falsy, it is taken from
      the last dimension of ``box_features`` when the layer is built.
    is_training: Training flag forwarded to every projection layer.
    name: Layer name.
    max_num_proposals: Number of proposal slots per image; the leading
      dimension of ``box_features`` must be a multiple of it.
    **kwargs: Forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, bottleneck_dimension, attention_temperature,
               output_dimension=None, is_training=False,
               name='AttentionBlock', max_num_proposals=100, **kwargs):
    # Initialize the base Layer before attaching sub-layers so Keras attribute
    # tracking is in place when they are assigned.
    super(AttentionBlock, self).__init__(name=name, **kwargs)
    self._key_proj = ContextProjection(bottleneck_dimension)
    self._val_proj = ContextProjection(bottleneck_dimension)
    self._query_proj = ContextProjection(bottleneck_dimension)
    self._feature_proj = None
    self._attention_temperature = attention_temperature
    self._bottleneck_dimension = bottleneck_dimension
    self._is_training = is_training
    self._output_dimension = output_dimension
    self._max_num_proposals = max_num_proposals
    if self._output_dimension:
      self._feature_proj = ContextProjection(self._output_dimension)

  def build(self, input_shapes):
    """Creates the output projection lazily if no output_dimension was given.

    Args:
      input_shapes: Shape Keras derives from the first call argument
        (``box_features``); its last dimension becomes the output dimension.
    """
    if self._feature_proj is None:
      self._output_dimension = input_shapes[-1]
      self._feature_proj = ContextProjection(self._output_dimension)
    super(AttentionBlock, self).build(input_shapes)

  def call(self, box_features, context_features, valid_context_size,
           num_proposals):
    """Attends from each proposal's pooled features to the context features.

    Args:
      box_features: Rank-4 tensor [batch * max_num_proposals, height, width,
        channels] whose leading dimension is statically known.
      context_features: Rank-3 tensor [batch, context_size, depth].
      valid_context_size: Per-image count of valid context entries; presumably
        an integer tensor of shape [batch] — TODO confirm against callers.
      num_proposals: Per-image count of valid proposals, same convention.

    Returns:
      Tensor [batch, max_num_proposals, 1, 1, output_dimension].

    Raises:
      ValueError: If the leading dimension of ``box_features`` is not a
        multiple of ``max_num_proposals``.
    """
    _, context_size, _ = context_features.shape
    keys_values_valid_mask = compute_valid_mask(
        valid_context_size, context_size)

    total_proposals, height, width, channels = box_features.shape
    # Fail early with a readable message instead of an opaque reshape error.
    if total_proposals % self._max_num_proposals != 0:
      raise ValueError(
          'The first dimension of box_features ({}) must be a multiple of '
          'max_num_proposals ({}).'.format(total_proposals,
                                           self._max_num_proposals))
    batch_size = total_proposals // self._max_num_proposals
    box_features = tf.reshape(
        box_features,
        [batch_size, self._max_num_proposals, height, width, channels])

    # Average-pool over the spatial axes: [batch, max_num_proposals, channels].
    box_features = tf.reduce_mean(box_features, [2, 3])

    queries_valid_mask = compute_valid_mask(num_proposals,
                                            box_features.shape[1])
    queries = project_features(
        box_features, self._bottleneck_dimension, self._is_training,
        self._query_proj, normalize=True)
    keys = project_features(
        context_features, self._bottleneck_dimension, self._is_training,
        self._key_proj, normalize=True)
    values = project_features(
        context_features, self._bottleneck_dimension, self._is_training,
        self._val_proj, normalize=True)

    # Zero out keys and queries that sit in padded slots.
    keys *= tf.cast(keys_values_valid_mask[..., tf.newaxis], keys.dtype)
    queries *= tf.cast(queries_valid_mask[..., tf.newaxis], queries.dtype)

    # Dot-product logits: [batch, max_num_proposals, context_size].
    weights = tf.matmul(queries, keys, transpose_b=True)
    # Give invalid context columns a large negative logit and zero their values.
    weights, values = filter_weight_value(weights, values,
                                          keys_values_valid_mask)
    weights = tf.nn.softmax(weights / self._attention_temperature)

    features = tf.matmul(weights, values)
    output_features = project_features(
        features, self._output_dimension, self._is_training,
        self._feature_proj, normalize=False)
    # Restore singleton spatial axes: [batch, max_num_proposals, 1, 1, depth].
    return output_features[:, :, tf.newaxis, tf.newaxis, :]

In [None]:
def filter_weight_value(weights, values, valid_mask):
  """Masks invalid context entries out of attention logits and values.

  Args:
    weights: Rank-3 attention logits [batch, num_queries, context_size].
    values: Rank-3 tensor [batch, context_size, depth].
    valid_mask: Rank-2 boolean tensor [batch, context_size]; True marks a
      valid context entry.

  Returns:
    A ``(weights, values)`` pair: ``weights`` with ``_PADDING_VALUE`` added to
    every invalid context column, and ``values`` with invalid rows zeroed.

  Raises:
    ValueError: If the static batch or context dimensions disagree.
  """
  batch_w, _, ctx_w = weights.shape
  batch_v, ctx_v, _ = values.shape
  batch_m, ctx_m = valid_mask.shape

  if batch_w != batch_v or batch_v != batch_m:
    raise ValueError('Please make sure the first dimension of the input'
                     ' tensors are the same.')
  if ctx_w != ctx_v:
    raise ValueError('Please make sure the third dimension of weights matches'
                     ' the second dimension of values.')
  if ctx_w != ctx_m:
    raise ValueError('Please make sure the third dimension of the weights'
                     ' matches the second dimension of the valid_mask.')

  # [batch, context_size, 1] boolean mask of entries to suppress.
  invalid_entries = tf.math.logical_not(valid_mask[..., tf.newaxis])
  # Transpose to [batch, 1, context_size] so the penalty broadcasts across
  # every query row of the logits.
  logit_penalty = tf.transpose(
      tf.cast(invalid_entries, weights.dtype) * _PADDING_VALUE,
      perm=[0, 2, 1])
  masked_weights = weights + logit_penalty
  masked_values = values * tf.cast(valid_mask[..., tf.newaxis], values.dtype)
  return masked_weights, masked_values



In [None]:
def project_features(features, bottleneck_dimension, is_training,
                     layer, normalize=True):
  """Applies ``layer`` to every feature vector of a rank-3 tensor.

  Args:
    features: Rank-3 tensor [batch, num_items, num_features]; the batch and
      feature dims are read from its static shape.
    bottleneck_dimension: Size of the last axis produced by ``layer``; used to
      restore the rank-3 shape.
    is_training: Training flag forwarded to ``layer`` as its second argument.
    layer: Callable taking ``(features_2d, is_training)``, e.g. a
      ``ContextProjection``.
    normalize: If True, L2-normalize each projected vector along the last axis.

  Returns:
    Tensor [batch, num_items, bottleneck_dimension].
  """
  batch_size, _, num_features = features.shape
  # Flatten so the layer sees one feature vector per row.
  flat_features = tf.reshape(features, [-1, num_features])

  projected_features = layer(flat_features, is_training)
  projected_features = tf.reshape(projected_features,
                                  [batch_size, -1, bottleneck_dimension])

  if normalize:
    # tf.math.l2_normalize is what the Keras-2 tf.keras.backend.l2_normalize
    # wrapped; calling it directly avoids the legacy Keras backend module,
    # which does not provide l2_normalize under Keras 3.
    projected_features = tf.math.l2_normalize(projected_features, axis=-1)

  return projected_features


In [None]:
def compute_valid_mask(num_valid_elements, num_elements):
  """Builds a boolean mask marking the first ``num_valid_elements[i]`` slots.

  Args:
    num_valid_elements: Tensor-like of per-row valid counts, presumably shape
      [batch] — TODO confirm against callers. Any numeric dtype is accepted.
    num_elements: Number of slots per row (the mask width).

  Returns:
    Boolean tensor [batch, num_elements] where
    ``mask[i, j] = j < num_valid_elements[i]``.
  """
  num_valid_elements = tf.convert_to_tensor(num_valid_elements)
  # Build the index row in the counts' dtype so tf.less sees matching dtypes
  # (e.g. int64 counts no longer fail against a hard-coded int32 range).
  element_idxs = tf.range(num_elements, dtype=num_valid_elements.dtype)
  # Broadcast [num_elements] against [batch, 1] instead of tiling by the
  # static batch size, so this also works when the batch dim is unknown.
  return tf.less(element_idxs, num_valid_elements[..., tf.newaxis])