Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

In [1]:
import os
import sys

# Put the repository root (two directories above this notebook) at the front
# of the import path so the local `utils_nlp` package imported below resolves.
nlp_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if nlp_path not in sys.path:
    sys.path[:0] = [nlp_path]

from utils_nlp.dataset.squad import load_pandas_df
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.models.bert.question_answering import BERTQAExtractor
from utils_nlp.models.bert.qa_utils import postprocess_answers, evaluate_qa

In [2]:
# --- Data --------------------------------------------------------------------
SQUAD_VERSION = "v1.1"  # SQuAD release to load

# Column names of the DataFrames returned by load_pandas_df.
DOC_TEXT_COL = "doc_text"
QUESTION_TEXT_COL = "question_text"
ANSWER_START_COL = "answer_start"
ANSWER_TEXT_COL = "answer_text"
QA_ID_COL = "qa_id"
IS_IMPOSSIBLE_COL = "is_impossible"

# --- Model / tokenizer -------------------------------------------------------
# BERT variant shared by Tokenizer and BERTQAExtractor.
LANGUAGE = Language.ENGLISHLARGEWWM
# Passed to Tokenizer(to_lower=...) and postprocess_answers(do_lower_case=...).
DO_LOWER_CASE = True
# cache_dir for Tokenizer / BERTQAExtractor and model_output_dir for fit().
CACHE_DIR = "./temp"

# --- Training ----------------------------------------------------------------
MAX_SEQ_LENGTH = 384  # max_len passed to tokenize_qa
NUM_EPOCHS = 2
BATCH_SIZE = 16
LEARNING_RATE = 3e-5
WARMUP = 0.1  # passed to fit() as warmup_proportion

In [3]:
# Load the SQuAD train and dev splits as DataFrames. The version is taken from
# SQUAD_VERSION in the configuration cell so that changing it there actually
# takes effect (previously "v1.1" was hard-coded here).
train_df = load_pandas_df(
    local_cache_path=".", squad_version=SQUAD_VERSION, file_split="train"
)
dev_df = load_pandas_df(
    local_cache_path=".", squad_version=SQUAD_VERSION, file_split="dev"
)

In [4]:
# Preview the training split: one answer start/text per question.
train_df.head(5)

Unnamed: 0,doc_text,question_text,answer_start,answer_text,qa_id,is_impossible
0,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,515,Saint Bernadette Soubirous,5733be284776f41900661182,False
1,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,188,a copper statue of Christ,5733be284776f4190066117f,False
2,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,279,the Main Building,5733be284776f41900661180,False
3,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,381,a Marian place of prayer and reflection,5733be284776f41900661181,False
4,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,92,a golden statue of the Virgin Mary,5733be284776f4190066117e,False


In [5]:
# Preview the dev split: unlike train, each row holds lists of answer starts
# and texts because several reference answers are given per question.
dev_df.head(5)

Unnamed: 0,doc_text,question_text,answer_start,answer_text,qa_id,is_impossible
0,Super Bowl 50 was an American football game to...,Which NFL team represented the AFC at Super Bo...,"[177, 177, 177]","[Denver Broncos, Denver Broncos, Denver Broncos]",56be4db0acb8001400a502ec,False
1,Super Bowl 50 was an American football game to...,Which NFL team represented the NFC at Super Bo...,"[249, 249, 249]","[Carolina Panthers, Carolina Panthers, Carolin...",56be4db0acb8001400a502ed,False
2,Super Bowl 50 was an American football game to...,Where did Super Bowl 50 take place?,"[403, 355, 355]","[Santa Clara, California, Levi's Stadium, Levi...",56be4db0acb8001400a502ee,False
3,Super Bowl 50 was an American football game to...,Which NFL team won Super Bowl 50?,"[177, 177, 177]","[Denver Broncos, Denver Broncos, Denver Broncos]",56be4db0acb8001400a502ef,False
4,Super Bowl 50 was an American football game to...,What color was used to emphasize the 50th anni...,"[488, 488, 521]","[gold, gold, gold]",56be4db0acb8001400a502f0,False


In [6]:
# Tokenizer for the selected BERT variant; casing follows DO_LOWER_CASE.
tokenizer = Tokenizer(
    language=LANGUAGE,
    to_lower=DO_LOWER_CASE,
    cache_dir=CACHE_DIR,
)

In [7]:
# Convert the training split into model input features (is_training=True).
# The examples are named `train_examples` to mirror `dev_examples`;
# `qa_examples` is kept as an alias for any code using the old name.
train_features, train_examples = tokenizer.tokenize_qa(
    doc_text=train_df[DOC_TEXT_COL],
    question_text=train_df[QUESTION_TEXT_COL],
    answer_start=train_df[ANSWER_START_COL],
    answer_text=train_df[ANSWER_TEXT_COL],
    qa_id=train_df[QA_ID_COL],
    is_impossible=train_df[IS_IMPOSSIBLE_COL],
    is_training=True,
    max_len=MAX_SEQ_LENGTH,
)
qa_examples = train_examples

In [8]:
# Featurize the dev split. These features feed prediction and answer
# post-processing only, never fitting, hence is_training=False.
dev_features, dev_examples = tokenizer.tokenize_qa(
    doc_text=dev_df[DOC_TEXT_COL],
    question_text=dev_df[QUESTION_TEXT_COL],
    answer_start=dev_df[ANSWER_START_COL],
    answer_text=dev_df[ANSWER_TEXT_COL],
    qa_id=dev_df[QA_ID_COL],
    is_impossible=dev_df[IS_IMPOSSIBLE_COL],
    is_training=False,
    max_len=MAX_SEQ_LENGTH,
)

In [9]:
# Show every field of the first dev feature (name, then value, then a blank
# line) to illustrate what tokenize_qa produces.
sample_feature = dev_features[0]
for field_name in type(sample_feature)._fields:
    print(field_name, getattr(sample_feature, field_name), sep="\n", end="\n\n")

unique_id
1000000000

example_index
0

tokens
['[CLS]', 'which', 'nfl', 'team', 'represented', 'the', 'afc', 'at', 'super', 'bowl', '50', '?', '[SEP]', 'super', 'bowl', '50', 'was', 'an', 'american', 'football', 'game', 'to', 'determine', 'the', 'champion', 'of', 'the', 'national', 'football', 'league', '(', 'nfl', ')', 'for', 'the', '2015', 'season', '.', 'the', 'american', 'football', 'conference', '(', 'afc', ')', 'champion', 'denver', 'broncos', 'defeated', 'the', 'national', 'football', 'conference', '(', 'nfc', ')', 'champion', 'carolina', 'panthers', '24', '–', '10', 'to', 'earn', 'their', 'third', 'super', 'bowl', 'title', '.', 'the', 'game', 'was', 'played', 'on', 'february', '7', ',', '2016', ',', 'at', 'levi', "'", 's', 'stadium', 'in', 'the', 'san', 'francisco', 'bay', 'area', 'at', 'santa', 'clara', ',', 'california', '.', 'as', 'this', 'was', 'the', '50th', 'super', 'bowl', ',', 'the', 'league', 'emphasized', 'the', '"', 'golden', 'anniversary', '"', 'with', 'various', 'g

In [10]:
# Optional smoke-test switch: set to a small int (e.g. 63) to train on only the
# first N training features; leave as None to train on the full set.
# Re-running this cell is idempotent (slicing to the first N is stable).
TRAIN_FEATURE_LIMIT = None
if TRAIN_FEATURE_LIMIT is not None:
    train_features = train_features[:TRAIN_FEATURE_LIMIT]

In [11]:
# Question-answering model for the selected BERT variant, using CACHE_DIR as
# its cache directory.
qa_extractor = BERTQAExtractor(
    language=LANGUAGE,
    cache_dir=CACHE_DIR,
)

In [None]:
# Fine-tune on the training features; model_output_dir is CACHE_DIR.
# NOTE: expensive. The recorded run shows 5541 iterations per epoch at
# ~1.05 s/it, i.e. roughly 1.5-1.6 hours per epoch.
qa_extractor.fit(
    train_features,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LEARNING_RATE,
    warmup_proportion=WARMUP,
    model_output_dir=CACHE_DIR,
)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 1/5541 [00:08<13:36:53,  8.85s/it][A
Iteration:   0%|          | 2/5541 [00:09<10:00:52,  6.51s/it][A
Iteration:   0%|          | 3/5541 [00:10<7:29:29,  4.87s/it] [A
Iteration:   0%|          | 4/5541 [00:12<5:43:38,  3.72s/it][A
Iteration:   0%|          | 5/5541 [00:13<4:30:02,  2.93s/it][A
Iteration:   0%|          | 6/5541 [00:14<3:37:55,  2.36s/it][A
Iteration:   0%|          | 7/5541 [00:15<3:01:39,  1.97s/it][A
Iteration:   0%|          | 8/5541 [00:18<3:51:36,  2.51s/it][A
Iteration:   0%|          | 9/5541 [00:19<3:11:09,  2.07s/it][A
Iteration:   0%|          | 10/5541 [00:21<2:43:07,  1.77s/it][A
Iteration:   0%|          | 11/5541 [00:22<2:23:10,  1.55s/it][A
Iteration:   0%|          | 12/5541 [00:23<2:09:31,  1.41s/it][A
Iteration:   0%|          | 13/5541 [00:24<1:59:51,  1.30s/it][A
Iteration:   0%|          | 14/5541 [00:25<1:52:59,  1.23s/it][A
Iteration:   0%|          | 15/5541 

Iteration:   2%|▏         | 119/5541 [02:21<1:35:08,  1.05s/it][A
Iteration:   2%|▏         | 120/5541 [02:22<1:35:00,  1.05s/it][A
Iteration:   2%|▏         | 121/5541 [02:23<1:34:58,  1.05s/it][A
Iteration:   2%|▏         | 122/5541 [02:24<1:34:51,  1.05s/it][A
Iteration:   2%|▏         | 123/5541 [02:25<1:34:44,  1.05s/it][A
Iteration:   2%|▏         | 124/5541 [02:26<1:34:54,  1.05s/it][A
Iteration:   2%|▏         | 125/5541 [02:27<1:34:48,  1.05s/it][A
Iteration:   2%|▏         | 126/5541 [02:28<1:34:50,  1.05s/it][A
Iteration:   2%|▏         | 127/5541 [02:29<1:34:43,  1.05s/it][A
Iteration:   2%|▏         | 128/5541 [02:30<1:34:32,  1.05s/it][A
Iteration:   2%|▏         | 129/5541 [02:31<1:34:34,  1.05s/it][A
Iteration:   2%|▏         | 130/5541 [02:32<1:34:27,  1.05s/it][A
Iteration:   2%|▏         | 131/5541 [02:33<1:34:23,  1.05s/it][A
Iteration:   2%|▏         | 132/5541 [02:34<1:34:26,  1.05s/it][A
Iteration:   2%|▏         | 133/5541 [02:35<1:34:19,  1.05s/it

Iteration:   4%|▍         | 241/5541 [04:37<1:33:13,  1.06s/it][A
Iteration:   4%|▍         | 242/5541 [04:38<1:33:06,  1.05s/it][A
Iteration:   4%|▍         | 243/5541 [04:39<1:33:05,  1.05s/it][A
Iteration:   4%|▍         | 244/5541 [04:40<1:33:03,  1.05s/it][A
Iteration:   4%|▍         | 245/5541 [04:41<1:32:58,  1.05s/it][A
Iteration:   4%|▍         | 246/5541 [04:42<1:33:07,  1.06s/it][A
Iteration:   4%|▍         | 247/5541 [04:43<1:32:57,  1.05s/it][A
Iteration:   4%|▍         | 248/5541 [04:44<1:32:51,  1.05s/it][A
Iteration:   4%|▍         | 249/5541 [04:46<1:32:47,  1.05s/it][A
Iteration:   5%|▍         | 250/5541 [04:47<1:32:34,  1.05s/it][A
Iteration:   5%|▍         | 251/5541 [04:48<1:33:04,  1.06s/it][A
Iteration:   5%|▍         | 252/5541 [04:49<1:33:03,  1.06s/it][A
Iteration:   5%|▍         | 253/5541 [04:50<1:32:53,  1.05s/it][A
Iteration:   5%|▍         | 254/5541 [04:51<1:32:52,  1.05s/it][A
Iteration:   5%|▍         | 255/5541 [04:52<1:32:45,  1.05s/it

Iteration:   7%|▋         | 363/5541 [06:51<1:30:37,  1.05s/it][A
Iteration:   7%|▋         | 364/5541 [06:52<1:30:35,  1.05s/it][A
Iteration:   7%|▋         | 365/5541 [06:53<1:30:35,  1.05s/it][A
Iteration:   7%|▋         | 366/5541 [06:54<1:30:40,  1.05s/it][A
Iteration:   7%|▋         | 367/5541 [06:55<1:30:34,  1.05s/it][A
Iteration:   7%|▋         | 368/5541 [06:56<1:30:43,  1.05s/it][A
Iteration:   7%|▋         | 369/5541 [07:00<2:40:07,  1.86s/it][A
Iteration:   7%|▋         | 370/5541 [07:01<2:19:14,  1.62s/it][A
Iteration:   7%|▋         | 371/5541 [07:02<2:04:46,  1.45s/it][A
Iteration:   7%|▋         | 372/5541 [07:03<1:54:27,  1.33s/it][A
Iteration:   7%|▋         | 373/5541 [07:04<1:47:19,  1.25s/it][A
Iteration:   7%|▋         | 374/5541 [07:05<1:42:17,  1.19s/it][A
Iteration:   7%|▋         | 375/5541 [07:06<1:38:51,  1.15s/it][A
Iteration:   7%|▋         | 376/5541 [07:07<1:36:31,  1.12s/it][A
Iteration:   7%|▋         | 377/5541 [07:08<1:34:40,  1.10s/it

Iteration:   9%|▉         | 485/5541 [09:08<1:28:30,  1.05s/it][A
Iteration:   9%|▉         | 486/5541 [09:09<1:28:25,  1.05s/it][A
Iteration:   9%|▉         | 487/5541 [09:10<1:28:19,  1.05s/it][A
Iteration:   9%|▉         | 488/5541 [09:11<1:28:19,  1.05s/it][A
Iteration:   9%|▉         | 489/5541 [09:12<1:28:12,  1.05s/it][A
Iteration:   9%|▉         | 490/5541 [09:13<1:28:16,  1.05s/it][A
Iteration:   9%|▉         | 491/5541 [09:14<1:28:12,  1.05s/it][A
Iteration:   9%|▉         | 492/5541 [09:15<1:28:14,  1.05s/it][A
Iteration:   9%|▉         | 493/5541 [09:16<1:28:21,  1.05s/it][A
Iteration:   9%|▉         | 494/5541 [09:17<1:28:16,  1.05s/it][A
Iteration:   9%|▉         | 495/5541 [09:18<1:28:18,  1.05s/it][A
Iteration:   9%|▉         | 496/5541 [09:19<1:28:19,  1.05s/it][A
Iteration:   9%|▉         | 497/5541 [09:20<1:28:19,  1.05s/it][A
Iteration:   9%|▉         | 498/5541 [09:21<1:28:28,  1.05s/it][A
Iteration:   9%|▉         | 499/5541 [09:22<1:28:27,  1.05s/it

Iteration:  11%|█         | 607/5541 [11:24<1:26:43,  1.05s/it][A
Iteration:  11%|█         | 608/5541 [11:25<1:26:36,  1.05s/it][A
Iteration:  11%|█         | 609/5541 [11:26<1:26:27,  1.05s/it][A
Iteration:  11%|█         | 610/5541 [11:28<1:26:32,  1.05s/it][A
Iteration:  11%|█         | 611/5541 [11:29<1:26:22,  1.05s/it][A
Iteration:  11%|█         | 612/5541 [11:30<1:26:23,  1.05s/it][A
Iteration:  11%|█         | 613/5541 [11:31<1:26:16,  1.05s/it][A
Iteration:  11%|█         | 614/5541 [11:32<1:26:10,  1.05s/it][A
Iteration:  11%|█         | 615/5541 [11:33<1:26:13,  1.05s/it][A
Iteration:  11%|█         | 616/5541 [11:34<1:26:05,  1.05s/it][A
Iteration:  11%|█         | 617/5541 [11:35<1:26:02,  1.05s/it][A
Iteration:  11%|█         | 618/5541 [11:36<1:26:02,  1.05s/it][A
Iteration:  11%|█         | 619/5541 [11:37<1:25:56,  1.05s/it][A
Iteration:  11%|█         | 620/5541 [11:38<1:26:06,  1.05s/it][A
Iteration:  11%|█         | 621/5541 [11:42<2:33:54,  1.88s/it

Iteration:  13%|█▎        | 729/5541 [13:41<1:24:31,  1.05s/it][A
Iteration:  13%|█▎        | 730/5541 [13:42<1:24:28,  1.05s/it][A
Iteration:  13%|█▎        | 731/5541 [13:43<1:24:23,  1.05s/it][A
Iteration:  13%|█▎        | 732/5541 [13:44<1:24:40,  1.06s/it][A
Iteration:  13%|█▎        | 733/5541 [13:45<1:24:30,  1.05s/it][A
Iteration:  13%|█▎        | 734/5541 [13:46<1:24:35,  1.06s/it][A
Iteration:  13%|█▎        | 735/5541 [13:47<1:24:31,  1.06s/it][A
Iteration:  13%|█▎        | 736/5541 [13:48<1:24:16,  1.05s/it][A
Iteration:  13%|█▎        | 737/5541 [13:49<1:24:15,  1.05s/it][A
Iteration:  13%|█▎        | 738/5541 [13:50<1:24:08,  1.05s/it][A
Iteration:  13%|█▎        | 739/5541 [13:51<1:24:04,  1.05s/it][A
Iteration:  13%|█▎        | 740/5541 [13:52<1:24:10,  1.05s/it][A
Iteration:  13%|█▎        | 741/5541 [13:53<1:24:00,  1.05s/it][A
Iteration:  13%|█▎        | 742/5541 [13:54<1:24:00,  1.05s/it][A
Iteration:  13%|█▎        | 743/5541 [13:55<1:23:57,  1.05s/it

Iteration:  15%|█▌        | 851/5541 [15:54<1:22:14,  1.05s/it][A
Iteration:  15%|█▌        | 852/5541 [15:55<1:22:15,  1.05s/it][A
Iteration:  15%|█▌        | 853/5541 [15:56<1:22:13,  1.05s/it][A
Iteration:  15%|█▌        | 854/5541 [15:58<1:22:17,  1.05s/it][A
Iteration:  15%|█▌        | 855/5541 [15:59<1:22:10,  1.05s/it][A
Iteration:  15%|█▌        | 856/5541 [16:00<1:22:07,  1.05s/it][A
Iteration:  15%|█▌        | 857/5541 [16:01<1:22:11,  1.05s/it][A
Iteration:  15%|█▌        | 858/5541 [16:02<1:22:04,  1.05s/it][A
Iteration:  16%|█▌        | 859/5541 [16:03<1:22:08,  1.05s/it][A
Iteration:  16%|█▌        | 860/5541 [16:04<1:22:18,  1.05s/it][A
Iteration:  16%|█▌        | 861/5541 [16:05<1:22:26,  1.06s/it][A
Iteration:  16%|█▌        | 862/5541 [16:06<1:22:25,  1.06s/it][A
Iteration:  16%|█▌        | 863/5541 [16:07<1:22:11,  1.05s/it][A
Iteration:  16%|█▌        | 864/5541 [16:08<1:22:03,  1.05s/it][A
Iteration:  16%|█▌        | 865/5541 [16:09<1:21:56,  1.05s/it

Iteration:  18%|█▊        | 973/5541 [18:11<2:22:30,  1.87s/it][A
Iteration:  18%|█▊        | 974/5541 [18:12<2:03:42,  1.63s/it][A
Iteration:  18%|█▊        | 975/5541 [18:13<1:50:23,  1.45s/it][A
Iteration:  18%|█▊        | 976/5541 [18:14<1:41:42,  1.34s/it][A
Iteration:  18%|█▊        | 977/5541 [18:15<1:35:14,  1.25s/it][A
Iteration:  18%|█▊        | 978/5541 [18:16<1:30:53,  1.20s/it][A
Iteration:  18%|█▊        | 979/5541 [18:17<1:27:35,  1.15s/it][A
Iteration:  18%|█▊        | 980/5541 [18:18<1:25:47,  1.13s/it][A
Iteration:  18%|█▊        | 981/5541 [18:19<1:24:04,  1.11s/it][A
Iteration:  18%|█▊        | 982/5541 [18:20<1:22:53,  1.09s/it][A
Iteration:  18%|█▊        | 983/5541 [18:21<1:21:56,  1.08s/it][A
Iteration:  18%|█▊        | 984/5541 [18:23<1:21:30,  1.07s/it][A
Iteration:  18%|█▊        | 985/5541 [18:24<1:21:07,  1.07s/it][A
Iteration:  18%|█▊        | 986/5541 [18:25<1:20:47,  1.06s/it][A
Iteration:  18%|█▊        | 987/5541 [18:26<1:20:38,  1.06s/it

Iteration:  20%|█▉        | 1093/5541 [20:23<1:18:09,  1.05s/it][A
Iteration:  20%|█▉        | 1094/5541 [20:24<1:18:06,  1.05s/it][A
Iteration:  20%|█▉        | 1095/5541 [20:25<1:17:56,  1.05s/it][A
Iteration:  20%|█▉        | 1096/5541 [20:26<1:17:58,  1.05s/it][A
Iteration:  20%|█▉        | 1097/5541 [20:27<1:17:44,  1.05s/it][A
Iteration:  20%|█▉        | 1098/5541 [20:28<1:17:43,  1.05s/it][A
Iteration:  20%|█▉        | 1099/5541 [20:29<1:17:38,  1.05s/it][A
Iteration:  20%|█▉        | 1100/5541 [20:30<1:17:31,  1.05s/it][A
Iteration:  20%|█▉        | 1101/5541 [20:31<1:17:35,  1.05s/it][A
Iteration:  20%|█▉        | 1102/5541 [20:32<1:17:28,  1.05s/it][A
Iteration:  20%|█▉        | 1103/5541 [20:33<1:17:27,  1.05s/it][A
Iteration:  20%|█▉        | 1104/5541 [20:34<1:17:30,  1.05s/it][A
Iteration:  20%|█▉        | 1105/5541 [20:35<1:17:27,  1.05s/it][A
Iteration:  20%|█▉        | 1106/5541 [20:39<2:17:37,  1.86s/it][A
Iteration:  20%|█▉        | 1107/5541 [20:40<1:5

Iteration:  22%|██▏       | 1213/5541 [22:37<1:15:50,  1.05s/it][A
Iteration:  22%|██▏       | 1214/5541 [22:38<1:15:46,  1.05s/it][A
Iteration:  22%|██▏       | 1215/5541 [22:39<1:15:54,  1.05s/it][A
Iteration:  22%|██▏       | 1216/5541 [22:40<1:15:52,  1.05s/it][A
Iteration:  22%|██▏       | 1217/5541 [22:41<1:15:49,  1.05s/it][A
Iteration:  22%|██▏       | 1218/5541 [22:42<1:15:52,  1.05s/it][A
Iteration:  22%|██▏       | 1219/5541 [22:43<1:15:49,  1.05s/it][A
Iteration:  22%|██▏       | 1220/5541 [22:44<1:15:50,  1.05s/it][A
Iteration:  22%|██▏       | 1221/5541 [22:45<1:15:50,  1.05s/it][A
Iteration:  22%|██▏       | 1222/5541 [22:46<1:15:48,  1.05s/it][A
Iteration:  22%|██▏       | 1223/5541 [22:50<2:19:33,  1.94s/it][A
Iteration:  22%|██▏       | 1224/5541 [22:51<2:00:26,  1.67s/it][A
Iteration:  22%|██▏       | 1225/5541 [22:52<1:46:56,  1.49s/it][A
Iteration:  22%|██▏       | 1226/5541 [22:53<1:37:33,  1.36s/it][A
Iteration:  22%|██▏       | 1227/5541 [22:54<1:3

Iteration:  24%|██▍       | 1333/5541 [24:52<1:13:40,  1.05s/it][A
Iteration:  24%|██▍       | 1334/5541 [24:53<1:13:35,  1.05s/it][A
Iteration:  24%|██▍       | 1335/5541 [24:54<1:13:39,  1.05s/it][A
Iteration:  24%|██▍       | 1336/5541 [24:55<1:13:30,  1.05s/it][A
Iteration:  24%|██▍       | 1337/5541 [24:56<1:13:32,  1.05s/it][A
Iteration:  24%|██▍       | 1338/5541 [24:57<1:13:34,  1.05s/it][A
Iteration:  24%|██▍       | 1339/5541 [24:58<1:13:31,  1.05s/it][A
Iteration:  24%|██▍       | 1340/5541 [25:02<2:09:32,  1.85s/it][A
Iteration:  24%|██▍       | 1341/5541 [25:03<1:52:38,  1.61s/it][A
Iteration:  24%|██▍       | 1342/5541 [25:04<1:40:53,  1.44s/it][A
Iteration:  24%|██▍       | 1343/5541 [25:05<1:32:41,  1.32s/it][A
Iteration:  24%|██▍       | 1344/5541 [25:06<1:26:50,  1.24s/it][A
Iteration:  24%|██▍       | 1345/5541 [25:07<1:22:54,  1.19s/it][A
Iteration:  24%|██▍       | 1346/5541 [25:08<1:20:03,  1.15s/it][A
Iteration:  24%|██▍       | 1347/5541 [25:09<1:1

Iteration:  26%|██▌       | 1453/5541 [27:06<1:11:43,  1.05s/it][A
Iteration:  26%|██▌       | 1454/5541 [27:07<1:11:38,  1.05s/it][A
Iteration:  26%|██▋       | 1455/5541 [27:08<1:11:36,  1.05s/it][A
Iteration:  26%|██▋       | 1456/5541 [27:09<1:11:27,  1.05s/it][A
Iteration:  26%|██▋       | 1457/5541 [27:10<1:11:28,  1.05s/it][A
Iteration:  26%|██▋       | 1458/5541 [27:11<1:11:20,  1.05s/it][A
Iteration:  26%|██▋       | 1459/5541 [27:12<1:11:16,  1.05s/it][A
Iteration:  26%|██▋       | 1460/5541 [27:14<1:11:16,  1.05s/it][A
Iteration:  26%|██▋       | 1461/5541 [27:15<1:11:09,  1.05s/it][A
Iteration:  26%|██▋       | 1462/5541 [27:16<1:11:16,  1.05s/it][A
Iteration:  26%|██▋       | 1463/5541 [27:17<1:11:13,  1.05s/it][A
Iteration:  26%|██▋       | 1464/5541 [27:18<1:11:08,  1.05s/it][A
Iteration:  26%|██▋       | 1465/5541 [27:19<1:11:14,  1.05s/it][A
Iteration:  26%|██▋       | 1466/5541 [27:20<1:11:12,  1.05s/it][A
Iteration:  26%|██▋       | 1467/5541 [27:21<1:1

Iteration:  28%|██▊       | 1573/5541 [29:20<1:22:17,  1.24s/it][A
Iteration:  28%|██▊       | 1574/5541 [29:21<1:18:34,  1.19s/it][A
Iteration:  28%|██▊       | 1575/5541 [29:22<1:15:52,  1.15s/it][A
Iteration:  28%|██▊       | 1576/5541 [29:24<1:13:59,  1.12s/it][A
Iteration:  28%|██▊       | 1577/5541 [29:25<1:12:41,  1.10s/it][A
Iteration:  28%|██▊       | 1578/5541 [29:26<1:11:43,  1.09s/it][A
Iteration:  28%|██▊       | 1579/5541 [29:27<1:11:07,  1.08s/it][A
Iteration:  29%|██▊       | 1580/5541 [29:28<1:10:36,  1.07s/it][A
Iteration:  29%|██▊       | 1581/5541 [29:29<1:10:14,  1.06s/it][A
Iteration:  29%|██▊       | 1582/5541 [29:30<1:10:02,  1.06s/it][A
Iteration:  29%|██▊       | 1583/5541 [29:31<1:09:38,  1.06s/it][A
Iteration:  29%|██▊       | 1584/5541 [29:32<1:09:29,  1.05s/it][A
Iteration:  29%|██▊       | 1585/5541 [29:33<1:09:19,  1.05s/it][A
Iteration:  29%|██▊       | 1586/5541 [29:34<1:09:10,  1.05s/it][A
Iteration:  29%|██▊       | 1587/5541 [29:35<1:0

Iteration:  31%|███       | 1693/5541 [31:32<1:07:14,  1.05s/it][A
Iteration:  31%|███       | 1694/5541 [31:33<1:07:13,  1.05s/it][A
Iteration:  31%|███       | 1695/5541 [31:34<1:07:08,  1.05s/it][A
Iteration:  31%|███       | 1696/5541 [31:35<1:07:12,  1.05s/it][A
Iteration:  31%|███       | 1697/5541 [31:36<1:07:10,  1.05s/it][A
Iteration:  31%|███       | 1698/5541 [31:37<1:07:08,  1.05s/it][A
Iteration:  31%|███       | 1699/5541 [31:38<1:07:12,  1.05s/it][A
Iteration:  31%|███       | 1700/5541 [31:39<1:07:09,  1.05s/it][A
Iteration:  31%|███       | 1701/5541 [31:40<1:07:11,  1.05s/it][A
Iteration:  31%|███       | 1702/5541 [31:41<1:07:14,  1.05s/it][A
Iteration:  31%|███       | 1703/5541 [31:42<1:07:13,  1.05s/it][A
Iteration:  31%|███       | 1704/5541 [31:43<1:07:18,  1.05s/it][A
Iteration:  31%|███       | 1705/5541 [31:44<1:07:12,  1.05s/it][A
Iteration:  31%|███       | 1706/5541 [31:45<1:07:16,  1.05s/it][A
Iteration:  31%|███       | 1707/5541 [31:46<1:0

Iteration:  33%|███▎      | 1813/5541 [33:43<1:05:15,  1.05s/it][A
Iteration:  33%|███▎      | 1814/5541 [33:44<1:05:09,  1.05s/it][A
Iteration:  33%|███▎      | 1815/5541 [33:45<1:05:07,  1.05s/it][A
Iteration:  33%|███▎      | 1816/5541 [33:46<1:05:07,  1.05s/it][A
Iteration:  33%|███▎      | 1817/5541 [33:47<1:05:02,  1.05s/it][A
Iteration:  33%|███▎      | 1818/5541 [33:51<1:50:46,  1.79s/it][A
Iteration:  33%|███▎      | 1819/5541 [33:52<1:37:02,  1.56s/it][A
Iteration:  33%|███▎      | 1820/5541 [33:53<1:27:22,  1.41s/it][A
Iteration:  33%|███▎      | 1821/5541 [33:54<1:20:42,  1.30s/it][A
Iteration:  33%|███▎      | 1822/5541 [33:55<1:15:57,  1.23s/it][A
Iteration:  33%|███▎      | 1823/5541 [33:56<1:12:43,  1.17s/it][A
Iteration:  33%|███▎      | 1824/5541 [33:57<1:10:25,  1.14s/it][A
Iteration:  33%|███▎      | 1825/5541 [33:58<1:08:43,  1.11s/it][A
Iteration:  33%|███▎      | 1826/5541 [33:59<1:07:38,  1.09s/it][A
Iteration:  33%|███▎      | 1827/5541 [34:00<1:0

Iteration:  35%|███▍      | 1933/5541 [35:56<1:03:21,  1.05s/it][A
Iteration:  35%|███▍      | 1934/5541 [35:57<1:03:15,  1.05s/it][A
Iteration:  35%|███▍      | 1935/5541 [35:58<1:03:18,  1.05s/it][A
Iteration:  35%|███▍      | 1936/5541 [35:59<1:03:12,  1.05s/it][A
Iteration:  35%|███▍      | 1937/5541 [36:00<1:03:10,  1.05s/it][A
Iteration:  35%|███▍      | 1938/5541 [36:02<1:03:10,  1.05s/it][A
Iteration:  35%|███▍      | 1939/5541 [36:03<1:03:04,  1.05s/it][A
Iteration:  35%|███▌      | 1940/5541 [36:04<1:03:05,  1.05s/it][A
Iteration:  35%|███▌      | 1941/5541 [36:05<1:03:00,  1.05s/it][A
Iteration:  35%|███▌      | 1942/5541 [36:06<1:02:56,  1.05s/it][A
Iteration:  35%|███▌      | 1943/5541 [36:07<1:02:56,  1.05s/it][A
Iteration:  35%|███▌      | 1944/5541 [36:08<1:03:11,  1.05s/it][A
Iteration:  35%|███▌      | 1945/5541 [36:11<1:47:27,  1.79s/it][A
Iteration:  35%|███▌      | 1946/5541 [36:12<1:34:02,  1.57s/it][A
Iteration:  35%|███▌      | 1947/5541 [36:13<1:2

Iteration:  37%|███▋      | 2053/5541 [38:10<1:01:24,  1.06s/it][A
Iteration:  37%|███▋      | 2054/5541 [38:11<1:01:13,  1.05s/it][A
Iteration:  37%|███▋      | 2055/5541 [38:12<1:01:08,  1.05s/it][A
Iteration:  37%|███▋      | 2056/5541 [38:13<1:01:01,  1.05s/it][A
Iteration:  37%|███▋      | 2057/5541 [38:14<1:01:02,  1.05s/it][A
Iteration:  37%|███▋      | 2058/5541 [38:15<1:01:02,  1.05s/it][A
Iteration:  37%|███▋      | 2059/5541 [38:16<1:01:06,  1.05s/it][A
Iteration:  37%|███▋      | 2060/5541 [38:17<1:01:11,  1.05s/it][A
Iteration:  37%|███▋      | 2061/5541 [38:18<1:01:04,  1.05s/it][A
Iteration:  37%|███▋      | 2062/5541 [38:19<1:01:08,  1.05s/it][A
Iteration:  37%|███▋      | 2063/5541 [38:20<1:01:06,  1.05s/it][A
Iteration:  37%|███▋      | 2064/5541 [38:21<1:01:11,  1.06s/it][A
Iteration:  37%|███▋      | 2065/5541 [38:22<1:01:14,  1.06s/it][A
Iteration:  37%|███▋      | 2066/5541 [38:23<1:01:06,  1.06s/it][A
Iteration:  37%|███▋      | 2067/5541 [38:24<1:0

Iteration:  39%|███▉      | 2174/5541 [40:24<1:38:08,  1.75s/it][A
Iteration:  39%|███▉      | 2175/5541 [40:25<1:26:16,  1.54s/it][A
Iteration:  39%|███▉      | 2176/5541 [40:26<1:17:56,  1.39s/it][A
Iteration:  39%|███▉      | 2177/5541 [40:27<1:12:11,  1.29s/it][A
Iteration:  39%|███▉      | 2178/5541 [40:28<1:08:01,  1.21s/it][A
Iteration:  39%|███▉      | 2179/5541 [40:29<1:05:13,  1.16s/it][A
Iteration:  39%|███▉      | 2180/5541 [40:30<1:03:13,  1.13s/it][A
Iteration:  39%|███▉      | 2181/5541 [40:31<1:01:45,  1.10s/it][A
Iteration:  39%|███▉      | 2182/5541 [40:32<1:00:51,  1.09s/it][A
Iteration:  39%|███▉      | 2183/5541 [40:33<1:00:07,  1.07s/it][A
Iteration:  39%|███▉      | 2184/5541 [40:34<59:40,  1.07s/it]  [A
Iteration:  39%|███▉      | 2185/5541 [40:35<59:24,  1.06s/it][A
Iteration:  39%|███▉      | 2186/5541 [40:36<59:07,  1.06s/it][A
Iteration:  39%|███▉      | 2187/5541 [40:37<59:01,  1.06s/it][A
Iteration:  39%|███▉      | 2188/5541 [40:39<58:54,  1

Iteration:  41%|████▏     | 2297/5541 [42:40<1:01:17,  1.13s/it][A
Iteration:  41%|████▏     | 2298/5541 [42:41<59:53,  1.11s/it]  [A
Iteration:  41%|████▏     | 2299/5541 [42:42<58:59,  1.09s/it][A
Iteration:  42%|████▏     | 2300/5541 [42:43<58:14,  1.08s/it][A
Iteration:  42%|████▏     | 2301/5541 [42:44<57:44,  1.07s/it][A
Iteration:  42%|████▏     | 2302/5541 [42:45<57:25,  1.06s/it][A
Iteration:  42%|████▏     | 2303/5541 [42:46<57:10,  1.06s/it][A
Iteration:  42%|████▏     | 2304/5541 [42:47<57:05,  1.06s/it][A
Iteration:  42%|████▏     | 2305/5541 [42:48<56:54,  1.06s/it][A
Iteration:  42%|████▏     | 2306/5541 [42:50<56:45,  1.05s/it][A
Iteration:  42%|████▏     | 2307/5541 [42:51<56:42,  1.05s/it][A
Iteration:  42%|████▏     | 2308/5541 [42:52<56:35,  1.05s/it][A
Iteration:  42%|████▏     | 2309/5541 [42:53<56:35,  1.05s/it][A
Iteration:  42%|████▏     | 2310/5541 [42:54<56:31,  1.05s/it][A
Iteration:  42%|████▏     | 2311/5541 [42:55<56:25,  1.05s/it][A
Iterat

In [None]:
# Run the fine-tuned model over the dev features; the raw results are turned
# into text answers by postprocess_answers in the next cell.
qa_results = qa_extractor.predict(
    dev_features
)

In [None]:
# Map the raw model results back onto the dev examples to obtain the final
# answers; casing must follow the same DO_LOWER_CASE used by the tokenizer.
final_answers = postprocess_answers(
    dev_examples,
    dev_features,
    qa_results,
    do_lower_case=DO_LOWER_CASE,
)

In [None]:
# Optional: uncomment to inspect the post-processed dev-set answers. Left
# disabled to keep the saved notebook output small.
# final_answers

In [None]:
# Optional: load previously saved predictions (e.g. from an earlier fine-tuning
# run) for comparison. Set to a path relative to this notebook, for example
# os.path.join(CACHE_DIR, "predictions.json"); leave as None to skip.
# Avoid hard-coding absolute, user-specific paths here.
PREDICTIONS_JSON = None
if PREDICTIONS_JSON is not None:
    import json

    with open(PREDICTIONS_JSON, encoding="utf-8") as f:
        preds = json.load(f)

In [None]:
# Score the post-processed answers against the dev-set reference answers.
# Column names come from the configuration cell (not string literals) so they
# stay consistent with every other cell.
evaluate_qa(
    qa_ids=dev_df[QA_ID_COL],
    actuals=dev_df[ANSWER_TEXT_COL],
    preds=final_answers,
)