In [None]:
import bert
from bert import run_pretraining
from bert import optimization
import tensorflow as tf
import tensorflow_hub as hub
from datetime import datetime
from sklearn import metrics
logger = tf.get_logger()
# Stop TensorFlow log records from being passed on to ancestor (root) logger
# handlers; presumably this avoids each message being printed twice in the
# notebook output — confirm.
logger.propagate = False

## Pretrained weights
The pretrained BERT weights are loaded from TF Hub (tfhub.dev).

In [None]:
bert_model_hub = "https://tfhub.dev/google/small_bert/bert_uncased_L-4_H-256_A-4/1"
output_dir = "finetuned_weights/in_task_pretraining/"

# The BERT input_fn parses fixed-length features of these sizes, so they must
# match the values used when the .tfrecord files were written.
max_seq_len = 256  # A pre-training example typically holds 2 text segments, i.e. 2 x 128 tokens
max_predictions_per_seq = 40  # At most this many masked tokens per example for the MLM task

# The record file names encode max_seq_len; derive them from it so the paths
# and the parsing length cannot drift apart.
train_files = [f"datasets/in_task_pretraining/train/mlm_max_seq_len_{max_seq_len}.tfrecord"]
dev_files = [f"datasets/in_task_pretraining/dev/mlm_max_seq_len_{max_seq_len}.tfrecord"]

tf.gfile.MakeDirs(output_dir)

# BERT + MLM and NSP layers
We take the pretrained BERT model and add output layers for masked language modelling (MLM) and next sentence prediction (NSP).

In [None]:
def get_masked_lm_prediction(sequence_output, masked_lm_positions):
    """Masked-LM output head: score the hidden states at the masked positions against the vocabulary.

    This implementation follows the original from
    https://github.com/google-research/bert/blob/master/run_pretraining.py#L240
    but makes it work with the tfhub model. Unlike the original, no extra
    dense + layer-norm "transform" layer is applied before the projection.

    Args:
        sequence_output: the BERT module's "sequence_output" (per-token hidden
            states; presumably [batch, seq_len, hidden] — TODO confirm).
        masked_lm_positions: indices of the masked tokens within each sequence,
            forwarded to `run_pretraining.gather_indexes`.

    Returns:
        Vocabulary logits, one row per gathered (masked) position.

    Raises:
        ValueError: if no trainable variable whose name contains
            "word_embeddings" exists, e.g. the BERT hub module has not been
            created in the current graph yet.
    """
    # Reuse the word embedding matrix as the output projection (weight tying).
    # Look it up by name instead of assuming it is the first trainable variable,
    # which silently breaks if any other variable is created before the module.
    word_embeddings = next(
        (var for var in tf.trainable_variables() if "word_embeddings" in var.name), None)
    if word_embeddings is None:
        raise ValueError(
            "No trainable 'word_embeddings' variable found; create the BERT hub "
            "module before building the masked LM head.")
    print("Re-useing embedding matrix: ", word_embeddings)

    # Slice out the hidden states at the masked positions
    input_tensor = run_pretraining.gather_indexes(sequence_output, masked_lm_positions)

    # Project onto the tied embedding matrix and add a per-vocabulary-entry bias
    vocab_size = word_embeddings.shape.as_list()[0]
    output_bias = tf.get_variable(
        "output_bias", shape=[vocab_size], initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, word_embeddings, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    return logits
    
    
def get_next_sentence_prediction(pooled_output):
    """Next-sentence-prediction output head: a 2-way linear classifier on the pooled output.

    Follows the original implementation at
    https://github.com/google-research/bert/blob/master/run_pretraining.py#L285
    adapted to work with the tfhub model.

    Args:
        pooled_output: the BERT module's "pooled_output" as a matrix; its last
            (static) dimension is used as the input width of the classifier.

    Returns:
        Logits with one row of 2 scores per input row.
    """
    num_classes = 2
    width = pooled_output.shape.as_list()[-1]

    with tf.variable_scope("next_sentence_prediction"):
        # Variable names/scope are what checkpoints store, so keep them stable.
        weights = tf.get_variable(
            "weights", [width, num_classes],
            initializer=tf.glorot_uniform_initializer())
        bias = tf.get_variable(
            "bias", [num_classes], initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(tf.matmul(pooled_output, weights), bias)

    return logits


def create_model(is_predicting, inputs):
    """Instantiate BERT from TF Hub and attach the MLM and NSP output heads.

    Args:
        is_predicting: not used by this function; kept for call-site compatibility.
        inputs: a 4-item sequence
            (input_ids, input_mask, segment_ids, masked_lm_positions).

    Returns:
        A tuple (masked_lm_logits, next_sentence_logits).
    """
    input_ids, input_mask, segment_ids, masked_lm_positions = inputs

    # The hub module has to be created and applied before the MLM head is built,
    # since that head looks up the module's word-embedding variable.
    bert_module = hub.Module(bert_model_hub, trainable=True)
    module_outputs = bert_module(
        inputs={
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": segment_ids,
        },
        signature="tokens",
        as_dict=True,
    )

    # Per-token outputs feed the MLM head; the pooled output feeds the NSP head.
    mlm_logits = get_masked_lm_prediction(module_outputs["sequence_output"], masked_lm_positions)
    nsp_logits = get_next_sentence_prediction(module_outputs["pooled_output"])

    return mlm_logits, nsp_logits

It is convenient to wrap this model in a TensorFlow Estimator, which automates the training loop for us.

In [None]:
def _weighted_masked_lm_loss(masked_lm_ids, masked_lm_logits, masked_lm_weights):
    """Mean masked-LM cross-entropy over the real (non-zero weight) masked positions only."""
    per_position_loss = tf.keras.losses.sparse_categorical_crossentropy(
        masked_lm_ids, masked_lm_logits, from_logits=True)
    numerator = tf.reduce_sum(per_position_loss * masked_lm_weights)
    # Divide by the number of real predictions, not by all (padded) slots;
    # the epsilon guards against a batch without any masked tokens.
    denominator = tf.reduce_sum(masked_lm_weights) + 1e-5
    return numerator / denominator


def _find_learning_rate_tensor(graph=None):
    """Return the output of the graph's "PolynomialDecay" op, or None if there is none."""
    graph = graph if graph is not None else tf.get_default_graph()
    try:
        return graph.get_operation_by_name("PolynomialDecay").outputs[0]
    except KeyError:
        return None


def model_fn_builder(learning_rate, num_train_steps, num_warmup_steps):
    """Build an Estimator `model_fn` for joint MLM + NSP in-task pre-training.

    Args:
        learning_rate: learning rate passed to `bert.optimization.create_optimizer`.
        num_train_steps: total training steps passed to `create_optimizer`.
        num_warmup_steps: warmup steps passed to `create_optimizer`.

    Returns:
        A `model_fn(features, labels, mode, params)` for `tf.estimator.Estimator`.
    """
    def model_fn(features, labels, mode, params):
        """Estimator model_fn. All inputs and targets are read from `features`;
        `labels` and `params` are not used here."""
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        # Additional inputs for masked language modelling and next sentence prediction
        masked_lm_positions = features["masked_lm_positions"]

        # Flatten the label information to treat batch as 1 sequence
        masked_lm_weights = tf.reshape(features["masked_lm_weights"], [-1])
        masked_lm_ids = tf.reshape(features["masked_lm_ids"], [-1])
        next_sentence_labels = tf.reshape(features["next_sentence_labels"], [-1])

        is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

        # Model definition (identical for every mode)
        inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
        masked_lm_logits, next_sentence_logits = create_model(is_predicting, inputs)

        if is_predicting:
            # NOTE(review): "masked_lm_logts" is misspelled, but the key is kept
            # unchanged so existing consumers of `estimator.predict` keep working.
            predictions = {'masked_lm_logts': masked_lm_logits, 'masked_lm_ids': masked_lm_ids}
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        # TRAIN and EVAL
        masked_lm_predictions = tf.argmax(masked_lm_logits, axis=-1)
        next_sentence_predictions = tf.argmax(next_sentence_logits, axis=-1)

        # Losses
        masked_lm_loss = _weighted_masked_lm_loss(masked_lm_ids, masked_lm_logits, masked_lm_weights)
        next_sentence_loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                next_sentence_labels, next_sentence_logits, from_logits=True))

        # The training loss is the sum of the masked language model and next sentence prediction losses
        loss = masked_lm_loss + next_sentence_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Build the optimizer only when training; evaluation does not need it.
            train_op = bert.optimization.create_optimizer(
                loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

            # Summaries
            tf.summary.scalar("masked_lm_cross_entropy_loss", masked_lm_loss)
            tf.summary.scalar("next_sentence_cross_entropy_loss", next_sentence_loss)

            # Streaming accuracies. Summarising the metric's update op means the
            # running totals are updated whenever the summary is evaluated.
            # Padded masked-LM slots are excluded via masked_lm_weights.
            _, masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids, predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            _, next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels, predictions=next_sentence_predictions)
            tf.summary.scalar("masked_lm_accuracy", masked_lm_accuracy)
            tf.summary.scalar("next_sentence_accuracy", next_sentence_accuracy)

            # Log the decayed learning rate created by create_optimizer, if the
            # op exists (it is the schedule before any warmup blending).
            lr = _find_learning_rate_tensor()
            if lr is not None:
                tf.summary.scalar("learning_rate", lr)

            return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

        # EVAL: calculate evaluation metrics
        eval_metrics = {
            "masked_lm_accuracy": tf.metrics.accuracy(
                labels=masked_lm_ids, predictions=masked_lm_predictions,
                weights=masked_lm_weights),
            "next_sentence_accuracy": tf.metrics.accuracy(
                labels=next_sentence_labels, predictions=next_sentence_predictions),
        }
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metrics)

    # Return the actual model function in the closure
    return model_fn


## Training the model

In [None]:
# Training hyper-parameters; step counts are derived from the batch size
batch_size = 64
learning_rate = 5e-5
num_train_steps = 800000 // batch_size  # about one pass over the ~800k examples in the dataset
num_warmup_steps = 0

# Checkpoints and summaries go to output_dir; only the two newest checkpoints are kept
run_config = tf.estimator.RunConfig(
    model_dir=output_dir,
    save_summary_steps=10,
    save_checkpoints_steps=500,
    keep_checkpoint_max=2,
)

model_fn = model_fn_builder(
    learning_rate=learning_rate,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
)

# The BERT input_fn reads params["batch_size"] to batch the records
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    config=run_config,
    params={"batch_size": batch_size},
)

# Load the in-task pretraining data
train_input_fn = bert.run_pretraining.input_fn_builder(
    train_files, max_seq_len, max_predictions_per_seq, is_training=True)

# max_steps is an absolute global step: with an existing checkpoint in
# output_dir, training resumes and stops once that step is reached.
print(f"Training for {num_train_steps} steps")
current_time = datetime.now()
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
print("Training took time ", datetime.now() - current_time)