# Evaluation

In [1]:
import os
import re
import argparse
import json
import sys
import string
import pandas as pd
import numpy as np
import math
from collections import defaultdict, Counter
import os
import glob

In [2]:


import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer, DataCollatorForSeq2Seq, DataCollatorWithPadding
from datasets import Dataset, load_dataset

from pprint import pprint

from tqdm import tqdm

device = "cuda:0"# "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
for i in tqdm(range(10)):
    continue

100%|██████████| 10/10 [00:00<00:00, 104077.02it/s]


In [4]:
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        tx = re.sub(r'\b(a|an|the)\b.', ' ', text)
        tx = tx.replace('pad', '').replace('s', '')
        return tx

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(gold_answers, predictions):
    f1 = exact_match = total = 0

    for ground_truths, prediction in zip(gold_answers, predictions):
        total += 1
        exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
          f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}

In [5]:
# process the examples in input and target text format and the eos token at the end 
def add_eos_to_examples(example):
    result = {}
    context = example['context']
    question = example['question']
    if 'answers' in example.keys():
        answer = example['answers']['text'][0]
        result['answers'] = example['answers']
    else:
        answer = example['answer']

    result['input_text'] =  'question: %s context: %s' % (question, context)
    result['answer'] = result
    result['target_text'] = '%s' % answer
    
    return result

def convert_to_features(example):

    encoding = {}
    
    input_encoding = tokenizer.encode_plus(example['input_text'],
#                                            pad_to_max_length=True,
                                           truncation=True,
                                           max_length=512,
                                           add_special_tokens=True)
    target_encoding = tokenizer.encode_plus(example['target_text'],
                                            pad_to_max_length=True,
                                              truncation=True,
                                            max_length=16, add_special_tokens=True)

    encoding['input_ids'] = input_encoding['input_ids'] 
#     encoding['attention_mask'] = input_encoding['attention_mask'] 
    encoding['target_ids'] = target_encoding['input_ids'] 
#     encoding['labels_attention_mask'] =  target_encoding['attention_mask'] 

    # print(f"type(encodings['input_ids']: {type(encodings['input_ids'])}")
    return encoding

In [6]:
!ls /ist

ls: cannot access '/ist': No such file or directory


# MLQA

In [68]:
ALL_MLQA_FINETUNED_MODEL_DIR = glob.glob('../checkpoints/t5-base.adafactor.seq2seq.mlqa*')

T5_TOKENIZER_MODEL_DIR = '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/'

tokenizer = MT5Tokenizer.from_pretrained(T5_TOKENIZER_MODEL_DIR, max_length=512)
tokenizer


PreTrainedTokenizer(name_or_path='../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/', vocab_size=250100, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})

In [95]:
MLQA_TEST_DATA_PATH = '../data/mlqa/datasets_format/test.json'

mlqa_xx = { 
     'test': json.load(open(MLQA_TEST_DATA_PATH))['data'],
}

In [96]:
mlqa_xx['test'][0]

{'id': 'd019ebb48d4d69d6ae303155',
 'title': 'Pappataci fever',
 'context': 'Pappataci fever is prevalent in the subtropical zone of the Eastern Hemisphere between 20°N and 45°N, particularly in Southern Europe, North Africa, the Balkans, Eastern Mediterranean, Iraq, Iran, Pakistan, Afghanistan and India.The disease is transmitted by the bites of phlebotomine sandflies of the Genus Phlebotomus, in particular, Phlebotomus papatasi, Phlebotomus perniciosus and Phlebotomus perfiliewi. The sandfly becomes infected when biting an infected human in the period between 48 hours before the onset of fever and 24 hours after the end of the fever, and remains infected for its lifetime. Besides this horizontal virus transmission from man to sandfly, the virus can be transmitted in insects transovarially, from an infected female sandfly to its offspring.Pappataci fever is seldom recognised in endemic populations because it is mixed with other febrile illnesses of childhood, but it is more well-known

### Your dataset

In [97]:
# data_files = {"test": "/ist/ist-share/scads/korn/datasets/qa_datasset/MLQA/dev/dev-context-ar-question-vi.json"}
mlqa_test_dataset = mlqa_xx['test']
mlqa_test_dataset[0]

{'id': 'd019ebb48d4d69d6ae303155',
 'title': 'Pappataci fever',
 'context': 'Pappataci fever is prevalent in the subtropical zone of the Eastern Hemisphere between 20°N and 45°N, particularly in Southern Europe, North Africa, the Balkans, Eastern Mediterranean, Iraq, Iran, Pakistan, Afghanistan and India.The disease is transmitted by the bites of phlebotomine sandflies of the Genus Phlebotomus, in particular, Phlebotomus papatasi, Phlebotomus perniciosus and Phlebotomus perfiliewi. The sandfly becomes infected when biting an infected human in the period between 48 hours before the onset of fever and 24 hours after the end of the fever, and remains infected for its lifetime. Besides this horizontal virus transmission from man to sandfly, the virus can be transmitted in insects transovarially, from an infected female sandfly to its offspring.Pappataci fever is seldom recognised in endemic populations because it is mixed with other febrile illnesses of childhood, but it is more well-known

In [98]:
mlqa_test_features = list(map(convert_to_features, map(add_eos_to_examples, mlqa_test_dataset)))
len(mlqa_test_features)



(4199,
 {'input_ids': [7680,
   267,
   259,
   64996,
   461,
   259,
   63446,
   332,
   9016,
   110663,
   1002,
   8125,
   910,
   1459,
   291,
   19730,
   267,
   128108,
   1319,
   769,
   259,
   50646,
   339,
   786,
   56451,
   281,
   287,
   2411,
   75763,
   21102,
   304,
   287,
   259,
   48918,
   13044,
   266,
   111192,
   259,
   4964,
   259,
   102226,
   538,
   305,
   147363,
   538,
   261,
   6962,
   484,
   281,
   259,
   40892,
   3970,
   261,
   6424,
   15750,
   261,
   287,
   87536,
   263,
   261,
   259,
   48918,
   259,
   122307,
   261,
   259,
   36986,
   261,
   19255,
   261,
   24272,
   261,
   73261,
   305,
   4783,
   260,
   2009,
   32664,
   339,
   31012,
   3678,
   455,
   287,
   5485,
   299,
   304,
   690,
   129343,
   476,
   5650,
   14684,
   110663,
   304,
   287,
   114313,
   263,
   2879,
   129343,
   17991,
   263,
   261,
   281,
   6962,
   261,
   2879,
   129343,
   17991,
   263,
   12498,
   26039,


In [99]:
mlqa_references = [ item['answers']['text'] for item in mlqa_xx['test'] ]
len(mlqa_references),  mlqa_references[:2]

(4199, [['remains infected for its lifetime'], ['collenchyma']])

Evaluate MLQA

In [100]:
ALL_MLQA_FINETUNED_MODEL_DIR


['../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500',
 '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500',
 '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500',
 '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500']

In [101]:
!mkdir -p mlqa_scores

In [103]:

for MLQA_FINETUNED_MODEL_DIRS in ALL_MLQA_FINETUNED_MODEL_DIR:
    
    EACH_CKP_MLQA_FINETUNED_MODEL_DIRS = glob.glob(f'{MLQA_FINETUNED_MODEL_DIRS}/checkpoint-*')
    print('='*50)
    print(f'\n\n\nModel EXP: {MLQA_FINETUNED_MODEL_DIRS}')
    model_exp_name = MLQA_FINETUNED_MODEL_DIRS.split('/')[-1]
    model_exp_name

    BATCH_SIZE = 128
    mlqa_en_scores = []
    for MODEL_DIR in EACH_CKP_MLQA_FINETUNED_MODEL_DIRS:
    # MODEL_DIR = '/ist/ist-share/scads/aires/'
        print(f'MODEL_DIR: {MODEL_DIR}')
        model_ckp = MODEL_DIR.split('-')[-1]

        model = MT5ForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)

        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                                padding=True,
                                                max_length=512,
                                               )
        data_loader = torch.utils.data.DataLoader(mlqa_test_features,
                                                  batch_size=BATCH_SIZE,
                                                  collate_fn=data_collator)

        predictions = []
        answers = []
        c = 0
        batched_gt_ans = []
        for i, batch in tqdm(enumerate(data_loader)):
            batch_size = len(batch['input_ids'])

            target_ans = tokenizer.batch_decode(batch['target_ids'], skip_special_tokens=True)
            batched_gt_ans.extend(target_ans)
            # print('batch size', len(batch['input_ids']))
            outs = model.generate(input_ids=batch['input_ids'].to(device), 
                                attention_mask=batch['attention_mask'].to(device),
                                max_length=64,
                                early_stopping=True,
                                num_beams=1,
                                decoder_start_token_id=0)

            answer = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            if c < 5:
    #             print(f'\nRefs and answer \n')
    #             pprint(list(zip(answer[:BATCH_SIZE], references[:BATCH_SIZE])))
    #             print('\n')
                batch_level_eval_results = evaluate(mlqa_references[i*BATCH_SIZE:(i+1)*BATCH_SIZE], answer[:BATCH_SIZE])
                print(' - Batch-level eval:')
                pprint(batch_level_eval_results, indent=4)
                print('\n')
    #             print(f'\n  refs: {mlqa_references[-6:]}')
    #             print(f'\nanswer: {answer[-6:]}')
                c+=1
            answers.extend(answer)
    #         print(f'answers: {answers}')
    #         break
        # break
        predictions = answers

        eval_results = evaluate(mlqa_references, predictions)
        print('Per-epoch eval results')
        print(eval_results)
        print('\n\n')
        mlqa_en_scores.append({
            'model_ckp': model_ckp,
            'model_dir': MODEL_DIR,
            **eval_results,
        })

    with open(f'./mlqa_scores/mlqa_en_scores.{model_exp_name}.json', 'w') as f:
        json.dump(mlqa_en_scores, f, indent=4)
    pprint(mlqa_en_scores, sort_dicts=False)
    print('\n')
    print('-'*60)
    print('\n\n')




Model EXP: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500
MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-4753


1it [00:03,  3.86s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 80.63807936854815}




2it [00:09,  4.34s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 69.52182904695981}




3it [00:13,  4.29s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 74.27730381816828}




4it [00:17,  4.20s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 75.14909173706728}




5it [00:21,  4.27s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 69.06361001256947}




33it [02:29,  4.52s/it]


Per-epoch eval results
{'exact_match': 56.98975946653965, 'f1': 70.30063090841755}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-19012


1it [00:03,  3.73s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 79.92322649811875}




2it [00:09,  4.23s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 69.37767199875913}




3it [00:13,  4.15s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 74.38014070174566}




4it [00:16,  3.96s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 79.03979488872729}




5it [00:20,  3.97s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 70.39228553148347}




33it [02:20,  4.25s/it]


Per-epoch eval results
{'exact_match': 60.30007144558228, 'f1': 72.77181955997746}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-9506


1it [00:03,  3.84s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 81.72486584595963}




2it [00:07,  3.79s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 69.40029247039276}




3it [00:12,  4.27s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 76.20414769698826}




4it [00:16,  4.05s/it]

 - Batch-level eval:
{'exact_match': 67.96875, 'f1': 78.3337002249846}




5it [00:20,  4.08s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 71.54730444152801}




33it [02:20,  4.26s/it]


Per-epoch eval results
{'exact_match': 61.038342462491066, 'f1': 73.42493795877951}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-20000


1it [00:03,  3.75s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 77.8669602816963}




2it [00:07,  3.74s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 72.34840078568857}




3it [00:11,  3.81s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 74.525487682131}




4it [00:15,  3.78s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 78.04205313851841}




5it [00:19,  3.85s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 69.8658154737634}




33it [02:16,  4.13s/it]


Per-epoch eval results
{'exact_match': 60.20481066920696, 'f1': 73.2191860198789}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-14259


1it [00:03,  3.91s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 78.76343682869421}




2it [00:07,  3.88s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 69.53018514716042}




3it [00:11,  3.89s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 74.43948377095714}




4it [00:15,  3.92s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 79.65077724518662}




5it [00:19,  3.94s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 71.10341747229808}




33it [02:16,  4.13s/it]


Per-epoch eval results
{'exact_match': 59.70469159323648, 'f1': 72.53781831120766}



[{'model_ckp': '4753',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-4753',
  'exact_match': 56.98975946653965,
  'f1': 70.30063090841755},
 {'model_ckp': '19012',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-19012',
  'exact_match': 60.30007144558228,
  'f1': 72.77181955997746},
 {'model_ckp': '9506',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-9506',
  'exact_match': 61.038342462491066,
  'f1': 73.42493795877951},
 {'model_ckp': '20000',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-20000',
  'exact_match': 60.20481066920696,
  'f1': 73.219186

1it [00:04,  4.94s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 76.85346988081363}




2it [00:09,  4.89s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 67.93031882162107}




3it [00:14,  4.90s/it]

 - Batch-level eval:
{'exact_match': 53.125, 'f1': 69.78435246974819}




4it [00:18,  4.67s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 74.71743675924768}




5it [00:22,  4.52s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 68.04614707188236}




33it [02:27,  4.48s/it]


Per-epoch eval results
{'exact_match': 54.79876160990712, 'f1': 68.59596041032331}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-23765


1it [00:03,  3.74s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 78.17145049634277}




2it [00:07,  3.84s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 67.76228310649813}




3it [00:11,  3.84s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 71.18196676483879}




4it [00:16,  4.12s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 72.32655409629099}




5it [00:20,  4.10s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 69.88255654519197}




33it [02:19,  4.22s/it]


Per-epoch eval results
{'exact_match': 55.53703262681591, 'f1': 68.31767098201894}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-25000


1it [00:03,  3.79s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 72.28772593317633}




2it [00:07,  3.87s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 66.22366323769735}




3it [00:12,  4.19s/it]

 - Batch-level eval:
{'exact_match': 53.125, 'f1': 68.538942204355}




4it [00:17,  4.35s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 70.80187651692592}




5it [00:21,  4.26s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 70.92705178298007}




33it [02:22,  4.33s/it]


Per-epoch eval results
{'exact_match': 55.06072874493927, 'f1': 68.33177766482166}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-19012


1it [00:03,  3.42s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 74.23217037762078}




2it [00:07,  3.56s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 70.47625296822825}




3it [00:10,  3.59s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 70.93581831350942}




4it [00:14,  3.56s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 71.57121072709212}




5it [00:18,  3.69s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 71.73699742192147}




33it [02:13,  4.06s/it]


Per-epoch eval results
{'exact_match': 56.918313884258154, 'f1': 69.43475807565589}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-9506


1it [00:03,  3.88s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 78.4775489267677}




2it [00:07,  3.91s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 69.77469634979717}




3it [00:12,  4.21s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 72.19071964264738}




4it [00:16,  4.20s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 74.49588644388741}




5it [00:20,  4.14s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 72.0865276471894}




33it [02:23,  4.33s/it]


Per-epoch eval results
{'exact_match': 56.68016194331984, 'f1': 69.82175165962082}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-14259


1it [00:03,  3.51s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 71.41307043650797}




2it [00:07,  3.62s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 66.01209050085676}




3it [00:11,  3.79s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 69.75416096939435}




4it [00:15,  3.85s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 78.02741100524965}




5it [00:19,  3.91s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 70.44807253102289}




33it [02:14,  4.08s/it]


Per-epoch eval results
{'exact_match': 56.84686830197666, 'f1': 69.32650788028167}



[{'model_ckp': '4753',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-4753',
  'exact_match': 54.79876160990712,
  'f1': 68.59596041032331},
 {'model_ckp': '23765',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-23765',
  'exact_match': 55.53703262681591,
  'f1': 68.31767098201894},
 {'model_ckp': '25000',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-25000',
  'exact_match': 55.06072874493927,
  'f1': 68.33177766482166},
 {'model_ckp': '19012',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500/checkpoint-19012',
  'exact_match': 56.918313884258154,
  'f1': 69.4347

1it [00:03,  3.92s/it]

 - Batch-level eval:
{'exact_match': 44.53125, 'f1': 55.73939732142856}




2it [00:09,  4.36s/it]

 - Batch-level eval:
{'exact_match': 33.59375, 'f1': 43.4280122010629}




3it [00:14,  4.68s/it]

 - Batch-level eval:
{'exact_match': 35.9375, 'f1': 49.244017291997835}




4it [00:19,  4.77s/it]

 - Batch-level eval:
{'exact_match': 41.40625, 'f1': 53.90829569106281}




5it [00:24,  4.69s/it]

 - Batch-level eval:
{'exact_match': 32.8125, 'f1': 42.50617718678381}




33it [02:32,  4.62s/it]


Per-epoch eval results
{'exact_match': 32.62681590854965, 'f1': 44.15215756179154}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-23765


1it [00:03,  3.86s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 78.77213715104342}




2it [00:07,  3.88s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 73.29344666991496}




3it [00:11,  3.97s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 76.91178449952396}




4it [00:16,  4.21s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 77.19791385582653}




5it [00:20,  4.15s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 71.3352161576813}




33it [02:21,  4.30s/it]


Per-epoch eval results
{'exact_match': 59.823767563705644, 'f1': 72.83228562322316}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-25000


1it [00:03,  3.89s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 78.99173481272082}




2it [00:09,  4.35s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 72.26848063730625}




3it [00:13,  4.31s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 77.9128520461078}




4it [00:18,  4.45s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 75.84249672635512}




5it [00:22,  4.32s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 69.81412950332259}




33it [02:26,  4.45s/it]


Per-epoch eval results
{'exact_match': 60.10954989283162, 'f1': 72.85439604850133}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-19012


1it [00:03,  3.91s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 77.56468141233768}




2it [00:09,  4.37s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 70.30091434942747}




3it [00:13,  4.44s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 75.03304955693032}




4it [00:18,  4.51s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 76.29388561524402}




5it [00:22,  4.36s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 70.17258538567552}




33it [02:25,  4.41s/it]


Per-epoch eval results
{'exact_match': 58.871159799952366, 'f1': 71.31343546219146}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-9506


1it [00:03,  3.90s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 72.42466517857143}




2it [00:09,  4.35s/it]

 - Batch-level eval:
{'exact_match': 50.0, 'f1': 62.40402381083025}




3it [00:14,  4.67s/it]

 - Batch-level eval:
{'exact_match': 51.5625, 'f1': 65.55966537852086}




4it [00:19,  4.68s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 69.623138699782}




5it [00:23,  4.48s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 68.02617379073455}




33it [02:31,  4.59s/it]


Per-epoch eval results
{'exact_match': 50.34532031436056, 'f1': 63.20791112460852}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-14259


1it [00:03,  3.91s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 76.93650826807999}




2it [00:09,  4.37s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 68.59131937013139}




3it [00:13,  4.31s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 72.9418301265144}




4it [00:18,  4.42s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 75.82289814172898}




5it [00:22,  4.30s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 69.71585461644479}




33it [02:23,  4.34s/it]


Per-epoch eval results
{'exact_match': 56.58490116694451, 'f1': 69.22956717060629}



[{'model_ckp': '4753',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-4753',
  'exact_match': 32.62681590854965,
  'f1': 44.15215756179154},
 {'model_ckp': '23765',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-23765',
  'exact_match': 59.823767563705644,
  'f1': 72.83228562322316},
 {'model_ckp': '25000',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-25000',
  'exact_match': 60.10954989283162,
  'f1': 72.85439604850133},
 {'model_ckp': '19012',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-19012',
  'exact_match': 58.871159799952366,
  'f1': 71.313

1it [00:03,  3.90s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 73.05799885878011}




2it [00:09,  4.34s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 66.01682456152896}




3it [00:13,  4.29s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 66.96560446598568}




4it [00:18,  4.42s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 72.0192464754128}




5it [00:22,  4.43s/it]

 - Batch-level eval:
{'exact_match': 49.21875, 'f1': 62.64590559527624}




33it [02:30,  4.55s/it]


Per-epoch eval results
{'exact_match': 51.03596094308168, 'f1': 64.52186701205903}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-23765


1it [00:03,  3.87s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 79.41035265321366}




2it [00:07,  3.84s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 72.1159534769666}




3it [00:11,  3.95s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 78.26520151955101}




4it [00:16,  4.18s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 76.67070489706187}




5it [00:20,  4.12s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 73.5851829306744}




33it [02:23,  4.34s/it]


Per-epoch eval results
{'exact_match': 60.06191950464396, 'f1': 73.63979497073646}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-28518


1it [00:03,  3.87s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 78.57472994259096}




2it [00:09,  4.33s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 74.86701199348028}




3it [00:13,  4.29s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 77.54397854968676}




4it [00:17,  4.07s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 76.27824197709214}




5it [00:20,  4.04s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 74.93922499606515}




33it [02:22,  4.33s/it]


Per-epoch eval results
{'exact_match': 61.96713503215051, 'f1': 74.49240693744254}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-30000


1it [00:03,  3.91s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 78.26471010132113}




2it [00:09,  4.36s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 73.13276100141678}




3it [00:13,  4.32s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 76.82174913859865}




4it [00:17,  4.14s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 75.74500785010801}




5it [00:21,  4.10s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 74.60576824813472}




33it [02:21,  4.30s/it]


Per-epoch eval results
{'exact_match': 60.30007144558228, 'f1': 73.5842736188262}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-19012


1it [00:03,  3.90s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 79.67467645202024}




2it [00:09,  4.35s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 71.08928674784471}




3it [00:13,  4.31s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 78.20907444915768}




4it [00:18,  4.51s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 76.09636033666713}




5it [00:22,  4.36s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 72.04471065623831}




33it [02:21,  4.29s/it]


Per-epoch eval results
{'exact_match': 60.65729935698976, 'f1': 73.36494553911567}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-9506


1it [00:03,  3.89s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 80.00419935966812}




2it [00:09,  4.35s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 69.63353313063561}




3it [00:13,  4.42s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 75.3549347030342}




4it [00:18,  4.51s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 75.65227933505078}




5it [00:22,  4.36s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 69.63563102059621}




33it [02:31,  4.58s/it]


Per-epoch eval results
{'exact_match': 58.27577994760657, 'f1': 70.91729448423729}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-14259


1it [00:03,  3.91s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 79.76238388347765}




2it [00:09,  4.36s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 73.05951328725435}




3it [00:13,  4.31s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 76.45351367221531}




4it [00:18,  4.44s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 78.59623813118884}




5it [00:22,  4.31s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 69.07832476476051}




33it [02:27,  4.47s/it]


Per-epoch eval results
{'exact_match': 60.10954989283162, 'f1': 72.88491438643344}



[{'model_ckp': '4753',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-4753',
  'exact_match': 51.03596094308168,
  'f1': 64.52186701205903},
 {'model_ckp': '23765',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-23765',
  'exact_match': 60.06191950464396,
  'f1': 73.63979497073646},
 {'model_ckp': '28518',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-28518',
  'exact_match': 61.96713503215051,
  'f1': 74.49240693744254},
 {'model_ckp': '30000',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.mlqa_hparams.bz-8.grad_acc-1.lr-2.5e-5.max_steps-30000.save_steps-500/checkpoint-30000',
  'exact_match': 60.30007144558228,
  'f1': 

# XORQA

In [6]:

T5_TOKENIZER_MODEL_DIR = '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500'

tokenizer = MT5Tokenizer.from_pretrained(T5_TOKENIZER_MODEL_DIR, max_length=512)
tokenizer

PreTrainedTokenizer(name_or_path='../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500', vocab_size=250100, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})

### Your dataset

In [7]:
XORQA_TEST_DATA_PATH = '../data/xorqa/datasets_format/test.json'

xorqa_xx = { 
     'test': json.load(open(XORQA_TEST_DATA_PATH))['data'],
}

In [8]:
# data_files = {"test": "/ist/ist-share/scads/korn/datasets/qa_datasset/MLQA/dev/dev-context-ar-question-vi.json"}
xorqa_test_dataset = xorqa_xx['test']


xorqa_test_features = list(map(convert_to_features, map(add_eos_to_examples, xorqa_test_dataset)))
len(xorqa_test_features)

xorqa_references = [ item['answers']['text'] for item in xorqa_xx['test'] ]
len(xorqa_references),  xorqa_references[:2]



(3108, [['Bloomsbury'], ['White Star Line']])

In [9]:
ALL_XORQA_FINETUNED_MODEL_DIR = glob.glob('../checkpoints/t5-base.adafactor.seq2seq.xorqa*')

ALL_XORQA_FINETUNED_MODEL_DIR

['../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500',
 '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500',
 '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-4.max_steps-25000.save_steps-500']

In [10]:
!mkdir -p xorqa_scores

In [None]:

for XORQA_FINETUNED_MODEL_DIRS in ALL_XORQA_FINETUNED_MODEL_DIR:
    
    EACH_CKP_XORQA_FINETUNED_MODEL_DIRS = glob.glob(f'{XORQA_FINETUNED_MODEL_DIRS}/checkpoint-*')
    print('='*50)
    print(f'\n\n\nModel EXP: {XORQA_FINETUNED_MODEL_DIRS}')
    model_exp_name = XORQA_FINETUNED_MODEL_DIRS.split('/')[-1]
    model_exp_name

    BATCH_SIZE = 128
    xorqa_en_scores = []
    for MODEL_DIR in EACH_CKP_XORQA_FINETUNED_MODEL_DIRS:
    # MODEL_DIR = '/ist/ist-share/scads/aires/'
        print(f'MODEL_DIR: {MODEL_DIR}')
        model_ckp = MODEL_DIR.split('-')[-1]

        model = MT5ForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)

        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                                padding=True,
                                                max_length=512,
                                               )
        data_loader = torch.utils.data.DataLoader(xorqa_test_features,
                                                  batch_size=BATCH_SIZE,
                                                  collate_fn=data_collator)

        predictions = []
        answers = []
        c = 0
        batched_gt_ans = []
        for i, batch in tqdm(enumerate(data_loader)):
            batch_size = len(batch['input_ids'])

            target_ans = tokenizer.batch_decode(batch['target_ids'], skip_special_tokens=True)
            batched_gt_ans.extend(target_ans)
            # print('batch size', len(batch['input_ids']))
            outs = model.generate(input_ids=batch['input_ids'].to(device), 
                                attention_mask=batch['attention_mask'].to(device),
                                max_length=64,
                                early_stopping=True,
                                num_beams=1,
                                decoder_start_token_id=0)

            answer = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            if c < 5:
    #             print(f'\nRefs and answer \n')
    #             pprint(list(zip(answer[:BATCH_SIZE], references[:BATCH_SIZE])))
    #             print('\n')
                batch_level_eval_results = evaluate(xorqa_references[i*BATCH_SIZE:(i+1)*BATCH_SIZE],
                                                    answer[:BATCH_SIZE])
                print(' - Batch-level eval:')
                pprint(batch_level_eval_results, indent=4)
                print('\n')
    #             print(f'\n  refs: {xorqa_references[-6:]}')
    #             print(f'\nanswer: {answer[-6:]}')
                c+=1
            answers.extend(answer)
    #         print(f'answers: {answers}')
    #         break
        # break
        predictions = answers

        eval_results = evaluate(xorqa_references, predictions)
        print('Per-epoch eval results')
        print(eval_results)
        print('\n\n')
        xorqa_en_scores.append({
            'model_ckp': model_ckp,
            'model_dir': MODEL_DIR,
            **eval_results,
        })

    with open(f'./xorqa_scores/xorqa_en_scores.{model_exp_name}.json', 'w') as f:
        json.dump(xorqa_en_scores, f, indent=4)
    pprint(xorqa_en_scores, sort_dicts=False)
    print('\n')
    print('-'*60)
    print('\n\n')




Model EXP: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500
MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-12608


1it [00:07,  7.62s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 65.65028207811925}




2it [00:14,  7.27s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 65.71490575396827}




3it [00:22,  7.60s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 67.0723586309524}




4it [00:29,  7.56s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 71.20208744349569}




5it [00:38,  7.89s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 65.5328807135902}




25it [03:27,  8.29s/it]


Per-epoch eval results
{'exact_match': 54.536679536679536, 'f1': 65.59756396653704}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-3152


1it [00:03,  3.82s/it]

 - Batch-level eval:
{'exact_match': 23.4375, 'f1': 28.66392721861472}




2it [00:09,  4.29s/it]

 - Batch-level eval:
{'exact_match': 28.125, 'f1': 34.33864442516136}




3it [00:15,  5.02s/it]

 - Batch-level eval:
{'exact_match': 32.8125, 'f1': 39.316983865914786}




4it [00:23,  5.74s/it]

 - Batch-level eval:
{'exact_match': 31.25, 'f1': 36.48619462512126}




5it [00:32,  6.70s/it]

 - Batch-level eval:
{'exact_match': 26.5625, 'f1': 32.15012286358794}




25it [03:09,  7.60s/it]


Per-epoch eval results
{'exact_match': 26.319176319176318, 'f1': 33.346876833831345}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-14184


1it [00:05,  5.96s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 65.55316428363304}




2it [00:12,  6.09s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 63.402777777777786}




3it [00:19,  6.47s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 69.4614955357143}




4it [00:26,  6.57s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 73.48921803037956}




5it [00:32,  6.33s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 68.07655692214516}




25it [03:10,  7.63s/it]


Per-epoch eval results
{'exact_match': 55.85585585585586, 'f1': 66.5010978035878}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-9456


1it [00:06,  6.69s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 64.84918076714953}




2it [00:13,  6.60s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 63.68737599206349}




3it [00:21,  7.22s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 67.6117931547619}




4it [00:28,  7.08s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 69.49735987939313}




5it [00:35,  6.95s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 62.36632964941787}




25it [03:17,  7.91s/it]


Per-epoch eval results
{'exact_match': 52.76705276705277, 'f1': 63.63552389874214}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-22064


1it [00:06,  6.49s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 66.78797331141082}




2it [00:12,  6.46s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 65.53695436507937}




3it [00:19,  6.51s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 69.53657670454545}




4it [00:27,  7.10s/it]

 - Batch-level eval:
{'exact_match': 68.75, 'f1': 76.59843520209672}




5it [00:38,  8.03s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 67.80167618416371}




25it [03:08,  7.56s/it]


Per-epoch eval results
{'exact_match': 58.88030888030888, 'f1': 69.63617473223172}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-23640


1it [00:05,  5.86s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 67.16056946525697}




2it [00:12,  6.03s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 67.55518353174604}




3it [00:19,  6.36s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 70.05368979978354}




4it [00:26,  6.69s/it]

 - Batch-level eval:
{'exact_match': 67.1875, 'f1': 76.02518033196687}




5it [00:35,  7.28s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 69.09876321454314}




25it [03:14,  7.76s/it]


Per-epoch eval results
{'exact_match': 58.526383526383526, 'f1': 69.3286575674731}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-20488


1it [00:07,  7.27s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 66.67636616855367}




2it [00:13,  7.01s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 67.9923115079365}




3it [00:22,  7.52s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 71.21124751984127}




4it [00:29,  7.43s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 73.54023038607943}




5it [00:38,  7.85s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 69.82339587761786}




25it [03:23,  8.13s/it]


Per-epoch eval results
{'exact_match': 57.78635778635778, 'f1': 69.08052942043315}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-17336


1it [00:07,  7.19s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 68.31730118839495}




2it [00:13,  6.98s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 67.03559027777779}




3it [00:20,  7.03s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 68.67466517857143}




4it [00:27,  7.07s/it]

 - Batch-level eval:
{'exact_match': 67.96875, 'f1': 77.21565652244307}




5it [00:36,  7.65s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 69.03905280609806}




25it [03:03,  7.35s/it]


Per-epoch eval results
{'exact_match': 57.4002574002574, 'f1': 68.43656333943095}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-11032


1it [00:08,  8.51s/it]

 - Batch-level eval:
{'exact_match': 55.46875, 'f1': 61.94089467732561}




2it [00:14,  7.89s/it]

 - Batch-level eval:
{'exact_match': 53.125, 'f1': 65.07130456349206}




3it [00:22,  7.93s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 67.53366815476191}




4it [00:30,  7.79s/it]

 - Batch-level eval:
{'exact_match': 57.8125, 'f1': 67.0538261653002}




5it [00:39,  8.08s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 66.52719889540839}




25it [03:23,  8.15s/it]


Per-epoch eval results
{'exact_match': 52.541827541827544, 'f1': 63.90148988160305}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-25000


1it [00:06,  6.49s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 66.13693164474415}




2it [00:12,  6.47s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 68.01091269841271}




3it [00:19,  6.59s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 72.07682291666667}




4it [00:27,  6.87s/it]

 - Batch-level eval:
{'exact_match': 70.3125, 'f1': 78.480537474824}




5it [00:35,  7.30s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 69.10300559131842}




25it [03:19,  7.99s/it]


Per-epoch eval results
{'exact_match': 58.397683397683394, 'f1': 69.62096902679285}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-6304


1it [00:08,  8.49s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 59.69684141289745}




2it [00:16,  8.47s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 62.03085213345708}




3it [00:25,  8.36s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 61.158607809762685}




4it [00:31,  7.85s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 67.82992354867355}




5it [00:39,  7.74s/it]

 - Batch-level eval:
{'exact_match': 47.65625, 'f1': 57.053867227580454}




25it [03:06,  7.47s/it]


Per-epoch eval results
{'exact_match': 48.80952380952381, 'f1': 59.56037437873313}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-7880


1it [00:06,  6.94s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 62.386073128260634}




2it [00:13,  6.78s/it]

 - Batch-level eval:
{'exact_match': 50.0, 'f1': 61.062534800646205}




3it [00:21,  7.32s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 65.32545319264068}




4it [00:29,  7.32s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 67.92186092981328}




5it [00:35,  7.10s/it]

 - Batch-level eval:
{'exact_match': 51.5625, 'f1': 60.22617813426636}




25it [03:17,  7.90s/it]


Per-epoch eval results
{'exact_match': 51.22265122265122, 'f1': 62.04036000620208}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-25001


1it [00:06,  6.12s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 66.13693164474415}




2it [00:12,  6.19s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 68.37148962148962}




3it [00:19,  6.43s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 72.43739983974359}




4it [00:26,  6.69s/it]

 - Batch-level eval:
{'exact_match': 70.3125, 'f1': 78.480537474824}




5it [00:35,  7.33s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 69.10300559131842}




25it [03:18,  7.92s/it]


Per-epoch eval results
{'exact_match': 58.46203346203346, 'f1': 69.6691288058015}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-15760


1it [00:06,  6.44s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 67.16645984224108}




2it [00:12,  6.42s/it]

 - Batch-level eval:
{'exact_match': 54.6875, 'f1': 66.32874503968254}




3it [00:20,  6.76s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 68.14453125}




4it [00:27,  7.01s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 75.26253152244307}




5it [00:36,  7.53s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 68.20141099796102}




25it [03:20,  8.01s/it]


Per-epoch eval results
{'exact_match': 56.04890604890605, 'f1': 66.97855717677479}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-4728


1it [00:03,  3.85s/it]

 - Batch-level eval:
{'exact_match': 45.3125, 'f1': 52.39290223665224}




2it [00:09,  4.30s/it]

 - Batch-level eval:
{'exact_match': 41.40625, 'f1': 50.43319452286843}




3it [00:16,  5.24s/it]

 - Batch-level eval:
{'exact_match': 46.09375, 'f1': 54.23885886591477}




4it [00:24,  5.89s/it]

 - Batch-level eval:
{'exact_match': 52.34375, 'f1': 61.32212539165863}




5it [00:31,  6.35s/it]

 - Batch-level eval:
{'exact_match': 42.1875, 'f1': 49.21719756278579}




25it [03:11,  7.67s/it]


Per-epoch eval results
{'exact_match': 39.86486486486486, 'f1': 49.5154054701122}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-1576


1it [00:07,  7.26s/it]

 - Batch-level eval:
{'exact_match': 9.375, 'f1': 19.23611111111111}




2it [00:14,  7.24s/it]

 - Batch-level eval:
{'exact_match': 7.8125, 'f1': 18.140763954183072}




3it [00:25,  8.32s/it]

 - Batch-level eval:
{'exact_match': 4.6875, 'f1': 13.445274343711846}




4it [00:32,  7.96s/it]

 - Batch-level eval:
{'exact_match': 2.34375, 'f1': 8.421186656480774}




5it [00:41,  8.26s/it]

 - Batch-level eval:
{'exact_match': 6.25, 'f1': 13.256062513875014}




25it [03:13,  7.74s/it]


Per-epoch eval results
{'exact_match': 4.665379665379666, 'f1': 12.833010771270196}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-18912


1it [00:06,  6.20s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 70.2133348422411}




2it [00:12,  6.26s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 67.9842509920635}




3it [00:19,  6.45s/it]

 - Batch-level eval:
{'exact_match': 60.9375, 'f1': 68.8569568452381}




4it [00:26,  6.76s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 76.35070116530021}




5it [00:35,  7.36s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 68.52483682381141}




25it [03:13,  7.72s/it]


Per-epoch eval results
{'exact_match': 58.17245817245817, 'f1': 68.79156988217447}



[{'model_ckp': '12608',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-12608',
  'exact_match': 54.536679536679536,
  'f1': 65.59756396653704},
 {'model_ckp': '3152',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-3152',
  'exact_match': 26.319176319176318,
  'f1': 33.346876833831345},
 {'model_ckp': '14184',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-14184',
  'exact_match': 55.85585585585586,
  'f1': 66.5010978035878},
 {'model_ckp': '9456',
  'model_dir': '../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-1e-5.max_steps-25000.save_steps-500/checkpoint-9456',
  'exact_match': 52.76705276705277,
  'f1': 63.6

1it [00:07,  7.00s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 70.3140435952936}




2it [00:12,  6.68s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 71.21532547313797}




3it [00:21,  7.18s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 73.91896081349208}




4it [00:29,  7.40s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 76.90356450019279}




5it [00:38,  7.85s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 73.08413329075782}




25it [02:59,  7.19s/it]


Per-epoch eval results
{'exact_match': 58.42985842985843, 'f1': 70.1691519615059}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-3152


1it [00:06,  6.30s/it]

 - Batch-level eval:
{'exact_match': 56.25, 'f1': 61.937355179542685}




2it [00:12,  6.37s/it]

 - Batch-level eval:
{'exact_match': 50.0, 'f1': 59.739878177378195}




3it [00:19,  6.54s/it]

 - Batch-level eval:
{'exact_match': 57.03125, 'f1': 65.23200757575758}




4it [00:26,  6.69s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 68.97917901134814}




5it [00:32,  6.53s/it]

 - Batch-level eval:
{'exact_match': 53.90625, 'f1': 63.14630143927017}




25it [03:04,  7.38s/it]


Per-epoch eval results
{'exact_match': 53.18532818532819, 'f1': 64.066593265394}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-14184


1it [00:06,  6.47s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 71.85336365023865}




2it [00:12,  6.46s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 72.8822783119658}




3it [00:19,  6.57s/it]

 - Batch-level eval:
{'exact_match': 69.53125, 'f1': 75.17456501831502}




4it [00:26,  6.71s/it]

 - Batch-level eval:
{'exact_match': 69.53125, 'f1': 78.13816323479152}




5it [00:34,  7.09s/it]

 - Batch-level eval:
{'exact_match': 62.5, 'f1': 72.1158035321199}




25it [02:57,  7.12s/it]


Per-epoch eval results
{'exact_match': 59.84555984555985, 'f1': 70.40215089404819}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-9456


1it [00:06,  6.71s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 70.80828676531802}




2it [00:12,  6.55s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 73.2657490079365}




3it [00:19,  6.63s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 72.77529761904763}




4it [00:26,  6.79s/it]

 - Batch-level eval:
{'exact_match': 65.625, 'f1': 73.81973273759263}




5it [00:34,  7.18s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 70.65161206123432}




25it [02:50,  6.83s/it]


Per-epoch eval results
{'exact_match': 58.59073359073359, 'f1': 69.65042148559827}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-17336


1it [00:06,  6.14s/it]

 - Batch-level eval:
{'exact_match': 63.28125, 'f1': 70.97023636086136}




2it [00:14,  6.73s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 74.18030753968254}




3it [00:21,  6.86s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 71.3045634920635}




4it [00:29,  7.31s/it]

 - Batch-level eval:
{'exact_match': 66.40625, 'f1': 74.80088419438748}




5it [00:37,  7.48s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 74.65651001856217}




25it [03:09,  7.58s/it]


Per-epoch eval results
{'exact_match': 59.33075933075933, 'f1': 70.24665588137069}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-11032


1it [00:06,  6.74s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 70.98407805644646}




2it [00:13,  6.64s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 71.60251640720391}




3it [00:20,  6.76s/it]

 - Batch-level eval:
{'exact_match': 67.96875, 'f1': 74.10964686355312}




4it [00:27,  6.90s/it]

 - Batch-level eval:
{'exact_match': 67.1875, 'f1': 75.53835334395997}




5it [00:36,  7.48s/it]

 - Batch-level eval:
{'exact_match': 59.375, 'f1': 70.63006527990787}




25it [03:19,  7.97s/it]


Per-epoch eval results
{'exact_match': 58.3011583011583, 'f1': 69.59416738318495}



MODEL_DIR: ../checkpoints/t5-base.adafactor.seq2seq.xorqa_hparams.bz-8.grad_acc-1.lr-5e-5.max_steps-20000.save_steps-500/checkpoint-6304


1it [00:06,  6.16s/it]

 - Batch-level eval:
{'exact_match': 64.0625, 'f1': 68.69831557331558}




2it [00:12,  6.23s/it]

 - Batch-level eval:
{'exact_match': 60.15625, 'f1': 70.3804181929182}




3it [00:20,  6.80s/it]

 - Batch-level eval:
{'exact_match': 64.84375, 'f1': 71.31616141381767}




4it [00:27,  6.87s/it]

 - Batch-level eval:
{'exact_match': 61.71875, 'f1': 71.15031561335215}




5it [00:36,  7.51s/it]

 - Batch-level eval:
{'exact_match': 58.59375, 'f1': 69.63546617840477}




13it [01:40,  7.52s/it]

# XQuAD

In [7]:
XQUAD_FINETUNED_MODEL_DIR = '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-4.epochs-10'

T5_TOKENIZER_MODEL_DIR = XQUAD_FINETUNED_MODEL_DIR

tokenizer = MT5Tokenizer.from_pretrained(T5_TOKENIZER_MODEL_DIR)


### Your dataset

In [8]:
SQUAD_EN_DATA_DIR = '../data/xquad/en/'

squad_en = { 
     'train': json.load(open(os.path.join(SQUAD_EN_DATA_DIR, 'train-v1.1.json')))['data'],
     'validation': json.load(open(os.path.join(SQUAD_EN_DATA_DIR, 'dev-v1.1.json')))['data']
}


XSQUAD_DATA_DIR = '../data/xquad/datasets_format'

squad_xx = { 
#      'train': json.load(open(os.path.join(SQUAD_EN_DATA_DIR, 'train-v1.1.json')))['data'],
     'test': json.load(open(os.path.join(XSQUAD_DATA_DIR, 'test.json')))['data']
}

In [9]:
squad_xx['test'][0]

{'lang': 'en',
 'qid': '56beb4343aeaaa14008c925b',
 'question': 'How many points did the Panthers defense surrender?',
 'context': "The Panthers defense gave up just 308 points, ranking sixth in the league, while also leading the NFL in interceptions with 24 and boasting four Pro Bowl selections. Pro Bowl defensive tackle Kawann Short led the team in sacks with 11, while also forcing three fumbles and recovering two. Fellow lineman Mario Addison added 6½ sacks. The Panthers line also featured veteran defensive end Jared Allen, a 5-time pro bowler who was the NFL's active career sack leader with 136, along with defensive end Kony Ealy, who had 5 sacks in just 9 starts. Behind them, two of the Panthers three starting linebackers were also selected to play in the Pro Bowl: Thomas Davis and Luke Kuechly. Davis compiled 5½ sacks, four forced fumbles, and four interceptions, while Kuechly led the team in tackles (118) forced two fumbles, and intercepted four passes of his own. Carolina's sec

In [10]:
# process the examples in input and target text format and the eos token at the end 
def add_eos_to_examples(example):
    result = {}
    context = example['context']
    question = example['question']
    if 'answers' in example.keys() and type(example['answers']) == list and type(example['answers'][0]) == dict:
        answer = example['answers'][0]['text']
        result['answers'] = example['answers']
    elif 'answers' in example.keys()  and type(example['answers']) == list and type(example['answers'][0]) == str:
        answer = example['answers'][0]
        result['answers'] = example['answers']
    else:
        answer = example['answer']

    result['input_text'] =  'question: %s context: %s' % (question, context)
    result['answer'] = result
    result['target_text'] = '%s' % answer
    
    return result

def convert_to_features(example):

    encoding = {}
    
    input_encoding = tokenizer.encode_plus(example['input_text'],
#                                            pad_to_max_length=True,
                                           truncation=True,
                                           max_length=512,
                                           add_special_tokens=True)
    target_encoding = tokenizer.encode_plus(example['target_text'],
                                            pad_to_max_length=True,
                                              truncation=True,
                                            max_length=16, add_special_tokens=True)

    encoding['input_ids'] = input_encoding['input_ids'] 
#     encoding['attention_mask'] = input_encoding['attention_mask'] 
    encoding['target_ids'] = target_encoding['input_ids'] 
#     encoding['labels_attention_mask'] =  target_encoding['attention_mask'] 

    # print(f"type(encodings['input_ids']: {type(encodings['input_ids'])}")
    return encoding

In [11]:
# data_files = {"test": "/ist/ist-share/scads/korn/datasets/qa_datasset/MLQA/dev/dev-context-ar-question-vi.json"}
xquad_test_dataset = squad_xx['test']

xquad_test_features = list(map(convert_to_features, map(add_eos_to_examples, xquad_test_dataset)))
len(xquad_test_features)

xquad_references = [ list(map(lambda x: x['text'], item['answers'])) for item in xquad_test_dataset ]
len(xquad_references),  xquad_references[:2]



(14280, [['308'], ['136']])

In [12]:
def get_squad_answer_str(context, qas):
    context_qa_pairs = []
    for qa in qas:
        qid = qa['id']
        question = qa['question']
        answer = qa['answers'][0]['text']
        answers = list(map(lambda x: x['text'], qa['answers']))
        answer_start = qa['answers'][0]['answer_start']
        context_qa_pairs.append((qid, context, question, answer, answer_start, answers))
    return context_qa_pairs

In [13]:
squad_dataset = defaultdict(lambda : dict())
for split_name in ['train', 'validation']:
    for i, item in enumerate(squad_en[split_name]):
        paragraphs = item['paragraphs']
#         print('.' ,end='')
        for j, paragraph in enumerate(paragraphs):

            context = paragraph['context']
            context_qa_pairs = get_squad_answer_str(context=context, qas=paragraph['qas'])

            for context_qa_pair in context_qa_pairs:
                qid, context, question, answer, answer_start, answers = context_qa_pair

                qa_item = {
                    'qid': qid,
                    'question': question,
                    'context': context,
                    'answer': answer,
                    'answers': answers,
                    'answer_start': answer_start,
                }
                squad_dataset[split_name][qid] = qa_item
    
    print(f'Number of {split_name} examples: {len(squad_dataset[split_name]):,}')

Number of train examples: 87,599
Number of validation examples: 10,570


In [14]:
squad_dev = list(squad_dataset['validation'].values())
len(squad_dev)

10570

In [15]:
squad_dev[0]

{'qid': '56be4db0acb8001400a502ec',
 'question': 'Which NFL team represented the AFC at Super Bowl 50?',
 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 'answer': 'Denver Broncos',
 'answers': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answ

In [16]:
squad_dev_features = list(map(convert_to_features, map(add_eos_to_examples, squad_dev)))

In [17]:
squad_dev_features[0].keys()

dict_keys(['input_ids', 'target_ids'])

In [18]:
references = [item['answers'] for item in squad_dev]
len(references)

10570

In [19]:
references[:5]

[['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers'],
 ['Santa Clara, California',
  "Levi's Stadium",
  "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."],
 ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 ['gold', 'gold', 'gold']]

In [20]:
dev_squad = Dataset.from_pandas(pd.DataFrame(data=squad_dev_features))
dev_squad

Dataset({
    features: ['input_ids', 'target_ids'],
    num_rows: 10570
})

### Evaluate for each ckp

In [21]:
# final ckp
# XQUAD_FINETUNED_MODEL_DIR = '../checkpoints/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-4.epochs-10'
# XQUAD_FINETUNED_MODEL_DIR = '../checkpoints/t5-base.adafactor.seq2seq.squad_hparams.bz-8.grad_acc-1.lr-1e-4.epochs-5'
ALL_SQUAD_FINETUNED_MODEL_DIRS = glob.glob('../checkpoints/squad/*')
xquad_test_features

['../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-2e-4.max_steps-35000.save_steps-1000',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-4.epochs-10',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-4e-4.max_steps-35000.save_steps-1000',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-4.grad_acc-1.lr-5e-5.epochs-5',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-4.grad_acc-1.lr-1e-4.epochs-5',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-4.max_steps-30000.save_steps-1500',
 '../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-8.grad_acc-1.lr-1e-4.epochs-5']

In [None]:
for ALL_XQUAD_FINETUNED_MODEL_DIR in ALL_SQUAD_FINETUNED_MODEL_DIRS:
    
    
    print('='*50)
    print(f'\n\n\nModel EXP: {ALL_XQUAD_FINETUNED_MODEL_DIR}  ')
    
    model_exp_name = ALL_XQUAD_FINETUNED_MODEL_DIR.split('/')[-1]
    model_exp_name
    XQUAD_FINETUNED_MODEL_DIRS = sorted(glob.glob(os.path.join(ALL_XQUAD_FINETUNED_MODEL_DIR, 'checkpoint-*')), key=lambda x: int(x.split('-')[-1]) )
    XQUAD_FINETUNED_MODEL_DIRS

    BATCH_SIZE = 128
    xquad_en_scores = []
    xquad_scores = []
    for MODEL_DIR in XQUAD_FINETUNED_MODEL_DIRS:
    # MODEL_DIR = '/ist/ist-share/scads/aires/'
        print(f'MODEL_DIR: {MODEL_DIR}')
        model_ckp = MODEL_DIR.split('-')[-1]
        device = 'cuda:3'
        model = MT5ForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
        model.eval()
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                                padding=True,
                                                max_length=512)
        
        ### FOR SQUAD
        data_loader = torch.utils.data.DataLoader(dev_squad,
                                                  batch_size=BATCH_SIZE,
                                                  collate_fn=data_collator)

        predictions = []
        answers = []
        c = 0
        batched_gt_ans = []
        for i, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            batch_size = len(batch['input_ids'])

            target_ans = tokenizer.batch_decode(batch['target_ids'], skip_special_tokens=True)
            batched_gt_ans.extend(target_ans)
            # print('batch size', len(batch['input_ids']))
            outs = model.generate(input_ids=batch['input_ids'].to(device), 
                                attention_mask=batch['attention_mask'].to(device),
                                max_length=64,
                                early_stopping=True,
                                num_beams=1,
                                decoder_start_token_id=0)

            answer = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
#             if c < 5:
    
#                 batch_level_eval_results = evaluate(references[i*BATCH_SIZE:(i+1)*BATCH_SIZE], answer[:BATCH_SIZE])
#                 print(' - Batch-level eval:')
#                 pprint(batch_level_eval_results, indent=4)
#                 print('\n')
   
#                 c+=1
            answers.extend(answer)
  
        predictions = answers

        ### FOR XSQUAD
#         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
#                                                 padding=True,
#                                                 max_length=512)
#         xquad_data_loader = torch.utils.data.DataLoader(xquad_test_features,
#                                                   batch_size=BATCH_SIZE,
#                                                   collate_fn=data_collator)

#         xquad_predictions = []
#         xquad_answers = []
#         c = 0
#         batched_gt_ans = []
        
#         for i, batch in tqdm(enumerate(xquad_data_loader), total=len(xquad_data_loader)):
#             batch_size = len(batch['input_ids'])

#             target_ans = tokenizer.batch_decode(batch['target_ids'], skip_special_tokens=True)
#             batched_gt_ans.extend(target_ans)
#             # print('batch size', len(batch['input_ids']))
#             outs = model.generate(input_ids=batch['input_ids'].to(device), 
#                                 attention_mask=batch['attention_mask'].to(device),
#                                 max_length=64,
#                                 early_stopping=True,
#                                 num_beams=1,
#                                 decoder_start_token_id=0)

#             xquad_answer = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
# #             if c < 5:
# #     #             print(f'\nRefs and answer \n')
# #     #             pprint(list(zip(answer[:BATCH_SIZE], references[:BATCH_SIZE])))
# #     #             print('\n')
# #                 batch_level_eval_results = evaluate(xquad_references[i*BATCH_SIZE:(i+1)*BATCH_SIZE],
# #                                                     xquad_answer[:BATCH_SIZE])
# #                 print(' - Batch-level eval:')
# #                 pprint(batch_level_eval_results, indent=4)
# #                 print('\n')
# #     #             print(f'\n  refs: {xorqa_references[-6:]}')
# #     #             print(f'\nanswer: {answer[-6:]}')
# #                 c+=1
#             xquad_answers.extend(xquad_answer)
     
#         xquad_predictions = xquad_answers
        
        
        eval_results = evaluate(references, predictions)
#         xquad_eval_results = evaluate(xquad_references,xquad_predictions)
        print('Per-epoch eval results')
        print(eval_results)
#         print('Per-epoch eval results (XQUAD)')
#         print(xquad_eval_results)
#         print('\n\n')
        xquad_en_scores.append({
            'model_ckp': model_ckp,
            'model_dir': MODEL_DIR,
            **eval_results,
        })
#         xquad_scores.append({
#             'model_ckp': model_ckp,
#             'model_dir': MODEL_DIR,
#             **xquad_eval_results,
#         })

    with open(f'./squad/squad_en_scores.{model_exp_name}.json', 'w') as f:
        json.dump(xquad_en_scores, f, indent=4)
#     with open(f'./xquad/squad_xx_scores.{model_exp_name}.json', 'w') as f:
#         json.dump(xquad_scores, f, indent=4)    
#     pprint(xquad_en_scores, sort_dicts=False)
    print('\n')
    print('-'*60)
    print('\n\n')




Model EXP: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500  
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-1500


100%|██████████| 83/83 [04:40<00:00,  3.38s/it]


Per-epoch eval results
{'exact_match': 41.68401135288553, 'f1': 52.552103738141}
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-3000


100%|██████████| 83/83 [04:06<00:00,  2.97s/it]


Per-epoch eval results
{'exact_match': 52.78145695364238, 'f1': 64.78896116378334}
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-4500


100%|██████████| 83/83 [04:36<00:00,  3.33s/it]


Per-epoch eval results
{'exact_match': 51.59886471144749, 'f1': 63.5315791925234}
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-6000


100%|██████████| 83/83 [04:34<00:00,  3.31s/it]


Per-epoch eval results
{'exact_match': 53.58561967833491, 'f1': 66.34840055523262}
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-7500


100%|██████████| 83/83 [04:00<00:00,  2.90s/it]


Per-epoch eval results
{'exact_match': 55.24124881740776, 'f1': 67.53465170176963}
MODEL_DIR: ../checkpoints/squad/t5-base.adafactor.seq2seq.squad_hparams.bz-16.grad_acc-1.lr-1e-3.max_steps-30000.save_steps-1500/checkpoint-9000


 61%|██████▏   | 51/83 [02:18<01:45,  3.31s/it]

## Error analysis

In [28]:
ref_ans = list(map(lambda x: x[0], references))
ref_tokens = list( map(lambda x: tokenizer.tokenize(x[0]), references))

ref_ntokens = list(map(len, ref_tokens))
len(ref_ntokens)

10570

In [29]:
# ref_tokens[125]

['▁New',
 '▁Orleans',
 "'",
 '▁Mercedes',
 '-',
 'Benz',
 '▁Super',
 'dome',
 ',',
 '▁Miami',
 "'",
 's',
 '▁Sun',
 '▁Life',
 '▁Stadium',
 ',',
 '▁and',
 '▁the',
 '▁San',
 '▁Francisco',
 '▁Bay',
 '▁Area',
 "'",
 's',
 '▁Levi',
 "'",
 's',
 '▁Stadium']

In [30]:
df = pd.DataFrame.from_dict({ 'ntokens': ref_ntokens, 'answer': ref_ans})
df.describe()

Unnamed: 0,ntokens
count,10570.0
mean,4.863103
std,4.684753
min,1.0
25%,2.0
50%,3.0
75%,6.0
max,52.0


In [31]:
df[df['ntokens']> 6]

Unnamed: 0,ntokens,answer
60,11,Newton was limited by Denver's defense
117,8,New Orleans' Mercedes-Benz Superdome
119,10,San Francisco Bay Area's Levi's Stadium
125,28,"New Orleans' Mercedes-Benz Superdome, Miami's ..."
262,8,a plantar fasciitis injury
...,...,...
10480,9,"Newton's Universal Gravitation Constant,"
10524,7,approximately 1015 kelvins
10543,7,rotational equivalent for position
10545,7,Newton's Second Law of Motion


In [32]:
answers[-10:]

['statistical mechanics',
 'nonconservative forces',
 'nonconservative forces',
 'Second law',
 'entropy increases',
 'pound-force',
 'kilopond',
 'kilogram-force',
 'metric slug',
 'metric counterpart']

In [33]:
pred_ans = list(map(lambda x: x, answers))
pred_tokens = list( map(lambda x: tokenizer.tokenize(x), answers))

pred_ntokens = list(map(len, pred_tokens))
len(pred_ntokens)
pred_df = pd.DataFrame.from_dict({ 'ntokens': pred_ntokens, 'answer': pred_ans})
pred_df.describe()

Unnamed: 0,ntokens
count,10570.0
mean,5.012015
std,4.825947
min,1.0
25%,2.0
50%,4.0
75%,6.0
max,29.0


In [34]:
merged_df = df
merged_df['pred_ans'] = pred_df['answer']
merged_df['pred_ntokens'] = pred_df['ntokens']

In [61]:

def get_hash(text):
    import hashlib
    return hashlib.sha256(text.encode()).hexdigest()[:6]
get_hash('sssdadasda'), \
get_hash('1'), \
get_hash('1')

('e1dc34', '6b86b2', '6b86b2')

In [66]:
counter = 0
wrong_counter = 0
ctx_hash_counter = Counter()
em_count_per_ctx = Counter()
for k, item in merged_df.iterrows():
#     squad_dev[k]
  
    squad_item = squad_dev[k]
    ctx = squad_item['context']
    ctx_hash = get_hash(ctx)
    ctx_hash_counter[ctx_hash] += 1
    print(f'\nidx: {k}')
    print(f'ctx_hash: {ctx_hash}')
#     break
    gt, pred = item['answer'], item['pred_ans']
    if gt == pred:
        counter += 1
        em_count_per_ctx[ctx_hash] += 1
        print(f'✅ ({counter:4,}/{len(merged_df):,})')
    else:
        wrong_counter+=1
        print(f'❌ ({wrong_counter:4,}/{len(merged_df):,})')
    print('  gt:', item['answer'])
    print('pref:', item['pred_ans'])
      


idx: 0
ctx_hash: d1f67c
❌ (   1/10,570)
  gt: Denver Broncos
pref: Carolina Panthers

idx: 1
ctx_hash: d1f67c
✅ (   1/10,570)
  gt: Carolina Panthers
pref: Carolina Panthers

idx: 2
ctx_hash: d1f67c
❌ (   2/10,570)
  gt: Santa Clara, California
pref: Levi's Stadium in the San Francisco Bay Area at Santa Clara, California

idx: 3
ctx_hash: d1f67c
❌ (   3/10,570)
  gt: Denver Broncos
pref: Carolina Panthers

idx: 4
ctx_hash: d1f67c
❌ (   4/10,570)
  gt: gold
pref: golden

idx: 5
ctx_hash: d1f67c
❌ (   5/10,570)
  gt: "golden anniversary"
pref: golden anniversary

idx: 6
ctx_hash: d1f67c
✅ (   2/10,570)
  gt: February 7, 2016
pref: February 7, 2016

idx: 7
ctx_hash: d1f67c
✅ (   3/10,570)
  gt: American Football Conference
pref: American Football Conference

idx: 8
ctx_hash: d1f67c
❌ (   6/10,570)
  gt: "golden anniversary"
pref: golden anniversary

idx: 9
ctx_hash: d1f67c
✅ (   4/10,570)
  gt: American Football Conference
pref: American Football Conference

idx: 10
ctx_hash: d1f67c
✅ ( 

In [122]:
em_percent_per_ctx = dict(sorted(dict({
    hash_ctx: {
        'em_percentage': round(raw_count / ctx_hash_counter[hash_ctx] * 100, 2),
        'raw_count': raw_count,
        'num_questions': ctx_hash_counter[hash_ctx],
    } \
    for hash_ctx, raw_count in em_count_per_ctx.items()}).items(),                  
    key=lambda x: x[1]['em_percentage'], reverse=True))

In [144]:
N_contexts =len(em_percent_per_ctx)
N_dev = len(merged_df)

In [146]:
print(f'Total number of contexts: {N_contexts:7,}')
print(f'Total number of questions: {N_dev:,}')
print(f'Average number of questions per context: {N_dev / N_contexts:.3f}')

Total number of contexts:   1,985
Total number of questions: 10,570
Average number of questions per context: 5.325


In [124]:
pprint(em_percent_per_ctx,sort_dicts=False, width=200)

{'4b7cad': {'em_percentage': 100.0, 'raw_count': 14, 'num_questions': 14},
 '6e7b59': {'em_percentage': 100.0, 'raw_count': 5, 'num_questions': 5},
 '1b6b54': {'em_percentage': 100.0, 'raw_count': 5, 'num_questions': 5},
 'de308e': {'em_percentage': 100.0, 'raw_count': 5, 'num_questions': 5},
 'f59086': {'em_percentage': 100.0, 'raw_count': 5, 'num_questions': 5},
 'c4689e': {'em_percentage': 100.0, 'raw_count': 5, 'num_questions': 5},
 '6cc341': {'em_percentage': 100.0, 'raw_count': 2, 'num_questions': 2},
 '277496': {'em_percentage': 100.0, 'raw_count': 1, 'num_questions': 1},
 '752ba5': {'em_percentage': 100.0, 'raw_count': 2, 'num_questions': 2},
 '8d0f13': {'em_percentage': 100.0, 'raw_count': 3, 'num_questions': 3},
 '00490c': {'em_percentage': 100.0, 'raw_count': 3, 'num_questions': 3},
 'd352fa': {'em_percentage': 100.0, 'raw_count': 3, 'num_questions': 3},
 '2eb036': {'em_percentage': 100.0, 'raw_count': 3, 'num_questions': 3},
 'e77cc2': {'em_percentage': 100.0, 'raw_count': 

Measure Correct EM versus number of questions per context



In [172]:
nquestion_vs_em_per_stats = defaultdict(lambda : [])
nquestion_counter = Counter()

for ctx_hash, data in em_percent_per_ctx.items():
    em_percentage = data['em_percentage']
    raw_count = data['raw_count']
    num_questions = data['num_questions']
    nquestion_counter[num_questions] += 1
    nquestion_vs_em_per_stats[num_questions].append(em_percentage)

In [173]:
nquestion_vs_em_per_summary = dict()
for nquestion, em_per_list in nquestion_vs_em_per_stats.items():
    nquestion_vs_em_per_summary[nquestion] = {
        'mean': round(np.mean(em_per_list), 4),
        'sd': round( np.std(em_per_list),4),
        'nquestion_correct': math.floor(np.mean(em_per_list) / 100 * nquestion * nquestion_counter[nquestion]),
        'nquestion_total': nquestion * nquestion_counter[nquestion],
        'nquestion_percentage': round((nquestion * nquestion_counter[nquestion]) / N_dev * 100 , 2),
    }
nquestion_vs_em_per_summary = dict(sorted(nquestion_vs_em_per_summary.items(), key=lambda x: x[0]))

In [174]:
sum(map(lambda x: x[1]['nquestion_percentage'], nquestion_vs_em_per_summary.items()))

97.02000000000001

In [175]:
pprint(nquestion_vs_em_per_summary, sort_dicts=False)

{1: {'mean': 100.0,
     'sd': 0.0,
     'nquestion_correct': 7,
     'nquestion_total': 7,
     'nquestion_percentage': 0.07},
 2: {'mean': 73.3333,
     'sd': 24.9444,
     'nquestion_correct': 22,
     'nquestion_total': 30,
     'nquestion_percentage': 0.28},
 3: {'mean': 65.1994,
     'sd': 25.9955,
     'nquestion_correct': 311,
     'nquestion_total': 477,
     'nquestion_percentage': 4.51},
 4: {'mean': 61.6348,
     'sd': 24.7375,
     'nquestion_correct': 1033,
     'nquestion_total': 1676,
     'nquestion_percentage': 15.86},
 5: {'mean': 60.6511,
     'sd': 23.068,
     'nquestion_correct': 3633,
     'nquestion_total': 5990,
     'nquestion_percentage': 56.67},
 6: {'mean': 44.4417,
     'sd': 18.4259,
     'nquestion_correct': 15,
     'nquestion_total': 36,
     'nquestion_percentage': 0.34},
 7: {'mean': 50.001,
     'sd': 24.1154,
     'nquestion_correct': 35,
     'nquestion_total': 70,
     'nquestion_percentage': 0.66},
 8: {'mean': 52.9412,
     'sd': 19.8937,
    

In [65]:
ctx_hash_counter.most_common(30)

[('d1f67c', 30),
 ('24a811', 26),
 ('52fbc2', 25),
 ('0fc4e5', 25),
 ('415c63', 23),
 ('35981b', 21),
 ('327e5a', 20),
 ('3d7f28', 20),
 ('69092d', 20),
 ('bc7c32', 19),
 ('35a736', 19),
 ('190806', 18),
 ('be9904', 17),
 ('ca7fe1', 17),
 ('093de0', 17),
 ('61f255', 17),
 ('549c34', 16),
 ('92d214', 16),
 ('24b4d3', 16),
 ('e3e756', 15),
 ('9bd616', 15),
 ('c182de', 15),
 ('cae68c', 15),
 ('eae43b', 15),
 ('919476', 15),
 ('e1e68b', 15),
 ('c7c7ec', 15),
 ('24fea2', 15),
 ('bb8a76', 15),
 ('6490f5', 15)]

In [37]:
evaluate(references, answers)

{'exact_match': 66.21570482497634, 'f1': 80.62264427404348}

In [143]:
evaluate(list(map(lambda x: [x], batched_gt_ans)), answers)

{'exact_match': 64.74929044465468, 'f1': 79.34378028240067}

In [145]:
evaluate(references, batched_gt_ans)

{'exact_match': 96.13055818353831, 'f1': 99.17341278932695}

In [None]:
references[0:5]

In [None]:
len(predictions)

In [None]:
# predictions[:10]

In [None]:
# references[:10]

In [None]:
# evaluate(references, predictions)

In [None]:
print('\tPrediction \t|\t Groundtruth')
list(zip(predictions[:10], references[:10]))