In [1]:
import collections
import copy
import logging
import math

import numpy as np
import six
import tensorflow as tf

import tokenization
from albert import AlbertConfig, AlbertModel
from utils import tf_utils

In [2]:
class ALBertQALayer(tf.keras.layers.Layer):
    """Layer computing position and is_possible for question answering task.

    From the encoder's sequence output it derives start-position scores,
    end-position scores conditioned on a start position, and one
    answerability logit per example.

    When invoked with ``training=True`` the layer returns
    ``(start_log_probs, end_log_probs, cls_logits)``.  Otherwise it returns the
    beam ``(start_top_log_probs, start_top_index, end_top_log_probs,
    end_top_index, cls_logits)``.
    """

    def __init__(self, hidden_size, start_n_top, end_n_top, initializer, dropout, **kwargs):
        """Constructs the question-answering head.
        Args:
          hidden_size: Int, the hidden size.
          start_n_top: Beam size for span start.
          end_n_top: Beam size for span end.
          initializer: Initializer used for parameters.
          dropout: float, dropout rate.
          **kwargs: Other parameters.
        """
        super(ALBertQALayer, self).__init__(**kwargs)
        self.hidden_size = hidden_size
        self.start_n_top = start_n_top
        self.end_n_top = end_n_top
        self.initializer = initializer
        self.dropout = dropout

    def build(self, unused_input_shapes):
        """Implements build() for the layer.

        Creates the sub-layers: one projection for start logits, a
        tanh-dense -> LayerNorm -> dense stack for end logits, and a
        tanh-dense -> dropout -> dense stack for the answerability logit.
        """
        self.start_logits_proj_layer = tf.keras.layers.Dense(
            units=1, kernel_initializer=self.initializer, name='start_logits/dense')
        self.end_logits_proj_layer0 = tf.keras.layers.Dense(
            units=self.hidden_size,
            kernel_initializer=self.initializer,
            activation=tf.nn.tanh,
            name='end_logits/dense_0')
        self.end_logits_proj_layer1 = tf.keras.layers.Dense(
            units=1, kernel_initializer=self.initializer, name='end_logits/dense_1')
        self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12, name='end_logits/LayerNorm')
        self.answer_class_proj_layer0 = tf.keras.layers.Dense(
            units=self.hidden_size,
            kernel_initializer=self.initializer,
            activation=tf.nn.tanh,
            name='answer_class/dense_0')
        self.answer_class_proj_layer1 = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=self.initializer,
            use_bias=False,
            name='answer_class/dense_1')
        self.ans_feature_dropout = tf.keras.layers.Dropout(rate=self.dropout)
        super(ALBertQALayer, self).build(unused_input_shapes)

    def __call__(self,
                 sequence_output,
                 p_mask,
                 cls_index,
                 start_positions=None,
                 **kwargs):
        """Packs the separate arguments into one input and defers to Keras.

        `start_positions` is only read by `call` on the training path; it may
        be left as None otherwise.
        """
        inputs = tf_utils.pack_inputs(
            [sequence_output, p_mask, cls_index, start_positions])
        return super(ALBertQALayer, self).__call__(inputs, **kwargs)


    def call(self, inputs, **kwargs):
        """Implements call() for the layer.

        Args:
          inputs: Packed `[sequence_output, p_mask, cls_index, start_positions]`
            as produced by `__call__`.
          **kwargs: `training` (default False) selects the training path
            (ground-truth start positions) or the beam-search path.

        Returns:
          `(start_log_probs, end_log_probs, cls_logits)` when training,
          otherwise `(start_top_log_probs, start_top_index,
          end_top_log_probs, end_top_index, cls_logits)`.
        """
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        sequence_output = unpacked_inputs[0]
        p_mask = unpacked_inputs[1]
        cls_index = unpacked_inputs[2]
        start_positions = unpacked_inputs[3]

        # sequence_output is rank 3; only its middle (sequence) dimension is
        # needed as a static Python int.
        _, seq_len, _ = sequence_output.shape.as_list()
        # Swap the first two axes so the sequence axis leads; this is the
        # 'lbh' layout used by the einsum calls below.
        sequence_output = tf.transpose(sequence_output, [1, 0, 2])

        start_logits = self.start_logits_proj_layer(sequence_output)
        start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
        # Positions where p_mask == 1 are pushed to -1e30 so they receive
        # (effectively) zero probability after the softmax.
        start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
        start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)

        if kwargs.get("training", False):
            # during training, compute the end logits based on the
            # ground truth of the start position
            start_positions = tf.reshape(start_positions, [-1])
            start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1,
                                     dtype=tf.float32)
            # The one-hot contraction picks the hidden vector at the start
            # position of every example.
            start_features = tf.einsum(
                'lbh,bl->bh', sequence_output, start_index)
            start_features = tf.tile(start_features[None], [seq_len, 1, 1])
            end_logits = self.end_logits_proj_layer0(
                tf.concat([sequence_output, start_features], axis=-1))

            end_logits = self.end_logits_layer_norm(end_logits)

            end_logits = self.end_logits_proj_layer1(end_logits)
            end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
            end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
            end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
        else:
            # Inference: keep the start_n_top best start positions and score
            # end positions separately for each of them.
            start_top_log_probs, start_top_index = tf.nn.top_k(
                start_log_probs, k=self.start_n_top)
            start_index = tf.one_hot(
                start_top_index, depth=seq_len, axis=-1, dtype=tf.float32)
            start_features = tf.einsum(
                'lbh,bkl->bkh', sequence_output, start_index)
            # Insert a beam axis and repeat the sequence output start_n_top
            # times so it can be concatenated with every start candidate.
            end_input = tf.tile(sequence_output[:, :, None], [
                                1, 1, self.start_n_top, 1])
            start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1])
            end_input = tf.concat([end_input, start_features], axis=-1)
            end_logits = self.end_logits_proj_layer0(end_input)
            # Flatten the batch and beam axes around the layer norm, then
            # restore them.
            end_logits = tf.reshape(end_logits, [seq_len, -1, self.hidden_size])
            end_logits = self.end_logits_layer_norm(end_logits)

            end_logits = tf.reshape(end_logits,
                                    [seq_len, -1, self.start_n_top, self.hidden_size])

            end_logits = self.end_logits_proj_layer1(end_logits)
            end_logits = tf.reshape(
                end_logits, [seq_len, -1, self.start_n_top])
            # Move the sequence axis last so softmax / top_k run over positions.
            end_logits = tf.transpose(end_logits, [1, 2, 0])
            end_logits_masked = end_logits * (
                1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
            end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
            end_top_log_probs, end_top_index = tf.nn.top_k(
                end_log_probs, k=self.end_n_top)
            # Flatten (start beam, end beam) into a single axis of size
            # start_n_top * end_n_top; entry i * end_n_top + j belongs to
            # start candidate i.
            end_top_log_probs = tf.reshape(end_top_log_probs,
                                           [-1, self.start_n_top * self.end_n_top])
            end_top_index = tf.reshape(end_top_index,
                                       [-1, self.start_n_top * self.end_n_top])

        # an additional layer to predict answerability

        # get the representation of CLS
        cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
        cls_feature = tf.einsum('lbh,bl->bh', sequence_output, cls_index)

        # get the representation of START
        # (the hidden states averaged with the start-position probabilities)
        start_p = tf.nn.softmax(start_logits_masked,
                                axis=-1, name='softmax_start')
        start_feature = tf.einsum('lbh,bl->bh', sequence_output, start_p)

        ans_feature = tf.concat([start_feature, cls_feature], -1)
        ans_feature = self.answer_class_proj_layer0(ans_feature)
        ans_feature = self.ans_feature_dropout(
            ans_feature, training=kwargs.get('training', False))
        cls_logits = self.answer_class_proj_layer1(ans_feature)
        cls_logits = tf.squeeze(cls_logits, -1)

        if kwargs.get("training", False):
            return (start_log_probs, end_log_probs, cls_logits)
        else:
            return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)


class ALBertQAModel(tf.keras.Model):
    """ALBERT encoder followed by an `ALBertQALayer` span-prediction head."""

    def __init__(self, albert_config, max_seq_length, init_checkpoint, start_n_top, end_n_top, dropout=0.1, **kwargs):
        """Builds the encoder sub-model and the QA head.

        Args:
          albert_config: `AlbertConfig` describing the encoder; a deep copy is
            stored on the instance.
          max_seq_length: Int, fixed length of the three encoder inputs.
          init_checkpoint: Optional path of encoder weights to load into the
            encoder sub-model; pass None to skip loading.
          start_n_top: Beam size for span start.
          end_n_top: Beam size for span end.
          dropout: float, dropout rate used by the QA head.
          **kwargs: Forwarded to `tf.keras.Model.__init__`.
        """
        super(ALBertQAModel, self).__init__(**kwargs)
        self.albert_config = copy.deepcopy(albert_config)
        self.initializer = tf.keras.initializers.TruncatedNormal(
            stddev=self.albert_config.initializer_range)
        float_type = tf.float32

        input_word_ids = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
        input_mask = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
        input_type_ids = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

        albert_layer = AlbertModel(config=albert_config, float_type=float_type)

        # Only the per-token sequence output is used; the first output of the
        # encoder is discarded.
        _, sequence_output = albert_layer(
            input_word_ids, input_mask, input_type_ids)

        self.albert_model = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids],
                                           outputs=[sequence_output])
        # Identity comparison is the correct way to test for None.
        if init_checkpoint is not None:
            self.albert_model.load_weights(init_checkpoint)

        self.qalayer = ALBertQALayer(self.albert_config.hidden_size, start_n_top, end_n_top,
                                     self.initializer, dropout)

    def call(self, inputs, **kwargs):
        """Runs the encoder and the QA head on one batch.

        Args:
          inputs: Mapping with keys "unique_ids", "input_ids", "input_mask",
            "segment_ids", "cls_index", "p_mask" and, when `training` is
            truthy, "start_positions".
          **kwargs: Forwarded to both sub-models; `training` (default False)
            decides whether "start_positions" is read.

        Returns:
          A tuple `(unique_ids,) + <outputs of the QA head>`.
        """
        unique_ids = inputs["unique_ids"]
        input_word_ids = inputs["input_ids"]
        input_mask = inputs["input_mask"]
        segment_ids = inputs["segment_ids"]
        cls_index = tf.reshape(inputs["cls_index"], [-1])
        p_mask = inputs["p_mask"]
        if kwargs.get('training', False):
            start_positions = inputs["start_positions"]
        else:
            start_positions = None
        sequence_output = self.albert_model(
            [input_word_ids, input_mask, segment_ids], **kwargs)
        output = self.qalayer(
            sequence_output, p_mask, cls_index, start_positions, **kwargs)
        return (unique_ids,) + output

In [3]:
# Model assets: encoder config, SentencePiece model and fine-tuned weights.
albert_config_file = "xxlarge/config.json"
spm_model_file = "xxlarge/vocab/30k-clean.model"
model_dir = "squad_out_v2.0"

# Pre-processing / inference settings.
do_lower_case = True
max_seq_length = 384
null_score_diff_threshold = 0.0

In [4]:
# We can have documents that are longer than the maximum sequence length.
# To deal with this we use a sliding-window approach, where we take chunks
# of up to our max length with a stride of `doc_stride`.
# One sliding window over the paragraph tokens.
_DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
    "DocSpan",
    ("start", "length"))

# Model output for a single feature.
RawResultV2 = collections.namedtuple(
    "RawResultV2",
    (
        "unique_id",
        "start_top_log_probs",
        "start_top_index",
        "end_top_log_probs",
        "end_top_index",
        "cls_logits",
    ))

# Candidate answer span before n-best filtering.
_PrelimPredictionV2 = collections.namedtuple(  # pylint: disable=invalid-name
    "PrelimPrediction",
    (
        "feature_index",
        "start_index",
        "end_index",
        "start_log_prob",
        "end_log_prob",
    ))

# Final n-best entry carrying the answer text.
_NbestPredictionV2 = collections.namedtuple(  # pylint: disable=invalid-name
    "NbestPrediction",
    (
        "text",
        "start_log_prob",
        "end_log_prob",
        "start_index",
        "end_index",
    ))

In [5]:
# Build the tokenizer from the SentencePiece model only (no vocab file is
# passed); lower-casing follows the `do_lower_case` setting above.
tokenizer = tokenization.FullTokenizer(
      vocab_file=None,spm_model_file=spm_model_file, do_lower_case=do_lower_case)

In [6]:
class SquadExample(object):
  """A single training/test example for the SQuAD question-answering task.

  Attributes mirror the constructor arguments: `qas_id`, `question_text`,
  `paragraph_text`, and the optional answer annotation `orig_answer_text`,
  `start_position`, `end_position` and `is_impossible`.
  For examples without an answer, the start and end position are -1.
  """

  def __init__(self,
               qas_id,
               question_text,
               paragraph_text,
               orig_answer_text=None,
               start_position=None,
               end_position=None,
               is_impossible=False):
    self.qas_id = qas_id
    self.question_text = question_text
    self.paragraph_text = paragraph_text
    self.orig_answer_text = orig_answer_text
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    """Returns a one-line, human-readable summary of the example.

    Optional fields are included only when they are set (not None). Each
    field is tested on its own value, so a position of 0 is still shown and
    a missing `end_position` no longer breaks the `%d` formatting.
    """
    paragraph = self.paragraph_text
    # A token list is joined with spaces; a plain string is shown verbatim
    # (joining a string would insert a space between every character).
    if isinstance(paragraph, (list, tuple)):
      paragraph = " ".join(paragraph)

    parts = [
        "qas_id: %s" % (tokenization.printable_text(self.qas_id)),
        "question_text: %s" % (
            tokenization.printable_text(self.question_text)),
        "paragraph_text: [%s]" % (paragraph),
    ]
    if self.start_position is not None:
      parts.append("start_position: %d" % (self.start_position))
    if self.end_position is not None:
      parts.append("end_position: %d" % (self.end_position))
    if self.is_impossible is not None:
      parts.append("is_impossible: %r" % (self.is_impossible))
    return ", ".join(parts)

In [7]:
class InputFeatures(object):
  """Plain record holding the model inputs built for one document span."""

  def __init__(self,
               unique_id,
               example_index,
               doc_span_index,
               tok_start_to_orig_index,
               tok_end_to_orig_index,
               token_is_max_context,
               tokens,
               input_ids,
               input_mask,
               segment_ids,
               paragraph_len,
               p_mask=None,
               cls_index=None,
               start_position=None,
               end_position=None,
               is_impossible=None):
    # Identity of the feature and where it came from.
    self.unique_id = unique_id
    self.example_index = example_index
    self.doc_span_index = doc_span_index

    # Mappings from span tokens back to the original text.
    self.tok_start_to_orig_index = tok_start_to_orig_index
    self.tok_end_to_orig_index = tok_end_to_orig_index
    self.token_is_max_context = token_is_max_context

    # Token-level model inputs.
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.paragraph_len = paragraph_len
    self.p_mask = p_mask
    self.cls_index = cls_index

    # Optional answer annotation.
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible


In [8]:
def convert_to_squad(question: str, paragraph: str):
    """Wraps one question/paragraph pair in a single-element example list.

    The example gets the fixed id "1234" and carries no answer annotation
    (answer text, start position and is_impossible are all None).
    """
    return [
        SquadExample(
            qas_id="1234",
            question_text=question,
            paragraph_text=paragraph,
            orig_answer_text=None,
            start_position=None,
            is_impossible=None)
    ]

In [9]:
def _check_is_max_context(doc_spans, cur_span_index, position):
  """Tells whether `cur_span_index` is the best-context span for a token.

  With a sliding window a token can fall into several spans, e.g.
    Doc:    the man went to the store and bought a gallon of milk
    Span A: the man went to the
    Span B: to the store and bought
    Span C: and bought a gallon of
  'bought' is in B (4 left, 0 right) and in C (1 left, 3 right).  Each span
  containing the token is scored by min(left, right) context plus a small
  bonus of 0.01 * span length; the first span with the highest score wins
  (here C).

  Returns True iff the winning span index equals `cur_span_index`.
  """
  top_score, top_index = None, None
  for idx, span in enumerate(doc_spans):
    last = span.start + span.length - 1
    # Ignore spans that do not contain the token at all.
    if not (span.start <= position <= last):
      continue
    left = position - span.start
    right = last - position
    candidate = min(left, right) + 0.01 * span.length
    if top_score is None or candidate > top_score:
      top_score, top_index = candidate, idx

  return cur_span_index == top_index

In [10]:
def _convert_index(index, pos, m=None, is_start=True):
  """Looks up `index[pos]`, estimating a value when the entry is None.

  `index` is a list whose entries are ints or None.  When `index[pos]` is
  None, the nearest non-None entries before and after `pos` are located and
  the result is chosen from them: towards the earlier neighbour for a start
  index (`is_start=True`), towards the later neighbour for an end index.
  `m`, when given, is used as an upper bound when no later neighbour exists.
  """
  if index[pos] is not None:
    return index[pos]

  last = len(index) - 1
  # Walk right, then left, until a mapped entry (or the list edge) is hit.
  hi = pos
  while hi < last and index[hi] is None:
    hi += 1
  lo = pos
  while lo > 0 and index[lo] is None:
    lo -= 1

  before, after = index[lo], index[hi]
  assert before is not None or after is not None

  if before is None:
    # Nothing mapped on the left side.
    if after >= 1:
      return 0 if is_start else after - 1
    return after

  if after is None:
    # Nothing mapped on the right side.
    if m is not None and before < m - 1:
      return before + 1 if is_start else m - 1
    return before

  has_gap = after > before + 1
  if is_start:
    return before + 1 if has_gap else after
  return after - 1 if has_gap else before

In [11]:
def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training=False):
  """Converts examples into a list of `InputFeatures`, one per document span.

  For every example the question and paragraph are tokenized with the
  SentencePiece model, the paragraph pieces are aligned back to the original
  characters, and the paragraph is cut into overlapping windows so that
  question + window fit into `max_seq_length`.

  Args:
    examples: Iterable of example objects exposing `question_text`,
      `paragraph_text`, `is_impossible` and, when `is_training` is true,
      `start_position` and `orig_answer_text`.
    tokenizer: Tokenizer exposing a SentencePiece processor as `sp_model`.
    max_seq_length: Int, length every feature is padded to.
    doc_stride: Int, step between the starts of consecutive windows.
    max_query_length: Int, the question is truncated to this many pieces.
    is_training: Bool, whether answer positions are computed.

  Returns:
    List of `InputFeatures`.  Every feature gets its own `unique_id`,
    counting up from 1000000000.  Examples whose pieces cannot be aligned
    with the original text are skipped (and logged).

  Note:
    Lower-casing follows the module-level `do_lower_case` setting.
  """

  base_id = 1000000000
  unique_id = base_id
  max_n, max_m = 1024, 1024
  # Scratch table for the LCS alignment, reused (and grown) across examples.
  f = np.zeros((max_n, max_m), dtype=np.float32)

  features = []

  for (example_index, example) in enumerate(examples):

    query_tokens = tokenization.encode_ids(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.question_text, lower=do_lower_case))

    if len(query_tokens) > max_query_length:
      query_tokens = query_tokens[0:max_query_length]

    paragraph_text = example.paragraph_text
    para_tokens = tokenization.encode_pieces(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.paragraph_text, lower=do_lower_case),
        return_unicode=False)

    # Pieces may come back as bytes; work with text from here on.
    para_tokens = [
        piece.decode("utf-8") if isinstance(piece, bytes) else piece
        for piece in para_tokens
    ]

    # Character <-> piece maps over the concatenation of all pieces.
    chartok_to_tok_index = []
    tok_start_to_chartok_index = []
    tok_end_to_chartok_index = []
    char_cnt = 0
    for i, token in enumerate(para_tokens):
      chartok_to_tok_index.extend([i] * len(token))
      tok_start_to_chartok_index.append(char_cnt)
      char_cnt += len(token)
      tok_end_to_chartok_index.append(char_cnt - 1)

    tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ")
    n, m = len(paragraph_text), len(tok_cat_text)

    if n > max_n or m > max_m:
      max_n = max(n, max_n)
      max_m = max(m, max_m)
      f = np.zeros((max_n, max_m), dtype=np.float32)

    # g[(i, j)] records the move taken at cell (i, j):
    # 0 = from (i - 1, j), 1 = from (i, j - 1), 2 = characters matched.
    g = {}

    def _lcs_match(max_dist, n=n, m=m):
      """Longest-common-subsequence alignment inside a diagonal band."""
      f.fill(0)
      g.clear()

      ### longest common sub sequence
      # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
      for i in range(n):

        # note(zhiliny):
        # unlike standard LCS, this is specifically optimized for the setting
        # because the mismatch between sentence pieces and original text will
        # be small
        for j in range(i - max_dist, i + max_dist):
          if j >= m or j < 0: continue

          if i > 0:
            g[(i, j)] = 0
            f[i, j] = f[i - 1, j]

          if j > 0 and f[i, j - 1] > f[i, j]:
            g[(i, j)] = 1
            f[i, j] = f[i, j - 1]

          f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
          if (tokenization.preprocess_text(
              paragraph_text[i], lower=do_lower_case,
              remove_space=False) == tok_cat_text[j]
              and f_prev + 1 > f[i, j]):
            g[(i, j)] = 2
            f[i, j] = f_prev + 1

    # Try a narrow band first; widen it once if the match is too poor.
    max_dist = abs(n - m) + 5
    for _ in range(2):
      _lcs_match(max_dist)
      if f[n - 1, m - 1] > 0.8 * n: break
      max_dist *= 2

    # Backtrack through g to recover the character alignment.
    orig_to_chartok_index = [None] * n
    chartok_to_orig_index = [None] * m
    i, j = n - 1, m - 1
    while i >= 0 and j >= 0:
      if (i, j) not in g: break
      if g[(i, j)] == 2:
        orig_to_chartok_index[i] = j
        chartok_to_orig_index[j] = i
        i, j = i - 1, j - 1
      elif g[(i, j)] == 1:
        j = j - 1
      else:
        i = i - 1

    if (all(v is None for v in orig_to_chartok_index) or
        f[n - 1, m - 1] < 0.8 * n):
      logging.info("MISMATCH DETECTED!")
      continue

    tok_start_to_orig_index = []
    tok_end_to_orig_index = []
    for i in range(len(para_tokens)):
      start_chartok_pos = tok_start_to_chartok_index[i]
      end_chartok_pos = tok_end_to_chartok_index[i]
      start_orig_pos = _convert_index(chartok_to_orig_index, start_chartok_pos,
                                      n, is_start=True)
      end_orig_pos = _convert_index(chartok_to_orig_index, end_chartok_pos,
                                    n, is_start=False)

      tok_start_to_orig_index.append(start_orig_pos)
      tok_end_to_orig_index.append(end_orig_pos)

    if not is_training:
      tok_start_position = tok_end_position = None

    if is_training and example.is_impossible:
      tok_start_position = 0
      tok_end_position = 0

    if is_training and not example.is_impossible:
      start_position = example.start_position
      end_position = start_position + len(example.orig_answer_text) - 1

      start_chartok_pos = _convert_index(orig_to_chartok_index, start_position,
                                         is_start=True)
      tok_start_position = chartok_to_tok_index[start_chartok_pos]

      end_chartok_pos = _convert_index(orig_to_chartok_index, end_position,
                                       is_start=False)
      tok_end_position = chartok_to_tok_index[end_chartok_pos]
      assert tok_start_position <= tok_end_position

    def _piece_to_id(x):
      if six.PY2 and isinstance(x, six.text_type):
        x = six.ensure_binary(x, "utf-8")
      return tokenizer.sp_model.PieceToId(x)

    all_doc_tokens = list(map(_piece_to_id, para_tokens))

    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

    # Cut the paragraph into windows of at most max_tokens_for_doc pieces.
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
      length = len(all_doc_tokens) - start_offset
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(_DocSpan(start=start_offset, length=length))
      if start_offset + length == len(all_doc_tokens):
        break
      start_offset += min(length, doc_stride)

    for (doc_span_index, doc_span) in enumerate(doc_spans):
      tokens = []
      token_is_max_context = {}
      segment_ids = []
      # p_mask is 0 for positions that may be part of an answer ([CLS] and
      # paragraph pieces) and 1 everywhere else.
      p_mask = []
      cls_index = 0

      cur_tok_start_to_orig_index = []
      cur_tok_end_to_orig_index = []

      tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
      segment_ids.append(0)
      p_mask.append(0)
      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
        p_mask.append(1)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(0)
      p_mask.append(1)

      for i in range(doc_span.length):
        split_token_index = doc_span.start + i

        cur_tok_start_to_orig_index.append(
            tok_start_to_orig_index[split_token_index])
        cur_tok_end_to_orig_index.append(
            tok_end_to_orig_index[split_token_index])

        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                               split_token_index)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[split_token_index])
        segment_ids.append(1)
        p_mask.append(0)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(1)
      p_mask.append(1)

      paragraph_len = len(tokens)
      # NOTE: input_ids is the same list object as tokens, so the padding
      # appended below is also visible through `tokens`.
      input_ids = tokens

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        p_mask.append(1)

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length

      span_is_impossible = example.is_impossible
      start_position = None
      end_position = None
      if is_training and not span_is_impossible:
        # For training, a document chunk that does not contain the annotation
        # is kept but labelled as impossible (answer at the [CLS] position).
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if not (tok_start_position >= doc_start and
                tok_end_position <= doc_end):
          out_of_span = True
        if out_of_span:
          # continue
          start_position = 0
          end_position = 0
          span_is_impossible = True
        else:
          doc_offset = len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset

      if is_training and span_is_impossible:
        start_position = cls_index
        end_position = cls_index

      if is_training:
        feat_example_index = None
      else:
        feat_example_index = example_index

      feature = InputFeatures(
          unique_id=unique_id,
          example_index=feat_example_index,
          doc_span_index=doc_span_index,
          tok_start_to_orig_index=cur_tok_start_to_orig_index,
          tok_end_to_orig_index=cur_tok_end_to_orig_index,
          token_is_max_context=token_is_max_context,
          tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          paragraph_len=paragraph_len,
          start_position=start_position,
          end_position=end_position,
          is_impossible=span_is_impossible,
          p_mask=p_mask,
          cls_index=cls_index)

      features.append(feature)
      # Every feature needs a distinct id: results are looked up by
      # unique_id downstream, so a shared id would make the spans of a long
      # paragraph overwrite each other.
      unique_id += 1

  return features

In [12]:
# Load the encoder hyper-parameters from the JSON file named in the config cell.
albert_config = AlbertConfig.from_json_file(albert_config_file)

In [13]:
# Build the QA model for inference: no encoder checkpoint is loaded here
# (init_checkpoint=None), dropout is disabled, and both beams have width 5.
squad_model = ALBertQAModel(albert_config, max_seq_length, init_checkpoint=None, start_n_top=5, end_n_top=5, dropout=0.)

In [14]:
# Load the fine-tuned weights from the directory named by `model_dir` in the
# config cell, instead of repeating the path as a literal.
squad_model.load_weights("%s/model" % model_dir)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f8d04c18da0>

In [15]:
# checkpoint_path = tf.train.latest_checkpoint(model_dir)
# checkpoint = tf.train.Checkpoint(model=squad_model)
# checkpoint.restore(checkpoint_path).expect_partial()

In [16]:
def get_raw_results_v2(predictions):
    """Builds a `RawResultV2` from the first row of `predictions`.

    `predictions` maps 'unique_ids', 'start_top_log_probs',
    'start_top_index', 'end_top_log_probs', 'end_top_index' and 'cls_logits'
    to iterables of objects offering `.numpy()`.  Only the first entry of
    each is used; None is returned when any of them is empty.
    """
    rows = zip(
        predictions['unique_ids'],
        predictions['start_top_log_probs'],
        predictions['start_top_index'],
        predictions['end_top_log_probs'],
        predictions['end_top_index'],
        predictions['cls_logits'])
    first_row = next(rows, None)
    if first_row is None:
        return None

    uid, start_lp, start_idx, end_lp, end_idx, cls = first_row
    return RawResultV2(
        unique_id=uid.numpy(),
        start_top_log_probs=start_lp.numpy().tolist(),
        start_top_index=start_idx.numpy().tolist(),
        end_top_log_probs=end_lp.numpy().tolist(),
        end_top_index=end_idx.numpy().tolist(),
        cls_logits=cls.numpy().tolist())

In [17]:
def accumulate_predictions_v2(result_dict, cls_dict, all_examples,
                              all_features, all_results, n_best_size,
                              max_answer_length, start_n_top, end_n_top):
  """Collects valid (start, end) span scores per example and feature.

  `result_dict[example_index][unique_id][(start, end)]` receives a list of
  `(start_log_prob, end_log_prob)` pairs, with both positions expressed
  relative to the first paragraph token.  `cls_dict[example_index]` receives
  the smallest `cls_logits` seen over the example's features (or 1000000 if
  it has none).  Both dictionaries are updated in place and returned as a
  tuple.  `n_best_size` is accepted but not used here.
  """
  features_by_example = collections.defaultdict(list)
  for feat in all_features:
    features_by_example[feat.example_index].append(feat)

  results_by_id = {}
  for res in all_results:
    results_by_id[res.unique_id] = res

  for example_index, _ in enumerate(all_examples):
    per_example = result_dict.setdefault(example_index, {})

    # Smallest null score over all spans; starts large and positive.
    score_null = 1000000

    for feature in features_by_example[example_index]:
      span_scores = per_example.setdefault(feature.unique_id, {})
      result = results_by_id[feature.unique_id]
      score_null = min(score_null, result.cls_logits)

      # Paragraph tokens start right after the first "[SEP]".
      doc_offset = feature.tokens.index("[SEP]") + 1
      for start_rank in range(start_n_top):
        for end_rank in range(end_n_top):
          start_log_prob = result.start_top_log_probs[start_rank]
          start_index = result.start_top_index[start_rank]

          flat = start_rank * end_n_top + end_rank
          end_log_prob = result.end_top_log_probs[flat]
          end_index = result.end_top_index[flat]

          rel_start = start_index - doc_offset
          rel_end = end_index - doc_offset

          # Drop candidates that are invalid, e.g. a start inside the
          # question, a span running backwards, or one that is too long.
          if rel_start >= len(feature.tok_start_to_orig_index):
            continue
          if rel_start < 0:
            continue
          if rel_end >= len(feature.tok_end_to_orig_index):
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          if end_index - start_index + 1 > max_answer_length:
            continue

          span_scores.setdefault((rel_start, rel_end), []).append(
              (start_log_prob, end_log_prob))

    cls_dict.setdefault(example_index, []).append(score_null)
  return (result_dict, cls_dict)

In [18]:
# Number of weight tensors held by the model.
len(squad_model.weights)

25

In [19]:
def _compute_softmax(scores):
  """Turns a list of raw scores into softmax probabilities.

  The largest score is subtracted before exponentiating, for numerical
  stability.  An empty input yields an empty list.
  """
  if not scores:
    return []

  # Find the largest score (first one wins on ties).
  peak = None
  for value in scores:
    if peak is None or value > peak:
      peak = value

  shifted_exps = [math.exp(value - peak) for value in scores]

  normaliser = 0.0
  for term in shifted_exps:
    normaliser += term

  return [term / normaliser for term in shifted_exps]

In [20]:
def write_predictions_v2(all_examples, all_features, all_results, n_best_size=20,
                      max_answer_length=30,start_n_top=5, end_n_top=5,
                      null_score_diff_threshold=None):
  """Builds the final answer text and the n-best list for every example.

  Despite its name, this function does not write any file; it returns the
  predictions to the caller.

  Args:
    all_examples: Sequence of examples; each must expose `qas_id` and
      `paragraph_text`.
    all_features: Sequence of features; each must expose `example_index`,
      `unique_id`, `tok_start_to_orig_index` and `tok_end_to_orig_index`.
    all_results: Raw model results, forwarded to `accumulate_predictions_v2`.
    n_best_size: Maximum number of distinct answer texts kept per example.
    max_answer_length: Forwarded to `accumulate_predictions_v2`.
    start_n_top: Forwarded to `accumulate_predictions_v2`.
    end_n_top: Forwarded to `accumulate_predictions_v2`.
    null_score_diff_threshold: If not None, an example whose averaged null
      score is greater than or equal to this value gets the empty string as
      its answer. When left as None, a module-level variable of the same name
      is used if one exists (earlier revisions of this function read the
      threshold from such a global); otherwise the best span is always
      returned.

  Returns:
    A tuple `(all_predictions, all_nbest_json)` of `OrderedDict`s keyed by
    `qas_id`: the chosen answer text, and the list of n-best candidates with
    their probabilities, log-probabilities and span indices.
  """
  if null_score_diff_threshold is None:
    # Backward compatibility: honour the notebook-level global this function
    # used to depend on, but do not fail with NameError when it is absent.
    null_score_diff_threshold = globals().get("null_score_diff_threshold")

  result_dict, cls_dict = accumulate_predictions_v2({}, {}, all_examples,
                                    all_features, all_results, n_best_size, max_answer_length,
                                    start_n_top, end_n_top)

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    for (feature_index, feature) in enumerate(features):
      for ((start_idx, end_idx), logprobs) in \
        result_dict[example_index][feature.unique_id].items():
        # The same span may have been recorded several times; score it by the
        # mean of its start and end log-probabilities.
        start_log_prob = 0
        end_log_prob = 0
        for logprob in logprobs:
          start_log_prob += logprob[0]
          end_log_prob += logprob[1]
        prelim_predictions.append(
            _PrelimPredictionV2(
                feature_index=feature_index,
                start_index=start_idx,
                end_index=end_idx,
                start_log_prob=start_log_prob / len(logprobs),
                end_log_prob=end_log_prob / len(logprobs)))

    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_log_prob + x.end_log_prob),
        reverse=True)

    # Walk candidates from best to worst, keeping the first occurrence of each
    # distinct answer text until n_best_size answers have been collected.
    seen_predictions = set()
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]

      start_orig_pos = feature.tok_start_to_orig_index[pred.start_index]
      end_orig_pos = feature.tok_end_to_orig_index[pred.end_index]
      final_text = example.paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

      if final_text in seen_predictions:
        continue
      seen_predictions.add(final_text)

      nbest.append(
          _NbestPredictionV2(
              text=final_text,
              start_log_prob=pred.start_log_prob,
              end_log_prob=pred.end_log_prob,
              start_index=pred.start_index,
              end_index=pred.end_index))

    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPredictionV2(
              text="",
              start_log_prob=-1e6,
              end_log_prob=-1e6,
              start_index=-1,
              end_index=-1))

    # nbest is sorted best-first and is guaranteed non-empty at this point.
    best_non_null_entry = nbest[0]

    total_scores = [entry.start_log_prob + entry.end_log_prob for entry in nbest]
    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_log_prob"] = entry.start_log_prob
      output["end_log_prob"] = entry.end_log_prob
      output["start_index"] = entry.start_index
      # Reported end index is one past the last position of the span.
      output["end_index"] = entry.end_index + 1
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    # Mean of the null scores collected for this example.
    score_diff = sum(cls_dict[example_index]) / len(cls_dict[example_index])
    # predict null answers when null threshold is provided
    if null_score_diff_threshold is None or score_diff < null_score_diff_threshold:
      all_predictions[example.qas_id] = best_non_null_entry.text
    else:
      all_predictions[example.qas_id] = ""

    all_nbest_json[example.qas_id] = nbest_json
  return all_predictions,all_nbest_json

In [21]:
# Demo input: one context paragraph and one question to answer from it.
doc = "Closely related fields in theoretical computer science are analysis of algorithms and computability theory. A key distinction between analysis of algorithms and computational complexity theory is that the former is devoted to analyzing the amount of resources needed by a particular algorithm to solve a problem, whereas the latter asks a more general question about all possible algorithms that could be used to solve the same problem. More precisely, it tries to classify problems that can or cannot be solved with appropriately restricted resources. In turn, imposing restrictions on the available resources is what distinguishes computational complexity from computability theory: the latter theory asks what kind of problems can, in principle, be solved algorithmically."

q = 'What two fields of theoretical computer science closely mirror computational complexity theory?'

# Wrap the question/context pair into example object(s) via convert_to_squad,
# which is defined elsewhere in this notebook.
examples = convert_to_squad(q,doc)

# Convert the examples into model input features using the notebook's
# `tokenizer`. NOTE(review): 384 / 128 / 64 are hard-coded here; presumably
# they must match the settings the loaded model was trained with - confirm.
features = convert_examples_to_features(examples,tokenizer,max_seq_length=384,doc_stride=128,max_query_length=64)

In [22]:
# Run the SQuAD model on every feature, one feature per call, and collect the
# raw results for post-processing.
all_results = []
for X in features:
    # Each value is wrapped in a list so the tensor gets a leading axis of
    # length 1. All inputs are int64 except p_mask, which is float32.
    x = {
        key: tf.convert_to_tensor([value], dtype=tf.int64)
        for key, value in (
            ("unique_ids", X.unique_id),
            ("input_ids", X.input_ids),
            ("input_mask", X.input_mask),
            ("segment_ids", X.segment_ids),
            ("cls_index", X.cls_index),
        )
    }
    x["p_mask"] = tf.convert_to_tensor([X.p_mask], dtype=tf.float32)

    y = squad_model(x, training=False)
    (unique_ids, start_top_log_probs, start_top_index,
     end_top_log_probs, end_top_index, cls_logits) = y
    predictions = {
        "unique_ids": unique_ids,
        "start_top_log_probs": start_top_log_probs,
        "start_top_index": start_top_index,
        "end_top_log_probs": end_top_log_probs,
        "end_top_index": end_top_index,
        "cls_logits": cls_logits,
    }
    all_results.append(get_raw_results_v2(predictions))

In [23]:
# Post-process the raw results; the returned tuple is shown as the cell output.
write_predictions_v2(all_examples=examples,
                     all_features=features,
                     all_results=all_results)

(OrderedDict([('1234', 'analysis of algorithms and computability theory')]),
 OrderedDict([('1234',
               [OrderedDict([('text',
                              'analysis of algorithms and computability theory'),
                             ('probability', 0.9944961576562703),
                             ('start_log_prob', -0.012674023397266865),
                             ('end_log_prob', -0.0032195420935750008),
                             ('start_index', 8),
                             ('end_index', 16)]),
                OrderedDict([('text',
                              'analysis of algorithms and computability theory.'),
                             ('probability', 0.0029161772185747484),
                             ('start_log_prob', -0.012674023397266865),
                             ('end_log_prob', -5.835182189941406),
                             ('start_index', 8),
                             ('end_index', 17)]),
                OrderedDict([('text', 'analy