In [1]:
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip
# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

In [2]:
class SquadExample(object):
  """A single training/test example for simple sequence classification.
     For examples without an answer, the start and end position are -1.
  """

  def __init__(self,
               qas_id,
               question_text,
               doc_tokens,
               orig_answer_text=None,
               start_position=None,
               end_position=None,
               is_impossible=False):
    self.qas_id = qas_id
    self.question_text = question_text
    self.doc_tokens = doc_tokens
    self.orig_answer_text = orig_answer_text
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    s = ""
    s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
    s += ", question_text: %s" % (
        tokenization.printable_text(self.question_text))
    s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
    if self.start_position:
      s += ", start_position: %d" % (self.start_position)
    if self.start_position:
      s += ", end_position: %d" % (self.end_position)
    if self.start_position:
      s += ", is_impossible: %r" % (self.is_impossible)
    return s

In [96]:
import tensorflow as tf
import bert
from bert import run_classifier
from bert import optimization
from bert import tokenization
from bert import modeling
from tqdm import tqdm
import json
import math

BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'
BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'
BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'

tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)
tokenizer = tokenization.FullTokenizer(
      vocab_file=BERT_VOCAB, do_lower_case=True)

In [80]:
version_2_with_negative = False

def read_squad_examples(input_file, is_training, version_2_with_negative = False):
    """Read a SQuAD json file into a list of SquadExample."""
    with tf.gfile.Open(input_file, 'r') as reader:
        input_data = json.load(reader)['data']

    def is_whitespace(c):
        if c == ' ' or c == '\t' or c == '\r' or c == '\n' or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:
        for paragraph in entry['paragraphs']:
            paragraph_text = paragraph['context']
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph['qas']:
                qas_id = qa['id']
                question_text = qa['question']
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:

                    if version_2_with_negative:
                        is_impossible = qa['is_impossible']
                    if (len(qa['answers']) != 1) and (not is_impossible):
                        raise ValueError(
                            'For training, each question should have exactly 1 answer.'
                        )
                    if not is_impossible:
                        answer = qa['answers'][0]
                        orig_answer_text = answer['text']
                        answer_offset = answer['answer_start']
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[
                            answer_offset + answer_length - 1
                        ]
                        actual_text = ' '.join(
                            doc_tokens[start_position : (end_position + 1)]
                        )
                        cleaned_answer_text = ' '.join(
                            tokenization.whitespace_tokenize(orig_answer_text)
                        )
                        if actual_text.find(cleaned_answer_text) == -1:
                            tf.logging.warning(
                                "Could not find answer: '%s' vs. '%s'",
                                actual_text,
                                cleaned_answer_text,
                            )
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ''

                example = SquadExample(
                    qas_id = qas_id,
                    question_text = question_text,
                    doc_tokens = doc_tokens,
                    orig_answer_text = orig_answer_text,
                    start_position = start_position,
                    end_position = end_position,
                    is_impossible = is_impossible,
                )
                examples.append(example)

    return examples

In [5]:
squad_train = read_squad_examples('train-v1.1.json', True)
squad_test = read_squad_examples('dev-v1.1.json', False)

In [174]:
import six

def _improve_answer_span(
    doc_tokens, input_start, input_end, tokenizer, orig_answer_text
):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electornics?
    #   Context: The Japanese electronics industry is the lagest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = ' '.join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = ' '.join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = (
            min(num_left_context, num_right_context) + 0.01 * doc_span.length
        )
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index

def get_final_text(pred_text, orig_text, do_lower_case):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heruistic between
    # `pred_text` and `orig_text` to get a character-to-charcter alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == ' ':
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = ''.join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = tokenization.BasicTokenizer(do_lower_case = do_lower_case)

    tok_text = ' '.join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        return orig_text

    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text


In [116]:
max_seq_length = 384
doc_stride = 128
max_query_length = 64
import collections

def example_feature(examples, is_training = True):
    inputs_ids, input_masks, segment_ids, start_positions, end_positions = [], [], [], [], []
    token_to_orig_maps, token_is_max_contexts, tokenss = [], [], []
    indices = []
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)
        
        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[:max_query_length]
            
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)
        
        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                  all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                  example.orig_answer_text)
        
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
        _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)
        
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_id = []
            tokens.append('[CLS]')
            segment_id.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_id.append(0)
            tokens.append('[SEP]')
            segment_id.append(0)
            
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[
                    split_token_index
                ]

                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_id.append(1)
            tokens.append('[SEP]')
            segment_id.append(1)
            
            input_id = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_id)
            
            while len(input_id) < max_seq_length:
                input_id.append(0)
                input_mask.append(0)
                segment_id.append(0)
                
            assert len(input_id) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_id) == max_seq_length
            
            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (
                    tok_start_position >= doc_start
                    and tok_end_position <= doc_end
                ):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
            
            inputs_ids.append(input_id)
            input_masks.append(input_mask)
            segment_ids.append(segment_id)
            start_positions.append(start_position)
            end_positions.append(end_position)
            token_is_max_contexts.append(token_is_max_context)
            token_to_orig_maps.append(token_to_orig_map)
            tokenss.append(tokens)
            indices.append(example_index)
    return (inputs_ids, input_masks, segment_ids, start_positions, 
            end_positions, token_to_orig_maps, token_is_max_contexts, tokenss, indices)

In [117]:
train_inputs_ids, train_input_masks, train_segment_ids, \
train_start_positions, train_end_positions, \
train_token_to_orig_maps, train_token_is_max_contexts, train_tokens, train_indices = example_feature(squad_train)

In [118]:
test_inputs_ids, test_input_masks, test_segment_ids, \
test_start_positions, test_end_positions, \
test_token_to_orig_maps, test_token_is_max_contexts, test_tokens, test_indices = example_feature(squad_test, False)

In [10]:
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)

In [47]:
epoch = 4
batch_size = 12
warmup_proportion = 0.1
n_best_size = 20
num_train_steps = int(len(train_inputs_ids) / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)

In [12]:
class Model:
    def __init__(
        self,
        learning_rate = 2e-5,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        self.segment_ids = tf.placeholder(tf.int32, [None, None])
        self.input_masks = tf.placeholder(tf.int32, [None, None])
        self.start_positions = tf.placeholder(tf.int32, [None])
        self.end_positions = tf.placeholder(tf.int32, [None])
        
        model = modeling.BertModel(
            config=bert_config,
            is_training=True,
            input_ids=self.X,
            input_mask=self.input_masks,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False)
        
        final_hidden = model.get_sequence_output()
        final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
        batch_size = final_hidden_shape[0]
        seq_length = final_hidden_shape[1]
        hidden_size = final_hidden_shape[2]
        
        output_weights = tf.get_variable(
            "cls/squad/output_weights", [2, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable(
              "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

        final_hidden_matrix = tf.reshape(final_hidden,
                                           [batch_size * seq_length, hidden_size])
        logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        logits = tf.reshape(logits, [batch_size, seq_length, 2])
        logits = tf.transpose(logits, [2, 0, 1])

        unstacked_logits = tf.unstack(logits, axis=0)

        (self.start_logits, self.end_logits) = (unstacked_logits[0], unstacked_logits[1])
        print(self.start_logits, self.end_logits)
        
        def compute_loss(logits, positions):
            one_hot_positions = tf.one_hot(
                positions, depth=seq_length, dtype=tf.float32)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            loss = -tf.reduce_mean(
                tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
            return loss
        
        start_loss = compute_loss(self.start_logits, self.start_positions)
        end_loss = compute_loss(self.end_logits, self.end_positions)

        self.cost = (start_loss + end_loss) / 2.0
    
        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, 
                                                       num_train_steps, num_warmup_steps, False)

In [13]:
learning_rate = 2e-5

tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model(
    learning_rate
)

sess.run(tf.global_variables_initializer())
var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')
saver = tf.train.Saver(var_list = var_lists)
saver.restore(sess, BERT_INIT_CHKPNT)

Instructions for updating:
Colocations handled automatically by placer.

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use keras.layers.dense instead.
Tensor("unstack:0", shape=(?, ?), dtype=float32) Tensor("unstack:1", shape=(?, ?), dtype=float32)
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt


In [14]:
for e in range(epoch):
    train_loss = 0
    pbar = tqdm(
        range(0, len(train_inputs_ids), batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(train_inputs_ids))
        batch_ids = train_inputs_ids[i: index]
        batch_masks = train_input_masks[i: index]
        batch_segment = train_segment_ids[i: index]
        batch_start = train_start_positions[i: index]
        batch_end = train_end_positions[i: index]
        cost, _ = sess.run(
            [model.cost, model.optimizer],
            feed_dict = {
                model.start_positions: batch_start,
                model.end_positions: batch_end,
                model.X: batch_ids,
                model.segment_ids: batch_segment,
                model.input_masks: batch_masks
            },
        )
        pbar.set_postfix(cost = cost)
        train_loss += cost
    train_loss /= len(train_inputs_ids) / batch_size
    print(
        'epoch: %d, training loss: %f\n'
        % (e, train_loss)
    )

train minibatch loop: 100%|██████████| 7387/7387 [46:42<00:00,  2.81it/s, cost=1.21]  
train minibatch loop:   0%|          | 0/7387 [00:00<?, ?it/s]

epoch: 0, training loss: 1.642079



train minibatch loop: 100%|██████████| 7387/7387 [46:43<00:00,  2.80it/s, cost=0.477] 
train minibatch loop:   0%|          | 0/7387 [00:00<?, ?it/s]

epoch: 1, training loss: 0.863868



train minibatch loop: 100%|██████████| 7387/7387 [46:42<00:00,  2.81it/s, cost=0.598]  
train minibatch loop:   0%|          | 0/7387 [00:00<?, ?it/s]

epoch: 2, training loss: 0.631467



train minibatch loop: 100%|██████████| 7387/7387 [46:43<00:00,  2.80it/s, cost=0.073]  

epoch: 3, training loss: 0.509195






In [101]:
len(test_inputs_ids)

10833

In [18]:
batch_ids = test_inputs_ids[:10]
batch_masks = test_input_masks[:10]
batch_segment = test_segment_ids[:10]
batch_start = test_start_positions[:10]
batch_end = test_end_positions[:10]

In [105]:
p = []

In [106]:
p.extend(start_logits.tolist())
p.extend(start_logits.tolist())

In [108]:
np.array(p).shape

(20, 384)

In [110]:
starts, ends = [], []
pbar = tqdm(
        range(0, len(test_inputs_ids), batch_size), desc = 'test minibatch loop'
    )
for i in pbar:
    index = min(i + batch_size, len(test_inputs_ids))
    batch_ids = test_inputs_ids[i: index]
    batch_masks = test_input_masks[i: index]
    batch_segment = test_segment_ids[i: index]
    start_logits, end_logits = sess.run(
                [model.start_logits, model.end_logits],
                feed_dict = {
                    model.X: batch_ids,
                    model.segment_ids: batch_segment,
                    model.input_masks: batch_masks
                },
            )
    starts.extend(start_logits.tolist())
    ends.extend(end_logits.tolist())


test minibatch loop:   0%|          | 0/903 [00:00<?, ?it/s][A
test minibatch loop: 100%|██████████| 903/903 [02:00<00:00,  8.02it/s]


In [57]:
import numpy as np

def _get_best_indexes(logits, n_best_size):
    index_and_score = sorted(
        enumerate(logits), key = lambda x: x[1], reverse = True
    )

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes

def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs

In [177]:
def to_predict(
    indices,
    examples,
    start_logits,
    end_logits,
    tokens,
    token_to_orig_maps,
    token_is_max_contexts,
    max_answer_length = 30,
    n_best_size = 20,
    do_lower_case = False,
    null_score_diff_threshold = 0.0,
    output_prediction_file = 'predictions.json',
    output_nbest_file = 'nbest_predictions.json',
    output_null_log_odds_file = 'null_odds.json',
):

    example_index_to_features = collections.defaultdict(list)
    for no, feature in enumerate(indices):
        example_index_to_features[feature].append(no)

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    _PrelimPrediction = collections.namedtuple(
        'PrelimPrediction',
        [
            'feature_index',
            'start_index',
            'end_index',
            'start_logit',
            'end_logit',
        ],
    )

    for (example_index, example) in enumerate(examples):
        features = example_index_to_features[example_index]
        prelim_predictions = []
        score_null = 1000000
        min_null_feature_index = 0
        null_start_logit = 0
        null_end_logit = 0
        for (feature_index, i) in enumerate(features):

            start_indexes = _get_best_indexes(start_logits[i], n_best_size)
            end_indexes = _get_best_indexes(end_logits[i], n_best_size)
            if version_2_with_negative:
                feature_null_score = start_logits[i][0] + end_logits[i][0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = start_logits[i][0]
                    null_end_logit = end_logits[i][0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if start_index >= len(tokens[i]):
                        continue
                    if end_index >= len(tokens[i]):
                        continue
                    if start_index not in token_to_orig_maps[i]:
                        continue
                    if end_index not in token_to_orig_maps[i]:
                        continue
                    if not token_is_max_contexts[i].get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index = i,
                            start_index = start_index,
                            end_index = end_index,
                            start_logit = start_logits[i][start_index],
                            end_logit = end_logits[i][end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index = min_null_feature_index,
                    start_index = 0,
                    end_index = 0,
                    start_logit = null_start_logit,
                    end_logit = null_end_logit,
                )
            )

        prelim_predictions = sorted(
            prelim_predictions,
            key = lambda x: (x.start_logit + x.end_logit),
            reverse = True,
        )

        _NbestPrediction = collections.namedtuple(
            'NbestPrediction', ['text', 'start_logit', 'end_logit']
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            i = pred.feature_index
            if pred.start_index > 0:
                tok_tokens = tokens[i][pred.start_index : (pred.end_index + 1)]
                orig_doc_start = token_to_orig_maps[i][pred.start_index]
                orig_doc_end = token_to_orig_maps[i][pred.end_index]
                orig_tokens = example.doc_tokens[
                    orig_doc_start : (orig_doc_end + 1)
                ]
                tok_text = ' '.join(tok_tokens)
                tok_text = tok_text.replace(' ##', '')
                tok_text = tok_text.replace('##', '')
                tok_text = tok_text.strip()
                tok_text = ' '.join(tok_text.split())
                orig_text = ' '.join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case)
                if final_text in seen_predictions:
                    continue
                seen_predictions[final_text] = True
            else:
                final_text = ''
                seen_predictions[final_text] = True
            nbest.append(
                _NbestPrediction(
                    text = final_text,
                    start_logit = pred.start_logit,
                    end_logit = pred.end_logit,
                )
            )
        if version_2_with_negative:
            if '' not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text = '',
                        start_logit = null_start_logit,
                        end_logit = null_end_logit,
                    )
                )

        if not nbest:
            nbest.append(
                _NbestPrediction(
                    text = 'empty', start_logit = 0.0, end_logit = 0.0
                )
            )

        assert len(nbest) >= 1
        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output['text'] = entry.text
            output['probability'] = probs[i]
            output['start_logit'] = entry.start_logit
            output['end_logit'] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1
        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]['text']
        else:
            score_diff = (
                score_null
                - best_non_null_entry.start_logit
                - (best_non_null_entry.end_logit)
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ''
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    with tf.gfile.GFile(output_prediction_file, 'w') as writer:
        writer.write(json.dumps(all_predictions, indent = 4) + '\n')
    with tf.gfile.GFile(output_nbest_file, 'w') as writer:
        writer.write(json.dumps(all_nbest_json, indent = 4) + '\n')
    if version_2_with_negative:
        with tf.gfile.GFile(output_null_log_odds_file, 'w') as writer:
            writer.write(json.dumps(scores_diff_json, indent = 4) + '\n')


In [178]:
to_predict(test_indices, squad_test, starts, ends,
          test_tokens, test_token_to_orig_maps, test_token_is_max_contexts)

In [184]:
!wget https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py

--2019-07-02 17:46:28--  https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.8.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.8.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3419 (3.3K) [text/plain]
Saving to: ‘evaluate-v1.1.py’


2019-07-02 17:46:29 (99.2 MB/s) - ‘evaluate-v1.1.py’ saved [3419/3419]



In [185]:
!python3 evaluate-v1.1.py dev-v1.1.json predictions.json

{"exact_match": 77.57805108798486, "f1": 86.18327335287402}
