In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf

from scripts import nqa_utils
from scripts.nqa_utils import AnswerType
from scripts import bert_modeling as modeling
from scripts import bert_optimization
from scripts import albert_optimization
from scripts import albert

import tqdm
import json
import absl
import sys
import os

NQA Utils Loaded!


In [2]:
### Define Flags ###

def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()
    keys_list = [keys for keys in flags_dict]
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(absl.flags.FLAGS)

flags = absl.flags

flags.DEFINE_string(
    "model", "albert",
    "The name of model to use. Choose from ['bert', 'albert'].")

flags.DEFINE_string(
    "config_file", "models/albert_xxl/config.json",
    "The config json file corresponding to the pre-trained BERT/ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("vocab_file", "models/albert_xxl/vocab/modified-30k-clean.model",
                    "The vocabulary file that the ALBERT/BERT model was trained on.")

flags.DEFINE_string(
    "output_dir", "output/",
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string("train_precomputed_file", "data/albert_train.tf_record",
                    "Precomputed tf records for training.")

flags.DEFINE_integer("train_num_precomputed", -1,
                     "Number of precomputed tf records for training.")

flags.DEFINE_string(
    "output_checkpoint_file", "tf2_albert_finetuned.ckpt",
    "Where to save finetuned checkpoints to.")

flags.DEFINE_string(
    "output_predictions_file", "predictions.json",
    "Where to print predictions in NQ prediction format, to be passed to"
    "natural_questions.nq_eval.")

flags.DEFINE_string(
    "log_dir", "logs/",
    "Where logs, specifically Tensorboard logs, will be saved to.")

flags.DEFINE_integer(
    "log_freq", 128,
    "How many samples between each training log update.")

flags.DEFINE_string(
    "init_checkpoint", "models/bert_joint_baseline/tf2_bert_joint.ckpt",
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

# This should be changed to 512 at some point,
# as training was done with that value, it may
# not make a big difference though
flags.DEFINE_integer(
    "max_seq_length", 384,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", True, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 1, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_integer("num_train_epochs", 3,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 10000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer(
    "n_best_size", 20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.")

flags.DEFINE_integer(
    "verbosity", 1, "How verbose our error messages should be")

flags.DEFINE_integer(
    "max_answer_length", 30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.")

flags.DEFINE_float(
    "include_unknowns", -1.0,
    "If positive, probability of including answers of type `UNKNOWN`.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string("tpu_name", None, "Name of the TPU to use.")

flags.DEFINE_string("tpu_zone", None, "Which zone the TPU is in.")

flags.DEFINE_bool("use_one_hot_embeddings", False, "Whether to use use_one_hot_embeddings")

absl.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_bool(
    "verbose_logging", False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal NQ evaluation.")

# TODO(Edan): Look at nested contents too at some point
# Around 5% of long answers are nested, and around 50% of questions have
# long answers
# This means that this setting alone restricts us from a correct answer
# around 2.5% of the time
flags.DEFINE_boolean(
    "skip_nested_contexts", True,
    "Completely ignore context that are not top level nodes in the page.")

flags.DEFINE_integer("task_id", 0,
                     "Train and dev shard to read from and write to.")

flags.DEFINE_integer("max_contexts", 48,
                     "Maximum number of contexts to output for an example.")

flags.DEFINE_integer(
    "max_position", 50,
    "Maximum context position for which to generate special tokens.")

## Custom flags

flags.DEFINE_integer(
    "n_examples", -1,
    "Number of examples to read from files. Only applicable during testing")

flags.DEFINE_string(
    "train_file", "data/simplified-nq-train.jsonl",
    "NQ json for training. E.g., dev-v1.1.jsonl.gz or test-v1.1.jsonl.gz")

## Special flags - do not change

flags.DEFINE_string(
    "predict_file", "data/simplified-nq-test.jsonl",
    "NQ json for predictions. E.g., dev-v1.1.jsonl.gz or test-v1.1.jsonl.gz")
flags.DEFINE_boolean("logtostderr", True, "Logs to stderr")
flags.DEFINE_boolean("undefok", True, "it's okay to be undefined")
flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_string('HistoryManager.hist_file', '', 'kernel')

FLAGS = flags.FLAGS
FLAGS(sys.argv) # Parse the flags

VOCAB_SIZE = 30209

In [3]:
def blocks(f, size=65536):
    while True:
        b = f.read(size)
        if not b:
            break
        yield b

n_records = 0
for record in tf.compat.v1.python_io.tf_record_iterator(FLAGS.train_precomputed_file):
    n_records += 1

# with open(FLAGS.train_file, 'r') as f:
#     n_train_examples = sum([bl.count('\n') for bl in blocks(f)])

# print('# Training Examples:', n_train_examples)
print('# Training Records:', n_records)

if FLAGS.do_train and FLAGS.train_num_precomputed != n_records:
    print('Changing the number of precomuted records listed to use all avaliable data.')
    FLAGS.train_num_precomputed = n_records

Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
# Training Records: 457137
Changing the number of precomuted records listed to use all avaliable data.


In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [5]:
### Create Generator for Training Data ###

train_filenames = tf.io.gfile.glob(FLAGS.train_precomputed_file)

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
    "input_mask": tf.io.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
}
if FLAGS.do_train:
    name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["answer_types"] = tf.io.FixedLenFeature([], tf.int64)

def decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.io.parse_single_example(serialized=record, features=name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
        t = example[name]
        if t.dtype == tf.int64:
            t = tf.cast(t, dtype=tf.int32)
        example[name] = t

    output = ({
        'input_ids': example['input_ids'],
        'input_mask': example['input_mask'],
        'segment_ids': example['segment_ids']
    },
    {
        'tf_op_layer_start_logits': example['start_positions'],
        'tf_op_layer_end_logits': example['end_positions'],
        'ans_type_logits': example['answer_types']
    })

    return output

def data_generator(batch_size=32, seed=42, valid_frac=0.05):
    """The actual input function."""

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    dataset = tf.data.TFRecordDataset(train_filenames)
    dataset = dataset.map(lambda r: decode_record(r, name_to_features))

    if valid_frac <= 0:
        dataset = dataset.shuffle(buffer_size=20000, seed=seed, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)
        dataset = dataset.repeat()
        return dataset, None

    train_size = int(FLAGS.train_num_precomputed * (1.0 - valid_frac))

    train_dataset = dataset.take(train_size)
    valid_dataset = dataset.skip(train_size)

    train_dataset = train_dataset.batch(batch_size=batch_size, drop_remainder=False)
    valid_dataset = valid_dataset.batch(batch_size=batch_size, drop_remainder=False)

#     train_dataset = train_dataset.shuffle(buffer_size=20000, seed=seed, reshuffle_each_iteration=True)
#     valid_dataset = valid_dataset.shuffle(buffer_size=5000, seed=seed, reshuffle_each_iteration=True)

    return train_dataset, valid_dataset

### Train the Model ###

valid_frac = 0.02
train_dataset, valid_dataset = data_generator(batch_size=2000, valid_frac=valid_frac)
n_valid = np.ceil(FLAGS.train_num_precomputed * valid_frac)

In [6]:
print(int(FLAGS.train_num_precomputed * (1.0 - .02)))
print(int(FLAGS.train_num_precomputed * .02))
print(FLAGS.train_num_precomputed)

447994
9142
457137


In [7]:
ds = iter(train_dataset)

In [27]:
a = next(ds)
a = a[0]['input_ids'].numpy()

In [35]:
uq = set([])

for ids in a[:1000]:
    curr_q = tuple(ids[:np.argmax(ids == 3)])
    if not curr_q in uq:
        uq.add(curr_q)
#     print(curr_q)

In [46]:
ds = iter(train_dataset)

example_idx = 0
uq = set([])

while True:
    try:
        a = next(ds)
    except:
        break
        
    a = a[0]['input_ids'].numpy()
    for ids in a:
        curr_q = tuple(ids[:np.argmax(ids == 3)])
        if not curr_q in uq:
            example_idx += 1
            uq.add(curr_q)
            
print(len(uq))

206221


In [47]:
ds = iter(valid_dataset)

example_idx = 0
uq = set([])

while True:
    try:
        a = next(ds)
    except:
        break
        
    a = a[0]['input_ids'].numpy()
    for ids in a:
        curr_q = tuple(ids[:np.argmax(ids == 3)])
        if not curr_q in uq:
            example_idx += 1
            uq.add(curr_q)
            
print(len(uq))

8936


In [16]:
target_ids = [13, 5, 836, 13, 6, 1488, 713, 13, 8, 400, 4436, 13, 8, 9427, 25983, 13, 8, 713, 13, 5, 64, 1322, 13, 15, 24305, 13, 6, 13246, 268, 3526, 1996, 22651, 12421, 34, 1119, 709, 34, 17119, 17, 18650, 30164, 30000, 201, 16, 4194, 25, 21, 1155, 4906, 5270, 1211, 34, 8683, 103, 25113, 467, 19, 1089, 13, 9, 14, 8798, 25, 15197, 9102, 7570, 242, 248, 13, 1, 4194, 13, 22, 22, 22030, 13, 15, 40, 909, 883, 37, 6162, 49, 23656, 72, 17707, 1549, 16, 23291, 17, 5713, 856, 37, 40, 274, 348, 13, 9, 24, 16309, 13, 21811, 509, 75, 21, 23808, 133, 20721, 27, 21, 22145, 19, 14, 1819, 2144, 29, 21, 6192, 5825, 377, 1098, 5197, 13, 9, 30165, 30001, 14, 1211, 63, 945, 91, 119, 652, 507, 3298, 3497, 13, 9, 32, 23, 4682, 34, 35, 639, 355, 479, 3107, 1947, 115, 142, 2217, 34, 680, 251, 7721, 836, 13, 15, 56, 467, 32, 19, 299, 1089, 13, 9, 14, 1063, 1322, 230, 14, 169, 19717, 1507, 26, 3209, 14, 249, 159, 13, 9, 32, 23, 67, 2519, 26, 16361, 603, 13, 22, 18, 836, 11137, 973, 13, 15, 113, 32, 23, 1571, 69, 34, 1314, 7477, 1358, 13, 9, 30166, 30002, 14, 484, 4064, 644, 22, 22164, 121, 4194, 23, 2519, 19, 14, 484, 16361, 615, 16, 14, 3013, 1009, 2535, 1746, 8828, 99, 18, 13, 15, 113, 32, 23, 1571, 69, 34, 8219, 1334, 1944, 13, 9, 14, 1211, 230, 14, 973, 1607, 18031, 1507, 13, 15, 21, 180, 1180, 1211, 450, 13, 9, 19, 894, 13, 15, 32, 230, 14, 2282, 13, 118, 1819, 189, 450, 26, 2233, 19, 246, 3035, 3209, 26, 122, 1089, 13, 8, 8, 973, 13, 9, 19, 563, 32, 23, 4876, 77, 21, 1580, 171, 1012, 34, 13, 1712, 1358, 29, 21, 8413, 34, 684, 12507, 62, 13, 9, 30167, 30003, 14, 1211, 2661, 29, 21, 1945, 37, 14, 1314, 13, 15, 56, 25, 40, 9315, 141]

In [12]:
target_ids_2 = [34, 680, 251, 7721, 836, 13, 15, 56, 467, 32, 19, 299, 1089, 13, 9, 14, 1063, 1322, 230, 14, 169, 19717, 1507, 26, 3209, 14, 249, 159, 13, 9, 32, 23, 67, 2519, 26, 16361, 603, 13, 22, 18, 836, 11137, 973, 13, 15, 113, 32, 23, 1571, 69, 34, 1314, 7477, 1358, 13, 9, 30166, 30002, 14, 484, 4064, 644, 22, 22164, 121, 4194, 23, 2519, 19, 14, 484, 16361, 615, 16, 14, 3013, 1009, 2535, 1746, 8828, 99, 18, 13, 15, 113, 32, 23, 1571, 69, 34, 8219, 1334, 1944, 13, 9, 14, 1211, 230, 14, 973, 1607, 18031, 1507, 13, 15, 21, 180, 1180, 1211, 450, 13, 9, 19, 894, 13, 15, 32, 230, 14, 2282, 13, 118, 1819, 189, 450, 26, 2233, 19, 246, 3035, 3209, 26, 122, 1089, 13, 8, 8, 973, 13, 9, 19, 563, 32, 23, 4876, 77, 21, 1580, 171, 1012, 34, 13, 1712, 1358, 29, 21, 8413, 34, 684, 12507, 62, 13, 9, 30167, 30003, 14, 1211, 2661, 29, 21, 1945, 37, 14, 1314, 13, 15, 56, 25, 40, 9315, 141, 16, 32, 13, 9, 15200, 13, 15, 14, 1945, 4359, 2894, 6079, 963, 13, 9, 32, 2589, 20, 4088, 17, 16525, 53, 16, 14, 360, 13, 22, 18, 407, 6876, 13, 45, 14, 23490, 16, 1770, 13, 9, 30168, 30004, 201, 16, 4194, 25, 17640, 77, 132, 4501, 13, 45, 30169, 30005, 19, 14, 64, 1050, 13, 15, 14, 407, 925, 13, 15, 34, 14, 204, 16, 15197, 9102, 22030, 13, 15, 40, 3035, 1155, 13, 15, 27704, 18, 88, 33, 4292, 19, 739, 13, 9, 33, 321, 258, 18, 21, 6655, 19, 6162, 49, 23656, 13, 9, 14, 23663, 1927, 14, 190, 29, 21, 3109, 24578, 8937, 17, 109, 3260, 16, 2383, 6182, 13, 9, 30170, 30006]

In [10]:
def contains(small, big):
    for i in range(len(big)-len(small)+1):
        for j in range(len(small)):
            if big[i+j] != small[j]:
                break
        else:
            return i, i+len(small)
    return False

In [17]:
ds = iter(train_dataset)

b = 0
while True:
    try:
        a = next(ds)
    except:
        break
        
    a = a[0]['input_ids'].numpy()
    for i, ids in enumerate(a):
        if contains(target_ids, ids):
            print('Found:', b*2000 + i)
    b += 1
    
    if b % 5 == 0:
        print(b * 2000)

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
Found: 118294
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
Found: 376612
380000
390000
400000
410000
420000
430000
440000


In [13]:
ds = iter(train_dataset)

b = 0
while True:
    try:
        a = next(ds)
    except:
        break
        
    a = a[0]['input_ids'].numpy()
    for i, ids in enumerate(a):
        if contains(target_ids_2, ids):
            print('Found:', b*2000 + i)
    b += 1
    
    if b % 5 == 0:
        print(b * 2000)

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
Found: 117050
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000


In [15]:
ds = iter(valid_dataset)

b = 0
while True:
    try:
        a = next(ds)
    except:
        break
        
    a = a[0]['input_ids'].numpy()
    for i, ids in enumerate(a):
        if contains(target_ids_2, ids):
            print('Found:', b*2000 + i)
    b += 1
    
    if b % 5 == 0:
        print(b * 2000)

10000


In [8]:
ds = iter(train_dataset)
for i in range(188):
    next(ds)

data = next(ds)

In [11]:
nd = {k: x[612] for k, x in data[0].items()}
print(contains(target_ids, nd['input_ids']))
print(nd['input_ids'])

(162, 182)
tf.Tensor(
[    2 30204    98    25    14  5825    19   201    16  4194     3 30150
 30207 30151 30050   201    16  4194   201    16  4194  1227  1314  8683
   103 25113   501   581   201    16  4194   475   836   816   486  5975
  3209  5916   680   251  7721   836  3176  1231   299   547    13    15
  1089    13     5   836    13     6  1488   713    13     8   400  4436
    13     8  9427 25983    13     8   713    13     5    64  1322    13
    15 24305    13     6 13246   268  3526  1996 22651 12421    34  1119
   709    34 17119    17 18650 30164 30000   201    16  4194    25    21
  1155  4906  5270  1211    34  8683   103 25113   467    19  1089    13
     9    14  8798    25 15197  9102  7570   242   248    13     1  4194
    13    22    22 22030    13    15    40   909   883    37  6162    49
 23656    72 17707  1549    16 23291    17  5713   856    37    40   274
   348    13     9    24 16309    13 21811   509    75    21 23808   133
 20721    27    21 22145    1