<a href="https://colab.research.google.com/github/framoni/jup-notebooks/blob/master/BERT_SQuAD_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning BERT on SQuAD v2 dataset for question answering

This notebook reproduces, from scratch, the fine-tuning of BERT on the SQuAD v2 dataset for question answering, using the pre-trained BERT models published on TensorFlow Hub. It is meant to be run top to bottom with minimal configuration.



## Import libraries

In [0]:
!pip install bert-tensorflow

import bert
from bert import modeling
from bert import optimization
from bert import run_squad
from bert import tokenization
import collections
import datetime
import json
import math
import os
import random
import six
import tensorflow as tf
import tensorflow_hub as hub

## Set root folder

In [0]:
# Where the SQuAD v2 files ("train-v2.0.json" and "dev-v2.0.json") live:
# - use_drive = True: put them in a Google Drive folder named "BERT SQuAD v2"
#   (Drive is mounted below and that folder becomes `root`);
# - use_drive = False: upload them to the Colab working directory.

use_drive = True

if use_drive:
  from google.colab import drive
  drive.mount('/content/drive')
  root = '/content/drive/My Drive/BERT SQuAD v2'
else:
  root = ''

## Function to read samples

In [0]:
def read_examples(input_file, is_training):
  """
  Read a SQuAD-format JSON file into a list of run_squad.SquadExample.

  This function is an extension of run_squad.read_squad_examples to allow for
  input of external text files as paragraphs: when a paragraph's "context" is
  the path of an existing file, the paragraph text is read from that file
  (UTF-8) instead of being used literally.

  Args:
    input_file: path of the SQuAD-format JSON file.
    is_training: if True, each answerable question's answer is converted to
      word-level start/end positions in the paragraph. If False, answers are
      not read and the positions stay None.

  Returns:
    A list of run_squad.SquadExample, one per question. In training mode,
    questions whose answer text cannot be found at the given offset are
    skipped with a warning. Questions without an "is_impossible" field
    (SQuAD v1.1 format) are treated as answerable.

  Raises:
    ValueError: in training mode, if an answerable question does not have
      exactly one answer.
  """

  with tf.gfile.Open(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  def is_whitespace(c):
    # Besides ASCII whitespace, U+202F (narrow no-break space) separates words.
    return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F

  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:
      paragraph_text = paragraph["context"]
      # A context may instead be the path of an external text file.
      if os.path.isfile(paragraph_text):
        with open(paragraph_text, "r", encoding="utf-8") as f:
          paragraph_text = f.read()

      # Split the paragraph on whitespace into words. char_to_word_offset[i]
      # is the index of the word containing character i. For whitespace it is
      # the preceding word, or -1 before the first word.
      doc_tokens = []
      char_to_word_offset = []
      prev_is_whitespace = True
      for c in paragraph_text:
        if is_whitespace(c):
          prev_is_whitespace = True
        else:
          if prev_is_whitespace:
            doc_tokens.append(c)
          else:
            doc_tokens[-1] += c
          prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)

      for qa in paragraph["qas"]:
        qas_id = qa["id"]
        question_text = qa["question"]
        start_position = None
        end_position = None
        orig_answer_text = None
        is_impossible = False
        if is_training:
          # SQuAD v1.1 files have no "is_impossible" field: all questions are answerable.
          is_impossible = qa.get("is_impossible", False)
          if (len(qa["answers"]) != 1) and (not is_impossible):
            raise ValueError(
                "For training, each question should have exactly 1 answer.")
          if not is_impossible:
            answer = qa["answers"][0]
            orig_answer_text = answer["text"]
            answer_offset = answer["answer_start"]
            answer_length = len(orig_answer_text)
            # Convert the character span of the answer to a word span.
            start_position = char_to_word_offset[answer_offset]
            end_position = char_to_word_offset[answer_offset + answer_length - 1]
            # Skip examples whose annotated span does not contain the answer
            # text (e.g. due to bad offsets in the data).
            actual_text = " ".join(
                doc_tokens[start_position:(end_position + 1)])
            cleaned_answer_text = " ".join(
                tokenization.whitespace_tokenize(orig_answer_text))
            if actual_text.find(cleaned_answer_text) == -1:
              tf.logging.warning("Could not find answer: '%s' vs. '%s'",
                                 actual_text, cleaned_answer_text)
              continue
          else:
            # Unanswerable questions are encoded with -1 positions and empty text.
            start_position = -1
            end_position = -1
            orig_answer_text = ""

        example = run_squad.SquadExample(
            qas_id=qas_id,
            question_text=question_text,
            doc_tokens=doc_tokens,
            orig_answer_text=orig_answer_text,
            start_position=start_position,
            end_position=end_position,
            is_impossible=is_impossible)
        examples.append(example)

  return examples

## Settings and model parameters


In [0]:
# Pre-trained (uncased) BERT model from TensorFlow Hub.

# BERT-large: 24 layers, hidden size 1024, 16 attention heads
# bert_model_hub = 'https://tfhub.dev/google/bert_uncased_L-24_H-1024_A-16/1'

# BERT-base: 12 layers, hidden size 768, 12 attention heads
bert_model_hub = 'https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1'

output_dir = 'output'  # local folder for TFRecord files, checkpoints and summaries
do_delete = True  # whether to delete the content of the output folder if it already exists
input_file = os.path.join(root, 'train-v2.0.json')  # SQuAD v2 training set

batch_size = 32  # batch size passed to the Estimator (see "Build the model")
predict_batch_size = 8  # batch size intended for prediction on the dev set
learning_rate = 5e-5
num_train_epochs = 3.0
warmup_proportion = 0.1  # fraction of the training steps used as learning-rate warmup

max_seq_length = 128  # maximum input length in tokens
max_query_length = 64  # maximum question length in tokens
max_answer_length = 30  # longest answer span (in tokens) considered when predicting

# Stride of the window sliding through the length of the document.
# NOTE(review): with max_seq_length = 128, a stride of 128 is presumably at
# least as long as the document part of each window, so consecutive windows
# would not overlap and answers crossing a window boundary could be lost.
# Consider a smaller stride (e.g. 64); confirm against
# run_squad.convert_examples_to_features.
doc_stride = 128
null_score_diff_threshold = 0.0 # threshold on the difference (null_score - best_non_null) over which to predict null
n_best_size = 3 # number of n-best predictions to generate in the nbest_predictions.json output file

# Only warnings and errors are shown; the tf.logging.info calls below are silenced.
tf.logging.set_verbosity(tf.logging.WARN)

save_checkpoints_steps = 1000  # checkpoint frequency, in training steps
save_summary_steps = 100  # TensorBoard summary frequency, in training steps

In [0]:
# Size the learning-rate schedule from the number of training questions. Only
# the count is kept: the examples are read again (and shuffled) when the
# training TFRecord file is built.
num_train_examples = len(read_examples(input_file=input_file, is_training=True))

num_train_steps = int(num_train_examples / batch_size * num_train_epochs)
num_warmup_steps = int(num_train_steps * warmup_proportion)

## Create output folder

In [0]:
# Optionally start from an empty output folder so that files from a previous
# run (TFRecords, checkpoints, summaries) do not linger. Only delete what
# exists, so real failures (e.g. permissions) are reported instead of being
# silently swallowed.
if do_delete and tf.gfile.Exists(output_dir):
  tf.gfile.DeleteRecursively(output_dir)
tf.gfile.MakeDirs(output_dir)

## Create BERT tokenizer

In [0]:
def create_tokenizer_from_hub_module():
  """
  Build a WordPiece tokenizer matching the BERT model on TensorFlow Hub.

  The module's "tokenization_info" signature is evaluated in a throw-away
  graph to obtain the vocabulary file and the lower-casing flag.

  Returns:
    A two-element list [do_lower_case, tokenizer], where tokenizer is a
    bert.tokenization.FullTokenizer.
  """

  scratch_graph = tf.Graph()
  with scratch_graph.as_default():
    hub_module = hub.Module(bert_model_hub)
    info = hub_module(signature="tokenization_info", as_dict=True)
    fetches = [info["vocab_file"], info["do_lower_case"]]
    with tf.Session() as session:
      vocab_path, lower_case = session.run(fetches)

  full_tokenizer = bert.tokenization.FullTokenizer(vocab_file=vocab_path, do_lower_case=lower_case)
  return [lower_case, full_tokenizer]

[do_lower_case, tokenizer] = create_tokenizer_from_hub_module()

## Load training data

In [0]:
train_examples = read_examples(input_file=input_file, is_training=True)

# pre-shuffle the input (with a fixed seed) to avoid having to make a very
# large shuffle buffer in the "input_fn"

rng = random.Random(12345)
rng.shuffle(train_examples)

# write to a temporary file to avoid storing very large constant tensors in memory

train_writer = run_squad.FeatureWriter(filename=os.path.join(output_dir, "train.tf_record"), is_training=True)

try:
  run_squad.convert_examples_to_features(
      examples=train_examples,
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=True,
      output_fn=train_writer.process_feature)
finally:
  # Release the TFRecord file even if feature conversion fails.
  train_writer.close()

del train_examples

## Functions to create and build the model

In [0]:
def model_fn_builder(learning_rate, num_train_steps, num_warmup_steps):
  '''
  Build the Estimator model_fn for fine-tuning BERT on SQuAD.

  Args:
    learning_rate: learning rate passed to optimization.create_optimizer.
    num_train_steps: total number of training steps (learning-rate schedule).
    num_warmup_steps: number of warmup steps (learning-rate schedule).

  Returns:
    A `model_fn(features, labels, mode, params)` for tf.estimator.Estimator.
    It supports TRAIN and PREDICT modes and raises ValueError for any other
    mode (e.g. EVAL).
  '''

  def model_fn(features, labels, mode, params):
    """Build the graph for `mode`. `labels` and `params` are unused: gold
    positions are read from `features`."""

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      (total_loss, _start_logits, _end_logits) = create_model(
          is_predicting=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids,
          start_positions=start_positions, end_positions=end_positions)

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

      output_spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op)

    elif mode == tf.estimator.ModeKeys.PREDICT:

      # No gold positions at prediction time: only the logits are built.
      (start_logits, end_logits) = create_model(
          is_predicting=True, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids,
          start_positions=None, end_positions=None)

      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }

      output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    else:
      # Previously unreachable: EVAL silently fell into the training branch.
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec

  return model_fn

def create_model(is_predicting, input_ids, input_mask, segment_ids, start_positions, end_positions):
  """
  Create a BERT model with a span-prediction head for question answering.

  A trainable BERT encoder is loaded from the module-level `bert_model_hub`.
  A single linear layer with 2 outputs is applied to every token of its
  "sequence_output", giving one start logit and one end logit per token.

  Args:
    is_predicting: if True, only the logits are built and returned.
    input_ids, input_mask, segment_ids: inputs of the Hub module's "tokens"
      signature (presumably int32 tensors of shape [batch_size, seq_length];
      TODO confirm against run_squad.input_fn_builder).
    start_positions, end_positions: gold start/end token indices, used only
      when `is_predicting` is False (None is passed otherwise).

  Returns:
    (start_logits, end_logits) when `is_predicting` is True, otherwise
    (total_loss, start_logits, end_logits). Each logits tensor has shape
    [batch_size, seq_length]. total_loss is the mean of the start and end
    cross-entropy losses.
  """

  # trainable=True: the BERT weights are fine-tuned along with the new output layer
  bert_module = hub.Module(bert_model_hub, trainable=True)
  bert_inputs = dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
  bert_outputs = bert_module(inputs=bert_inputs, signature="tokens", as_dict=True)
  
  # per-token encodings; expected_rank=3 -> [batch_size, seq_length, hidden_size]
  output_layer = bert_outputs["sequence_output"]
  output_layer_shape = modeling.get_shape_list(output_layer, expected_rank=3)
  
  batch_size = output_layer_shape[0]
  seq_length = output_layer_shape[1]
  hidden_size = output_layer_shape[2]

  # Output layer (row 0 -> start logit, row 1 -> end logit). The explicit
  # variable names presumably must stay stable to restore saved checkpoints.
  output_weights = tf.get_variable("cls/squad/output_weights", [2, hidden_size], initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable("cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
  
  # Only the ops below are named under "total_loss"; the output variables
  # above are created outside this scope.
  with tf.variable_scope("total_loss"):
  
    # no dropout for question answering

    # flatten the tokens so the same projection is applied to each of them
    output_layer_matrix = tf.reshape(output_layer, [batch_size * seq_length, hidden_size])
  
    logits = tf.matmul(output_layer_matrix, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    logits = tf.reshape(logits, [batch_size, seq_length, 2])
    # -> [2, batch_size, seq_length] so start/end logits can be unstacked
    logits = tf.transpose(logits, [2, 0, 1])

    unstacked_logits = tf.unstack(logits, axis=0)

    (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
    
    if is_predicting:
      return (start_logits, end_logits)
  
    # Cross-entropy between the gold position (one-hot over the sequence) and
    # the softmax of the logits over the sequence, averaged over the batch.
    def compute_loss(logits, positions):
      one_hot_positions = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)
      log_probs = tf.nn.log_softmax(logits, axis=-1)
      loss = -tf.reduce_mean(tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
      return loss
  
    start_loss = compute_loss(start_logits, start_positions)
    end_loss = compute_loss(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2.0

    return (total_loss, start_logits, end_logits)

## Build the model

In [0]:
# Estimator configuration: where checkpoints/summaries go and how often they are written.
run_config = tf.estimator.RunConfig(
    model_dir=output_dir,
    save_summary_steps=save_summary_steps,
    save_checkpoints_steps=save_checkpoints_steps)

# The model_fn closes over the optimizer settings derived from the training-set size.
model_fn = model_fn_builder(
    num_warmup_steps=num_warmup_steps,
    num_train_steps=num_train_steps,
    learning_rate=learning_rate)

# model_fn ignores `params`; presumably run_squad's input_fn reads
# params["batch_size"] (TODO confirm).
estimator_params = {"batch_size": batch_size}
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config, params=estimator_params)

## Train the model

In [0]:
# Input pipeline streaming the pre-computed training features from the TFRecord file.
train_input_fn = run_squad.input_fn_builder(
    input_file=train_writer.filename,
    seq_length=max_seq_length,
    is_training=True,
    drop_remainder=False)

# Fine-tune for num_train_steps steps and report the wall-clock duration.
training_start = datetime.datetime.now()
print('Starting fine-tuning of BERT on SQuAD v2...')
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
training_duration = datetime.datetime.now() - training_start
print("Fine-tuning took time ", training_duration)

## Export the trained model to Google Drive

In [0]:
# Copy the latest checkpoint of the fine-tuned model from the local output
# folder to `root/output` (on Google Drive when use_drive is True), so that it
# survives the end of the Colab session. The graph, the TensorBoard event
# files and the `checkpoint` state file are copied along with it.
# TODO: determine automatically the best model (e.g. on a validation metric)
# instead of taking the latest checkpoint.

export_dir = os.path.join(root, 'output')
latest_checkpoint = tf.train.latest_checkpoint(output_dir)
if latest_checkpoint is None:
  raise ValueError("No checkpoint found in '%s': train the model first." % output_dir)

if os.path.abspath(export_dir) == os.path.abspath(output_dir):
  # use_drive = False: the checkpoint is already where the prediction cell looks for it.
  print('Checkpoint', latest_checkpoint, 'is already in', export_dir)
else:
  tf.gfile.MakeDirs(export_dir)
  export_patterns = [
      latest_checkpoint + '.*',  # .data-*, .index and .meta files of that step
      os.path.join(output_dir, 'graph.pbtxt'),
      os.path.join(output_dir, 'events.out.tfevents.*'),
      os.path.join(output_dir, 'checkpoint'),
  ]
  for pattern in export_patterns:
    for source_path in tf.gfile.Glob(pattern):
      target_path = os.path.join(export_dir, os.path.basename(source_path))
      tf.gfile.Copy(source_path, target_path, overwrite=True)
  print('Exported', latest_checkpoint, 'to', export_dir)

## Function to write predictions

In [0]:
def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file):
  """
  Write final predictions to the json file and log-odds of null if needed.

  For every example, candidate spans are collected from all of its features
  (document windows) and ranked by start_logit + end_logit. They are then
  mapped back to original text with get_final_text. The null answer ("") is
  predicted when score_null minus the best non-null score exceeds the
  module-level `null_score_diff_threshold`.

  Args:
    all_examples: run_squad.SquadExample list; position i matches the
      features whose example_index is i.
    all_features: input features, each with example_index, unique_id, tokens,
      token_to_orig_map and token_is_max_context.
    all_results: raw outputs with unique_id, start_logits and end_logits.
    n_best_size: number of top start/end indexes considered per feature, and
      maximum number of ranked entries kept in each n-best list.
    max_answer_length: maximum answer length in tokens.
    do_lower_case: whether the tokenizer lower-cases (used by get_final_text).
    output_prediction_file: path of the {qas_id: answer text} JSON file.
    output_nbest_file: path of the {qas_id: [n-best entries]} JSON file.
    output_null_log_odds_file: path of the {qas_id: score_diff} JSON file.
  """
  
  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  _PrelimPrediction = collections.namedtuple(
      "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

  _NbestPrediction = collections.namedtuple(
      "NbestPrediction", ["text", "start_logit", "end_logit"])

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
      # if we could have irrelevant answers, get the min score of irrelevant
      feature_null_score = result.start_logits[0] + result.end_logits[0]
      if feature_null_score < score_null:
        score_null = feature_null_score
        min_null_feature_index = feature_index
        null_start_logit = result.start_logits[0]
        null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          # discard spans outside the document part of this window, spans whose
          # start is better covered by another window, and invalid/too long spans
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

    # the null answer is always a candidate (position 0 of the best null window)
    prelim_predictions.append(
        _PrelimPrediction(
            feature_index=min_null_feature_index,
            start_index=0,
            end_index=0,
            start_logit=null_start_logit,
            end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
        tok_text = " ".join(tok_tokens)

        # de-tokenize WordPieces that have been split off
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")

        # clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, do_lower_case)
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # if we didn't include the empty option in the n-best, include it
    if "" not in seen_predictions:
      nbest.append(
          _NbestPrediction(
              text="", start_logit=null_start_logit,
              end_logit=null_end_logit))

    # In rare cases the null answer is the only candidate (every span was
    # filtered out). Add a nonce non-null entry so best_non_null_entry below
    # is never None, which previously crashed with AttributeError.
    if not any(entry.text for entry in nbest):
      nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

    assert len(nbest) >= 1

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_logit + entry.end_logit)
      if best_non_null_entry is None and entry.text:
        best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    # predict "" iff the null score - the score of best non-null > threshold
    score_diff = score_null - best_non_null_entry.start_logit - (
        best_non_null_entry.end_logit)
    scores_diff_json[example.qas_id] = score_diff
    if score_diff > null_score_diff_threshold:
      all_predictions[example.qas_id] = ""
    else:
      all_predictions[example.qas_id] = best_non_null_entry.text

    all_nbest_json[example.qas_id] = nbest_json

  with tf.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
    
def _get_best_indexes(logits, n_best_size):
  """
  Get the n-best logits from a list
  """
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes

def _compute_softmax(scores):
  """
  Compute softmax probability over raw logits
  """
  
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs

def get_final_text(pred_text, orig_text, do_lower_case):
  """
  Project the tokenized prediction back to the original text.

  `pred_text` is normalized text (WordPieces merged, BasicTokenizer-style).
  `orig_text` is the raw document text covering the same tokens. `orig_text`
  is re-tokenized with BasicTokenizer, and both strings are aligned
  character by character after removing spaces. This recovers the span of
  `orig_text` (with its original casing and spacing) that matches
  `pred_text`.

  Args:
    pred_text: normalized predicted answer text.
    orig_text: original document text around the prediction.
    do_lower_case: whether BasicTokenizer lower-cases `orig_text`.

  Returns:
    The aligned substring of `orig_text`, or `orig_text` itself whenever the
    alignment fails (an info message is logged).
  """

  def _strip_spaces(text):
    # Drop " " characters; also return a map from indexes in the stripped
    # string to indexes in `text`.
    kept_chars = []
    stripped_to_full = collections.OrderedDict()
    for (full_index, char) in enumerate(text):
      if char != " ":
        stripped_to_full[len(kept_chars)] = full_index
        kept_chars.append(char)
    return ("".join(kept_chars), stripped_to_full)

  basic_tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
  tok_text = " ".join(basic_tokenizer.tokenize(orig_text))

  # locate the prediction inside the re-tokenized original text
  start_position = tok_text.find(pred_text)
  if start_position == -1:
    tf.logging.info(
        "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
    return orig_text
  end_position = start_position + len(pred_text) - 1

  (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

  # the alignment is only possible character for character
  if len(orig_ns_text) != len(tok_ns_text):
    tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                    orig_ns_text, tok_ns_text)
    return orig_text

  # invert the tokenized-text map: index with spaces -> index without spaces
  tok_s_to_ns_map = dict(
      (s_index, ns_index) for (ns_index, s_index) in six.iteritems(tok_ns_to_s_map))

  def _to_orig_position(tok_position):
    # index in tok_text -> index in orig_text, or None if it cannot be mapped
    ns_position = tok_s_to_ns_map.get(tok_position)
    if ns_position is None:
      return None
    return orig_ns_to_s_map.get(ns_position)

  orig_start_position = _to_orig_position(start_position)
  if orig_start_position is None:
    tf.logging.info("Couldn't map start position")
    return orig_text

  orig_end_position = _to_orig_position(end_position)
  if orig_end_position is None:
    tf.logging.info("Couldn't map end position")
    return orig_text

  return orig_text[orig_start_position:orig_end_position + 1]

## Make predictions on SQuAD v2.0 dev set

In [0]:
predict_file = os.path.join(root, 'dev-v2.0.json')

# Checkpoint to evaluate. By default this is the latest one listed in the
# `checkpoint` state file of `root/output`. To evaluate a specific step
# instead, set it explicitly, e.g.
# checkpoint_path = os.path.join(root, 'output/model.ckpt-12345')
checkpoint_dir = os.path.join(root, 'output')
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None:
  raise ValueError(
      "No checkpoint found in '%s': export a trained model there first, "
      "or set checkpoint_path explicitly." % checkpoint_dir)

eval_examples = read_examples(input_file=predict_file, is_training=False)

# Features are kept in memory (write_predictions needs them to map logits
# back to text) and also written to a TFRecord file for the input_fn.
eval_writer = run_squad.FeatureWriter(
    filename=os.path.join(output_dir, "eval.tf_record"),
    is_training=False)
eval_features = []

def append_feature(feature):
  eval_features.append(feature)
  eval_writer.process_feature(feature)

try:
  run_squad.convert_examples_to_features(
      examples=eval_examples,
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=False,
      output_fn=append_feature)
finally:
  # Release the TFRecord file even if feature conversion fails.
  eval_writer.close()

tf.logging.info("***** Running predictions *****")
tf.logging.info("  Num orig examples = %d", len(eval_examples))
tf.logging.info("  Num split examples = %d", len(eval_features))
tf.logging.info("  Batch size = %d", predict_batch_size)

predict_input_fn = run_squad.input_fn_builder(
    input_file=eval_writer.filename,
    seq_length=max_seq_length,
    is_training=False,
    drop_remainder=False)

# The training estimator was built with params={"batch_size": batch_size}.
# Use a dedicated estimator so prediction actually runs with
# predict_batch_size. Per-example logits should not depend on the batching.
predict_estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    config=run_config,
    params={"batch_size": predict_batch_size})

RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
all_results = []
for result in predict_estimator.predict(predict_input_fn, yield_single_examples=True, checkpoint_path=checkpoint_path):
  if len(all_results) % 1000 == 0:
    tf.logging.info("Processing example: %d" % (len(all_results)))
  unique_id = int(result["unique_ids"])
  start_logits = [float(x) for x in result["start_logits"].flat]
  end_logits = [float(x) for x in result["end_logits"].flat]
  all_results.append(RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits))

output_prediction_file = os.path.join(root, "predictions.json")
output_nbest_file = os.path.join(root, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(root, "null_odds.json")

# Use this notebook's write_predictions, which handles SQuAD v2 null answers
# with null_score_diff_threshold. run_squad.write_predictions instead reads
# its settings from run_squad's command-line FLAGS, which are not parsed here
# (and default to SQuAD v1 behaviour).
write_predictions(eval_examples, eval_features, all_results,
                  n_best_size, max_answer_length,
                  do_lower_case, output_prediction_file,
                  output_nbest_file, output_null_log_odds_file)