## Sentence-level online prompty mining: XORQA

In [1]:
import copy
import re
import os, sys
import json
import glob
from collections import Counter, defaultdict
import pandas as pd
import jsonlines

from exploring_sentence_level import (
    load_model,
    mine_prompt_gt,  
    segment_sentence,
    run_online_prompt_mining
)

### 0. Download dataset

```bash
cd ../scripts
bash ./download_xorqa.sh
```

### 1. Process dataset

In [2]:
XORQA_BASE_DIR = '../data/xorqa/en/tydi_xor_gp/'
xorqa_xx = {
    'train': json.load(open(os.path.join(XORQA_BASE_DIR, 'gp_squad_train_data.json'), 'r'))['data'],
      'val': json.load(open(os.path.join(XORQA_BASE_DIR, 'gp_squad_dev_data.json'), 'r'))['data'],   
}

In [3]:
xorqa_xx['train'][-1]

{'title': 'title:Vamsy_parentSection:Introduction_sectionName:Career._sectionIndex:2',
 'paragraphs': [{'context': 'He has published a short stories compilation called "Maa Pasalapudi Kathalu". Besides that compilation, Vamsy has written a wide variety of short stories since 1974 when he was 18 years old. His major works include "Mahallo kokila", "Manchupallaki", "Aa Naati Vaana Chinukulu", "Venditera Kathalu" (original scripts of "Sankarabharanam" and "Anveshana"), "Vennela Bomma", "Gokulam lo Radha", "Ravvala konda", "Sree seetarama lanchi service Rajahmundry", "Manyam rani", "Rangularatnam". He has penned around 150 short stories published in swathi weekly under title "Maa Diguwa Godavari Kathalu" For his contributions to the art of story telling with a native approach through his books he was bestowed with "Sripada Puraskhaaram" at Rajamundry on 17 April 2011.',
   'qas': [{'question': 'మా పసలపూడి కథలు పుస్తకమును ఎవరు రచించారు?',
     'answers': [{'text': 'Vamsy', 'answer_start': 1

In [4]:
def get_xorqa_answer_str(context, qas):
    context_qa_pairs = []
    for qa in qas:
        question = qa['question']
        lang = qa['lang']
        answer = qa['answers'][0]['text']
        answer_start = qa['answers'][0]['answer_start']
        context_qa_pairs.append((context, question, answer, answer_start, lang))
    return context_qa_pairs

In [48]:
xorqa_xx_dataset = defaultdict(lambda: { 'train': [], 'val': [] })

xorqa_sentences = []
global_paragraph_id = 0
global_sentence_id = 0

for split_name in ['train', 'val']:
    for i, item in enumerate(xorqa_xx[split_name]):

        title = item['title']
        paragraphs = item['paragraphs']

        for j, paragraph in enumerate(paragraphs):

            context = paragraph['context']
            context_qa_pairs = get_xorqa_answer_str(context=context, qas=paragraph['qas'])
            segmented_context = segment_sentence(context)
            segmented_context_ids = []
        
            for sentence_id in range(len(segmented_context)):
                xorqa_sentences.append((title, global_paragraph_id, global_sentence_id, segmented_context[sentence_id], split_name))
                segmented_context_ids.append(global_sentence_id)
                global_sentence_id += 1

            for context_qa_pair in context_qa_pairs:
                context, question, answer, answer_start, lang = context_qa_pair
                gt_sentence, gt_sentence_idx = mine_prompt_gt((context, question, answer, answer_start))
                gt_sentence_global_idx = segmented_context_ids[gt_sentence_idx]
                
                qa_item = {
                     'question': question,
                     'lang': lang,
                     'context': context,
                     'segmented_context': segment_sentence(context),
                     'segmented_context_ids': segmented_context_ids,
                     'answer': answer,
                     'answer_start': answer_start,
                     'split_name': split_name,
                     'gt_sentence': gt_sentence,
                     'gt_sentence_idx': gt_sentence_global_idx,

                }
                xorqa_xx_dataset[lang][split_name].append(qa_item)
            global_paragraph_id += 1

In [23]:
list(xorqa_xx_dataset.keys())

['bn', 'ja', 'ko', 'ru', 'fi', ' ar', 'te', 'ar']

In [24]:
len(xorqa_xx_dataset['ar']['val'])

485

#### Write (English) segmented sentences into separated csv file:


In [25]:
len(xorqa_sentences), xorqa_sentences[0]

(76113,
 ('title:WikiLeaks_parentSection:Introduction_sectionName:Introduction_sectionIndex:0',
  0,
  0,
  'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources.'))

In [49]:
xorqa_en_sentences_df = pd.DataFrame.from_dict(xorqa_sentences)
xorqa_en_sentences_df.columns=['doc_title', 'paragraph_id', 'sentence_id', 'sentence', 'split_name']

In [56]:
xorqa_en_sentences_df.head(5)

Unnamed: 0,doc_title,paragraph_id,sentence_id,sentence,split_name
0,title:WikiLeaks_parentSection:Introduction_sec...,0,0,WikiLeaks () is an international non-profit or...,train
1,title:WikiLeaks_parentSection:Introduction_sec...,0,1,"Its website, initiated in 2006 in Iceland by t...",train
2,title:WikiLeaks_parentSection:Introduction_sec...,0,2,"Julian Assange, an Australian Internet activis...",train
3,title:WikiLeaks_parentSection:Introduction_sec...,0,3,Kristinn Hrafnsson is its editor-in-chief.,train
4,title:World War II_parentSection:Introduction_...,1,4,The war in Europe concluded with an invasion o...,train


In [57]:
xorqa_en_sentences_df.tail(5)

Unnamed: 0,doc_title,paragraph_id,sentence_id,sentence,split_name
76108,title:Gamergate controversy_parentSection:Hist...,17110,76108,"In mid-October Brianna Wu, another independent...",val
76109,title:Gamergate controversy_parentSection:Hist...,17110,76109,Wu then became the target of rape and death th...,val
76110,title:Gamergate controversy_parentSection:Hist...,17110,76110,"After contacting police, Wu fled her home with...",val
76111,title:Gamergate controversy_parentSection:Hist...,17110,76111,Wu announced an reward for information leading...,val
76112,title:Gamergate controversy_parentSection:Hist...,17110,76112,"As of April 2016, Wu was still receiving threa...",val


In [51]:
xorqa_en_sentences_df.to_csv('./question-sentences-pairs/xorqa/xorqa_sentence-en.csv')

#### Write (All languages) question-sentence pairs into separated csv file:



In [46]:
xorqa_xx_dataset.keys(), xorqa_xx_dataset['bn'].keys()

(dict_keys(['bn', 'ja', 'ko', 'ru', 'fi', ' ar', 'te', 'ar', 'question_lang']),
 dict_keys(['train', 'val']))

In [41]:
xorqa_xx_dataset['bn']['train'][0]

{'question': 'উইকিলিকস কত সালে সর্বপ্রথম ইন্টারনেটে প্রথম তথ্য প্রদর্শন করে ?',
 'lang': 'bn',
 'context': 'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources. Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch. Julian Assange, an Australian Internet activist, is generally described as its founder and director. Kristinn Hrafnsson is its editor-in-chief.',
 'segmented_context': ['WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources.',
  'Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch.',
  'Julian Assange, an Australian Internet activist, is generally described as its founder and director.

In [45]:
for question_lang in list(xorqa_xx_dataset.keys()):
    
    for split_name in list(xorqa_xx_dataset[question_lang].keys()):

        if len(xorqa_xx_dataset[question_lang][split_name]) > 0:
            xorqa_question_sentence_pairs_df = pd.DataFrame.from_dict(list(map(lambda x: (x['question'], x['gt_sentence_idx']), xorqa_xx_dataset[question_lang][split_name])))
            xorqa_question_sentence_pairs_df.columns = ['question' , 'gt_sentence_idx']

            xorqa_question_sentence_pairs_df.to_csv(f'./question-sentences-pairs/xorqa/xorqa-{split_name}_question-{question_lang}_sentence-en.csv')

### 2. Compute question-sentence similarity


#### 2.1 Load models

##### a) Load mUSE_small (v3) model (as a baseline)

In [8]:
muse_small_v3_model = load_model('https://tfhub.dev/google/universal-sentence-encoder-multilingual/3')


##### b) Load teacher models

In [9]:
XQUAD_TEACHER_DIR = '../../../../CL-ReLKT_store/models/XQUAD/teacher_model/'
MLQA_TEACHER_DIR = '../../../../CL-ReLKT_store/models/MLQA/teacher_model/'

In [10]:
xquad_teacher_model = load_model(XQUAD_TEACHER_DIR)
mlqa_teacher_model = load_model(MLQA_TEACHER_DIR)

##### c) Load student models

In [11]:
XQUAD_STUDENT_SUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/XQUAD/student_best_supported_languages/'
XQUAD_STUDENT_UNSUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/XQUAD/student_best_unsupported_languages/'

XORQA_STUDENT_SUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/XORQA/student_best_supported_languages/'
XORQA_STUDENT_UNSUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/XORQA/student_best_unsupported_languages/'

MLQA_STUDENT_SUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/MLQA/student_best_supported_languages/'
MLQA_STUDENT_UNSUPPORTED_LANGS_DIR = '../../../../CL-ReLKT_store/models/MLQA/student_best_unsupported_languages/'

In [12]:
xquad_student_supported_langs_model = load_model(XQUAD_STUDENT_SUPPORTED_LANGS_DIR)
xorqa_student_supported_langs_model = load_model(XORQA_STUDENT_SUPPORTED_LANGS_DIR)
mlqa_student_supported_langs_model = load_model(MLQA_STUDENT_SUPPORTED_LANGS_DIR)

xquad_student_unsupported_langs_model = load_model(XQUAD_STUDENT_UNSUPPORTED_LANGS_DIR)
xorqa_student_unsupported_langs_model = load_model(XORQA_STUDENT_UNSUPPORTED_LANGS_DIR)
mlqa_student_unsupported_langs_model = load_model(MLQA_STUDENT_UNSUPPORTED_LANGS_DIR)

In [13]:
MODEL_MAPPING = {
  # mUSE_small
  'model-muse_small_v3': muse_small_v3_model,
  # teacher    
  'model-xquad_teacher': xquad_teacher_model,
  'model-mlqa_teacher': mlqa_teacher_model,
  # student
  'model-xquad_student_supported_langs': xquad_student_supported_langs_model,
  'model-xorqa_student_supported_langs': xorqa_student_supported_langs_model,
  'model-mlqa_student_supported_langs': mlqa_student_supported_langs_model,
  'model-xquad_student_unsupported_langs': xquad_student_unsupported_langs_model,
  'model-xorqa_student_unsupported_langs': xorqa_student_unsupported_langs_model,
  'model-mlqa_student_unsupported_langs': mlqa_student_unsupported_langs_model,
}



In [14]:
DATASET_MAPPING = {}

for lang in list(xorqa_xx_dataset.keys()):
    if len(xorqa_xx_dataset[lang]['train']) != 0:
        DATASET_MAPPING[f'dataset-xorqa_{lang.strip()}_train'] = xorqa_xx_dataset[lang]['train']
    if len(xorqa_xx_dataset[lang]['val']) != 0:
        DATASET_MAPPING[f'dataset-xorqa_{lang.strip()}_val'] = xorqa_xx_dataset[lang]['val']
print(DATASET_MAPPING.keys())

dict_keys(['dataset-xorqa_bn_train', 'dataset-xorqa_bn_val', 'dataset-xorqa_ja_train', 'dataset-xorqa_ja_val', 'dataset-xorqa_ko_train', 'dataset-xorqa_ko_val', 'dataset-xorqa_ru_train', 'dataset-xorqa_ru_val', 'dataset-xorqa_fi_train', 'dataset-xorqa_fi_val', 'dataset-xorqa_ar_train', 'dataset-xorqa_te_train', 'dataset-xorqa_te_val', 'dataset-xorqa_ar_val'])


#### 2.2 Run inference and evaluate

The following function `run_online_prompt_mining` iterates over question-answer-passage triplets $(q_i, a_i, p_i)$ and compute 
the cosine similarity scores between question $q_i$ and segmented setences $s^i_j \textrm{ where } p_i = ( s^i_0, \ldots , s^i_{|p_i| - 1} )$ , and rank each quesiton-sentence pair by similairy score. Then, it evaluate the sentence-level precision@k.  Note: There is only 1 groundtruth sentence (i.e. the sentence where the answer span is a part of). 


In [None]:
results = defaultdict(lambda : defaultdict())

for dataset_prefix, dataset in DATASET_MAPPING.items():
    print(f'\n\ndataset_prefix: {dataset_prefix}')
    for model_prefix, model in MODEL_MAPPING.items():
        
        print(f'\n - model_prefix: {model_prefix}')
        prefix = f'{dataset_prefix}+{model_prefix}'
        _result = run_online_prompt_mining(dataset,
                             prefix=f'{dataset_prefix}_{model_prefix}',
                             model=model)


        results[dataset_prefix][model_prefix] = _result
        print('--'*50)
    print('\n')    
    print('=='*50)
    print('\n')    
    



dataset_prefix: dataset-xorqa_bn_train

 - model_prefix: model-muse_small_v3


100%|██████████| 2474/2474 [05:04<00:00,  8.11it/s]



	Evaluation result:
	 - Accuracy: 0.4062
	 - precision_at_k:
{1: 0.40622473726758285,
 2: 0.6471301535974131,
 3: 0.7934518997574778,
 4: 0.8811641067097817,
 5: 0.9329021827000809,
 6: 0.9595796281325788,
 7: 0.97696038803557,
 8: 0.9854486661277284,
 9: 0.9894907033144705,
 10: 0.9911075181891673}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-xquad_teacher


100%|██████████| 2474/2474 [05:00<00:00,  8.24it/s]



	Evaluation result:
	 - Accuracy: 0.3868
	 - precision_at_k:
{1: 0.3868229587712207,
 2: 0.6329830234438156,
 3: 0.778496362166532,
 4: 0.8779304769603881,
 5: 0.9240097008892482,
 6: 0.9583670169765561,
 7: 0.9737267582861763,
 8: 0.9834276475343573,
 9: 0.9878738884397736,
 10: 0.9907033144704931}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-mlqa_teacher


100%|██████████| 2474/2474 [05:05<00:00,  8.09it/s]



	Evaluation result:
	 - Accuracy: 0.4228
	 - precision_at_k:
{1: 0.42279708973322555,
 2: 0.650767987065481,
 3: 0.7922392886014551,
 4: 0.881568310428456,
 5: 0.9284559417946645,
 6: 0.9595796281325788,
 7: 0.9753435731608731,
 8: 0.9810024252223121,
 9: 0.9886822958771221,
 10: 0.99232012934519}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-xquad_student_supported_langs


100%|██████████| 2474/2474 [04:47<00:00,  8.60it/s]



	Evaluation result:
	 - Accuracy: 0.3108
	 - precision_at_k:
{1: 0.31083265966046886,
 2: 0.5675020210185934,
 3: 0.7485852869846402,
 4: 0.8536782538399353,
 5: 0.9058205335489087,
 6: 0.9470493128536782,
 7: 0.9652384801940178,
 8: 0.97696038803557,
 9: 0.9826192400970089,
 10: 0.9886822958771221}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-xorqa_student_supported_langs


100%|██████████| 2474/2474 [04:58<00:00,  8.28it/s]



	Evaluation result:
	 - Accuracy: 0.3015
	 - precision_at_k:
{1: 0.301535974130962,
 2: 0.5642683912691997,
 3: 0.732821341956346,
 4: 0.8480194017784963,
 5: 0.9042037186742118,
 6: 0.9405820533548909,
 7: 0.9632174616006467,
 8: 0.9757477768795473,
 9: 0.9822150363783346,
 10: 0.986661277283751}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-mlqa_student_supported_langs


100%|██████████| 2474/2474 [04:55<00:00,  8.37it/s]



	Evaluation result:
	 - Accuracy: 0.4268
	 - precision_at_k:
{1: 0.42683912691996767,
 2: 0.66410670978173,
 3: 0.8055780113177041,
 4: 0.883589329021827,
 5: 0.9288601455133387,
 6: 0.9555375909458367,
 7: 0.973322554567502,
 8: 0.9814066289409863,
 9: 0.9858528698464026,
 10: 0.9894907033144705}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-xquad_student_unsupported_langs


100%|██████████| 2474/2474 [04:58<00:00,  8.28it/s]



	Evaluation result:
	 - Accuracy: 0.3395
	 - precision_at_k:
{1: 0.3395311236863379,
 2: 0.5905416329830234,
 3: 0.7623282134195635,
 4: 0.862570735650768,
 5: 0.9211802748585287,
 6: 0.952303961196443,
 7: 0.9717057396928052,
 8: 0.9814066289409863,
 9: 0.9894907033144705,
 10: 0.9927243330638642}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-xorqa_student_unsupported_langs


100%|██████████| 2474/2474 [04:39<00:00,  8.84it/s]



	Evaluation result:
	 - Accuracy: 0.2676
	 - precision_at_k:
{1: 0.2675828617623282,
 2: 0.5161681487469685,
 3: 0.6956345998383185,
 4: 0.8096200485044462,
 5: 0.8848019401778496,
 6: 0.9268391269199676,
 7: 0.9567502021018593,
 8: 0.9725141471301536,
 9: 0.9785772029102667,
 10: 0.9830234438156831}
----------------------------------------------------------------------------------------------------

 - model_prefix: model-mlqa_student_unsupported_langs


 61%|██████▏   | 1521/2474 [03:01<02:01,  7.86it/s]

In [None]:
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


#### 2.3 Write result as JSON file

In [None]:
json.dump(results, open('./eval_results.dataset_name-xorqa.json', 'w'), ensure_ascii=False, indent=2)

### 3. Convert evaluation results to a pandas.DataFrame

In [18]:
results = json.load(open('./eval_results.dataset_name-xorqa.json', 'r'))

In [19]:
list(results.keys()), len(list(results.keys()))

(['dataset-xorqa_bn_train',
  'dataset-xorqa_bn_val',
  'dataset-xorqa_ja_train',
  'dataset-xorqa_ja_val',
  'dataset-xorqa_ko_train',
  'dataset-xorqa_ko_val',
  'dataset-xorqa_ru_train',
  'dataset-xorqa_ru_val',
  'dataset-xorqa_fi_train',
  'dataset-xorqa_fi_val',
  'dataset-xorqa_ar_train',
  'dataset-xorqa_te_train',
  'dataset-xorqa_te_val',
  'dataset-xorqa_ar_val'],
 14)

In [22]:
result_objs = []
for dataset_name, result_model_group in results.items():
    for model_name, (metric, raw_result) in result_model_group.items():
        top1, precision_at_k = metric
        
        result_objs.append({
            'dataset_name': dataset_name,
            'model_name': model_name,
            'precision_at_1': top1,
            'precision_at_2': precision_at_k['2'],
            'precision_at_3': precision_at_k['6'],
            'precision_at_4': precision_at_k['4'],
            'precision_at_5': precision_at_k['5'],
            'precision_at_10': precision_at_k['10'],
        })
    
df = pd.DataFrame.from_dict(result_objs)
df.to_csv('./eval_results.dataset_name-xorqa.csv')

In [23]:
df

Unnamed: 0,dataset_name,model_name,precision_at_1,precision_at_2,precision_at_3,precision_at_4,precision_at_5,precision_at_10
0,dataset-xorqa_bn_train,model-muse_small_v3,0.406225,0.647130,0.959580,0.881164,0.932902,0.991108
1,dataset-xorqa_bn_train,model-xquad_teacher,0.386823,0.632983,0.958367,0.877930,0.924010,0.990703
2,dataset-xorqa_bn_train,model-mlqa_teacher,0.422797,0.650768,0.959580,0.881568,0.928456,0.992320
3,dataset-xorqa_bn_train,model-xquad_student_supported_langs,0.310833,0.567502,0.947049,0.853678,0.905821,0.988682
4,dataset-xorqa_bn_train,model-xorqa_student_supported_langs,0.301536,0.564268,0.940582,0.848019,0.904204,0.986661
...,...,...,...,...,...,...,...,...
121,dataset-xorqa_ar_val,model-xorqa_student_supported_langs,0.494845,0.721649,0.975258,0.917526,0.958763,0.995876
122,dataset-xorqa_ar_val,model-mlqa_student_supported_langs,0.540206,0.748454,0.975258,0.931959,0.967010,0.989691
123,dataset-xorqa_ar_val,model-xquad_student_unsupported_langs,0.404124,0.628866,0.956701,0.880412,0.925773,1.000000
124,dataset-xorqa_ar_val,model-xorqa_student_unsupported_langs,0.453608,0.717526,0.975258,0.921649,0.962887,0.985567
