# Process MLQA dataset for Seq2Seq QA training

In [1]:
import os, sys
import json
import glob
import hashlib
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, GroupShuffleSplit

In [2]:
# Length of a SQuAD-style question id (24 hex characters); get_hash truncates
# its SHA-256 hex digest to this same length.
len('5733be284776f41900661182')

24

In [3]:
def get_hash(text, length=24):
    """Return the first ``length`` hex characters of the SHA-256 digest of ``text``.

    ``text`` is UTF-8 encoded before hashing. The default length of 24 matches
    the length of SQuAD-style question ids, so generated ids look alike.
    """
    return hashlib.sha256(text.encode('utf-8')).hexdigest()[:length]


def get_qa_hash(question, context):
    """Return a deterministic 24-character id for a (question, context) pair."""
    return get_hash('question:' + question + ' ' + 'context:' + context)

In [4]:
# Make the exploratory online-prompt helpers importable. The entry is a relative
# path, so it resolves against the current working directory at import time.
# NOTE(review): re-running this cell appends the same entry to sys.path again.
# NOTE(review): load_model and run_online_prompt_mining are not used in the cells shown.
sys.path.append('../../exploratory/online_prompt/notebooks/')
from exploring_sentence_level import (
    load_model,
    mine_prompt_gt,  
    segment_sentence,
    run_online_prompt_mining
)

In [5]:
MLQA_BASE_DIR = '../data/mlqa/MLQA_V1/'

mlqa_xx = {}
MLQA_LANGS = ['en', 'ar', 'de', 'es', 'hi', 'vi', 'zh']


def _load_mlqa_data(split, lang):
    """Return the 'data' list of the MLQA ``split`` file with English context and ``lang`` questions."""
    path = os.path.join(MLQA_BASE_DIR, split, f'{split}-context-en-question-{lang}.json')
    # Explicit UTF-8: questions contain non-ASCII text and the platform default
    # encoding is not guaranteed to be UTF-8. `with` closes the file handle.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)['data']


for lang in MLQA_LANGS:
    # Values are stored as 1-tuples (what the original trailing commas produced);
    # later cells unwrap them with `[0]`, so the wrapping is kept on purpose.
    mlqa_xx[f'{lang}_val'] = (_load_mlqa_data('dev', lang),)
    mlqa_xx[f'{lang}_test'] = (_load_mlqa_data('test', lang),)


In [6]:
# Number of documents in the test file with 'ar' questions; `[0]` unwraps the
# 1-tuple stored by the loading cell.
len(mlqa_xx['ar_test'][0])

2389

In [7]:
def get_squad_answer_str(context, qas):
    """Expand a paragraph's SQuAD-style ``qas`` into one tuple per question.

    Each tuple is ``(context, question, first_answer_text,
    first_answer_start, all_answers)``; only the first answer is broken out,
    the full ``answers`` list is passed through unchanged.
    """
    return [
        (
            context,
            qa['question'],
            qa['answers'][0]['text'],
            qa['answers'][0]['answer_start'],
            qa['answers'],
        )
        for qa in qas
    ]

In [8]:
# Per language -> {'val': [...], 'test': [...]} lists of processed QA records.
mlqa_xx_dataset = defaultdict(lambda: {'val': [], 'test': []})

# Per language: one title entry per document per split.
doc_title_counter = defaultdict(list)
# Per language and document title. NOTE: incremented once per QA pair, so despite
# the name these are question counts, not paragraph counts.
paragraph_in_doc_counter = defaultdict(Counter)

for lang in MLQA_LANGS:
    for split_name in ['val', 'test']:
        # `[0]` unwraps the 1-tuple stored by the loading cell.
        for item in mlqa_xx[f'{lang}_{split_name}'][0]:
            document_title = item['title']
            paragraphs = item['paragraphs']
            doc_title_counter[lang].append(document_title)
            for paragraph in paragraphs:
                context = paragraph['context']
                context_qa_pairs = get_squad_answer_str(context=context, qas=paragraph['qas'])
                if not context_qa_pairs:
                    continue
                # Segment each paragraph once rather than once per question.
                # NOTE(review): assumes segment_sentence is deterministic and side-effect free.
                segmented_context = list(segment_sentence(context))
                for context, question, answer, answer_start, answers in context_qa_pairs:
                    gt_sentence = mine_prompt_gt((context, question, answer, answer_start))
                    paragraph_in_doc_counter[lang][document_title] += 1
                    mlqa_xx_dataset[lang][split_name].append({
                        'question': question,
                        'context': context,
                        'document_title': document_title,
                        # Per-record copy so records never share one mutable list.
                        'segmented_context': list(segmented_context),
                        'answer': answer,
                        'answer_start': answer_start,
                        'answers': answers,
                        'gt_sentence': gt_sentence,
                    })


In [9]:
# Per-language number of title entries (one per document per split) and the
# number of distinct titles over all languages.
merged_doc_titles = set()
for lang in MLQA_LANGS:
    doc_titles = doc_title_counter[lang]
    print(f'lang: {lang}, n_docs = {len(doc_titles)}')
    merged_doc_titles |= set(doc_titles)
print(f'\nTotal unique doc {len(merged_doc_titles):,}')

lang: en, n_docs = 5530
lang: ar, n_docs = 2627
lang: de, n_docs = 2806
lang: es, n_docs = 2762
lang: hi, n_docs = 2255
lang: vi, n_docs = 2682
lang: zh, n_docs = 2673

Total unique doc 5,530


In [10]:
# Total QA pairs per language (val + test). paragraph_in_doc_counter is
# incremented once per question, so its sum is a question count; the label
# previously said "n_docs", which was wrong.
for lang in MLQA_LANGS:
    n_paragraphs = sum(paragraph_in_doc_counter[lang].values())
    print(f'lang: {lang}, n_qa_pairs = {n_paragraphs}')
    
    

lang: en, n_docs = 12738
lang: ar, n_docs = 5852
lang: de, n_docs = 5029
lang: es, n_docs = 5753
lang: hi, n_docs = 5425
lang: vi, n_docs = 6006
lang: zh, n_docs = 5641


In [11]:
# Number of processed QA records per language and original MLQA split.
for lang in MLQA_LANGS:
    for split_name in ('val', 'test'):
        n_questions = len(mlqa_xx_dataset[lang][split_name])
        print(f'lang= {lang}, split_name={split_name}\nn_questions: {n_questions}')

lang= en, split_name=val
n_questions: 1148
lang= en, split_name=test
n_questions: 11590
lang= ar, split_name=val
n_questions: 517
lang= ar, split_name=test
n_questions: 5335
lang= de, split_name=val
n_questions: 512
lang= de, split_name=test
n_questions: 4517
lang= es, split_name=val
n_questions: 500
lang= es, split_name=test
n_questions: 5253
lang= hi, split_name=val
n_questions: 507
lang= hi, split_name=test
n_questions: 4918
lang= vi, split_name=val
n_questions: 511
lang= vi, split_name=test
n_questions: 5495
lang= zh, split_name=val
n_questions: 504
lang= zh, split_name=test
n_questions: 5137


### Data splitting strategy



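Official MLQA test set => our training and validation sets

Official MLQA dev set (`val` here) => our test set

------

Training and validation use disjoint sets of documents: examples are grouped by document title before splitting.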





In [12]:
# Pool the MLQA 'test' records of every language, plus a parallel list of
# document titles used as the groups for the train/validation split.
mlqa_train_val_set, mlqa_train_val_groups = [], []
for lang in MLQA_LANGS:
    for split_name in ('test',):
        qas = mlqa_xx_dataset[lang][split_name]
        mlqa_train_val_set += qas
        mlqa_train_val_groups += [qa['document_title'] for qa in qas]
        print(f'lang: {lang}, {len(qas):,}')
print(f'\nTotal number of QA example in train/val set: {len(mlqa_train_val_set):,}')
len(mlqa_train_val_groups)

lang: en, 11,590
lang: ar, 5,335
lang: de, 4,517
lang: es, 5,253
lang: hi, 4,918
lang: vi, 5,495
lang: zh, 5,137

Total number of QA example in train/val set: 42,245


42245

In [13]:
# Pool the MLQA 'val' (dev) records of every language as the held-out test
# set, with a parallel list of their document titles.
mlqa_test_set, mlqa_test_groups = [], []
for lang in MLQA_LANGS:
    for split_name in ('val',):
        qas = mlqa_xx_dataset[lang][split_name]
        mlqa_test_set += qas
        mlqa_test_groups += [qa['document_title'] for qa in qas]
        print(f'lang: {lang}, {len(qas):,}')
print(f'\nTotal number of QA example in test set: {len(mlqa_test_set):,}')
len(mlqa_test_groups)

lang: en, 1,148
lang: ar, 517
lang: de, 512
lang: es, 500
lang: hi, 507
lang: vi, 511
lang: zh, 504

Total number of QA example in test set: 4,199


4199

In [14]:
# Sanity checks: (1) titles shared by the train/val pool and the test pool
# (an empty set means the documents are disjoint), (2) distinct titles across
# both pools, (3) len(doc_title_counter), the number of language keys (not documents).
set(mlqa_train_val_groups).intersection(set(mlqa_test_groups)), \
len(list((set(mlqa_train_val_groups).union(set(mlqa_test_groups))))), \
len(doc_title_counter)

(set(), 5530, 7)

In [15]:
# Inspect the last record of the pooled train/val set.
mlqa_train_val_set[-1]

{'question': '2009年未判决案件有多少个？',
 'context': 'In 1999, 8,400 applications were allocated to be heard. In 2003, 27,200 cases were filed and the number of pending applications rose to approximately 65,000. In 2005, the Court opened 45,500 case files. In 2009, 57,200 applications were allocated, with the number of pending applications rose to 119,300. At the time, more than 90% of them were declared to be inadmissible, and the majority of cases decided—around 60% of the decisions by the Court—related to what is termed repetitive cases: where the Court has already delivered judgment finding a violation of the European Convention on Human Rights or where well established case law exists on a similar case.',
 'document_title': 'European Court of Human Rights',
 'segmented_context': ['In 1999, 8,400 applications were allocated to be heard.',
  'In 2003, 27,200 cases were filed and the number of pending applications rose to approximately 65,000.',
  'In 2005, the Court opened 45,500 case files.

In [16]:
# Check the group labels: one document title per pooled QA example, in pool order.
len(mlqa_train_val_groups), mlqa_train_val_groups[:40]

(42245,
 ['Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Area 51',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Quality of service',
  'Cell culture',
  'Cell culture',
  'Cell culture',
  'Cell culture',
  'Mughal architecture',
  'TRIPS Agreement',
  'TRIPS Agreement',
  'TRIPS Agreement',
  'TRIPS Agreement',
  'Triple H',
  'Triple H',
  'Triple H',
  'Triple H',
  'Anne Boleyn',
  'Anne Boleyn'])

In [17]:
# random_state for the GroupShuffleSplit train/validation split.
SEED = 93

In [18]:

from math import isclose
from tqdm import tqdm
# Seed-search harness: for each candidate seed, split the pooled train/val QA
# examples by document title and print the split statistics only when the
# QA-level train share lies in [90.0, 90.0015] %. Only SEED is tried here.
# `c` counts printed seeds; at most 11 (c = 0..10) are printed.
c =0
for random_seed in tqdm([SEED]):
    X = mlqa_train_val_set
    groups = mlqa_train_val_groups
    # With GroupShuffleSplit a float train_size is the proportion of groups
    # (document titles), not of individual QA examples.
    gss = GroupShuffleSplit(n_splits=1, train_size=0.90, random_state=random_seed)
    # NOTE(review): return value is discarded; this call does not affect the split.
    gss.get_n_splits(X)

    mlqa_group_split = {}
    # n_splits=1, so this loop body runs exactly once.
    for train_index, test_index in gss.split(X, groups=groups):
#         print("TRAIN:", len(train_index), "TEST:", len(test_index))

#         print(train_index[0:20], train_index[-20:], )

        mlqa_group_split['train'], mlqa_group_split['validation'] = [ X[idx] for idx in train_index ], [ X[idx] for idx in test_index ]


    per_split_doc_title_counter = defaultdict(lambda: set())
    # NOTE(review): incremented once per QA example, so the "paragraph" counts
    # and ratios below are QA-example counts.
    per_split_paragraph_in_doc_counter = defaultdict(lambda: Counter())


    for split_name in ['train', 'validation']:

        for qa in mlqa_group_split[split_name]:
            document_title = qa['document_title']
            per_split_doc_title_counter[split_name].add(document_title)
            per_split_paragraph_in_doc_counter[split_name][document_title] += 1

    N_doc_train = len(per_split_doc_title_counter['train'])
    N_doc_validation = len(per_split_doc_title_counter['validation'])

    N_paragraphs_train = sum(per_split_paragraph_in_doc_counter['train'].values())
    N_paragraphs_validation = sum(per_split_paragraph_in_doc_counter['validation'].values())

    total_docs, total_paragraphs = N_doc_train + N_doc_validation, N_paragraphs_train + N_paragraphs_validation



    # Percentages of QA examples in each split; despite its name,
    # n_paragraph_train_validation is the validation share.
    n_paragraph_train_ratio = (N_paragraphs_train/total_paragraphs)*100
    n_paragraph_train_validation = (N_paragraphs_validation/total_paragraphs)*100


    # Print only seeds whose QA-level train share is at least 90 % and within 0.0015 of it.
    if isclose(n_paragraph_train_ratio, 90.0 , abs_tol=0.0015) and n_paragraph_train_ratio >=90.0 and c <=10:

        print(f'\nrandom_seed: {random_seed}')

        print(f'N_doc_train : {N_doc_train:,} ({(N_doc_train/total_docs)*100:.4f})')
        print(f'N_doc_validation : {N_doc_validation:,} ({(N_doc_validation/total_docs)*100:.4f})')
        print(f'\nN_paragraphs_train : {N_paragraphs_train:,} ({n_paragraph_train_ratio:.4f})')
        print(f'N_paragraphs_validation : {N_paragraphs_validation:,} ({n_paragraph_train_validation:.4f})')

        print(f'\ntotal_paragraphs: {total_paragraphs:,}')
        print(f'total_docs: {total_docs:,}')
#         break
        c+=1

100%|████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.12it/s]


random_seed: 93
N_doc_train : 4,509 (90.0000)
N_doc_validation : 501 (10.0000)

N_paragraphs_train : 38,021 (90.0012)
N_paragraphs_validation : 4,224 (9.9988)

total_paragraphs: 42,245
total_docs: 5,010





### Convert into JSON format (for loading with the Hugging Face `datasets` library)

Example
```
5733be284776f41900661182

University_of_Notre_Dame

Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.

To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?

{ "text": [ "Saint Bernadette Soubirous" ], "answer_start": [ 515 ] }


```

In [20]:
def convert_qa_item_to_datasets_format(qa_item):
    """Flatten one processed QA record into a SQuAD-style example dict.

    Args:
        qa_item: record with 'document_title', 'question', 'context' and
            'answers' (a list of dicts with 'text' and 'answer_start').

    Returns:
        dict with keys id, title, context, question and answers, in that
        order. 'answers' holds parallel 'text' and 'answer_start' lists, and
        'id' is get_qa_hash(question, context).
    """
    title = qa_item['document_title']
    question = qa_item['question']
    context = qa_item['context']
    answers = qa_item['answers']

    example_id = get_qa_hash(question, context)
    answer_texts = [ans['text'] for ans in answers]
    answer_starts = [ans['answer_start'] for ans in answers]

    return {
        'id': example_id,
        'title': title,
        'context': context,
        'question': question,
        'answers': {'text': answer_texts, 'answer_start': answer_starts},
    }

In [21]:
# The pooled MLQA dev ('val') records become the held-out test split.
mlqa_group_split['test'] = mlqa_test_set

# Convert every split's records into flat SQuAD-style example dicts.
mlqa_processed_dataset = {
    split_name: [convert_qa_item_to_datasets_format(qa_item) for qa_item in mlqa_group_split[split_name]]
    for split_name in ('train', 'validation', 'test')
}

In [22]:
# Write each split to <split>.json as {"meta": {...}, "data": [examples]}.
output_dir = '../data/mlqa/datasets_format/'
# Create the output directory once, not on every loop iteration.
os.makedirs(output_dir, exist_ok=True)
for split_name in ['train', 'validation', 'test']:
    with open(os.path.join(output_dir, f'{split_name}.json'), 'w', encoding='utf-8') as f:
        json.dump({
            'meta': {
                'train_split_ratio': 0.9,
                'val_split_ratio': 0.1,
                'split_type': 'GroupShuffleSplit',
                # Record the seed actually used so the metadata cannot drift from it.
                'split_seed': SEED,
                'date': '02/08/2022',
            },
            'data': mlqa_processed_dataset[split_name],
        }, f, ensure_ascii=False, indent=4)
