<a href="https://colab.research.google.com/github/contextrcnn2/Context-R-CNN/blob/main/library1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
!pip install tf_slim
import tf_slim as slim




In [None]:
# Logit substituted for masked-out (invalid) attention entries so that they
# receive (near-)zero weight once the logits go through a softmax.
_PADDING_VALUE = -100000

def filter_weight_value(weights, values, valid_mask):
  """Masks attention logits and values using a per-element validity mask.

  Invalid positions in `weights` are replaced by `_PADDING_VALUE` so they do
  not contribute to a subsequent softmax, and invalid rows of `values` are
  zeroed.

  Args:
    weights: Rank-3 tensor [batch_size, num_queries, context_size] of
      attention logits.
    values: Rank-3 tensor [batch_size, context_size, depth].
    valid_mask: Boolean tensor [batch_size, context_size]; True marks a valid
      context element (it is used as a `tf.where` condition).

  Returns:
    A `(weights, values)` tuple with the same shapes as the inputs.

  Raises:
    ValueError: If the static batch or context dimensions of the three inputs
      disagree.
  """
  v_batch_size, v_context_size, _ = values.shape
  w_batch_size, _, w_context_size = weights.shape
  m_batch_size, m_context_size = valid_mask.shape
  if v_batch_size != m_batch_size or w_batch_size != v_batch_size:
    raise ValueError("please make the first dimensions same")

  if w_context_size != v_context_size:
    raise ValueError("Please make the third dimension of weights same as"
                     " the second dimension of values.")
  if w_context_size != m_context_size:
    raise ValueError("Please make sure the third dimension of the weights"
                     " matches the second dimension of the valid_mask.")

  # [batch_size, context_size, 1]
  valid_mask = valid_mask[..., tf.newaxis]

  # `ones_like` follows the runtime shape of `weights`; the previous
  # `tf.ones(weights.shape)` required every static dimension to be known.
  very_negative_mask = tf.ones_like(weights) * _PADDING_VALUE

  # [batch_size, 1, context_size] repeated along the query axis. The runtime
  # query count is used because the static one may be unknown (None).
  valid_weight_mask = tf.tile(tf.transpose(valid_mask, perm=[0, 2, 1]),
                              [1, tf.shape(weights)[1], 1])
  weights = tf.where(valid_weight_mask,
                     x=weights, y=very_negative_mask)

  # Zero out the values of invalid context elements.
  values *= tf.cast(valid_mask, values.dtype)

  return weights, values
  



In [None]:
def compute_valid_mask(num_valid_elements, num_elements):
  """Builds a boolean mask of the leading valid slots in each row.

  Entry `[b, i]` of the result is True iff `i < num_valid_elements[b]`.

  Args:
    num_valid_elements: Rank-1 integer tensor [batch_size] of per-row valid
      counts. It is compared against int32 indices, so it is presumably
      int32.
    num_elements: Number of slots per row (columns of the mask).

  Returns:
    Boolean tensor of shape [batch_size, num_elements].
  """
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  # Broadcasting [1, num_elements] against [batch_size, 1] produces the full
  # mask without tf.tile. The static batch size is never read, so an unknown
  # (None) batch dimension is fine.
  return tf.less(element_idxs[tf.newaxis, :],
                 num_valid_elements[..., tf.newaxis])

In [None]:
def project_features(features, projection_dimension, is_training, normalize):
  """Projects the last axis of a rank-3 tensor with a fully connected layer.

  The layer uses a ReLU6 activation and slim batch norm. The output is
  optionally L2-normalized along its last axis.

  Args:
    features: Rank-3 tensor [batch_size, num_items, num_features]; the last
      dimension must be statically known.
    projection_dimension: Output width of the fully connected layer.
    is_training: Forwarded to batch norm.
    normalize: If True, L2-normalize each projected vector.

  Returns:
    Tensor of shape [batch_size, num_items, projection_dimension].
  """
  static_batch, _, feature_depth = features.shape

  # Flatten to 2-D so the fully connected layer sees one row per item.
  flat_features = tf.reshape(features, [-1, feature_depth])

  bn_params = dict(is_training=is_training,
                   epsilon=0.001,
                   decay=0.97,
                   center=True,
                   scale=True)
  flat_projection = slim.fully_connected(
      flat_features,
      num_outputs=projection_dimension,
      activation_fn=tf.nn.relu6,
      normalizer_fn=slim.batch_norm,
      normalizer_params=bn_params)

  # Restore the [batch, items, depth] layout.
  projection = tf.reshape(flat_projection,
                          [static_batch, -1, projection_dimension])
  return tf.math.l2_normalize(projection, axis=-1) if normalize else projection


In [None]:
def attention_block(input_features, context_features, bottleneck_dimension,
                    output_dimension, attention_temperature,
                    keys_values_valid_mask, queries_valid_mask,
                    is_training, block_name="AttentionBlock"):
  """Attends `input_features` (queries) to `context_features` (keys/values).

  All projection layers are created inside `tf.variable_scope(block_name)`.

  Args:
    input_features: Rank-3 tensor from which the queries are projected.
    context_features: Rank-3 tensor from which keys and values are projected.
    bottleneck_dimension: Width of the L2-normalized query/key/value
      projections.
    output_dimension: Width of the final, un-normalized output projection.
    attention_temperature: Divisor applied to the query-key dot products
      before the softmax.
    keys_values_valid_mask: Boolean tensor [batch_size, context_size]
      marking valid keys/values.
    queries_valid_mask: Rank-2 mask [batch_size, num_queries] marking valid
      queries.
    is_training: Forwarded to batch norm in every projection.
    block_name: Variable-scope name. It also prefixes the name of the
      identity op holding the attention weights.

  Returns:
    Tensor of shape [batch_size, num_queries, output_dimension].
  """
  def _zero_out_invalid(tensor, mask):
    # Broadcast a [batch, n] mask over the feature axis; padded rows -> 0.
    return tensor * tf.cast(mask[..., tf.newaxis], tensor.dtype)

  with tf.variable_scope(block_name):
    # The projection creation order is kept unchanged; slim presumably names
    # the unscoped layers by creation order, so reordering could rename
    # variables.
    queries = project_features(
        input_features, bottleneck_dimension, is_training, normalize=True)
    keys = project_features(
        context_features, bottleneck_dimension, is_training, normalize=True)
    values = project_features(
        context_features, bottleneck_dimension, is_training, normalize=True)

    queries = _zero_out_invalid(queries, queries_valid_mask)
    keys = _zero_out_invalid(keys, keys_values_valid_mask)

    logits = tf.matmul(queries, keys, transpose_b=True)
    logits, values = filter_weight_value(logits, values,
                                         keys_values_valid_mask)
    attention_weights = tf.identity(
        tf.nn.softmax(logits / attention_temperature),
        name=block_name + "AttentionWeights")
    attended = tf.matmul(attention_weights, values)

    return project_features(
        attended, output_dimension, is_training, normalize=False)

In [None]:
def _compute_box_context_attention(box_features, num_proposals,
                                   context_features, valid_context_size,
                                   bottleneck_dimension,
                                   attention_temperature, is_training,
                                   max_num_proposals,
                                   use_self_attention=False,
                                   use_long_term_attention=True,
                                   self_attention_in_sequence=False,
                                   num_attention_heads=1,
                                   num_attention_layers=1):
  """Refines spatially pooled box features with self and/or context attention.

  Box features are average-pooled over height and width. They are then
  optionally attended against each other ("SelfAttentionBlock") and/or
  against `context_features`. The context attention uses
  `num_attention_layers` residual layers, each averaging
  `num_attention_heads` attention blocks. The result is expanded back to a
  rank-5 tensor.

  Args:
    box_features: Rank-4 tensor [total_proposals, height, width, channels].
      total_proposals is integer-divided by max_num_proposals to recover
      batch_size, so it is expected to equal batch_size * max_num_proposals.
    num_proposals: Rank-1 integer tensor [batch_size] with the number of
      valid proposals per image. It is compared against int32 indices, so it
      is presumably int32.
    context_features: Rank-3 tensor [batch_size, context_size, depth].
    valid_context_size: Rank-1 integer tensor [batch_size] with the number of
      valid context entries per image. The same dtype note applies.
    bottleneck_dimension: Width of the query/key/value projections.
    attention_temperature: Divisor applied to attention logits before the
      softmax.
    is_training: Forwarded to batch norm in the projection layers.
    max_num_proposals: Number of proposal slots per image.
    use_self_attention: If True, attend box features to each other.
    use_long_term_attention: If True, attend box features to the context
      features.
    self_attention_in_sequence: Only matters when both attentions are on.
      If True, the self-attention output is averaged with the pooled box
      features and fed into the context attention. Otherwise it is added to
      the final output.
    num_attention_heads: Attention blocks per context layer; their outputs
      are averaged.
    num_attention_layers: Number of stacked residual context-attention
      layers.

  Returns:
    Tensor of shape [batch_size, max_num_proposals, 1, 1, channels]. It is
    all zeros when both `use_self_attention` and `use_long_term_attention`
    are False.
  """
  _, context_size, _ = context_features.shape
  context_valid_mask = compute_valid_mask(valid_context_size, context_size)

  total_proposals, height, width, channels = box_features.shape
  # `channels.value` only works when static dims are `Dimension` objects
  # (TF1 shape behavior). `dimension_value` also accepts the plain ints
  # produced under TF2 shape behavior.
  num_channels = tf.dimension_value(channels)

  batch_size = total_proposals // max_num_proposals
  box_features = tf.reshape(
      box_features,
      [batch_size,
       max_num_proposals,
       height,
       width,
       channels])
  # Average-pool over height and width:
  # [batch_size, max_num_proposals, channels].
  box_features = tf.reduce_mean(box_features, [2, 3])
  box_valid_mask = compute_valid_mask(
      num_proposals,
      box_features.shape[1])

  if use_self_attention:
    self_attention_box_features = attention_block(
        box_features, box_features, bottleneck_dimension, num_channels,
        attention_temperature, keys_values_valid_mask=box_valid_mask,
        queries_valid_mask=box_valid_mask, is_training=is_training,
        block_name="SelfAttentionBlock")

  if use_long_term_attention:
    if use_self_attention and self_attention_in_sequence:
      input_features = tf.add(self_attention_box_features, box_features)
      input_features = tf.divide(input_features, 2)
    else:
      input_features = box_features
    original_input_features = input_features
    for jdx in range(num_attention_layers):
      layer_features = tf.zeros_like(input_features)
      for idx in range(num_attention_heads):
        block_name = "AttentionBlock" + str(idx) + "_AttentionLayer" + str(jdx)
        attention_features = attention_block(
            input_features,
            context_features,
            bottleneck_dimension,
            num_channels,
            attention_temperature,
            keys_values_valid_mask=context_valid_mask,
            queries_valid_mask=box_valid_mask,
            is_training=is_training,
            block_name=block_name)
        layer_features = tf.add(layer_features, attention_features)
      # Average the heads, then add as a residual to the running features.
      layer_features = tf.divide(layer_features, num_attention_heads)
      input_features = tf.add(input_features, layer_features)
    # NOTE: original_input_features is also part of input_features' residual
    # stream, so it appears twice in this sum. Kept unchanged deliberately.
    output_features = tf.add(input_features, original_input_features)
    if not self_attention_in_sequence and use_self_attention:
      output_features = tf.add(self_attention_box_features, output_features)
  elif use_self_attention:
    output_features = self_attention_box_features
  else:
    # Neither attention is enabled. `self_attention_box_features` is never
    # bound on this path, so build the zeros from the pooled box features,
    # which already have the output shape [batch, proposals, channels].
    output_features = tf.zeros_like(box_features)

  # Expand back to [batch_size, max_num_proposals, 1, 1, channels].
  output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
  return output_features