In [30]:
# from transformers.data.processors.squad import SquadResult, SquadV1Processor

# processor = SquadV1Processor()
# examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)

In [31]:
import torch
import json
import collections
from squad import SquadResult, SquadV1Processor, SquadV2Processor
from squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)

cached_features_file='tydiqa/cached_dev_tydiqa-gold_384_'
pred_file='tydiqa/predictions_.json'
features_and_dataset = torch.load(cached_features_file)

examples = features_and_dataset["examples"]
predictions=json.load(open(pred_file))

In [None]:
langs = ['swahili','bengali','arabic','korean','english','indonesian','japanese','russian',
        'telugu','finnish','thai']

for lan in langs:
    name = lan+'--'
    new_examples=[]
    new_preds=collections.OrderedDict()
    for exam in examples:
        if exam.qas_id.startswith(name)==True:
            new_examples.append(exam)
            new_preds[exam.qas_id]=predictions[exam.qas_id]
    if len(new_examples)!=0:
        results = squad_evaluate(new_examples, new_preds)
        print(name)
        print(results)

for lan in langs[:-1]:
    name = lan+'-english--'
    new_examples=[]
    new_preds=collections.OrderedDict()
    for exam in examples:
        if exam.qas_id.startswith(name)==True:
            new_examples.append(exam)
            new_preds[exam.qas_id]=predictions[exam.qas_id]
    if len(new_examples)!=0:
        results = squad_evaluate(new_examples, new_preds)
        print(name)
        print(results)

for lan in langs[:-1]:
    name = 'english-'+lan+'--'
    new_examples=[]
    new_preds=collections.OrderedDict()
    for exam in examples:
        if exam.qas_id.startswith(name)==True:
            new_examples.append(exam)
            new_preds[exam.qas_id]=predictions[exam.qas_id]
    if len(new_examples)!=0:
        results = squad_evaluate(new_examples, new_preds)
        print(name)
        print(results)

In [3]:
import os 
import json
import pickle

def get_qdata(filename):

    with open(filename,'r') as f:
        data=json.loads(f.read())
    mlqa_q={}
    i=0
    for dat in data['data']:
    #     if i>2:
    #         break
    #     print(dat)
        for par in dat['paragraphs']:
            for qos in par['qas']:
                mlqa_q[qos['id']] = qos['question'].strip()
                i+=1
    print(i, len(mlqa_q))
    if i!=len(mlqa_q):
        mlqa_q={}
        i=0
        for dat in data['data']:
        #     if i>2:
        #         break
        #     print(dat)
            for par in dat['paragraphs']:
                for qos in par['qas']:
                    if qos['id'] not in mlqa_q:
                        mlqa_q[qos['id']]=[qos['question'].strip()]
                    else:
                        mlqa_q[qos['id']].append(qos['question'].strip())
                    i+=1
        print(i, len(mlqa_q))        
    return mlqa_q

In [11]:
from wikidata.client import Client
import json
import os

json_name = 'tydiqa_data/all_d.json'
with open(json_name) as inp:
    jsonstr = inp.read()
ent_ids = json.loads(jsonstr)

json_name = 'country_dict.json'
with open(json_name) as inp:
    jsonstr = inp.read()
countries = json.loads(jsonstr)

json_name = 'id2geo.json'
with open(json_name) as inp:
    jsonstr = inp.read()
id2geo = json.loads(jsonstr)

json_name = 'place2country.json'
with open(json_name) as inp:
    jsonstr = inp.read()
place2coun = json.loads(jsonstr)

def get_orig_id_ent(id_ques,all_q_id, all_ent):
    client = Client()
    orig_id_ent={}
    tags=['country','citizen','born','died']
    for ids,q in id_ques.items():
        lang=ids.split('-')[0]
        if q in all_q_id[lang]:
            ent=all_ent[lang][int(all_q_id[lang][q])]
            orig_id_ent[ids]=[]
            if ent in ent_ids:
    #             print(q,ent_ids[ent],ent)
                for x,j in ent_ids[ent].items():
                    if x in tags:
                        if j in countries: 
                            orig_id_ent[ids].append(countries[j])
                        elif j in id2geo:
                            orig_id_ent[ids].append(id2geo[j]['country'])
                        elif j in place2coun:
                            orig_id_ent[ids].append(place2coun[j]['country'])
                        else:
                            place = client.get('Q745956', load=True)
                            for key2 in place.keys():
                                if key2.id == 'P17':
                                    try:
                                        countryid = place[key2].id
                                        orig_id_ent[ids].append(countries[countryid])
#                                         print(x,j,countries[countryid])
                                    except:
                                        print(key2,  place)
                orig_id_ent[ids]=set(orig_id_ent[ids])
    return orig_id_ent

def get_coun2ids(orig_id_ent):
    coun2ids={}
    for ids, coun in orig_id_ent.items():
        if len(coun)!=0:
            for x in coun:
                if x not in coun2ids:
                    coun2ids[x]=[ids]
                else:
                    coun2ids[x].append(ids)
    return coun2ids

#### train file

In [27]:
import pickle

datapath='../data/sentences/tydiqa'


id_ques=get_qdata(os.path.join('tydiqa','tydiqa-goldp-v1.1-train.json'))
ques_id = {j:i for i,j in id_ques.items()}

all_q={}
all_q_id={}
all_q_list=[]
for f in os.listdir(datapath):
    if 'train' in f and 'train-bn' not in f:
        with open(os.path.join(datapath,f),'rb') as f1:
            sents=pickle.load(f1)
            sents=[x.replace('[START] ','').replace(' [END]','') for x in sents]
            name=f.split('.')[0].replace('tydiqa-train-','')
            all_q[name]=sents
            all_q_id[name]={x:i for i,x in enumerate(sents)}
            

import pickle

datapath='../data/entities/tydiqa'

all_ent={}
all_ent_id={}

for f in os.listdir(datapath):
    if 'train' in f and 'train-bn' not in f:
        with open(os.path.join(datapath,f),'rb') as f1:
            sents=pickle.load(f1)
            name=f.split('.')[0].replace('tydiqa-train-','')
            all_ent[name]=[x[0]['id'] if len(x)>=1 else '0' for x in sents]

49881 49881


In [28]:
orig_train = get_orig_id_ent(id_ques,all_q_id, all_ent)
coun2ids_train = get_coun2ids(orig_train)
langs = ['swahili','bengali','arabic','korean','english','indonesian','japanese','russian',
        'telugu','finnish','thai']
count_train={}
for lang in langs:
    count_train[lang]={}
    for countr,ids in coun2ids_train.items():
        for id_one in ids:
            if lang in id_one:
                if countr not in count_train[lang]:
                    count_train[lang][countr]=1
                else:
                    count_train[lang][countr]+=1

#### dev file

In [23]:
import pickle

datapath='../data/sentences/tydiqa'

types='dev'
id_ques=get_qdata(os.path.join('tydiqa','tydiqa-goldp-v1.1-{}.json'.format(types)))
ques_id = {j:i for i,j in id_ques.items()}

all_q={}
all_q_id={}
all_q_list=[]
for f in os.listdir(datapath):
    if types in f and 'train-bn' not in f:
        with open(os.path.join(datapath,f),'rb') as f1:
            sents=pickle.load(f1)
            sents=[x.replace('[START] ','').replace(' [END]','') for x in sents]
            name=f.split('.')[0].replace('tydiqa-{}-'.format(types),'')
            all_q[name]=sents
            all_q_id[name]={x:i for i,x in enumerate(sents)}
            

import pickle

datapath='../data/entities/tydiqa'

all_ent={}
all_ent_id={}

for f in os.listdir(datapath):
    if types in f and 'train-bn' not in f:
        with open(os.path.join(datapath,f),'rb') as f1:
            sents=pickle.load(f1)
            name=f.split('.')[0].replace('tydiqa-{}-'.format(types),'')
            all_ent[name]=[x[0]['id'] if len(x)>=1 else '0' for x in sents]

5077 5077


In [25]:
orig_dev = get_orig_id_ent(id_ques,all_q_id, all_ent)
coun2ids_dev = get_coun2ids(orig_dev)
langs = ['swahili','bengali','arabic','korean','english','indonesian','japanese','russian',
        'telugu','finnish','thai']
count_dev={}
for lang in langs:
    count_dev[lang]={}
    for countr,ids in coun2ids_dev.items():
        for id_one in ids:
            if lang in id_one:
                if countr not in count_dev[lang]:
                    count_dev[lang][countr]=1
                else:
                    count_dev[lang][countr]+=1

In [48]:
langs = ['swahili','bengali','arabic','korean','english','indonesian','japanese','russian',
        'telugu','finnish','thai']

dev_result={}
for lan in langs:
    name = lan+'-'
    dev_result[lan]={}
    for cname,ids in coun2ids_dev.items():
        new_examples=[]
        new_preds=collections.OrderedDict()
        for exam in examples:
            if exam.qas_id.startswith(name)==True and exam.qas_id in ids:
                new_examples.append(exam)
                new_preds[exam.qas_id]=predictions[exam.qas_id]
            
            
        if len(new_examples)!=0:
            results = squad_evaluate(new_examples, new_preds)
#             print(name,cname,len(new_examples))
#             print(results['f1'])
            dev_result[lan][cname]=results['f1']

In [50]:
# count_train,count_dev,dev_result

#### analysis

In [85]:
def get_within_between(orig_split, count_split):
    lang_trainc={}
    langs = ['swahili','bengali','arabic','korean','english','indonesian','japanese','russian',
            'telugu','finnish','thai']
    lang_trainc['total']=0
    for lang in langs:
        for x in orig_split.keys():
            if lang in x:
                if lang not in lang_trainc:
                    lang_trainc[lang]=1
                else:
                    lang_trainc[lang]+=1
        if lang in lang_trainc:
            lang_trainc['total']+=lang_trainc[lang]

    gby_coun={}
    for lang,coun in count_split.items():
        for coun,count in coun.items():
            if coun not in gby_coun:
                gby_coun[coun]={}
            gby_coun[coun][lang]={
                'count':count,
                'within_lang':round(100*(count/lang_trainc[lang]),2),
                'between_lang':round(100*(count/lang_trainc['total']),2)
            }
    return gby_coun

In [89]:
gby_countrain=get_within_between(orig_train, count_train)
gby_coundev=get_within_between(orig_dev, count_dev)

In [93]:
gby_coundev

{'Italy': {'swahili': {'count': 6, 'within_lang': 1.2, 'between_lang': 0.12},
  'arabic': {'count': 18, 'within_lang': 1.95, 'between_lang': 0.35},
  'korean': {'count': 5, 'within_lang': 1.81, 'between_lang': 0.1},
  'english': {'count': 3, 'within_lang': 0.68, 'between_lang': 0.06},
  'indonesian': {'count': 7, 'within_lang': 1.24, 'between_lang': 0.14},
  'russian': {'count': 8, 'within_lang': 0.99, 'between_lang': 0.16},
  'finnish': {'count': 12, 'within_lang': 1.53, 'between_lang': 0.24}},
 'United States of America': {'swahili': {'count': 26,
   'within_lang': 5.21,
   'between_lang': 0.51},
  'arabic': {'count': 30, 'within_lang': 3.26, 'between_lang': 0.59},
  'korean': {'count': 30, 'within_lang': 10.87, 'between_lang': 0.59},
  'english': {'count': 47, 'within_lang': 10.68, 'between_lang': 0.93},
  'indonesian': {'count': 26, 'within_lang': 4.6, 'between_lang': 0.51},
  'russian': {'count': 35, 'within_lang': 4.31, 'between_lang': 0.69},
  'telugu': {'count': 13, 'within_lan