# Process XORQA dataset for Seq2Seq QA training

In [3]:
import os, sys, re
import json
import glob
import hashlib
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, GroupShuffleSplit
from pprint import pprint

In [4]:
# Length of the sample id that also appears in the example record further down in this
# notebook; get_hash() truncates its digest to the same number of characters.
len('5733be284776f41900661182')

24

In [5]:
def get_hash(text: str) -> str:
    """Return the first 24 hex characters of the SHA-256 digest of ``text``.

    ``text`` is encoded with ``str.encode()`` (UTF-8 by default) before hashing.
    """
    # `hashlib` is imported once in the notebook's import cell.
    return hashlib.sha256(text.encode()).hexdigest()[:24]


def get_qa_hash(question: str, context: str) -> str:
    """Return a 24-character id derived from a question and its context.

    The hashed string is ``'question:' + question + ' context:' + context``, so the
    same (question, context) pair always maps to the same id.
    """
    return get_hash('question:' + question + ' ' + 'context:' + context)

In [6]:
# Add the exploratory notebooks directory to the import path so that the
# `exploring_sentence_level` module can be imported below.
# The path is relative, so it is resolved against the current working directory.
sys.path.append('../../exploratory/online_prompt/notebooks/')
# NOTE(review): of these four names, only `mine_prompt_gt` and `segment_sentence` are
# called in the cells visible in this notebook — confirm whether the others are needed.
from exploring_sentence_level import (
    load_model,
    mine_prompt_gt,  
    segment_sentence,
    run_online_prompt_mining
)

In [7]:
XORQA_BASE_DIR = '../data/xorqa/en/tydi_xor_gp/'


def _load_xorqa_articles(filename):
    """Read one JSON file from XORQA_BASE_DIR and return its top-level 'data' entry.

    The file is opened in a `with` block so the handle is closed as soon as parsing
    finishes, and it is decoded explicitly as UTF-8 rather than with the platform
    default encoding (the questions contain non-ASCII text).
    """
    with open(os.path.join(XORQA_BASE_DIR, filename), 'r', encoding='utf-8') as f:
        return json.load(f)['data']


# Original dataset splits, keyed by the split names used throughout this notebook.
xorqa_xx = {
    'train': _load_xorqa_articles('gp_squad_train_data.json'),
    'val': _load_xorqa_articles('gp_squad_dev_data.json'),
}

In [8]:
# Inspect the last entry of the original train split to see the record layout
# (a 'title' plus a list of 'paragraphs', each with a 'context' and its 'qas').
xorqa_xx['train'][-1]

{'title': 'title:Vamsy_parentSection:Introduction_sectionName:Career._sectionIndex:2',
 'paragraphs': [{'context': 'He has published a short stories compilation called "Maa Pasalapudi Kathalu". Besides that compilation, Vamsy has written a wide variety of short stories since 1974 when he was 18 years old. His major works include "Mahallo kokila", "Manchupallaki", "Aa Naati Vaana Chinukulu", "Venditera Kathalu" (original scripts of "Sankarabharanam" and "Anveshana"), "Vennela Bomma", "Gokulam lo Radha", "Ravvala konda", "Sree seetarama lanchi service Rajahmundry", "Manyam rani", "Rangularatnam". He has penned around 150 short stories published in swathi weekly under title "Maa Diguwa Godavari Kathalu" For his contributions to the art of story telling with a native approach through his books he was bestowed with "Sripada Puraskhaaram" at Rajamundry on 17 April 2011.',
   'qas': [{'question': 'మా పసలపూడి కథలు పుస్తకమును ఎవరు రచించారు?',
     'answers': [{'text': 'Vamsy', 'answer_start': 1

In [9]:
# Sample title string, copied from the record displayed above.
title = 'title:Vamsy_parentSection:Introduction_sectionName:Career._sectionIndex:2'
# Capture the text between the leading "title:" and "Section:".
# NOTE(review): "Section:" is matched inside "_parentSection:", so group(2) for this
# sample is 'Vamsy_parent', not 'Vamsy' — confirm whether keeping the "_parent"
# suffix in the document title is intended.
document_title = re.search(r'(title:)(.+)(Section:)', title).group(2)


In [10]:
def get_xorqa_answer_str(context, qas):
    """Flatten the QA entries of one paragraph into a list of tuples.

    Each returned tuple is
    ``(context, question, answer, answer_start, lang, answers)`` where ``answer`` and
    ``answer_start`` come from the first element of the entry's 'answers' list,
    ``lang`` is the entry's 'lang' value with surrounding whitespace stripped, and
    ``answers`` is the entry's full 'answers' list (the same object, not a copy).

    An entry with an empty 'answers' list raises IndexError.
    """
    def _flatten(qa_entry):
        question_text = qa_entry['question']
        language = qa_entry['lang'].strip()
        first_answer = qa_entry['answers'][0]
        return (
            context,
            question_text,
            first_answer['text'],
            first_answer['answer_start'],
            language,
            qa_entry['answers'],
        )

    return [_flatten(qa_entry) for qa_entry in qas]

In [11]:
# QA examples bucketed first by question language, then by original split name.
xorqa_xx_dataset = defaultdict(lambda: { 'train': [], 'val': [] })

for split_name in ['train', 'val']:
    for item in xorqa_xx[split_name]:
        paragraphs = item['paragraphs']
        full_title = item['title']
        # Text between "title:" and "Section:".
        # NOTE(review): the recorded group names end in "_parent" (e.g. 'WikiLeaks_parent'),
        # i.e. the match stops inside "_parentSection:" — confirm that this is intended.
        document_title = re.search(r'(title:)(.+)(Section:)', full_title).group(2)

        for paragraph in paragraphs:
            context_qa_pairs = get_xorqa_answer_str(context=paragraph['context'], qas=paragraph['qas'])

            for context, question, answer, answer_start, lang, answers in context_qa_pairs:
                # Result of the imported helper, stored as the example's ground-truth sentence.
                gt_sentence = mine_prompt_gt((context, question, answer, answer_start))
                xorqa_xx_dataset[lang][split_name].append({
                    'question': question,
                    'lang': lang,
                    'context': context,
                    'segmented_context': segment_sentence(context),
                    'answer': answer,
                    'answer_start': answer_start,
                    'gt_sentence': gt_sentence,
                    'document_title': document_title,
                    'full_title': full_title,
                    'answers': answers
                })

In [12]:
# Language codes in the order they were first inserted into xorqa_xx_dataset.
XORQA_LANGS = list(xorqa_xx_dataset)
XORQA_LANGS

['bn', 'ja', 'ko', 'ru', 'fi', 'ar', 'te']

In [13]:
# Report how many questions each original split holds, language by language.
for lang in XORQA_LANGS:
    print(f'lang= {lang}')
    splits_for_lang = xorqa_xx_dataset[lang]
    for split_name in ('train', 'val'):
        n_questions = len(splits_for_lang[split_name])
        print(f'split_name={split_name:6}, n_questions: {n_questions}')
    print('\n')

lang= bn
split_name=train , n_questions: 2474
split_name=val   , n_questions: 523


lang= ja
split_name=train , n_questions: 1927
split_name=val   , n_questions: 371


lang= ko
split_name=train , n_questions: 2395
split_name=val   , n_questions: 460


lang= ru
split_name=train , n_questions: 1744
split_name=val   , n_questions: 384


lang= fi
split_name=train , n_questions: 1855
split_name=val   , n_questions: 509


lang= ar
split_name=train , n_questions: 2303
split_name=val   , n_questions: 485


lang= te
split_name=train , n_questions: 1305
split_name=val   , n_questions: 376




### Data splitting strategy



Original train set => split into our training set and validation set

Original val (dev) set => used as our test set

------

The training and validation sets contain disjoint sets of documents (the split is grouped by document title).





In [14]:
# Pool the original 'train' split of every language. `xorqa_train_val_groups` holds the
# document title of each pooled example, index-aligned with `xorqa_train_val_set`.
xorqa_train_val_set = []
xorqa_train_val_groups = []
per_lang_q_counter = {}
for lang in XORQA_LANGS:
    split_name = 'train'
    qas = xorqa_xx_dataset[lang][split_name]
    xorqa_train_val_set += qas
    xorqa_train_val_groups += [qa_item['document_title'] for qa_item in qas]
    print(f'lang: {lang}, {len(qas):,}')
    # Per-language total, used later as the denominator of the split percentages.
    per_lang_q_counter[lang] = len(qas)
print(f'\nTotal number of QA example in train/val set: {len(xorqa_train_val_set):,}')
len(xorqa_train_val_groups)

lang: bn, 2,474
lang: ja, 1,927
lang: ko, 2,395
lang: ru, 1,744
lang: fi, 1,855
lang: ar, 2,303
lang: te, 1,305

Total number of QA example in train/val set: 14,003


14003

In [15]:
# Pool the original 'val' split of every language; it is used as the test set.
# `xorqa_test_groups` holds the document title of each pooled example, index-aligned
# with `xorqa_test_set`.
xorqa_test_set = []
xorqa_test_groups = []
for lang in XORQA_LANGS:
    split_name = 'val'
    qas = xorqa_xx_dataset[lang][split_name]
    xorqa_test_set += qas
    xorqa_test_groups += [qa_item['document_title'] for qa_item in qas]
    print(f'lang: {lang}, {len(qas):,}')
print(f'\nTotal number of QA example in test set: {len(xorqa_test_set):,}')
len(xorqa_test_groups)

lang: bn, 523
lang: ja, 371
lang: ko, 460
lang: ru, 384
lang: fi, 509
lang: ar, 485
lang: te, 376

Total number of QA example in test set: 3,108


3108

In [16]:
# (number of document titles present in BOTH the train/val pool and the test pool,
#  number of distinct document titles across the two pools)
train_val_titles = set(xorqa_train_val_groups)
test_titles = set(xorqa_test_groups)
len(train_val_titles & test_titles), len(train_val_titles | test_titles)

(663, 12232)

In [17]:
# Inspect the last pooled train/val example to check the fields of a processed QA item.
xorqa_train_val_set[-1]

{'question': 'మా పసలపూడి కథలు పుస్తకమును ఎవరు రచించారు?',
 'lang': 'te',
 'context': 'He has published a short stories compilation called "Maa Pasalapudi Kathalu". Besides that compilation, Vamsy has written a wide variety of short stories since 1974 when he was 18 years old. His major works include "Mahallo kokila", "Manchupallaki", "Aa Naati Vaana Chinukulu", "Venditera Kathalu" (original scripts of "Sankarabharanam" and "Anveshana"), "Vennela Bomma", "Gokulam lo Radha", "Ravvala konda", "Sree seetarama lanchi service Rajahmundry", "Manyam rani", "Rangularatnam". He has penned around 150 short stories published in swathi weekly under title "Maa Diguwa Godavari Kathalu" For his contributions to the art of story telling with a native approach through his books he was bestowed with "Sripada Puraskhaaram" at Rajamundry on 17 April 2011.',
 'segmented_context': ['He has published a short stories compilation called "Maa Pasalapudi Kathalu".',
  'Besides that compilation, Vamsy has written 

In [18]:
# Number of group labels (one per pooled example) and the first 40 labels.
len(xorqa_train_val_groups), xorqa_train_val_groups[:40]

(14003,
 ['WikiLeaks_parent',
  'World War II_parent',
  '1948 Arab–Israeli War_parent',
  'History of capitalism_parent',
  'World War I_parent',
  'Objections to evolution_parent',
  'Evolution_parent',
  'Republican Party (United States)_parent',
  '1948 Arab–Israeli War_parent',
  'Jimmy Wales_parent',
  'History of Islam_parent',
  'Great Salt Lake_parent',
  'Wikipedia_parent',
  'Political career of Arnold Schwarzenegger_parent',
  'On the Origin of Species_parent',
  'Athens_parent',
  'Motivation_parent',
  'Feminist sex wars_parent',
  'Raja Harishchandra_parent',
  'History of evolutionary thought_parent',
  'Ted Hughes_parent',
  'History of Apple Inc._parent',
  'Television network_parent',
  'John Maynard Keynes_parent',
  'History of the United States Republican Party_parent',
  'Al-Qaeda_parent',
  'Indo-Pakistani Naval War of 1971_parent',
  'Moscow_parent',
  'Byzantine Empire_parent',
  'History of India_parent',
  'Origen_parent',
  'Linux_parent',
  '1st Dalai Lama

In [30]:
# Random seed passed to GroupShuffleSplit for the train/validation split below.
SEED = 68

In [31]:

from math import isclose
from tqdm import tqdm

# Number of seed reports printed so far; a report is only printed while c <= 5.
c = 0
for random_seed in tqdm([SEED]):
    X = xorqa_train_val_set
    groups = xorqa_train_val_groups
    gss = GroupShuffleSplit(n_splits=1, train_size=0.90, random_state=random_seed)

    # n_splits=1, so the splitter yields exactly one (train, validation) index pair.
    train_index, test_index = next(gss.split(X, groups=groups))
    xorqa_group_split = {
        'train': [X[idx] for idx in train_index],
        'validation': [X[idx] for idx in test_index],
    }

    # Per split: the set of document titles, and the number of QA examples per title
    # (despite the "paragraph" naming, one count is added per QA example).
    per_split_doc_title_counter = defaultdict(set)
    per_split_paragraph_in_doc_counter = defaultdict(Counter)
    for split_name in ('train', 'validation'):
        split_titles = [example['document_title'] for example in xorqa_group_split[split_name]]
        per_split_doc_title_counter[split_name].update(split_titles)
        per_split_paragraph_in_doc_counter[split_name].update(split_titles)

    N_doc_train = len(per_split_doc_title_counter['train'])
    N_doc_validation = len(per_split_doc_title_counter['validation'])

    N_paragraphs_train = sum(per_split_paragraph_in_doc_counter['train'].values())
    N_paragraphs_validation = sum(per_split_paragraph_in_doc_counter['validation'].values())

    total_docs = N_doc_train + N_doc_validation
    total_paragraphs = N_paragraphs_train + N_paragraphs_validation

    n_paragraph_train_ratio = (N_paragraphs_train/total_paragraphs)*100
    n_paragraph_train_validation = (N_paragraphs_validation/total_paragraphs)*100

    # Only report seeds whose train share lies in [90.0, 90.0024] percent of the examples.
    hits_target_ratio = (
        isclose(n_paragraph_train_ratio, 90.0, abs_tol=0.0024)
        and n_paragraph_train_ratio >= 90.0
    )
    if not hits_target_ratio or c > 5:
        continue

    report_lines = [
        f'\nrandom_seed: {random_seed}',
        f'N_doc_train : {N_doc_train:,} ({(N_doc_train/total_docs)*100:.4f})',
        f'N_doc_validation : {N_doc_validation:,} ({(N_doc_validation/total_docs)*100:.4f})',
        f'\nN_paragraphs_train : {N_paragraphs_train:,} ({n_paragraph_train_ratio:.4f})',
        f'N_paragraphs_validation : {N_paragraphs_validation:,} ({n_paragraph_train_validation:.4f})',
        f'\ntotal_paragraphs: {total_paragraphs:,}',
        f'total_docs: {total_docs:,}',
        '\n',
        '-'*40,
    ]
    print('\n'.join(report_lines))
    c += 1

100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 23.51it/s]


random_seed: 68
N_doc_train : 9,811 (89.9927)
N_doc_validation : 1,091 (10.0073)

N_paragraphs_train : 12,603 (90.0021)
N_paragraphs_validation : 1,400 (9.9979)

total_paragraphs: 14,003
total_docs: 10,902


----------------------------------------





### Convert into JSON format (for `Dataset.loader()`)

Example
```
5733be284776f41900661182

University_of_Notre_Dame

Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.

To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?

{ "text": [ "Saint Bernadette Soubirous" ], "answer_start": [ 515 ] }


```

In [32]:
def convert_qa_item_to_datasets_format(qa_item):
    """Convert one processed QA item into the record layout written to disk.

    The returned dict has the keys 'id', 'lang', 'title', 'context', 'question' and
    'answers' (in that order). 'id' is ``get_qa_hash(question, context)``, 'title' is
    the item's 'document_title', and 'answers' holds two parallel lists, 'text' and
    'answer_start', built from every entry of the item's 'answers' list.
    """
    title = qa_item['document_title']
    question = qa_item['question']
    context = qa_item['context']
    answers = qa_item['answers']
    lang = qa_item['lang']

    record = {'id': get_qa_hash(question, context)}
    record['lang'] = lang
    record['title'] = title
    record['context'] = context
    record['question'] = question

    answer_texts = [entry['text'] for entry in answers]
    answer_starts = [entry['answer_start'] for entry in answers]
    record['answers'] = {'text': answer_texts, 'answer_start': answer_starts}
    return record

In [33]:
# The pooled original 'val' examples become the 'test' split next to the grouped
# train/validation split.
xorqa_group_split['test'] = xorqa_test_set

# Convert every example of every split into the on-disk record layout.
xorqa_processed_dataset = {}
for split_name in ('train', 'validation', 'test'):
    xorqa_processed_dataset[split_name] = [
        convert_qa_item_to_datasets_format(qa_item)
        for qa_item in xorqa_group_split[split_name]
    ]

In [36]:
# Write one JSON file per split: {'meta': <split provenance>, 'data': [<records>]}.
output_dir = '../data/xorqa/datasets_format/'
# The directory only needs to be created once, not on every loop iteration.
os.makedirs(output_dir, exist_ok=True)

split_meta = {
    'train_split_ratio': 0.9,
    'val_split_ratio': 0.1,
    'split_type': 'GroupShuffleSplit',
    # Record the seed that was actually used for the split instead of a duplicated literal.
    'split_seed': SEED,
    # NOTE(review): fixed provenance string, not the date of the current run.
    'date': '04/08/2022',
}

for split_name in ['train','validation', 'test']:
    with open(os.path.join(output_dir, f'{split_name}.json'), 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps the non-ASCII questions readable in the output file.
        json.dump({ 'meta': split_meta,
                    'data': xorqa_processed_dataset[split_name] }, f, ensure_ascii=False, indent=4)


### Compute the distribution of questions per language for the train, validation and test sets

In [37]:
# lang_counter[split] -> Counter of question languages in that processed split.
lang_counter = defaultdict()
# lang_counter_percentage[split][lang] -> summary dict; filled for train/validation only.
lang_counter_percentage = defaultdict(lambda: defaultdict())                    

for split_name in ['train','validation', 'test']:
    lang_counter[split_name] = Counter(
        example['lang'] for example in xorqa_processed_dataset[split_name]
    )

    # The share is measured against the pooled original train split, which is where the
    # train and validation examples come from; the test split gets no entry.
    if split_name not in ('train', 'validation'):
        continue

    for lang in XORQA_LANGS:
        n_in_split = lang_counter[split_name][lang]
        n_pooled = per_lang_q_counter[lang]
        lang_counter_percentage[split_name][lang] = {
            'lang': lang,
            'lang_counter': n_in_split,
            'per_lang_q_counter': n_pooled,
            # Stored as a fraction in [0, 1], not multiplied by 100.
            'percentage': n_in_split / n_pooled
        }

In [38]:
# Pretty-print the per-split, per-language summary built in the previous cell.
pprint(lang_counter_percentage, indent=2)

defaultdict(<function <lambda> at 0x7fbcc9b739e0>,
            { 'train': defaultdict(None,
                                   { 'ar': { 'lang': 'ar',
                                             'lang_counter': 2077,
                                             'per_lang_q_counter': 2303,
                                             'percentage': 0.9018671298306556},
                                     'bn': { 'lang': 'bn',
                                             'lang_counter': 2203,
                                             'per_lang_q_counter': 2474,
                                             'percentage': 0.8904607922392886},
                                     'fi': { 'lang': 'fi',
                                             'lang_counter': 1676,
                                             'per_lang_q_counter': 1855,
                                             'percentage': 0.9035040431266846},
                                     'ja': { 'lang': 'ja',
            