In [1]:
from __future__ import absolute_import, division, print_function

import argparse
import logging
import os
import random
import glob

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from tensorboardX import SummaryWriter

from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
                                  BertForQuestionAnswering, BertTokenizer,
                                  XLMConfig, XLMForQuestionAnswering,
                                  XLMTokenizer, XLNetConfig,
                                  XLNetForQuestionAnswering,
                                  XLNetTokenizer,
                                  DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)

from pytorch_transformers import AdamW, WarmupLinearSchedule

from utils_squad import (read_squad_examples, convert_examples_to_features,
                         RawResult, write_predictions,
                         RawResultExtended, write_predictions_extended)
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

#import tensorflow as tf

#from keras.backend.tensorflow_backend import set_session
#config = tf.ConfigProto()
#config.gpu_options.allow_growth = True
#config.gpu_options.visible_device_list = "0" #only the gpu 0 is allowed

# config.gpu_options.per_process_gpu_memory_fraction = 0.01

#set_session(tf.Session(config=config))

In [2]:
logger = logging.getLogger(__name__)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())

In [3]:
MODEL_CLASSES = {
    'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
    'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
}

In [4]:
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

def to_list(tensor):
    return tensor.detach().cpu().tolist()

In [5]:
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':       batch[0],
                      'attention_mask':  batch[1], 
                      'token_type_ids':  None if args.model_type == 'xlm' else batch[2],  
                      'start_positions': batch[3], 
                      'end_positions':   batch[4]}
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5],
                               'p_mask':       batch[6]})
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step

In [6]:
def evaluate(args, model, tokenizer, prefix=""):
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_results = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
                      }
            example_indices = batch[3]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4],
                               'p_mask':    batch[5]})
            outputs = model(**inputs)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ['xlnet', 'xlm']:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(unique_id            = unique_id,
                                           start_top_log_probs  = to_list(outputs[0][i]),
                                           start_top_index      = to_list(outputs[1][i]),
                                           end_top_log_probs    = to_list(outputs[2][i]),
                                           end_top_index        = to_list(outputs[3][i]),
                                           cls_logits           = to_list(outputs[4][i]))
            else:
                result = RawResult(unique_id    = unique_id,
                                   start_logits = to_list(outputs[0][i]),
                                   end_logits   = to_list(outputs[1][i]))
            all_results.append(result)

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    if args.model_type in ['xlnet', 'xlm']:
        # XLNet uses a more complex post-processing procedure
        write_predictions_extended(examples, features, all_results, args.n_best_size,
                        args.max_answer_length, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, args.predict_file,
                        model.config.start_n_top, model.config.end_n_top,
                        args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
        write_predictions(examples, features, all_results, args.n_best_size,
                        args.max_answer_length, args.do_lower_case, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                        args.version_2_with_negative, args.null_score_diff_threshold)

    # Evaluate with the official SQuAD script
    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
                                 pred_file=output_prediction_file,
                                 na_prob_file=output_null_log_odds_file)
    results = evaluate_on_squad(evaluate_options)
    return results

In [7]:
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = 'cached_features_file'
    test_features_file = 'test_features_file'
    if os.path.exists(cached_features_file) and input_file == args.train_file:
        logger.info("Loading features from cached file %s", cached_features_file)
        examples = read_squad_examples(input_file=input_file,
                                            is_training=not evaluate,
                                            version_2_with_negative=args.version_2_with_negative)
        features = torch.load(cached_features_file)
    elif os.path.exists(test_features_file) and input_file == args.predict_file:
        logger.info("Loading features from cached file %s", cached_features_file)
        examples = read_squad_examples(input_file=input_file,
                                            is_training=not evaluate,
                                            version_2_with_negative=args.version_2_with_negative)
        features = torch.load(test_features_file)
        
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        examples = read_squad_examples(input_file=input_file,
                                            is_training=not evaluate,
                                            version_2_with_negative=args.version_2_with_negative)
        features = convert_examples_to_features(examples=examples,
                                            tokenizer=tokenizer,
                                            max_seq_length=args.max_seq_length,
                                            doc_stride=args.doc_stride,
                                            max_query_length=args.max_query_length,
                                            is_training=not evaluate)
        # save train features into cached_features_file
        if args.local_rank in [-1, 0] and input_file == args.train_file:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
        # save test features into test_features_file
        if args.local_rank in [-1, 0] and input_file == args.predict_file:
            logger.info("Saving features into test file %s", test_features_file)
            torch.save(features, test_features_file)
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    if evaluate:
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_example_index, all_cls_index, all_p_mask)
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_start_positions, all_end_positions,
                                all_cls_index, all_p_mask)

    if output_examples:
        return dataset, examples, features
    return dataset

In [8]:
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file", default=None, type=str, required=True,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str, required=True,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--version_2_with_negative', action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")

    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Rul evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    #args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    #if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        #import ptvsd
        #print("Waiting for debugger attach")
        #ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        #ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    #if args.local_rank == -1 or args.no_cuda:
        #device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        #args.n_gpu = torch.cuda.device_count()
    #else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        #torch.cuda.set_device(args.local_rank)
        #device = torch.device("cuda", args.local_rank)
        #torch.distributed.init_process_group(backend='nccl')
        #args.n_gpu = 1
    args.device = device

    # Setup logging
    #logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        #datefmt = '%m/%d/%Y %H:%M:%S',
                        #level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    #logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    #args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    #set_seed(args)

    # Load pretrained model and tokenizer
    #if args.local_rank not in [-1, 0]:
        #torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.model_name)
    tokenizer = tokenizer_class.from_pretrained(args.model_name)
    model = model_class.from_pretrained(args.model_name, config=config)
    #if args.local_rank == 0:
        #torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


    # Save the trained model and the tokenizer
    if args.do_train: #and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)


    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval: #and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))

    return results

In [9]:
class struct():
    def __init__(self):
        self.train_file = 'train-v2.0.json'
        self.predict_file = 'dev-v2.0.json'
        # exam the new dataset
        #self.predict_file = 'test.json'
        self.model_type = 'bert'
        self.model_name = 'bert-large-uncased-whole-word-masking'
        self.task_name = 'MRPC'
        self.do_train = True
        self.do_eval = True
        self.do_lower_case = True
        self.data_dir = 'GLUE_DIR/MRPC/'
        self.max_seq_length = 384
        self.per_gpu_eval_batch_size = 12
        self.per_gpu_train_batch_size = 12
        # the default is 2, we make it quicker
        self.num_train_epochs = 2.0
        self.learning_rate = 3e-5
        self.output_dir = 'tmp/mrpc_output/'
        self.overwrite_output_dir = True
        self.overwrite_cache = True   
        self.local_rank = -1
        self.version_2_with_negative = True
        self.doc_stride = 128
        self.max_query_length = 64
        self.n_gpu = 1
        self.max_steps = -1
        self.gradient_accumulation_steps = 1
        self.weight_decay = 0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1
        self.warmup_steps = 0
        self.n_best_size = 20
        self.max_answer_length = 30
        self.verbose_logging = True
        self.logging_steps = 50
        self.save_steps = 50000
        self.fp16 = False
        self.fp16_opt_level = 'O1'
        self.seed = 42
        self.no_cuda = False
        self.evaluate_during_training = False
        

In [10]:
args = struct()

In [11]:
args.predict_file

'dev-v2.0.json'

In [12]:
input_file = args.train_file
input_file

'train-v2.0.json'

In [13]:
read_squad_examples(input_file=input_file,
                            is_training=not evaluate,
                            version_2_with_negative=False)

[qas_id: 56be85543aeaaa14008c9063, question_text: When did Beyonce start becoming popular?, doc_tokens: [Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".],
 qas_id: 56be85543aeaaa14008c9065, question_text: What areas did Beyonce compete in when she was growing up?, doc_tokens: [Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4

In [13]:
# set up device
if args.local_rank == -1 or args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend='nccl')
    args.n_gpu = 1
args.device = device

In [15]:
# prepare model and dataset
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.model_name)
tokenizer = tokenizer_class.from_pretrained(args.model_name)
model = model_class.from_pretrained(args.model_name, config=config)

model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
    
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)

In [14]:
args.fp16

False

In [17]:
# start training
global_step, tr_loss = train(args, train_dataset, model, tokenizer)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 1/5498 [00:05<9:04:08,  5.94s/it][A
Iteration:   0%|          | 2/5498 [00:06<6:49:13,  4.47s/it][A
Iteration:   0%|          | 3/5498 [00:08<5:14:47,  3.44s/it][A
Iteration:   0%|          | 4/5498 [00:09<4:08:37,  2.72s/it][A
Iteration:   0%|          | 5/5498 [00:10<3:22:28,  2.21s/it][A
Iteration:   0%|          | 6/5498 [00:11<2:50:00,  1.86s/it][A
Iteration:   0%|          | 7/5498 [00:12<2:27:57,  1.62s/it][A
Iteration:   0%|          | 8/5498 [00:13<2:12:05,  1.44s/it][A
Iteration:   0%|          | 9/5498 [00:14<2:00:56,  1.32s/it][A
Iteration:   0%|          | 10/5498 [00:15<1:53:03,  1.24s/it][A
Iteration:   0%|          | 11/5498 [00:16<1:47:38,  1.18s/it][A
Iteration:   0%|          | 12/5498 [00:17<1:43:49,  1.14s/it][A
Iteration:   0%|          | 13/5498 [00:18<1:40:59,  1.10s/it][A
Iteration:   0%|          | 14/5498 [00:19<1:39:27,  1.09s/it][A
Iteration:   0%|          | 15/5498 [00

Iteration:   4%|▍         | 241/5498 [04:18<1:32:11,  1.05s/it][A
Iteration:   4%|▍         | 242/5498 [04:19<1:32:09,  1.05s/it][A
Iteration:   4%|▍         | 243/5498 [04:20<1:32:17,  1.05s/it][A
Iteration:   4%|▍         | 244/5498 [04:21<1:32:07,  1.05s/it][A
Iteration:   4%|▍         | 245/5498 [04:22<1:32:10,  1.05s/it][A
Iteration:   4%|▍         | 246/5498 [04:23<1:32:03,  1.05s/it][A
Iteration:   4%|▍         | 247/5498 [04:24<1:31:54,  1.05s/it][A
Iteration:   5%|▍         | 248/5498 [04:25<1:31:48,  1.05s/it][A
Iteration:   5%|▍         | 249/5498 [04:26<1:31:46,  1.05s/it][A
Iteration:   5%|▍         | 250/5498 [04:27<1:31:49,  1.05s/it][A
Iteration:   5%|▍         | 251/5498 [04:28<1:31:52,  1.05s/it][A
Iteration:   5%|▍         | 252/5498 [04:29<1:31:56,  1.05s/it][A
Iteration:   5%|▍         | 253/5498 [04:30<1:31:50,  1.05s/it][A
Iteration:   5%|▍         | 254/5498 [04:31<1:31:42,  1.05s/it][A
Iteration:   5%|▍         | 255/5498 [04:32<1:31:44,  1.05s/it

Iteration:   9%|▉         | 485/5498 [08:35<1:27:51,  1.05s/it][A
Iteration:   9%|▉         | 486/5498 [08:36<1:27:46,  1.05s/it][A
Iteration:   9%|▉         | 487/5498 [08:38<1:27:50,  1.05s/it][A
Iteration:   9%|▉         | 488/5498 [08:39<1:27:52,  1.05s/it][A
Iteration:   9%|▉         | 489/5498 [08:40<1:27:49,  1.05s/it][A
Iteration:   9%|▉         | 490/5498 [08:41<1:27:46,  1.05s/it][A
Iteration:   9%|▉         | 491/5498 [08:42<1:27:48,  1.05s/it][A
Iteration:   9%|▉         | 492/5498 [08:43<1:27:43,  1.05s/it][A
Iteration:   9%|▉         | 493/5498 [08:44<1:27:40,  1.05s/it][A
Iteration:   9%|▉         | 494/5498 [08:45<1:27:40,  1.05s/it][A
Iteration:   9%|▉         | 495/5498 [08:46<1:27:38,  1.05s/it][A
Iteration:   9%|▉         | 496/5498 [08:47<1:27:37,  1.05s/it][A
Iteration:   9%|▉         | 497/5498 [08:48<1:27:48,  1.05s/it][A
Iteration:   9%|▉         | 498/5498 [08:49<1:27:51,  1.05s/it][A
Iteration:   9%|▉         | 499/5498 [08:50<1:27:48,  1.05s/it

Iteration:  13%|█▎        | 729/5498 [12:53<1:23:44,  1.05s/it][A
Iteration:  13%|█▎        | 730/5498 [12:54<1:23:41,  1.05s/it][A
Iteration:  13%|█▎        | 731/5498 [12:55<1:23:43,  1.05s/it][A
Iteration:  13%|█▎        | 732/5498 [12:56<1:23:36,  1.05s/it][A
Iteration:  13%|█▎        | 733/5498 [12:57<1:23:33,  1.05s/it][A
Iteration:  13%|█▎        | 734/5498 [12:58<1:23:36,  1.05s/it][A
Iteration:  13%|█▎        | 735/5498 [12:59<1:23:32,  1.05s/it][A
Iteration:  13%|█▎        | 736/5498 [13:00<1:23:28,  1.05s/it][A
Iteration:  13%|█▎        | 737/5498 [13:01<1:23:34,  1.05s/it][A
Iteration:  13%|█▎        | 738/5498 [13:02<1:23:35,  1.05s/it][A
Iteration:  13%|█▎        | 739/5498 [13:04<1:23:37,  1.05s/it][A
Iteration:  13%|█▎        | 740/5498 [13:05<1:23:39,  1.05s/it][A
Iteration:  13%|█▎        | 741/5498 [13:06<1:23:32,  1.05s/it][A
Iteration:  13%|█▎        | 742/5498 [13:07<1:23:25,  1.05s/it][A
Iteration:  14%|█▎        | 743/5498 [13:08<1:23:25,  1.05s/it

Iteration:  18%|█▊        | 973/5498 [17:11<1:19:09,  1.05s/it][A
Iteration:  18%|█▊        | 974/5498 [17:12<1:19:17,  1.05s/it][A
Iteration:  18%|█▊        | 975/5498 [17:13<1:19:18,  1.05s/it][A
Iteration:  18%|█▊        | 976/5498 [17:14<1:19:17,  1.05s/it][A
Iteration:  18%|█▊        | 977/5498 [17:15<1:19:15,  1.05s/it][A
Iteration:  18%|█▊        | 978/5498 [17:16<1:19:15,  1.05s/it][A
Iteration:  18%|█▊        | 979/5498 [17:17<1:19:08,  1.05s/it][A
Iteration:  18%|█▊        | 980/5498 [17:18<1:19:12,  1.05s/it][A
Iteration:  18%|█▊        | 981/5498 [17:19<1:19:11,  1.05s/it][A
Iteration:  18%|█▊        | 982/5498 [17:20<1:19:12,  1.05s/it][A
Iteration:  18%|█▊        | 983/5498 [17:21<1:19:16,  1.05s/it][A
Iteration:  18%|█▊        | 984/5498 [17:22<1:19:22,  1.06s/it][A
Iteration:  18%|█▊        | 985/5498 [17:23<1:19:25,  1.06s/it][A
Iteration:  18%|█▊        | 986/5498 [17:24<1:19:39,  1.06s/it][A
Iteration:  18%|█▊        | 987/5498 [17:25<1:19:35,  1.06s/it

Iteration:  22%|██▏       | 1213/5498 [21:24<1:15:10,  1.05s/it][A
Iteration:  22%|██▏       | 1214/5498 [21:25<1:15:10,  1.05s/it][A
Iteration:  22%|██▏       | 1215/5498 [21:26<1:15:08,  1.05s/it][A
Iteration:  22%|██▏       | 1216/5498 [21:27<1:15:03,  1.05s/it][A
Iteration:  22%|██▏       | 1217/5498 [21:28<1:15:06,  1.05s/it][A
Iteration:  22%|██▏       | 1218/5498 [21:29<1:14:58,  1.05s/it][A
Iteration:  22%|██▏       | 1219/5498 [21:30<1:15:16,  1.06s/it][A
Iteration:  22%|██▏       | 1220/5498 [21:31<1:15:17,  1.06s/it][A
Iteration:  22%|██▏       | 1221/5498 [21:32<1:15:18,  1.06s/it][A
Iteration:  22%|██▏       | 1222/5498 [21:33<1:15:04,  1.05s/it][A
Iteration:  22%|██▏       | 1223/5498 [21:34<1:14:55,  1.05s/it][A
Iteration:  22%|██▏       | 1224/5498 [21:35<1:14:50,  1.05s/it][A
Iteration:  22%|██▏       | 1225/5498 [21:36<1:14:48,  1.05s/it][A
Iteration:  22%|██▏       | 1226/5498 [21:38<1:14:46,  1.05s/it][A
Iteration:  22%|██▏       | 1227/5498 [21:39<1:1

Iteration:  26%|██▋       | 1453/5498 [25:37<1:11:36,  1.06s/it][A
Iteration:  26%|██▋       | 1454/5498 [25:38<1:11:17,  1.06s/it][A
Iteration:  26%|██▋       | 1455/5498 [25:39<1:11:00,  1.05s/it][A
Iteration:  26%|██▋       | 1456/5498 [25:40<1:10:51,  1.05s/it][A
Iteration:  27%|██▋       | 1457/5498 [25:41<1:10:40,  1.05s/it][A
Iteration:  27%|██▋       | 1458/5498 [25:42<1:10:53,  1.05s/it][A
Iteration:  27%|██▋       | 1459/5498 [25:43<1:11:07,  1.06s/it][A
Iteration:  27%|██▋       | 1460/5498 [25:44<1:10:54,  1.05s/it][A
Iteration:  27%|██▋       | 1461/5498 [25:45<1:10:56,  1.05s/it][A
Iteration:  27%|██▋       | 1462/5498 [25:46<1:10:52,  1.05s/it][A
Iteration:  27%|██▋       | 1463/5498 [25:48<1:10:47,  1.05s/it][A
Iteration:  27%|██▋       | 1464/5498 [25:49<1:10:47,  1.05s/it][A
Iteration:  27%|██▋       | 1465/5498 [25:50<1:10:40,  1.05s/it][A
Iteration:  27%|██▋       | 1466/5498 [25:51<1:10:36,  1.05s/it][A
Iteration:  27%|██▋       | 1467/5498 [25:52<1:1

Iteration:  31%|███       | 1693/5498 [29:50<1:06:49,  1.05s/it][A
Iteration:  31%|███       | 1694/5498 [29:51<1:06:39,  1.05s/it][A
Iteration:  31%|███       | 1695/5498 [29:52<1:06:44,  1.05s/it][A
Iteration:  31%|███       | 1696/5498 [29:53<1:06:37,  1.05s/it][A
Iteration:  31%|███       | 1697/5498 [29:54<1:06:36,  1.05s/it][A
Iteration:  31%|███       | 1698/5498 [29:55<1:06:43,  1.05s/it][A
Iteration:  31%|███       | 1699/5498 [29:56<1:06:35,  1.05s/it][A
Iteration:  31%|███       | 1700/5498 [29:57<1:06:34,  1.05s/it][A
Iteration:  31%|███       | 1701/5498 [29:59<1:06:36,  1.05s/it][A
Iteration:  31%|███       | 1702/5498 [30:00<1:06:37,  1.05s/it][A
Iteration:  31%|███       | 1703/5498 [30:01<1:10:17,  1.11s/it][A
Iteration:  31%|███       | 1704/5498 [30:02<1:09:08,  1.09s/it][A
Iteration:  31%|███       | 1705/5498 [30:03<1:08:21,  1.08s/it][A
Iteration:  31%|███       | 1706/5498 [30:04<1:07:41,  1.07s/it][A
Iteration:  31%|███       | 1707/5498 [30:05<1:0

Iteration:  35%|███▌      | 1933/5498 [34:04<1:02:43,  1.06s/it][A
Iteration:  35%|███▌      | 1934/5498 [34:05<1:02:37,  1.05s/it][A
Iteration:  35%|███▌      | 1935/5498 [34:06<1:02:34,  1.05s/it][A
Iteration:  35%|███▌      | 1936/5498 [34:07<1:02:35,  1.05s/it][A
Iteration:  35%|███▌      | 1937/5498 [34:08<1:02:36,  1.05s/it][A
Iteration:  35%|███▌      | 1938/5498 [34:09<1:02:36,  1.06s/it][A
Iteration:  35%|███▌      | 1939/5498 [34:10<1:02:35,  1.06s/it][A
Iteration:  35%|███▌      | 1940/5498 [34:11<1:02:31,  1.05s/it][A
Iteration:  35%|███▌      | 1941/5498 [34:12<1:02:26,  1.05s/it][A
Iteration:  35%|███▌      | 1942/5498 [34:13<1:02:25,  1.05s/it][A
Iteration:  35%|███▌      | 1943/5498 [34:14<1:02:26,  1.05s/it][A
Iteration:  35%|███▌      | 1944/5498 [34:15<1:02:23,  1.05s/it][A
Iteration:  35%|███▌      | 1945/5498 [34:16<1:02:22,  1.05s/it][A
Iteration:  35%|███▌      | 1946/5498 [34:17<1:02:20,  1.05s/it][A
Iteration:  35%|███▌      | 1947/5498 [34:18<1:0

Iteration:  40%|███▉      | 2176/5498 [38:20<58:15,  1.05s/it][A
Iteration:  40%|███▉      | 2177/5498 [38:21<58:14,  1.05s/it][A
Iteration:  40%|███▉      | 2178/5498 [38:22<58:10,  1.05s/it][A
Iteration:  40%|███▉      | 2179/5498 [38:23<58:09,  1.05s/it][A
Iteration:  40%|███▉      | 2180/5498 [38:24<58:07,  1.05s/it][A
Iteration:  40%|███▉      | 2181/5498 [38:25<58:03,  1.05s/it][A
Iteration:  40%|███▉      | 2182/5498 [38:26<58:11,  1.05s/it][A
Iteration:  40%|███▉      | 2183/5498 [38:27<58:01,  1.05s/it][A
Iteration:  40%|███▉      | 2184/5498 [38:29<57:59,  1.05s/it][A
Iteration:  40%|███▉      | 2185/5498 [38:30<57:58,  1.05s/it][A
Iteration:  40%|███▉      | 2186/5498 [38:31<57:58,  1.05s/it][A
Iteration:  40%|███▉      | 2187/5498 [38:32<58:14,  1.06s/it][A
Iteration:  40%|███▉      | 2188/5498 [38:33<58:18,  1.06s/it][A
Iteration:  40%|███▉      | 2189/5498 [38:34<58:16,  1.06s/it][A
Iteration:  40%|███▉      | 2190/5498 [38:35<58:18,  1.06s/it][A
Iteration:

Iteration:  44%|████▍     | 2424/5498 [42:42<54:51,  1.07s/it][A
Iteration:  44%|████▍     | 2425/5498 [42:43<54:31,  1.06s/it][A
Iteration:  44%|████▍     | 2426/5498 [42:44<54:16,  1.06s/it][A
Iteration:  44%|████▍     | 2427/5498 [42:45<54:06,  1.06s/it][A
Iteration:  44%|████▍     | 2428/5498 [42:46<54:03,  1.06s/it][A
Iteration:  44%|████▍     | 2429/5498 [42:47<53:58,  1.06s/it][A
Iteration:  44%|████▍     | 2430/5498 [42:48<53:52,  1.05s/it][A
Iteration:  44%|████▍     | 2431/5498 [42:49<53:49,  1.05s/it][A
Iteration:  44%|████▍     | 2432/5498 [42:50<53:49,  1.05s/it][A
Iteration:  44%|████▍     | 2433/5498 [42:51<53:45,  1.05s/it][A
Iteration:  44%|████▍     | 2434/5498 [42:52<53:42,  1.05s/it][A
Iteration:  44%|████▍     | 2435/5498 [42:53<53:39,  1.05s/it][A
Iteration:  44%|████▍     | 2436/5498 [42:54<53:39,  1.05s/it][A
Iteration:  44%|████▍     | 2437/5498 [42:55<53:41,  1.05s/it][A
Iteration:  44%|████▍     | 2438/5498 [42:57<53:40,  1.05s/it][A
Iteration:

Iteration:  49%|████▊     | 2672/5498 [47:04<49:54,  1.06s/it][A
Iteration:  49%|████▊     | 2673/5498 [47:05<49:55,  1.06s/it][A
Iteration:  49%|████▊     | 2674/5498 [47:06<49:55,  1.06s/it][A
Iteration:  49%|████▊     | 2675/5498 [47:07<52:46,  1.12s/it][A
Iteration:  49%|████▊     | 2676/5498 [47:08<51:51,  1.10s/it][A
Iteration:  49%|████▊     | 2677/5498 [47:09<51:15,  1.09s/it][A
Iteration:  49%|████▊     | 2678/5498 [47:10<50:51,  1.08s/it][A
Iteration:  49%|████▊     | 2679/5498 [47:11<50:29,  1.07s/it][A
Iteration:  49%|████▊     | 2680/5498 [47:12<50:15,  1.07s/it][A
Iteration:  49%|████▉     | 2681/5498 [47:13<50:05,  1.07s/it][A
Iteration:  49%|████▉     | 2682/5498 [47:14<49:58,  1.06s/it][A
Iteration:  49%|████▉     | 2683/5498 [47:15<49:51,  1.06s/it][A
Iteration:  49%|████▉     | 2684/5498 [47:16<49:47,  1.06s/it][A
Iteration:  49%|████▉     | 2685/5498 [47:17<49:44,  1.06s/it][A
Iteration:  49%|████▉     | 2686/5498 [47:19<49:40,  1.06s/it][A
Iteration:

Iteration:  53%|█████▎    | 2920/5498 [51:26<45:09,  1.05s/it][A
Iteration:  53%|█████▎    | 2921/5498 [51:27<45:05,  1.05s/it][A
Iteration:  53%|█████▎    | 2922/5498 [51:28<45:08,  1.05s/it][A
Iteration:  53%|█████▎    | 2923/5498 [51:29<45:15,  1.05s/it][A
Iteration:  53%|█████▎    | 2924/5498 [51:30<45:12,  1.05s/it][A
Iteration:  53%|█████▎    | 2925/5498 [51:31<45:08,  1.05s/it][A
Iteration:  53%|█████▎    | 2926/5498 [51:32<45:08,  1.05s/it][A
Iteration:  53%|█████▎    | 2927/5498 [51:33<45:05,  1.05s/it][A
Iteration:  53%|█████▎    | 2928/5498 [51:34<45:06,  1.05s/it][A
Iteration:  53%|█████▎    | 2929/5498 [51:35<47:37,  1.11s/it][A
Iteration:  53%|█████▎    | 2930/5498 [51:36<46:50,  1.09s/it][A
Iteration:  53%|█████▎    | 2931/5498 [51:37<46:18,  1.08s/it][A
Iteration:  53%|█████▎    | 2932/5498 [51:38<45:52,  1.07s/it][A
Iteration:  53%|█████▎    | 2933/5498 [51:39<45:33,  1.07s/it][A
Iteration:  53%|█████▎    | 2934/5498 [51:41<45:20,  1.06s/it][A
Iteration:

Iteration:  58%|█████▊    | 3168/5498 [55:47<40:46,  1.05s/it][A
Iteration:  58%|█████▊    | 3169/5498 [55:48<40:44,  1.05s/it][A
Iteration:  58%|█████▊    | 3170/5498 [55:49<40:43,  1.05s/it][A
Iteration:  58%|█████▊    | 3171/5498 [55:50<40:44,  1.05s/it][A
Iteration:  58%|█████▊    | 3172/5498 [55:52<40:43,  1.05s/it][A
Iteration:  58%|█████▊    | 3173/5498 [55:53<40:40,  1.05s/it][A
Iteration:  58%|█████▊    | 3174/5498 [55:54<40:40,  1.05s/it][A
Iteration:  58%|█████▊    | 3175/5498 [55:55<40:41,  1.05s/it][A
Iteration:  58%|█████▊    | 3176/5498 [55:56<40:39,  1.05s/it][A
Iteration:  58%|█████▊    | 3177/5498 [55:57<40:45,  1.05s/it][A
Iteration:  58%|█████▊    | 3178/5498 [55:58<40:46,  1.05s/it][A
Iteration:  58%|█████▊    | 3179/5498 [55:59<40:43,  1.05s/it][A
Iteration:  58%|█████▊    | 3180/5498 [56:00<40:42,  1.05s/it][A
Iteration:  58%|█████▊    | 3181/5498 [56:01<40:45,  1.06s/it][A
Iteration:  58%|█████▊    | 3182/5498 [56:02<40:43,  1.06s/it][A
Iteration:

Iteration:  62%|██████▏   | 3415/5498 [1:00:08<36:33,  1.05s/it][A
Iteration:  62%|██████▏   | 3416/5498 [1:00:09<36:28,  1.05s/it][A
Iteration:  62%|██████▏   | 3417/5498 [1:00:10<36:27,  1.05s/it][A
Iteration:  62%|██████▏   | 3418/5498 [1:00:11<36:26,  1.05s/it][A
Iteration:  62%|██████▏   | 3419/5498 [1:00:12<36:27,  1.05s/it][A
Iteration:  62%|██████▏   | 3420/5498 [1:00:13<36:25,  1.05s/it][A
Iteration:  62%|██████▏   | 3421/5498 [1:00:14<36:24,  1.05s/it][A
Iteration:  62%|██████▏   | 3422/5498 [1:00:15<36:23,  1.05s/it][A
Iteration:  62%|██████▏   | 3423/5498 [1:00:17<36:22,  1.05s/it][A
Iteration:  62%|██████▏   | 3424/5498 [1:00:18<36:20,  1.05s/it][A
Iteration:  62%|██████▏   | 3425/5498 [1:00:19<36:20,  1.05s/it][A
Iteration:  62%|██████▏   | 3426/5498 [1:00:20<36:20,  1.05s/it][A
Iteration:  62%|██████▏   | 3427/5498 [1:00:21<36:18,  1.05s/it][A
Iteration:  62%|██████▏   | 3428/5498 [1:00:22<36:14,  1.05s/it][A
Iteration:  62%|██████▏   | 3429/5498 [1:00:23<3

Iteration:  66%|██████▋   | 3655/5498 [1:04:21<32:21,  1.05s/it][A
Iteration:  66%|██████▋   | 3656/5498 [1:04:22<32:17,  1.05s/it][A
Iteration:  67%|██████▋   | 3657/5498 [1:04:23<32:18,  1.05s/it][A
Iteration:  67%|██████▋   | 3658/5498 [1:04:24<32:17,  1.05s/it][A
Iteration:  67%|██████▋   | 3659/5498 [1:04:26<32:15,  1.05s/it][A
Iteration:  67%|██████▋   | 3660/5498 [1:04:27<32:15,  1.05s/it][A
Iteration:  67%|██████▋   | 3661/5498 [1:04:28<32:13,  1.05s/it][A
Iteration:  67%|██████▋   | 3662/5498 [1:04:29<32:12,  1.05s/it][A
Iteration:  67%|██████▋   | 3663/5498 [1:04:30<32:11,  1.05s/it][A
Iteration:  67%|██████▋   | 3664/5498 [1:04:31<32:10,  1.05s/it][A
Iteration:  67%|██████▋   | 3665/5498 [1:04:32<32:09,  1.05s/it][A
Iteration:  67%|██████▋   | 3666/5498 [1:04:33<32:06,  1.05s/it][A
Iteration:  67%|██████▋   | 3667/5498 [1:04:34<32:09,  1.05s/it][A
Iteration:  67%|██████▋   | 3668/5498 [1:04:35<32:05,  1.05s/it][A
Iteration:  67%|██████▋   | 3669/5498 [1:04:36<3

Iteration:  71%|███████   | 3895/5498 [1:08:35<28:41,  1.07s/it][A
Iteration:  71%|███████   | 3896/5498 [1:08:36<28:31,  1.07s/it][A
Iteration:  71%|███████   | 3897/5498 [1:08:37<28:26,  1.07s/it][A
Iteration:  71%|███████   | 3898/5498 [1:08:38<28:19,  1.06s/it][A
Iteration:  71%|███████   | 3899/5498 [1:08:39<28:16,  1.06s/it][A
Iteration:  71%|███████   | 3900/5498 [1:08:40<28:14,  1.06s/it][A
Iteration:  71%|███████   | 3901/5498 [1:08:41<28:14,  1.06s/it][A
Iteration:  71%|███████   | 3902/5498 [1:08:42<28:12,  1.06s/it][A
Iteration:  71%|███████   | 3903/5498 [1:08:43<28:09,  1.06s/it][A
Iteration:  71%|███████   | 3904/5498 [1:08:44<28:08,  1.06s/it][A
Iteration:  71%|███████   | 3905/5498 [1:08:45<28:06,  1.06s/it][A
Iteration:  71%|███████   | 3906/5498 [1:08:47<28:05,  1.06s/it][A
Iteration:  71%|███████   | 3907/5498 [1:08:48<28:04,  1.06s/it][A
Iteration:  71%|███████   | 3908/5498 [1:08:49<28:02,  1.06s/it][A
Iteration:  71%|███████   | 3909/5498 [1:08:50<2

Iteration:  75%|███████▌  | 4135/5498 [1:12:48<23:55,  1.05s/it][A
Iteration:  75%|███████▌  | 4136/5498 [1:12:49<23:55,  1.05s/it][A
Iteration:  75%|███████▌  | 4137/5498 [1:12:50<23:55,  1.05s/it][A
Iteration:  75%|███████▌  | 4138/5498 [1:12:51<23:53,  1.05s/it][A
Iteration:  75%|███████▌  | 4139/5498 [1:12:52<23:51,  1.05s/it][A
Iteration:  75%|███████▌  | 4140/5498 [1:12:53<23:48,  1.05s/it][A
Iteration:  75%|███████▌  | 4141/5498 [1:12:54<23:48,  1.05s/it][A
Iteration:  75%|███████▌  | 4142/5498 [1:12:56<23:49,  1.05s/it][A
Iteration:  75%|███████▌  | 4143/5498 [1:12:57<23:45,  1.05s/it][A
Iteration:  75%|███████▌  | 4144/5498 [1:12:58<23:43,  1.05s/it][A
Iteration:  75%|███████▌  | 4145/5498 [1:12:59<25:02,  1.11s/it][A
Iteration:  75%|███████▌  | 4146/5498 [1:13:00<24:38,  1.09s/it][A
Iteration:  75%|███████▌  | 4147/5498 [1:13:01<24:21,  1.08s/it][A
Iteration:  75%|███████▌  | 4148/5498 [1:13:02<24:08,  1.07s/it][A
Iteration:  75%|███████▌  | 4149/5498 [1:13:03<2

Iteration:  80%|███████▉  | 4375/5498 [1:17:02<19:39,  1.05s/it][A
Iteration:  80%|███████▉  | 4376/5498 [1:17:03<19:39,  1.05s/it][A
Iteration:  80%|███████▉  | 4377/5498 [1:17:04<19:39,  1.05s/it][A
Iteration:  80%|███████▉  | 4378/5498 [1:17:05<19:38,  1.05s/it][A
Iteration:  80%|███████▉  | 4379/5498 [1:17:06<19:35,  1.05s/it][A
Iteration:  80%|███████▉  | 4380/5498 [1:17:07<19:37,  1.05s/it][A
Iteration:  80%|███████▉  | 4381/5498 [1:17:08<19:36,  1.05s/it][A
Iteration:  80%|███████▉  | 4382/5498 [1:17:09<19:36,  1.05s/it][A
Iteration:  80%|███████▉  | 4383/5498 [1:17:10<19:32,  1.05s/it][A
Iteration:  80%|███████▉  | 4384/5498 [1:17:11<19:31,  1.05s/it][A
Iteration:  80%|███████▉  | 4385/5498 [1:17:12<19:29,  1.05s/it][A
Iteration:  80%|███████▉  | 4386/5498 [1:17:13<19:28,  1.05s/it][A
Iteration:  80%|███████▉  | 4387/5498 [1:17:14<19:29,  1.05s/it][A
Iteration:  80%|███████▉  | 4388/5498 [1:17:15<19:28,  1.05s/it][A
Iteration:  80%|███████▉  | 4389/5498 [1:17:16<1

Iteration:  84%|████████▍ | 4615/5498 [1:21:15<15:30,  1.05s/it][A
Iteration:  84%|████████▍ | 4616/5498 [1:21:16<15:28,  1.05s/it][A
Iteration:  84%|████████▍ | 4617/5498 [1:21:17<15:26,  1.05s/it][A
Iteration:  84%|████████▍ | 4618/5498 [1:21:18<15:25,  1.05s/it][A
Iteration:  84%|████████▍ | 4619/5498 [1:21:19<15:23,  1.05s/it][A
Iteration:  84%|████████▍ | 4620/5498 [1:21:20<15:22,  1.05s/it][A
Iteration:  84%|████████▍ | 4621/5498 [1:21:21<15:21,  1.05s/it][A
Iteration:  84%|████████▍ | 4622/5498 [1:21:22<15:21,  1.05s/it][A
Iteration:  84%|████████▍ | 4623/5498 [1:21:23<15:20,  1.05s/it][A
Iteration:  84%|████████▍ | 4624/5498 [1:21:24<15:18,  1.05s/it][A
Iteration:  84%|████████▍ | 4625/5498 [1:21:25<15:17,  1.05s/it][A
Iteration:  84%|████████▍ | 4626/5498 [1:21:27<15:15,  1.05s/it][A
Iteration:  84%|████████▍ | 4627/5498 [1:21:28<15:15,  1.05s/it][A
Iteration:  84%|████████▍ | 4628/5498 [1:21:29<15:13,  1.05s/it][A
Iteration:  84%|████████▍ | 4629/5498 [1:21:30<1

Iteration:  88%|████████▊ | 4855/5498 [1:25:28<11:15,  1.05s/it][A
Iteration:  88%|████████▊ | 4856/5498 [1:25:29<11:14,  1.05s/it][A
Iteration:  88%|████████▊ | 4857/5498 [1:25:30<11:14,  1.05s/it][A
Iteration:  88%|████████▊ | 4858/5498 [1:25:31<11:12,  1.05s/it][A
Iteration:  88%|████████▊ | 4859/5498 [1:25:32<11:12,  1.05s/it][A
Iteration:  88%|████████▊ | 4860/5498 [1:25:33<11:10,  1.05s/it][A
Iteration:  88%|████████▊ | 4861/5498 [1:25:35<11:09,  1.05s/it][A
Iteration:  88%|████████▊ | 4862/5498 [1:25:36<11:08,  1.05s/it][A
Iteration:  88%|████████▊ | 4863/5498 [1:25:37<11:07,  1.05s/it][A
Iteration:  88%|████████▊ | 4864/5498 [1:25:38<11:06,  1.05s/it][A
Iteration:  88%|████████▊ | 4865/5498 [1:25:39<11:05,  1.05s/it][A
Iteration:  89%|████████▊ | 4866/5498 [1:25:40<11:04,  1.05s/it][A
Iteration:  89%|████████▊ | 4867/5498 [1:25:41<11:02,  1.05s/it][A
Iteration:  89%|████████▊ | 4868/5498 [1:25:42<11:01,  1.05s/it][A
Iteration:  89%|████████▊ | 4869/5498 [1:25:43<1

Iteration:  93%|█████████▎| 5095/5498 [1:29:42<07:06,  1.06s/it][A
Iteration:  93%|█████████▎| 5096/5498 [1:29:43<07:05,  1.06s/it][A
Iteration:  93%|█████████▎| 5097/5498 [1:29:44<07:03,  1.06s/it][A
Iteration:  93%|█████████▎| 5098/5498 [1:29:45<07:01,  1.05s/it][A
Iteration:  93%|█████████▎| 5099/5498 [1:29:46<07:00,  1.05s/it][A
Iteration:  93%|█████████▎| 5100/5498 [1:29:47<06:59,  1.05s/it][A
Iteration:  93%|█████████▎| 5101/5498 [1:29:48<06:57,  1.05s/it][A
Iteration:  93%|█████████▎| 5102/5498 [1:29:49<06:56,  1.05s/it][A
Iteration:  93%|█████████▎| 5103/5498 [1:29:50<06:54,  1.05s/it][A
Iteration:  93%|█████████▎| 5104/5498 [1:29:51<06:53,  1.05s/it][A
Iteration:  93%|█████████▎| 5105/5498 [1:29:52<06:53,  1.05s/it][A
Iteration:  93%|█████████▎| 5106/5498 [1:29:53<06:51,  1.05s/it][A
Iteration:  93%|█████████▎| 5107/5498 [1:29:54<06:50,  1.05s/it][A
Iteration:  93%|█████████▎| 5108/5498 [1:29:55<06:49,  1.05s/it][A
Iteration:  93%|█████████▎| 5109/5498 [1:29:56<0

Iteration:  97%|█████████▋| 5335/5498 [1:33:55<02:51,  1.05s/it][A
Iteration:  97%|█████████▋| 5336/5498 [1:33:56<02:50,  1.05s/it][A
Iteration:  97%|█████████▋| 5337/5498 [1:33:57<02:49,  1.05s/it][A
Iteration:  97%|█████████▋| 5338/5498 [1:33:58<02:48,  1.05s/it][A
Iteration:  97%|█████████▋| 5339/5498 [1:33:59<02:46,  1.05s/it][A
Iteration:  97%|█████████▋| 5340/5498 [1:34:00<02:45,  1.05s/it][A
Iteration:  97%|█████████▋| 5341/5498 [1:34:01<02:44,  1.05s/it][A
Iteration:  97%|█████████▋| 5342/5498 [1:34:02<02:43,  1.05s/it][A
Iteration:  97%|█████████▋| 5343/5498 [1:34:03<02:52,  1.11s/it][A
Iteration:  97%|█████████▋| 5344/5498 [1:34:04<02:48,  1.10s/it][A
Iteration:  97%|█████████▋| 5345/5498 [1:34:05<02:45,  1.08s/it][A
Iteration:  97%|█████████▋| 5346/5498 [1:34:06<02:42,  1.07s/it][A
Iteration:  97%|█████████▋| 5347/5498 [1:34:07<02:41,  1.07s/it][A
Iteration:  97%|█████████▋| 5348/5498 [1:34:08<02:39,  1.07s/it][A
Iteration:  97%|█████████▋| 5349/5498 [1:34:09<0

Iteration:   1%|▏         | 79/5498 [01:23<1:35:01,  1.05s/it][A
Iteration:   1%|▏         | 80/5498 [01:24<1:35:00,  1.05s/it][A
Iteration:   1%|▏         | 81/5498 [01:25<1:34:58,  1.05s/it][A
Iteration:   1%|▏         | 82/5498 [01:26<1:34:53,  1.05s/it][A
Iteration:   2%|▏         | 83/5498 [01:27<1:35:14,  1.06s/it][A
Iteration:   2%|▏         | 84/5498 [01:28<1:35:21,  1.06s/it][A
Iteration:   2%|▏         | 85/5498 [01:29<1:35:29,  1.06s/it][A
Iteration:   2%|▏         | 86/5498 [01:30<1:35:28,  1.06s/it][A
Iteration:   2%|▏         | 87/5498 [01:31<1:35:28,  1.06s/it][A
Iteration:   2%|▏         | 88/5498 [01:32<1:35:24,  1.06s/it][A
Iteration:   2%|▏         | 89/5498 [01:33<1:35:27,  1.06s/it][A
Iteration:   2%|▏         | 90/5498 [01:34<1:35:23,  1.06s/it][A
Iteration:   2%|▏         | 91/5498 [01:35<1:35:19,  1.06s/it][A
Iteration:   2%|▏         | 92/5498 [01:37<1:35:21,  1.06s/it][A
Iteration:   2%|▏         | 93/5498 [01:38<1:35:18,  1.06s/it][A
Iteration:

Iteration:   6%|▌         | 323/5498 [05:40<1:30:40,  1.05s/it][A
Iteration:   6%|▌         | 324/5498 [05:41<1:30:40,  1.05s/it][A
Iteration:   6%|▌         | 325/5498 [05:42<1:30:40,  1.05s/it][A
Iteration:   6%|▌         | 326/5498 [05:44<1:30:44,  1.05s/it][A
Iteration:   6%|▌         | 327/5498 [05:45<1:30:43,  1.05s/it][A
Iteration:   6%|▌         | 328/5498 [05:46<1:30:41,  1.05s/it][A
Iteration:   6%|▌         | 329/5498 [05:47<1:30:47,  1.05s/it][A
Iteration:   6%|▌         | 330/5498 [05:48<1:30:40,  1.05s/it][A
Iteration:   6%|▌         | 331/5498 [05:49<1:30:46,  1.05s/it][A
Iteration:   6%|▌         | 332/5498 [05:50<1:30:42,  1.05s/it][A
Iteration:   6%|▌         | 333/5498 [05:51<1:30:44,  1.05s/it][A
Iteration:   6%|▌         | 334/5498 [05:52<1:30:41,  1.05s/it][A
Iteration:   6%|▌         | 335/5498 [05:53<1:30:35,  1.05s/it][A
Iteration:   6%|▌         | 336/5498 [05:54<1:30:34,  1.05s/it][A
Iteration:   6%|▌         | 337/5498 [05:55<1:30:54,  1.06s/it

Iteration:  10%|█         | 567/5498 [09:58<1:26:52,  1.06s/it][A
Iteration:  10%|█         | 568/5498 [09:59<1:26:52,  1.06s/it][A
Iteration:  10%|█         | 569/5498 [10:00<1:26:52,  1.06s/it][A
Iteration:  10%|█         | 570/5498 [10:01<1:26:45,  1.06s/it][A
Iteration:  10%|█         | 571/5498 [10:02<1:26:36,  1.05s/it][A
Iteration:  10%|█         | 572/5498 [10:03<1:26:31,  1.05s/it][A
Iteration:  10%|█         | 573/5498 [10:04<1:26:31,  1.05s/it][A
Iteration:  10%|█         | 574/5498 [10:05<1:26:34,  1.05s/it][A
Iteration:  10%|█         | 575/5498 [10:06<1:26:31,  1.05s/it][A
Iteration:  10%|█         | 576/5498 [10:07<1:26:32,  1.05s/it][A
Iteration:  10%|█         | 577/5498 [10:08<1:26:18,  1.05s/it][A
Iteration:  11%|█         | 578/5498 [10:09<1:26:12,  1.05s/it][A
Iteration:  11%|█         | 579/5498 [10:11<1:26:16,  1.05s/it][A
Iteration:  11%|█         | 580/5498 [10:12<1:26:14,  1.05s/it][A
Iteration:  11%|█         | 581/5498 [10:13<1:26:08,  1.05s/it

Iteration:  15%|█▍        | 811/5498 [14:16<1:22:12,  1.05s/it][A
Iteration:  15%|█▍        | 812/5498 [14:17<1:22:10,  1.05s/it][A
Iteration:  15%|█▍        | 813/5498 [14:18<1:22:10,  1.05s/it][A
Iteration:  15%|█▍        | 814/5498 [14:19<1:22:06,  1.05s/it][A
Iteration:  15%|█▍        | 815/5498 [14:20<1:22:01,  1.05s/it][A
Iteration:  15%|█▍        | 816/5498 [14:21<1:22:02,  1.05s/it][A
Iteration:  15%|█▍        | 817/5498 [14:22<1:21:58,  1.05s/it][A
Iteration:  15%|█▍        | 818/5498 [14:23<1:21:56,  1.05s/it][A
Iteration:  15%|█▍        | 819/5498 [14:24<1:21:59,  1.05s/it][A
Iteration:  15%|█▍        | 820/5498 [14:25<1:21:59,  1.05s/it][A
Iteration:  15%|█▍        | 821/5498 [14:26<1:21:54,  1.05s/it][A
Iteration:  15%|█▍        | 822/5498 [14:27<1:21:45,  1.05s/it][A
Iteration:  15%|█▍        | 823/5498 [14:28<1:21:54,  1.05s/it][A
Iteration:  15%|█▍        | 824/5498 [14:29<1:21:54,  1.05s/it][A
Iteration:  15%|█▌        | 825/5498 [14:30<1:21:48,  1.05s/it

Iteration:  19%|█▉        | 1054/5498 [18:32<1:18:05,  1.05s/it][A
Iteration:  19%|█▉        | 1055/5498 [18:33<1:17:58,  1.05s/it][A
Iteration:  19%|█▉        | 1056/5498 [18:34<1:17:52,  1.05s/it][A
Iteration:  19%|█▉        | 1057/5498 [18:35<1:18:00,  1.05s/it][A
Iteration:  19%|█▉        | 1058/5498 [18:36<1:17:57,  1.05s/it][A
Iteration:  19%|█▉        | 1059/5498 [18:37<1:18:08,  1.06s/it][A
Iteration:  19%|█▉        | 1060/5498 [18:39<1:17:59,  1.05s/it][A
Iteration:  19%|█▉        | 1061/5498 [18:40<1:17:51,  1.05s/it][A
Iteration:  19%|█▉        | 1062/5498 [18:41<1:17:52,  1.05s/it][A
Iteration:  19%|█▉        | 1063/5498 [18:42<1:17:48,  1.05s/it][A
Iteration:  19%|█▉        | 1064/5498 [18:43<1:17:46,  1.05s/it][A
Iteration:  19%|█▉        | 1065/5498 [18:44<1:17:40,  1.05s/it][A
Iteration:  19%|█▉        | 1066/5498 [18:45<1:17:51,  1.05s/it][A
Iteration:  19%|█▉        | 1067/5498 [18:46<1:17:43,  1.05s/it][A
Iteration:  19%|█▉        | 1068/5498 [18:47<1:1

Iteration:  24%|██▎       | 1294/5498 [22:46<1:13:58,  1.06s/it][A
Iteration:  24%|██▎       | 1295/5498 [22:47<1:13:48,  1.05s/it][A
Iteration:  24%|██▎       | 1296/5498 [22:48<1:13:47,  1.05s/it][A
Iteration:  24%|██▎       | 1297/5498 [22:49<1:13:36,  1.05s/it][A
Iteration:  24%|██▎       | 1298/5498 [22:50<1:13:44,  1.05s/it][A
Iteration:  24%|██▎       | 1299/5498 [22:51<1:13:46,  1.05s/it][A
Iteration:  24%|██▎       | 1300/5498 [22:52<1:13:44,  1.05s/it][A
Iteration:  24%|██▎       | 1301/5498 [22:53<1:13:44,  1.05s/it][A
Iteration:  24%|██▎       | 1302/5498 [22:54<1:13:34,  1.05s/it][A
Iteration:  24%|██▎       | 1303/5498 [22:55<1:13:32,  1.05s/it][A
Iteration:  24%|██▎       | 1304/5498 [22:56<1:13:38,  1.05s/it][A
Iteration:  24%|██▎       | 1305/5498 [22:57<1:13:29,  1.05s/it][A
Iteration:  24%|██▍       | 1306/5498 [22:58<1:13:21,  1.05s/it][A
Iteration:  24%|██▍       | 1307/5498 [22:59<1:13:14,  1.05s/it][A
Iteration:  24%|██▍       | 1308/5498 [23:00<1:1

Iteration:  28%|██▊       | 1534/5498 [26:59<1:09:34,  1.05s/it][A
Iteration:  28%|██▊       | 1535/5498 [27:00<1:09:30,  1.05s/it][A
Iteration:  28%|██▊       | 1536/5498 [27:01<1:09:24,  1.05s/it][A
Iteration:  28%|██▊       | 1537/5498 [27:02<1:09:25,  1.05s/it][A
Iteration:  28%|██▊       | 1538/5498 [27:03<1:09:31,  1.05s/it][A
Iteration:  28%|██▊       | 1539/5498 [27:04<1:13:18,  1.11s/it][A
Iteration:  28%|██▊       | 1540/5498 [27:05<1:12:08,  1.09s/it][A
Iteration:  28%|██▊       | 1541/5498 [27:06<1:11:13,  1.08s/it][A
Iteration:  28%|██▊       | 1542/5498 [27:07<1:10:38,  1.07s/it][A
Iteration:  28%|██▊       | 1543/5498 [27:08<1:10:20,  1.07s/it][A
Iteration:  28%|██▊       | 1544/5498 [27:09<1:10:08,  1.06s/it][A
Iteration:  28%|██▊       | 1545/5498 [27:10<1:10:11,  1.07s/it][A
Iteration:  28%|██▊       | 1546/5498 [27:12<1:09:52,  1.06s/it][A
Iteration:  28%|██▊       | 1547/5498 [27:13<1:09:45,  1.06s/it][A
Iteration:  28%|██▊       | 1548/5498 [27:14<1:0

Iteration:  32%|███▏      | 1774/5498 [31:12<1:05:16,  1.05s/it][A
Iteration:  32%|███▏      | 1775/5498 [31:13<1:05:17,  1.05s/it][A
Iteration:  32%|███▏      | 1776/5498 [31:14<1:05:17,  1.05s/it][A
Iteration:  32%|███▏      | 1777/5498 [31:15<1:05:11,  1.05s/it][A
Iteration:  32%|███▏      | 1778/5498 [31:16<1:05:08,  1.05s/it][A
Iteration:  32%|███▏      | 1779/5498 [31:17<1:05:01,  1.05s/it][A
Iteration:  32%|███▏      | 1780/5498 [31:18<1:05:06,  1.05s/it][A
Iteration:  32%|███▏      | 1781/5498 [31:19<1:05:05,  1.05s/it][A
Iteration:  32%|███▏      | 1782/5498 [31:20<1:05:01,  1.05s/it][A
Iteration:  32%|███▏      | 1783/5498 [31:22<1:04:55,  1.05s/it][A
Iteration:  32%|███▏      | 1784/5498 [31:23<1:04:54,  1.05s/it][A
Iteration:  32%|███▏      | 1785/5498 [31:24<1:05:12,  1.05s/it][A
Iteration:  32%|███▏      | 1786/5498 [31:25<1:05:17,  1.06s/it][A
Iteration:  33%|███▎      | 1787/5498 [31:26<1:05:05,  1.05s/it][A
Iteration:  33%|███▎      | 1788/5498 [31:27<1:0

Iteration:  37%|███▋      | 2014/5498 [35:25<1:01:07,  1.05s/it][A
Iteration:  37%|███▋      | 2015/5498 [35:27<1:01:08,  1.05s/it][A
Iteration:  37%|███▋      | 2016/5498 [35:28<1:01:06,  1.05s/it][A
Iteration:  37%|███▋      | 2017/5498 [35:29<1:01:07,  1.05s/it][A
Iteration:  37%|███▋      | 2018/5498 [35:30<1:01:03,  1.05s/it][A
Iteration:  37%|███▋      | 2019/5498 [35:31<1:00:59,  1.05s/it][A
Iteration:  37%|███▋      | 2020/5498 [35:32<1:00:59,  1.05s/it][A
Iteration:  37%|███▋      | 2021/5498 [35:33<1:00:57,  1.05s/it][A
Iteration:  37%|███▋      | 2022/5498 [35:34<1:01:01,  1.05s/it][A
Iteration:  37%|███▋      | 2023/5498 [35:35<1:00:54,  1.05s/it][A
Iteration:  37%|███▋      | 2024/5498 [35:36<1:00:50,  1.05s/it][A
Iteration:  37%|███▋      | 2025/5498 [35:37<1:00:51,  1.05s/it][A
Iteration:  37%|███▋      | 2026/5498 [35:38<1:00:52,  1.05s/it][A
Iteration:  37%|███▋      | 2027/5498 [35:39<1:00:50,  1.05s/it][A
Iteration:  37%|███▋      | 2028/5498 [35:40<1:0

Iteration:  41%|████      | 2259/5498 [39:44<58:18,  1.08s/it][A
Iteration:  41%|████      | 2260/5498 [39:45<57:49,  1.07s/it][A
Iteration:  41%|████      | 2261/5498 [39:46<57:33,  1.07s/it][A
Iteration:  41%|████      | 2262/5498 [39:47<57:17,  1.06s/it][A
Iteration:  41%|████      | 2263/5498 [39:48<57:09,  1.06s/it][A
Iteration:  41%|████      | 2264/5498 [39:49<57:08,  1.06s/it][A
Iteration:  41%|████      | 2265/5498 [39:50<56:57,  1.06s/it][A
Iteration:  41%|████      | 2266/5498 [39:51<56:50,  1.06s/it][A
Iteration:  41%|████      | 2267/5498 [39:52<56:48,  1.06s/it][A
Iteration:  41%|████▏     | 2268/5498 [39:54<56:44,  1.05s/it][A
Iteration:  41%|████▏     | 2269/5498 [39:55<56:41,  1.05s/it][A
Iteration:  41%|████▏     | 2270/5498 [39:56<56:46,  1.06s/it][A
Iteration:  41%|████▏     | 2271/5498 [39:57<56:39,  1.05s/it][A
Iteration:  41%|████▏     | 2272/5498 [39:58<56:32,  1.05s/it][A
Iteration:  41%|████▏     | 2273/5498 [39:59<56:29,  1.05s/it][A
Iteration:

Iteration:  46%|████▌     | 2507/5498 [44:06<52:29,  1.05s/it][A
Iteration:  46%|████▌     | 2508/5498 [44:07<52:32,  1.05s/it][A
Iteration:  46%|████▌     | 2509/5498 [44:08<52:31,  1.05s/it][A
Iteration:  46%|████▌     | 2510/5498 [44:09<52:32,  1.05s/it][A
Iteration:  46%|████▌     | 2511/5498 [44:10<55:28,  1.11s/it][A
Iteration:  46%|████▌     | 2512/5498 [44:11<54:31,  1.10s/it][A
Iteration:  46%|████▌     | 2513/5498 [44:12<53:51,  1.08s/it][A
Iteration:  46%|████▌     | 2514/5498 [44:13<53:25,  1.07s/it][A
Iteration:  46%|████▌     | 2515/5498 [44:15<53:07,  1.07s/it][A
Iteration:  46%|████▌     | 2516/5498 [44:16<52:54,  1.06s/it][A
Iteration:  46%|████▌     | 2517/5498 [44:17<52:41,  1.06s/it][A
Iteration:  46%|████▌     | 2518/5498 [44:18<52:33,  1.06s/it][A
Iteration:  46%|████▌     | 2519/5498 [44:19<52:29,  1.06s/it][A
Iteration:  46%|████▌     | 2520/5498 [44:20<52:26,  1.06s/it][A
Iteration:  46%|████▌     | 2521/5498 [44:21<52:21,  1.06s/it][A
Iteration:

Iteration:  50%|█████     | 2755/5498 [48:28<48:07,  1.05s/it][A
Iteration:  50%|█████     | 2756/5498 [48:29<48:06,  1.05s/it][A
Iteration:  50%|█████     | 2757/5498 [48:30<48:04,  1.05s/it][A
Iteration:  50%|█████     | 2758/5498 [48:31<48:05,  1.05s/it][A
Iteration:  50%|█████     | 2759/5498 [48:32<48:04,  1.05s/it][A
Iteration:  50%|█████     | 2760/5498 [48:33<48:04,  1.05s/it][A
Iteration:  50%|█████     | 2761/5498 [48:34<47:57,  1.05s/it][A
Iteration:  50%|█████     | 2762/5498 [48:35<47:57,  1.05s/it][A
Iteration:  50%|█████     | 2763/5498 [48:36<47:53,  1.05s/it][A
Iteration:  50%|█████     | 2764/5498 [48:37<47:51,  1.05s/it][A
Iteration:  50%|█████     | 2765/5498 [48:38<50:37,  1.11s/it][A
Iteration:  50%|█████     | 2766/5498 [48:40<49:46,  1.09s/it][A
Iteration:  50%|█████     | 2767/5498 [48:41<49:13,  1.08s/it][A
Iteration:  50%|█████     | 2768/5498 [48:42<48:44,  1.07s/it][A
Iteration:  50%|█████     | 2769/5498 [48:43<48:32,  1.07s/it][A
Iteration:

Iteration:  55%|█████▍    | 3003/5498 [52:50<43:36,  1.05s/it][A
Iteration:  55%|█████▍    | 3004/5498 [52:51<43:33,  1.05s/it][A
Iteration:  55%|█████▍    | 3005/5498 [52:52<43:30,  1.05s/it][A
Iteration:  55%|█████▍    | 3006/5498 [52:53<43:35,  1.05s/it][A
Iteration:  55%|█████▍    | 3007/5498 [52:54<43:37,  1.05s/it][A
Iteration:  55%|█████▍    | 3008/5498 [52:55<43:36,  1.05s/it][A
Iteration:  55%|█████▍    | 3009/5498 [52:56<43:36,  1.05s/it][A
Iteration:  55%|█████▍    | 3010/5498 [52:57<43:32,  1.05s/it][A
Iteration:  55%|█████▍    | 3011/5498 [52:58<43:38,  1.05s/it][A
Iteration:  55%|█████▍    | 3012/5498 [52:59<43:36,  1.05s/it][A
Iteration:  55%|█████▍    | 3013/5498 [53:00<43:32,  1.05s/it][A
Iteration:  55%|█████▍    | 3014/5498 [53:01<43:32,  1.05s/it][A
Iteration:  55%|█████▍    | 3015/5498 [53:02<43:36,  1.05s/it][A
Iteration:  55%|█████▍    | 3016/5498 [53:03<43:35,  1.05s/it][A
Iteration:  55%|█████▍    | 3017/5498 [53:04<43:35,  1.05s/it][A
Iteration:

Iteration:  59%|█████▉    | 3251/5498 [57:11<39:22,  1.05s/it][A
Iteration:  59%|█████▉    | 3252/5498 [57:12<39:24,  1.05s/it][A
Iteration:  59%|█████▉    | 3253/5498 [57:14<39:25,  1.05s/it][A
Iteration:  59%|█████▉    | 3254/5498 [57:15<39:25,  1.05s/it][A
Iteration:  59%|█████▉    | 3255/5498 [57:16<39:22,  1.05s/it][A
Iteration:  59%|█████▉    | 3256/5498 [57:17<39:19,  1.05s/it][A
Iteration:  59%|█████▉    | 3257/5498 [57:18<39:15,  1.05s/it][A
Iteration:  59%|█████▉    | 3258/5498 [57:19<39:20,  1.05s/it][A
Iteration:  59%|█████▉    | 3259/5498 [57:20<39:17,  1.05s/it][A
Iteration:  59%|█████▉    | 3260/5498 [57:21<39:13,  1.05s/it][A
Iteration:  59%|█████▉    | 3261/5498 [57:22<39:14,  1.05s/it][A
Iteration:  59%|█████▉    | 3262/5498 [57:23<39:17,  1.05s/it][A
Iteration:  59%|█████▉    | 3263/5498 [57:24<39:14,  1.05s/it][A
Iteration:  59%|█████▉    | 3264/5498 [57:25<39:13,  1.05s/it][A
Iteration:  59%|█████▉    | 3265/5498 [57:26<39:11,  1.05s/it][A
Iteration:

Iteration:  64%|██████▎   | 3496/5498 [1:01:30<35:13,  1.06s/it][A
Iteration:  64%|██████▎   | 3497/5498 [1:01:31<35:13,  1.06s/it][A
Iteration:  64%|██████▎   | 3498/5498 [1:01:32<35:06,  1.05s/it][A
Iteration:  64%|██████▎   | 3499/5498 [1:01:33<35:04,  1.05s/it][A
Iteration:  64%|██████▎   | 3500/5498 [1:01:34<35:03,  1.05s/it][A
Iteration:  64%|██████▎   | 3501/5498 [1:01:35<34:59,  1.05s/it][A
Iteration:  64%|██████▎   | 3502/5498 [1:01:36<34:57,  1.05s/it][A
Iteration:  64%|██████▎   | 3503/5498 [1:01:37<34:57,  1.05s/it][A
Iteration:  64%|██████▎   | 3504/5498 [1:01:38<34:58,  1.05s/it][A
Iteration:  64%|██████▍   | 3505/5498 [1:01:40<35:00,  1.05s/it][A
Iteration:  64%|██████▍   | 3506/5498 [1:01:41<34:58,  1.05s/it][A
Iteration:  64%|██████▍   | 3507/5498 [1:01:42<34:57,  1.05s/it][A
Iteration:  64%|██████▍   | 3508/5498 [1:01:43<34:55,  1.05s/it][A
Iteration:  64%|██████▍   | 3509/5498 [1:01:44<34:51,  1.05s/it][A
Iteration:  64%|██████▍   | 3510/5498 [1:01:45<3

Iteration:  68%|██████▊   | 3736/5498 [1:05:43<31:00,  1.06s/it][A
Iteration:  68%|██████▊   | 3737/5498 [1:05:44<30:57,  1.05s/it][A
Iteration:  68%|██████▊   | 3738/5498 [1:05:45<30:58,  1.06s/it][A
Iteration:  68%|██████▊   | 3739/5498 [1:05:47<30:56,  1.06s/it][A
Iteration:  68%|██████▊   | 3740/5498 [1:05:48<30:54,  1.05s/it][A
Iteration:  68%|██████▊   | 3741/5498 [1:05:49<30:56,  1.06s/it][A
Iteration:  68%|██████▊   | 3742/5498 [1:05:50<30:53,  1.06s/it][A
Iteration:  68%|██████▊   | 3743/5498 [1:05:51<30:49,  1.05s/it][A
Iteration:  68%|██████▊   | 3744/5498 [1:05:52<30:46,  1.05s/it][A
Iteration:  68%|██████▊   | 3745/5498 [1:05:53<30:44,  1.05s/it][A
Iteration:  68%|██████▊   | 3746/5498 [1:05:54<30:43,  1.05s/it][A
Iteration:  68%|██████▊   | 3747/5498 [1:05:55<30:42,  1.05s/it][A
Iteration:  68%|██████▊   | 3748/5498 [1:05:56<30:40,  1.05s/it][A
Iteration:  68%|██████▊   | 3749/5498 [1:05:57<30:40,  1.05s/it][A
Iteration:  68%|██████▊   | 3750/5498 [1:05:58<3

Iteration:  72%|███████▏  | 3976/5498 [1:09:56<26:38,  1.05s/it][A
Iteration:  72%|███████▏  | 3977/5498 [1:09:57<26:37,  1.05s/it][A
Iteration:  72%|███████▏  | 3978/5498 [1:09:58<26:36,  1.05s/it][A
Iteration:  72%|███████▏  | 3979/5498 [1:09:59<26:36,  1.05s/it][A
Iteration:  72%|███████▏  | 3980/5498 [1:10:00<26:35,  1.05s/it][A
Iteration:  72%|███████▏  | 3981/5498 [1:10:02<28:05,  1.11s/it][A
Iteration:  72%|███████▏  | 3982/5498 [1:10:03<27:40,  1.10s/it][A
Iteration:  72%|███████▏  | 3983/5498 [1:10:04<27:19,  1.08s/it][A
Iteration:  72%|███████▏  | 3984/5498 [1:10:05<27:02,  1.07s/it][A
Iteration:  72%|███████▏  | 3985/5498 [1:10:06<26:52,  1.07s/it][A
Iteration:  72%|███████▏  | 3986/5498 [1:10:07<26:44,  1.06s/it][A
Iteration:  73%|███████▎  | 3987/5498 [1:10:08<26:37,  1.06s/it][A
Iteration:  73%|███████▎  | 3988/5498 [1:10:09<26:31,  1.05s/it][A
Iteration:  73%|███████▎  | 3989/5498 [1:10:10<26:29,  1.05s/it][A
Iteration:  73%|███████▎  | 3990/5498 [1:10:11<2

Iteration:  77%|███████▋  | 4216/5498 [1:14:09<22:27,  1.05s/it][A
Iteration:  77%|███████▋  | 4217/5498 [1:14:10<22:28,  1.05s/it][A
Iteration:  77%|███████▋  | 4218/5498 [1:14:12<22:27,  1.05s/it][A
Iteration:  77%|███████▋  | 4219/5498 [1:14:13<22:24,  1.05s/it][A
Iteration:  77%|███████▋  | 4220/5498 [1:14:14<22:24,  1.05s/it][A
Iteration:  77%|███████▋  | 4221/5498 [1:14:15<22:23,  1.05s/it][A
Iteration:  77%|███████▋  | 4222/5498 [1:14:16<22:23,  1.05s/it][A
Iteration:  77%|███████▋  | 4223/5498 [1:14:17<22:21,  1.05s/it][A
Iteration:  77%|███████▋  | 4224/5498 [1:14:18<22:21,  1.05s/it][A
Iteration:  77%|███████▋  | 4225/5498 [1:14:19<22:21,  1.05s/it][A
Iteration:  77%|███████▋  | 4226/5498 [1:14:20<22:21,  1.05s/it][A
Iteration:  77%|███████▋  | 4227/5498 [1:14:21<22:19,  1.05s/it][A
Iteration:  77%|███████▋  | 4228/5498 [1:14:22<22:17,  1.05s/it][A
Iteration:  77%|███████▋  | 4229/5498 [1:14:23<22:15,  1.05s/it][A
Iteration:  77%|███████▋  | 4230/5498 [1:14:24<2

Iteration:  81%|████████  | 4456/5498 [1:18:23<18:15,  1.05s/it][A
Iteration:  81%|████████  | 4457/5498 [1:18:24<18:13,  1.05s/it][A
Iteration:  81%|████████  | 4458/5498 [1:18:25<18:14,  1.05s/it][A
Iteration:  81%|████████  | 4459/5498 [1:18:26<18:11,  1.05s/it][A
Iteration:  81%|████████  | 4460/5498 [1:18:27<18:11,  1.05s/it][A
Iteration:  81%|████████  | 4461/5498 [1:18:28<18:11,  1.05s/it][A
Iteration:  81%|████████  | 4462/5498 [1:18:29<18:09,  1.05s/it][A
Iteration:  81%|████████  | 4463/5498 [1:18:30<18:07,  1.05s/it][A
Iteration:  81%|████████  | 4464/5498 [1:18:31<18:09,  1.05s/it][A
Iteration:  81%|████████  | 4465/5498 [1:18:32<18:06,  1.05s/it][A
Iteration:  81%|████████  | 4466/5498 [1:18:33<18:04,  1.05s/it][A
Iteration:  81%|████████  | 4467/5498 [1:18:34<18:03,  1.05s/it][A
Iteration:  81%|████████▏ | 4468/5498 [1:18:35<18:02,  1.05s/it][A
Iteration:  81%|████████▏ | 4469/5498 [1:18:36<18:05,  1.05s/it][A
Iteration:  81%|████████▏ | 4470/5498 [1:18:37<1

Iteration:  85%|████████▌ | 4696/5498 [1:22:36<14:03,  1.05s/it][A
Iteration:  85%|████████▌ | 4697/5498 [1:22:37<14:02,  1.05s/it][A
Iteration:  85%|████████▌ | 4698/5498 [1:22:38<14:01,  1.05s/it][A
Iteration:  85%|████████▌ | 4699/5498 [1:22:39<13:59,  1.05s/it][A
Iteration:  85%|████████▌ | 4700/5498 [1:22:40<13:57,  1.05s/it][A
Iteration:  86%|████████▌ | 4701/5498 [1:22:41<13:56,  1.05s/it][A
Iteration:  86%|████████▌ | 4702/5498 [1:22:42<13:55,  1.05s/it][A
Iteration:  86%|████████▌ | 4703/5498 [1:22:43<13:56,  1.05s/it][A
Iteration:  86%|████████▌ | 4704/5498 [1:22:44<13:56,  1.05s/it][A
Iteration:  86%|████████▌ | 4705/5498 [1:22:45<13:55,  1.05s/it][A
Iteration:  86%|████████▌ | 4706/5498 [1:22:46<13:57,  1.06s/it][A
Iteration:  86%|████████▌ | 4707/5498 [1:22:48<13:56,  1.06s/it][A
Iteration:  86%|████████▌ | 4708/5498 [1:22:49<13:55,  1.06s/it][A
Iteration:  86%|████████▌ | 4709/5498 [1:22:50<13:54,  1.06s/it][A
Iteration:  86%|████████▌ | 4710/5498 [1:22:51<1

Iteration:  90%|████████▉ | 4936/5498 [1:26:50<09:52,  1.05s/it][A
Iteration:  90%|████████▉ | 4937/5498 [1:26:51<09:51,  1.05s/it][A
Iteration:  90%|████████▉ | 4938/5498 [1:26:52<09:50,  1.05s/it][A
Iteration:  90%|████████▉ | 4939/5498 [1:26:53<09:48,  1.05s/it][A
Iteration:  90%|████████▉ | 4940/5498 [1:26:54<09:47,  1.05s/it][A
Iteration:  90%|████████▉ | 4941/5498 [1:26:55<09:45,  1.05s/it][A
Iteration:  90%|████████▉ | 4942/5498 [1:26:56<09:44,  1.05s/it][A
Iteration:  90%|████████▉ | 4943/5498 [1:26:57<09:43,  1.05s/it][A
Iteration:  90%|████████▉ | 4944/5498 [1:26:58<09:43,  1.05s/it][A
Iteration:  90%|████████▉ | 4945/5498 [1:26:59<09:41,  1.05s/it][A
Iteration:  90%|████████▉ | 4946/5498 [1:27:00<09:40,  1.05s/it][A
Iteration:  90%|████████▉ | 4947/5498 [1:27:01<09:40,  1.05s/it][A
Iteration:  90%|████████▉ | 4948/5498 [1:27:02<09:38,  1.05s/it][A
Iteration:  90%|█████████ | 4949/5498 [1:27:03<09:36,  1.05s/it][A
Iteration:  90%|█████████ | 4950/5498 [1:27:04<0

Iteration:  94%|█████████▍| 5176/5498 [1:31:03<05:39,  1.05s/it][A
Iteration:  94%|█████████▍| 5177/5498 [1:31:04<05:38,  1.05s/it][A
Iteration:  94%|█████████▍| 5178/5498 [1:31:05<05:36,  1.05s/it][A
Iteration:  94%|█████████▍| 5179/5498 [1:31:06<05:55,  1.11s/it][A
Iteration:  94%|█████████▍| 5180/5498 [1:31:07<05:48,  1.10s/it][A
Iteration:  94%|█████████▍| 5181/5498 [1:31:08<05:43,  1.08s/it][A
Iteration:  94%|█████████▍| 5182/5498 [1:31:09<05:39,  1.07s/it][A
Iteration:  94%|█████████▍| 5183/5498 [1:31:10<05:36,  1.07s/it][A
Iteration:  94%|█████████▍| 5184/5498 [1:31:11<05:33,  1.06s/it][A
Iteration:  94%|█████████▍| 5185/5498 [1:31:12<05:31,  1.06s/it][A
Iteration:  94%|█████████▍| 5186/5498 [1:31:14<05:29,  1.06s/it][A
Iteration:  94%|█████████▍| 5187/5498 [1:31:15<05:28,  1.06s/it][A
Iteration:  94%|█████████▍| 5188/5498 [1:31:16<05:27,  1.06s/it][A
Iteration:  94%|█████████▍| 5189/5498 [1:31:17<05:26,  1.06s/it][A
Iteration:  94%|█████████▍| 5190/5498 [1:31:18<0

Iteration:  99%|█████████▊| 5416/5498 [1:35:16<01:26,  1.05s/it][A
Iteration:  99%|█████████▊| 5417/5498 [1:35:17<01:25,  1.05s/it][A
Iteration:  99%|█████████▊| 5418/5498 [1:35:18<01:24,  1.06s/it][A
Iteration:  99%|█████████▊| 5419/5498 [1:35:19<01:23,  1.05s/it][A
Iteration:  99%|█████████▊| 5420/5498 [1:35:20<01:22,  1.05s/it][A
Iteration:  99%|█████████▊| 5421/5498 [1:35:21<01:21,  1.06s/it][A
Iteration:  99%|█████████▊| 5422/5498 [1:35:22<01:20,  1.06s/it][A
Iteration:  99%|█████████▊| 5423/5498 [1:35:23<01:19,  1.05s/it][A
Iteration:  99%|█████████▊| 5424/5498 [1:35:25<01:18,  1.05s/it][A
Iteration:  99%|█████████▊| 5425/5498 [1:35:26<01:16,  1.05s/it][A
Iteration:  99%|█████████▊| 5426/5498 [1:35:27<01:15,  1.05s/it][A
Iteration:  99%|█████████▊| 5427/5498 [1:35:28<01:14,  1.06s/it][A
Iteration:  99%|█████████▊| 5428/5498 [1:35:29<01:13,  1.05s/it][A
Iteration:  99%|█████████▊| 5429/5498 [1:35:30<01:12,  1.05s/it][A
Iteration:  99%|█████████▉| 5430/5498 [1:35:31<0

In [18]:
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

In [21]:
# save the trained model and the tokenizer
#logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
#model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
#model_to_save.save_pretrained(args.output_dir)
#tokenizer.save_pretrained(args.output_dir)

('tmp/mrpc_output/vocab.txt',
 'tmp/mrpc_output/special_tokens_map.json',
 'tmp/mrpc_output/added_tokens.json')

In [151]:
 # Good practice: save your training arguments together with the trained model
#torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

In [14]:
# Load a trained model and vocabulary that you have fine-tuned
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.model_name)
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12,

In [66]:
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory (too lone)
#checkpoints = [args.output_dir]
#checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
#logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
#logger.info("Evaluate the following checkpoints: %s", checkpoints)

In [22]:
# evaluate the final results
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

In [23]:
# Eval!
prefix=46530
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info("  Num examples = %d", len(dataset))
logger.info("  Batch size = %d", args.eval_batch_size)
all_results = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    model.eval()
    batch = tuple(t.to(args.device) for t in batch)
    with torch.no_grad():
        inputs = {'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
                    }
        example_indices = batch[3]
        if args.model_type in ['xlnet', 'xlm']:
            inputs.update({'cls_index': batch[4],
                            'p_mask':    batch[5]})
        outputs = model(**inputs)

    for i, example_index in enumerate(example_indices):
        eval_feature = features[example_index.item()]
        unique_id = int(eval_feature.unique_id)
        if args.model_type in ['xlnet', 'xlm']:
            # XLNet uses a more complex post-processing procedure
            result = RawResultExtended(unique_id            = unique_id,
                                        start_top_log_probs  = to_list(outputs[0][i]),
                                        start_top_index      = to_list(outputs[1][i]),
                                        end_top_log_probs    = to_list(outputs[2][i]),
                                        end_top_index        = to_list(outputs[3][i]),
                                        cls_logits           = to_list(outputs[4][i]))
        else:
            result = RawResult(unique_id    = unique_id,
                                start_logits = to_list(outputs[0][i]),
                                end_logits   = to_list(outputs[1][i]))
        all_results.append(result)
#result = evaluate(args, model, tokenizer, prefix=global_step)

Evaluating: 100%|██████████| 510/510 [05:26<00:00,  1.75it/s]


In [24]:
# Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
# assume some question does not have an answer
if args.version_2_with_negative:
    output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))

In [25]:
args.null_score_diff_threshold = 0
args.version_2_with_negative = True
if args.model_type in ['xlnet', 'xlm']:
    # XLNet uses a more complex post-processing procedure
    write_predictions_extended(examples, features, all_results, args.n_best_size,
                    args.max_answer_length, output_prediction_file,
                    output_nbest_file, output_null_log_odds_file, args.predict_file,
                    model.config.start_n_top, model.config.end_n_top,
                    args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
    write_predictions(examples, features, all_results, args.n_best_size,
                    args.max_answer_length, args.do_lower_case, output_prediction_file,
                    output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                    args.version_2_with_negative, args.null_score_diff_threshold)

In [26]:
# Evaluate with the official SQuAD script
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
                            pred_file=output_prediction_file,
                            na_prob_file=output_null_log_odds_file)
results = evaluate_on_squad(evaluate_options)

{
  "exact": 82.12751621325697,
  "f1": 85.07844733920284,
  "total": 11873,
  "HasAns_exact": 80.1450742240216,
  "HasAns_f1": 86.05539899769825,
  "HasAns_total": 5928,
  "NoAns_exact": 84.10428931875526,
  "NoAns_f1": 84.10428931875526,
  "NoAns_total": 5945,
  "best_exact": 82.89396108818327,
  "best_exact_thresh": -2.081042766571045,
  "best_f1": 85.69798230133034,
  "best_f1_thresh": -2.081042766571045
}


In [219]:
# manually compare the true and predicted answer
import json
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(46530))
filename1 = output_prediction_file
filename2 = 'dev-v2.0.json'
with open(filename1, 'r') as f:
    data_predict = json.load(f)
with open(filename2, 'r') as f:
    data_true = json.load(f)    

In [221]:
data_true['data'][0]

{'title': 'Normans',
 'paragraphs': [{'qas': [{'question': 'In what country is Normandy located?',
     'id': '56ddde6b9a695914005b9628',
     'answers': [{'text': 'France', 'answer_start': 159},
      {'text': 'France', 'answer_start': 159},
      {'text': 'France', 'answer_start': 159},
      {'text': 'France', 'answer_start': 159}],
     'is_impossible': False},
    {'question': 'When were the Normans in Normandy?',
     'id': '56ddde6b9a695914005b9629',
     'answers': [{'text': '10th and 11th centuries', 'answer_start': 94},
      {'text': 'in the 10th and 11th centuries', 'answer_start': 87},
      {'text': '10th and 11th centuries', 'answer_start': 94},
      {'text': '10th and 11th centuries', 'answer_start': 94}],
     'is_impossible': False},
    {'question': 'From which countries did the Norse originate?',
     'id': '56ddde6b9a695914005b962a',
     'answers': [{'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
      {'text': 'Denmark, Iceland and Norway', 'answer

In [149]:
data_predict

{'56ddde6b9a695914005b9628': 'France',
 '56ddde6b9a695914005b9629': '10th and 11th centuries',
 '56ddde6b9a695914005b962a': 'Denmark, Iceland and Norway',
 '56ddde6b9a695914005b962b': 'Rollo',
 '56ddde6b9a695914005b962c': '10th',
 '5ad39d53604f3c001a3fe8d1': '',
 '5ad39d53604f3c001a3fe8d2': '',
 '5ad39d53604f3c001a3fe8d3': '',
 '5ad39d53604f3c001a3fe8d4': '',
 '56dddf4066d3e219004dad5f': 'William the Conqueror',
 '56dddf4066d3e219004dad60': 'Richard I of Normandy',
 '56dddf4066d3e219004dad61': 'Christian',
 '5ad3a266604f3c001a3fea27': '',
 '5ad3a266604f3c001a3fea28': 'The Normans',
 '5ad3a266604f3c001a3fea29': '',
 '5ad3a266604f3c001a3fea2a': 'Richard I of Normandy',
 '5ad3a266604f3c001a3fea2b': '',
 '56dde0379a695914005b9636': 'Norseman, Viking',
 '56dde0379a695914005b9637': '9th century',
 '5ad3ab70604f3c001a3feb89': '',
 '5ad3ab70604f3c001a3feb8a': '',
 '56dde0ba66d3e219004dad75': '911',
 '56dde0ba66d3e219004dad76': '',
 '56dde0ba66d3e219004dad77': 'the river Epte and the Atlantic c

In [192]:
# learning the input and output data format
import json
filename1 = args.train_file
filename2 = args.predict_file
with open(filename1, 'r') as f:
    data_train = json.load(f)
with open(filename2, 'r') as f:
    data_true = json.load(f)    

In [21]:
data_true.keys()

dict_keys(['version', 'data'])

In [28]:
data_true['data'][0]['paragraphs'][0]['qas']

[{'question': 'In what country is Normandy located?',
  'id': '56ddde6b9a695914005b9628',
  'answers': [{'text': 'France', 'answer_start': 159},
   {'text': 'France', 'answer_start': 159},
   {'text': 'France', 'answer_start': 159},
   {'text': 'France', 'answer_start': 159}],
  'is_impossible': False},
 {'question': 'When were the Normans in Normandy?',
  'id': '56ddde6b9a695914005b9629',
  'answers': [{'text': '10th and 11th centuries', 'answer_start': 94},
   {'text': 'in the 10th and 11th centuries', 'answer_start': 87},
   {'text': '10th and 11th centuries', 'answer_start': 94},
   {'text': '10th and 11th centuries', 'answer_start': 94}],
  'is_impossible': False},
 {'question': 'From which countries did the Norse originate?',
  'id': '56ddde6b9a695914005b962a',
  'answers': [{'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
   {'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
   {'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
   {'text': 'Den

In [196]:
data_train['data'][0]['paragraphs'][0]

{'qas': [{'question': 'When did Beyonce start becoming popular?',
   'id': '56be85543aeaaa14008c9063',
   'answers': [{'text': 'in the late 1990s', 'answer_start': 269}],
   'is_impossible': False},
  {'question': 'What areas did Beyonce compete in when she was growing up?',
   'id': '56be85543aeaaa14008c9065',
   'answers': [{'text': 'singing and dancing', 'answer_start': 207}],
   'is_impossible': False},
  {'question': "When did Beyonce leave Destiny's Child and become a solo singer?",
   'id': '56be85543aeaaa14008c9066',
   'answers': [{'text': '2003', 'answer_start': 526}],
   'is_impossible': False},
  {'question': 'In what city and state did Beyonce  grow up? ',
   'id': '56bf6b0f3aeaaa14008c9601',
   'answers': [{'text': 'Houston, Texas', 'answer_start': 166}],
   'is_impossible': False},
  {'question': 'In which decade did Beyonce become famous?',
   'id': '56bf6b0f3aeaaa14008c9602',
   'answers': [{'text': 'late 1990s', 'answer_start': 276}],
   'is_impossible': False},
  {'q

In [16]:
# create similar input data
test_grandson = {}
test_grandson['qas'] = [{'question':'What happened?',
                         'id': '1',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                        {'question':'Where is the delay?',
                         'id': '2',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Where is the construction?',
                         'id': '3',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Where is the incident?',
                         'id': '4',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'When is the delay?',
                         'id': '5',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'When is the incident?',
                         'id': '6',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'When is the construction?',
                         'id': '7',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False}]
test_grandson['context'] = 'An accident on 73 SB approaching Fellowship Rd is causing a delay between 295 and the NJ Turnpike.'

In [17]:
test_child = {}
test_child['title'] =  'Traffic'
test_child['paragraphs'] = [test_grandson]

In [18]:
test = {}
test['version'] = 'test'
test['data'] = [test_child]

In [19]:
test

{'version': 'test',
 'data': [{'title': 'Traffic',
   'paragraphs': [{'qas': [{'question': 'What happened?',
       'id': '1',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'Where is the delay?',
       'id': '2',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'Where is the construction?',
       'id': '3',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'Where is the incident?',
       'id': '4',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'When is the delay?',
       'id': '5',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'When is the incident?',
       'id': '6',
       'answers': [{'text': 'food', 'answer_start': 0}],
       'is_impossible': False},
      {'question': 'When is 

In [20]:
# save test to json
import json
with open('test.json', 'w') as outfile:
    json.dump(test, outfile)

In [17]:
# we need to change the datasets
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = 'cached_features_file'
    test_features_file = 'test_features_file'
    
    logger.info("Creating features from dataset file at %s", input_file)
    examples = read_squad_examples(input_file=input_file,
                                        is_training=not evaluate,
                                        version_2_with_negative=args.version_2_with_negative)
    features = convert_examples_to_features(examples=examples,
                                        tokenizer=tokenizer,
                                        max_seq_length=args.max_seq_length,
                                        doc_stride=args.doc_stride,
                                        max_query_length=args.max_query_length,
                                            is_training=not evaluate)
    
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    if evaluate:
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_example_index, all_cls_index, all_p_mask)
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_start_positions, all_end_positions,
                                all_cls_index, all_p_mask)

    if output_examples:
        return dataset, examples, features
    return dataset

In [22]:
# evaluate the QA for the new dataset
# evaluate the final results
args.predict_file = 'test.json'
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

In [23]:
prefix=1
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info("  Num examples = %d", len(dataset))
logger.info("  Batch size = %d", args.eval_batch_size)
all_results = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    model.eval()
    batch = tuple(t.to(args.device) for t in batch)
    with torch.no_grad():
        inputs = {'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
                    }
        example_indices = batch[3]
        if args.model_type in ['xlnet', 'xlm']:
            inputs.update({'cls_index': batch[4],
                            'p_mask':    batch[5]})
        outputs = model(**inputs)

    for i, example_index in enumerate(example_indices):
        eval_feature = features[example_index.item()]
        unique_id = int(eval_feature.unique_id)
        if args.model_type in ['xlnet', 'xlm']:
            # XLNet uses a more complex post-processing procedure
            result = RawResultExtended(unique_id            = unique_id,
                                        start_top_log_probs  = to_list(outputs[0][i]),
                                        start_top_index      = to_list(outputs[1][i]),
                                        end_top_log_probs    = to_list(outputs[2][i]),
                                        end_top_index        = to_list(outputs[3][i]),
                                        cls_logits           = to_list(outputs[4][i]))
        else:
            result = RawResult(unique_id    = unique_id,
                                start_logits = to_list(outputs[0][i]),
                                end_logits   = to_list(outputs[1][i]))
        all_results.append(result)

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  4.67it/s]


In [43]:
# Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
# assume some question does not have an answer
if args.version_2_with_negative:
    output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))

In [44]:
args.null_score_diff_threshold = 0
args.version_2_with_negative = True
if args.model_type in ['xlnet', 'xlm']:
    # XLNet uses a more complex post-processing procedure
    write_predictions_extended(examples, features, all_results, args.n_best_size,
                    args.max_answer_length, output_prediction_file,
                    output_nbest_file, output_null_log_odds_file, args.predict_file,
                    model.config.start_n_top, model.config.end_n_top,
                    args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
    write_predictions(examples, features, all_results, args.n_best_size,
                    args.max_answer_length, args.do_lower_case, output_prediction_file,
                    output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                    args.version_2_with_negative, args.null_score_diff_threshold)

In [45]:
# read the answer
import json
filename1 = output_prediction_file
with open(filename1, 'r') as f:
    answer = json.load(f)

In [46]:
answer

{'1': 'An accident on 73 SB approaching Fellowship Rd is causing a delay between 295 and the NJ Turnpike',
 '2': 'between 295 and the NJ Turnpike',
 '3': '',
 '4': 'Fellowship Rd',
 '5': 'between 295 and the NJ Turnpike',
 '6': 'An accident on 73 SB approaching Fellowship Rd',
 '7': ''}

In [18]:
import pandas as pd
data = pd.read_csv('local.csv')

In [19]:
# collect all the information
# collect all answers from tweets
import time
import json
answer_collect = {}
for key in range(len(data)):
#for key in range(3):
    context = data.iloc[key]['text']
    context = context.replace('#','')
    test_grandson = {}
    test_grandson['qas'] = [{'question':'What happened?',
                         'id': '1',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                        {'question':'Where is the delay?',
                         'id': '2',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False}, 
                       {'question':'Where is the construction?',
                         'id': '3',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Where is the incident?',
                         'id': '4',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Where is the event?',
                         'id': '5',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Where is the closure?',
                         'id': '6',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False},
                       {'question':'Cleared or Updated?',
                         'id': '7',
                          'answers': [{'text':'food','answer_start': 0}],
                          'is_impossible': False}]
    test_grandson['context'] = context

    test_child = {}
    test_child['title'] =  'Thai'
    test_child['paragraphs'] = [test_grandson]

    test = {}
    test['version'] = 'test'
    test['data'] = [test_child]
    with open('test.json', 'w') as outfile:
        json.dump(test, outfile)
    #time.sleep(1)
    # evaluate the QA for the new dataset
    # evaluate the final results
    args.predict_file = 'test.json'
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
    
    prefix=1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_results = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {'input_ids':      batch[0],
                        'attention_mask': batch[1],
                        'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
                        }
            example_indices = batch[3]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4],
                                'p_mask':    batch[5]})
            outputs = model(**inputs)
    
        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ['xlnet', 'xlm']:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(unique_id            = unique_id,
                                            start_top_log_probs  = to_list(outputs[0][i]),
                                            start_top_index      = to_list(outputs[1][i]),
                                            end_top_log_probs    = to_list(outputs[2][i]),
                                            end_top_index        = to_list(outputs[3][i]),
                                            cls_logits           = to_list(outputs[4][i]))
            else:
                result = RawResult(unique_id    = unique_id,
                                    start_logits = to_list(outputs[0][i]),
                                    end_logits   = to_list(outputs[1][i]))
            all_results.append(result)
            
    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
    # assume some question does not have an answer
    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
        
    args.null_score_diff_threshold = 0
    args.version_2_with_negative = True
    if args.model_type in ['xlnet', 'xlm']:
        # XLNet uses a more complex post-processing procedure
        write_predictions_extended(examples, features, all_results, args.n_best_size,
                        args.max_answer_length, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, args.predict_file,
                        model.config.start_n_top, model.config.end_n_top,
                        args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
        write_predictions(examples, features, all_results, args.n_best_size,
                        args.max_answer_length, args.do_lower_case, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                        args.version_2_with_negative, args.null_score_diff_threshold)
        
    filename1 = output_prediction_file
    with open(filename1, 'r') as f:
        answer = json.load(f)    
    #time.sleep(1)    
    answer_collect[key] = answer

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.71it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.71it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  3.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.87it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.66it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.08it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.07it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.01it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.69it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.03it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.94it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.33it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.79it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.04it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.30it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.01it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.10it/s]
Evaluating: 10

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.30it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Evaluating: 10

In [20]:
answer_collect[2]

{'1': 'Construction on GeorgeWashingtonBridge WB at Lower Trans Manhattan Expressway',
 '2': '',
 '3': 'Lower Trans Manhattan Expressway',
 '4': 'Lower Trans Manhattan Expressway',
 '5': 'Lower Trans Manhattan Expressway',
 '6': '',
 '7': ''}

In [22]:
len(answer_collect)

1867

In [23]:
df = pd.DataFrame(columns=['What happened?','Where is the delay?','Where is the construction?','Where is the incident?',
                           'Where is the event?','Where is the closure?','Cleared or Updated?'])

In [24]:
for i in range(len(answer_collect)):
    df.at[i,'What happened?'] = answer_collect[i]['1']
    df.at[i,'Where is the delay?'] = answer_collect[i]['2']
    df.at[i,'Where is the construction?'] = answer_collect[i]['3']
    df.at[i,'Where is the incident?'] = answer_collect[i]['4']
    df.at[i,'Where is the event?'] = answer_collect[i]['5']
    df.at[i,'Where is the closure?'] = answer_collect[i]['6']
    df.at[i,'Cleared or Updated?'] = answer_collect[i]['7']

In [25]:
df.to_csv('df.csv')