# Process the XQuAD dataset for Seq2Seq QA training

In [15]:
import os, sys
import json
import glob
from collections import defaultdict, Counter

In [27]:
# Directory the English SQuAD v1.1 JSON files (train-v1.1.json, dev-v1.1.json) are read from.
SQUAD_EN_DATA_DIR = '../../exploratory/online_prompt/data/xquad/en/'
# Directory holding the per-language XQuAD files, named xquad.<lang>.json.
XQUAD_DATA_DIR = '../../exploratory/online_prompt/data/xquad/xx/'

In [25]:

# IPython shell escapes: list the files available in the two data directories.
!ls $SQUAD_EN_DATA_DIR
!ls $XQUAD_DATA_DIR

train-v1.1.json
xquad.ar.json xquad.en.json xquad.ro.json xquad.tr.json
xquad.de.json xquad.es.json xquad.ru.json xquad.vi.json
xquad.el.json xquad.hi.json xquad.th.json xquad.zh.json


In [65]:
# Load the English SQuAD v1.1 splits, keeping only the top-level 'data' list of each file.
# Each file is opened in a `with` block so its handle is closed as soon as it has been
# parsed, and it is decoded explicitly as UTF-8 instead of the platform default encoding.
squad_en = {}
for _split, _file_name in (('train', 'train-v1.1.json'), ('validation', 'dev-v1.1.json')):
    with open(os.path.join(SQUAD_EN_DATA_DIR, _file_name), encoding='utf-8') as _squad_file:
        squad_en[_split] = json.load(_squad_file)['data']

In [66]:
# Number of top-level entries in the 'data' list of each split.
len(squad_en['train']), len(squad_en['validation'])

(442, 48)

In [64]:
# squad_dataset['train']

In [68]:
# Index every English SQuAD question by its id, separately for each split:
#   squad_dataset[split_name][qid] -> {'qid', 'question', 'context', 'answer', 'answer_start'}
# The QA fields are extracted inline rather than through get_squad_answer_str: that helper is
# only defined in a later cell, so calling it here fails with NameError on a top-to-bottom run.
squad_dataset = defaultdict(dict)
for split_name in ['train', 'validation']:
    for item in squad_en[split_name]:
        for paragraph in item['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                qid = qa['id']
                # Only the first entry of qa['answers'] is kept; any further entries are ignored.
                first_answer = qa['answers'][0]
                squad_dataset[split_name][qid] = {
                    'qid': qid,
                    'question': qa['question'],
                    'context': context,
                    'answer': first_answer['text'],
                    'answer_start': first_answer['answer_start'],
                }

    print(f'Number of {split_name} examples: {len(squad_dataset[split_name]):,}')

Number of train examples: 87,599
Number of validation examples: 10,570


In [23]:
# train_squad[0]
# qid, lang, question, context

In [90]:
# Load every XQuAD language file into xquad_xx under the key '<lang>_all', keeping only the
# top-level 'data' list of each file.
xquad_xx = {}
XQUAD_LANGS = ['en', 'ar', 'de', 'el', 'es', 'hi', 'vi', 'ro', 'ru', 'th', 'tr', 'zh']
for lang in XQUAD_LANGS:
    # `with` closes the handle once the file is parsed; the explicit UTF-8 decoding avoids
    # depending on the platform default encoding for these multilingual files.
    with open(os.path.join(XQUAD_DATA_DIR, f'xquad.{lang}.json'), 'r', encoding='utf-8') as xquad_file:
        xquad_xx[f'{lang}_all'] = json.load(xquad_file)['data']
xquad_xx.keys()

dict_keys(['en_all', 'ar_all', 'de_all', 'el_all', 'es_all', 'hi_all', 'vi_all', 'ro_all', 'ru_all', 'th_all', 'tr_all', 'zh_all'])

In [91]:
# xquad_xx['en_all'][0]

In [92]:
def get_squad_answer_str(context, qas):
    """Flatten the QA entries of one paragraph into plain tuples.

    Args:
        context: Paragraph text; copied unchanged into every returned tuple.
        qas: Iterable of QA dicts, each with 'id', 'question' and a non-empty
            'answers' list whose entries carry 'text' and 'answer_start'.

    Returns:
        A list of (qid, context, question, answer, answer_start) tuples, one
        per entry of `qas` and in the same order. Only the first answer of
        each entry is used.
    """
    return [
        (
            entry['id'],
            context,
            entry['question'],
            entry['answers'][0]['text'],
            entry['answers'][0]['answer_start'],
        )
        for entry in qas
    ]

In [93]:
# Flatten the XQuAD files into one list of QA records per language, and record how many
# questions each (document index, paragraph index) position holds.
xquad_question_counter = Counter()
xquad_dataset = defaultdict(list)

for lang in XQUAD_LANGS:
    for i, item in enumerate(xquad_xx[f'{lang}_all']):
        for j, paragraph in enumerate(item['paragraphs']):
            paragraph_qas = paragraph['qas']
            # Keyed by position only, so each language overwrites the previous one's count.
            xquad_question_counter[f'd-{i}_p-{j}'] = len(paragraph_qas)

            flattened = get_squad_answer_str(context=paragraph['context'], qas=paragraph_qas)
            # The paragraph context comes back in each tuple but is not stored in the record.
            for qid, _context, question, answer, answer_start in flattened:
                xquad_dataset[lang].append({
                    'lang': lang,
                    'qid': qid,
                    'question': question,
                    'answer': answer,
                    'answer_start': answer_start,
                })

In [94]:
# Number of keys (one per language appended to above) currently in xquad_dataset.
len(xquad_dataset)

12

In [95]:
# Peek at the first English record. xquad_dataset is a defaultdict keyed by language code,
# so indexing it with 0 would silently insert an empty list under the key 0.
xquad_dataset['en'][0]

In [96]:
# squad_dataset['train'].keys()

### Separate train and validation sets

In [97]:
# Bucket each language's XQuAD records by the English SQuAD split their question id belongs
# to, under the key '<lang>_<split>', and attach the context stored for that id.
# NOTE: the attached context is the one held in squad_dataset (built from the English SQuAD
# files), not the paragraph text from the language's own XQuAD file.
xquad_dataset_split = defaultdict(list)

for split_name in ['train', 'validation']:
    # Hoisted: the same per-split index is used for every language and every record.
    split_examples = squad_dataset[split_name]
    for lang in XQUAD_LANGS:
        xquad_dataset_split[f'{lang}_{split_name}'] = [
            # Build a new dict so the records held in xquad_dataset are not mutated.
            {**item, 'context': split_examples[item['qid']]['context']}
            for item in xquad_dataset[lang]
            if item['qid'] in split_examples
        ]

In [98]:
# Report how many examples ended up in each '<lang>_<split>' bucket.
for bucket_name in xquad_dataset_split:
    print(f'{bucket_name}: {len(xquad_dataset_split[bucket_name])}')

en_train: 0
ar_train: 0
de_train: 0
el_train: 0
es_train: 0
hi_train: 0
vi_train: 0
ro_train: 0
ru_train: 0
th_train: 0
tr_train: 0
zh_train: 0
en_validation: 1190
ar_validation: 1190
de_validation: 1190
el_validation: 1190
es_validation: 1190
hi_validation: 1190
vi_validation: 1190
ro_validation: 1190
ru_validation: 1190
th_validation: 1190
tr_validation: 1190
zh_validation: 1190
