<a href="https://colab.research.google.com/github/jojivk/The-Ramp/blob/master/bertmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import gin
import tensorflow as tf
import tensorflow_hub as hub

class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Computes the BERT pretraining loss and registers pretraining metrics.

  The loss is the weighted masked-LM cross-entropy plus, when next-sentence
  labels are supplied, the mean next-sentence cross-entropy.
  """

  def __init__(self, vocab_size, **kwargs):
    """Creates the layer.

    Args:
      vocab_size: Vocabulary size; stored so the layer can be re-created from
        its config.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {'vocab_size': vocab_size}

  def get_config(self):
    """Returns the base Keras layer config extended with `vocab_size`."""
    config = super(BertPretrainLossAndMetricLayer, self).get_config()
    config.update(self.config)
    return config

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Registers masked-LM and next-sentence metrics on the layer.

    Args:
      lm_output: Masked-LM logits.
      lm_labels: Integer masked-LM target ids.
      lm_label_weights: Float weights; entries with weight 0 contribute
        nothing to the masked-LM accuracy.
      lm_example_loss: Scalar masked-LM loss, reported as `lm_example_loss`.
      sentence_output: Next-sentence logits; only read when `sentence_labels`
        is not None.
      sentence_labels: Integer next-sentence labels, or None.
      next_sentence_loss: Scalar next-sentence loss, or None.
    """
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    # Small epsilon keeps the ratio finite when every weight is zero.
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(masked_lm_accuracy, name='masked_lm_accuracy',
                    aggregation='mean')

    # Report the masked-LM loss passed in by call().
    self.add_metric(lm_example_loss, name='lm_example_loss',
                    aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(next_sentence_accuracy,
                      name='next_sentence_accuracy',
                      aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(next_sentence_loss,
                      name='next_sentence_loss',
                      aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Computes the total pretraining loss and records metrics.

    Args:
      lm_output_logits: Masked-LM logits over the vocabulary; presumably
        (batch, num_predictions, vocab_size) -- TODO confirm.
      sentence_output_logits: Next-sentence logits; only used when
        `sentence_labels` is given.
      lm_label_ids: Integer masked-LM targets; its first dimension sets the
        length of the returned tensor.
      lm_label_weights: Per-target weights (cast to float32).
      sentence_labels: Optional integer next-sentence labels.

    Returns:
      A float32 1-D tensor of length `tf.shape(lm_label_ids)[0]` in which
      every element equals the scalar total loss.
    """
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    # Weighted mean over target positions; divide_no_nan returns 0 rather
    # than NaN when all weights are zero.
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    # Replicate the scalar loss once per entry of lm_label_ids' leading
    # dimension (presumably the batch dimension).
    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids,
                      lm_label_weights, mask_label_loss,
                      sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss


  


In [4]:
@gin.configurable
def get_transformer_encoder(bert_config,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a transformer encoder object.

  NOTE(review): `tf_utils`, `networks`, `albert_configs` and `configs` are not
  imported anywhere in this notebook. They are presumably the TF Model Garden
  modules (e.g. `official.modeling.tf_utils`, `official.nlp.modeling.networks`)
  and must be imported before this function is called.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: the sequence output range, [0, output_range). Default setting
      is to return the entire sequence output. Only forwarded to the default
      BERT encoder; not used for ALBERT configs or a custom
      `transformer_encoder_cls`.

  Returns:
    A encoder object.

  Raises:
    AssertionError: if no `transformer_encoder_cls` is given and `bert_config`
      is neither an AlbertConfig nor a BertConfig (only when assertions are
      enabled).
  """

  def _truncated_normal():
    # Builds a TruncatedNormal initializer from the config's initializer_range.
    return tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if transformer_encoder_cls is not None:
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=_truncated_normal(),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=_truncated_normal(),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=_truncated_normal(),
    )

    # Relies on gin configuration to define the Transformer
    # encoder arguments
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=_truncated_normal())

  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    # output_range is only supported by the BERT encoder.
    kwargs['output_range'] = output_range
    return networks.BertEncoder(**kwargs)

  

NameError: ignored