Skip to content

Commit

Permalink
Integrate Mewsli-X into XTREME-R evaluation, and hook up some missing…
Browse files Browse the repository at this point in the history
… tests for the XTREME-R task dictionary.

PiperOrigin-RevId: 456011711
  • Loading branch information
Jan Botha authored and sebastianruder committed Jun 20, 2022
1 parent 66367fa commit da430dd
Show file tree
Hide file tree
Showing 24 changed files with 94 additions and 11 deletions.
31 changes: 22 additions & 9 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from seqeval.metrics import recall_score
from third_party.evaluate_mlqa import evaluate as mlqa_eval
from third_party.evaluate_squad import evaluate as squad_eval
from third_party.utils_mewslix import evaluate as mewslix_eval


def read_tag(file):
Expand Down Expand Up @@ -108,6 +109,12 @@ def mlqa_em_f1(labels, predictions, language):
return mlqa_eval(labels, predictions, language)


def mewslix_map20(labels, predictions, language=None):
  """Scores Mewsli-X predictions at a cutoff of 20, as a percentage.

  Args:
    labels: gold data, passed through unchanged to `mewslix_eval`.
    predictions: predicted data, passed through unchanged to `mewslix_eval`.
    language: ignored; this metric does not depend on the language.

  Returns:
    A dict with the single key 'map@20' whose value is the score returned by
    `mewslix_eval(..., k=20)` multiplied by 100.
  """
  del language  # Unused by this metric.
  # NOTE(review): the score is published under the 'map@20' key; confirm in
  # `third_party.utils_mewslix` which ranking metric `evaluate` computes.
  return {'map@20': mewslix_eval(labels, predictions, k=20) * 100}


XTREME_GROUP2TASK = {
'classification': ['pawsx', 'xnli'],
'tagging': ['udpos', 'panx'],
Expand All @@ -120,7 +127,7 @@ def mlqa_em_f1(labels, predictions, language):
'classification': ['xnli', 'xcopa'],
'tagging': ['udpos', 'panx'],
'qa': ['xquad', 'mlqa', 'tydiqa'],
'retrieval': ['tatoeba'],
'retrieval': ['tatoeba', 'mewslix'],
'multi_choice': ['xcopa'],
}

Expand Down Expand Up @@ -162,7 +169,7 @@ def mlqa_em_f1(labels, predictions, language):
'ro'.split(','),
'xcopa': 'et,ht,id,it,qu,sw,ta,th,tr,vi,zh'.split(','),
'lareqa': [],
'mewslix': [],
'mewslix': 'ar,de,en,es,fa,ja,pl,ro,ta,tr,uk'.split(','),
'xquad': 'en,es,de,el,ru,tr,ar,vi,th,zh,hi,ro'.split(','),
'mlqa': 'en,es,de,ar,hi,vi,zh'.split(','),
'tydiqa': 'en,ar,bn,fi,id,ko,ru,sw,te'.split(','),
Expand All @@ -183,6 +190,7 @@ def mlqa_em_f1(labels, predictions, language):
'bucc2018': read_label,
'tatoeba': read_label,
'xquad': read_squad,
'mewslix': read_squad,
'mlqa': read_squad,
'tydiqa': read_squad,
'xcopa': read_xcopa,
Expand All @@ -197,6 +205,7 @@ def mlqa_em_f1(labels, predictions, language):
'bucc2018': bucc_f1,
'tatoeba': accuracy,
'xquad': squad_em_f1,
'mewslix': mewslix_map20,
'mlqa': mlqa_em_f1,
'tydiqa': squad_em_f1,
'xcopa': accuracy,
Expand All @@ -219,14 +228,23 @@ def evaluate_one_task(prediction_file, label_file, task, language=None):
"""
predictions = READER_FUNCTION[task](prediction_file)
labels = READER_FUNCTION[task](label_file)
if task not in ['bucc2018', 'mlqa', 'tydiqa', 'xquad']:
if task not in ['bucc2018', 'mewslix', 'mlqa', 'tydiqa', 'xquad']:
assert len(predictions) == len(labels), (
'Number of examples in {} and {} not matched in {} task'.format(
prediction_file, label_file, task))
result = METRIC_FUNCTION[task](labels, predictions, language)
return result


def get_suffix(task, group2task):
  """Returns the file extension of a task's prediction and label files.

  Args:
    task: name of the task, e.g. 'xquad' or 'mewslix'.
    group2task: dict mapping a task-group name to the list of task names in
      that group. Both the 'qa' and the 'multi_choice' groups are optional.

  Returns:
    'json' for tasks in the 'qa' group and for 'mewslix', 'jsonl' for tasks
    in the 'multi_choice' group, and 'tsv' for every other task.
  """
  # Mewsli-X is grouped under 'retrieval' but its files are JSON, so it is
  # special-cased here rather than derived from its group.
  # `.get` keeps a mapping without a 'qa' group from raising KeyError, which
  # matches how the optional 'multi_choice' group is treated below.
  if task in group2task.get('qa', ()) or task in ('mewslix',):
    return 'json'
  if task in group2task.get('multi_choice', ()):
    return 'jsonl'
  return 'tsv'


def evaluate(prediction_folder, label_folder, xtreme_version, verbose=False):
"""Evaluate on all tasks if available.
Expand All @@ -250,12 +268,7 @@ def evaluate(prediction_folder, label_folder, xtreme_version, verbose=False):
detailed_scores = {}
for task, langs in task2langs.items():
if task in prediction_tasks and task in label_tasks:
if task in group2task['qa']:
suffix = 'json'
elif 'multi_choice' in group2task and task in group2task['multi_choice']:
suffix = 'jsonl'
else:
suffix = 'tsv'
suffix = get_suffix(task, group2task)
# collect scores over all languages
score = collections.defaultdict(dict)
for lg in langs:
Expand Down
52 changes: 50 additions & 2 deletions evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
from absl.testing import absltest
from absl.testing import parameterized
from xtreme.evaluate import evaluate_one_task
from xtreme.evaluate import get_suffix
from xtreme.evaluate import XTREME_GROUP2TASK
from xtreme.evaluate import XTREME_R_GROUP2TASK
from xtreme.evaluate import XTREME_R_TASK2LANGS
from xtreme.evaluate import XTREME_TASK2LANGS

DATA_DIR = './/mock_test_data'

# Mock submission scores for testing
# Mock submission scores for testing XTREME.
TASK2AVG_SCORES = {
'pawsx': {'avg_accuracy': 51.42857142857143},
'xnli': {'avg_accuracy': 30.666666666666668},
Expand All @@ -40,6 +43,24 @@
'tydiqa': {'avg_exact_match': 88.88888888888889, 'avg_f1': 97.22222222222223}
}

# Expected average scores of the mock submission, keyed by XTREME-R task.
# Each value maps 'avg_<metric>' to the mean of that metric over the task's
# languages.
# TODO(ruder): Update data/numbers for tasks with added languages (UD-POS,
# PANX, Tatoeba, and XQuAD) and for new tasks (XCOPA, LAReQA). The entries
# that are commented out below correspond to the tasks named in this TODO.
XTREME_R_TASK2AVG_SCORES = {
    'xnli': {
        'avg_accuracy': 30.666666666666668,
    },
    # 'panx': {'avg_f1': 57.50793650793652,
    #          'avg_precision': 54.729166666666664,
    #          'avg_recall': 62.750000000000014},
    # 'udpos': {'avg_f1': 70.21746048354693,
    #           'avg_precision': 71.02232625883823,
    #           'avg_recall': 69.54982073976082},
    # 'tatoeba': {'avg_accuracy': 53.611111111111114},
    # 'xcopa'
    # 'lareqa'
    'mewslix': {
        'avg_map@20': 14.39025156130419,
    },
    # 'xquad': {'avg_exact_match': 77.27272727272727,
    #           'avg_f1': 79.9586776859504},
    'mlqa': {
        'avg_exact_match': 57.142857142857146,
        'avg_f1': 81.76870748299321,
    },
    'tydiqa': {
        'avg_exact_match': 88.88888888888889,
        'avg_f1': 97.22222222222223,
    },
}


class EvaluateTest(parameterized.TestCase):
"""Test cases for evaluate.py."""
Expand All @@ -54,7 +75,7 @@ class EvaluateTest(parameterized.TestCase):
('XQuAD', 'xquad'),
('MLQA', 'mlqa'),
('TyDiQA', 'tydiqa'))
def testTask(self, task):
def testXtremeTask(self, task):
data_dir = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR)
suffix = 'json' if task in XTREME_GROUP2TASK['qa'] else 'tsv'
score = collections.defaultdict(dict)
Expand All @@ -71,5 +92,32 @@ def testTask(self, task):
self.assertEqual(avg_score, TASK2AVG_SCORES[task])


@parameterized.named_parameters(
    ('XNLI', 'xnli'),
    # ('PANX', 'panx'),
    # ('UDPOS', 'udpos'),
    # ('Tatoeba', 'tatoeba'),
    # ('XCOPA', 'xcopa'),
    # ('LAReQA', 'lareqa'),
    ('Mewsli-X', 'mewslix'),
    # ('XQuAD', 'xquad'),
    ('MLQA', 'mlqa'),
    ('TyDiQA', 'tydiqa'))
def testXtremeRTask(self, task):
  """Checks the per-task average scores of the XTREME-R mock submission.

  Evaluates the mock prediction file against the mock label file for every
  language of `task`, averages each metric over the languages, and compares
  the averages with the expected numbers in XTREME_R_TASK2AVG_SCORES.

  Args:
    task: name of the XTREME-R task under test.
  """
  root = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR)
  extension = get_suffix(task, XTREME_R_GROUP2TASK)

  # metric name -> {language -> score}
  per_metric = collections.defaultdict(dict)
  for lang in XTREME_R_TASK2LANGS[task]:
    # Predictions and labels share the same file name in sibling folders.
    filename = f'test-{lang}.{extension}'
    result = evaluate_one_task(
        os.path.join(root, 'predictions', task, filename),
        os.path.join(root, 'labels', task, filename),
        task,
        language=lang)
    for metric in result:
      per_metric[metric][lang] = result[metric]

  averages = {
      f'avg_{metric}': sum(per_metric[metric].values()) / len(per_metric[metric])
      for metric in per_metric
  }
  self.assertEqual(averages, XTREME_R_TASK2AVG_SCORES[task])

if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-ar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"8e18e51eced73e6495df0043192edbfe": ["Q46930"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-de.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"4be5a1742223cc3a8c01e6bf9c6e3f27": ["Q156913"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-en.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"64ca9e2f229acf8e39c2a3d2e45f81e7": ["Q720285"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-es.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"4a2d7fd3e4791f09bc3c804a15d647ef": ["Q786"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-fa.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"d35cc57a7869168ddeb8143c1b2260f3": ["Q76"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-ja.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"d0e7a9dd0359610c53bba176d702dfce": ["Q174691"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-pl.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"64232b8a3c3ee67f76f96ccd963b78f7": ["Q1362561"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-ro.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"ebd92132adbb679fdd090503cd925f81": ["Q185007"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-ta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"12760cb39680a822c3cd0c8495cf1b4b": ["Q11468"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-tr.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"9f39acb0fef259aaf24224fe41954f6c": ["Q258"]}
1 change: 1 addition & 0 deletions mock_test_data/labels/mewslix/test-uk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"9f4dba86a6d21cfd246353403da46abd": ["Q1899"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-ar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"8e18e51eced73e6495df0043192edbfe": ["Q4963862", "Q42309905", "Q45789", "Q13403337", "Q5564588", "Q4009605", "Q1635932", "Q4980057", "Q5958027", "Q233750", "Q2922959", "Q203023", "Q2425422", "Q2340576", "Q4639323", "Q46930", "Q66891", "Q5423986", "Q15556629", "Q1347825"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-de.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"4be5a1742223cc3a8c01e6bf9c6e3f27": ["Q11490423", "Q156913", "Q490356", "Q16222746", "Q4873731", "Q2102531", "Q209944", "Q4630241", "Q9033638", "Q18249334", "Q65216438", "Q333185", "Q2530561", "Q20013418", "Q10826362", "Q2575270", "Q2914850", "Q55697199", "Q853167", "Q111730"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-en.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"64ca9e2f229acf8e39c2a3d2e45f81e7": ["Q82674", "Q13551861", "Q2418898", "Q198748", "Q1146387", "Q6730240", "Q6769706", "Q2315496", "Q3375182", "Q711611", "Q55732114", "Q720285", "Q4760035", "Q28670149", "Q375278", "Q260559", "Q82840", "Q878942", "Q269810", "Q427535"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-es.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"4a2d7fd3e4791f09bc3c804a15d647ef": ["Q6151759", "Q19904197", "Q1138905", "Q440165", "Q787524", "Q13050046", "Q15748660", "Q6604140", "Q11400285", "Q20071151", "Q2912875", "Q786", "Q1999706", "Q11398056", "Q4486275", "Q3744158", "Q63524702", "Q38745473", "Q37996883", "Q29260670"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-fa.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"d35cc57a7869168ddeb8143c1b2260f3": ["Q333972", "Q48270", "Q5254564", "Q76", "Q5947394", "Q3151708", "Q1756916", "Q63091766", "Q13104276", "Q5839704", "Q6598064", "Q1008989", "Q48762758", "Q55842144", "Q461358", "Q447087", "Q13640998", "Q535894", "Q223278", "Q3504372"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-ja.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"d0e7a9dd0359610c53bba176d702dfce": ["Q1210312", "Q3662301", "Q2877167", "Q13548902", "Q3458109", "Q65159649", "Q49892", "Q204547", "Q12699816", "Q372592", "Q1776619", "Q16633277", "Q1658454", "Q174691", "Q1053638", "Q23653996", "Q798074", "Q24939391", "Q8037644", "Q65967892"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-pl.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"64232b8a3c3ee67f76f96ccd963b78f7": ["Q1033066", "Q565472", "Q11598441", "Q29522", "Q16027287", "Q1174348", "Q1052293", "Q16903684", "Q12860947", "Q48769622", "Q2606279", "Q7315521", "Q268776", "Q13621486", "Q1400430", "Q7124665", "Q11280748", "Q710911", "Q1362561", "Q34754"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-ro.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"ebd92132adbb679fdd090503cd925f81": ["Q1144739", "Q5836568", "Q20582855", "Q1311", "Q711832", "Q185007", "Q311559", "Q50391138", "Q55418237", "Q5037965", "Q601712", "Q6654524", "Q615949", "Q980941", "Q5188638", "Q15060144", "Q6737309", "Q21670139", "Q1040955", "Q928053"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-ta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"12760cb39680a822c3cd0c8495cf1b4b": ["Q22959171", "Q13385006", "Q608803", "Q3046191", "Q1750336", "Q15353797", "Q1695555", "Q124473", "Q836937", "Q3297349", "Q430687", "Q2181287", "Q11468", "Q20393369", "Q888226", "Q56477015", "Q22692651", "Q13829184", "Q2479497", "Q3207103"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-tr.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"9f39acb0fef259aaf24224fe41954f6c": ["Q11350542", "Q188447", "Q15905812", "Q15868", "Q6630136", "Q6734763", "Q105927", "Q258", "Q9181720", "Q313196", "Q4099359", "Q15567185", "Q587455", "Q190436", "Q5284896", "Q18709782", "Q16233625", "Q5246694", "Q11620425", "Q12568992"]}
1 change: 1 addition & 0 deletions mock_test_data/predictions/mewslix/test-uk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"9f4dba86a6d21cfd246353403da46abd": ["Q524624", "Q3830755", "Q3800390", "Q508679", "Q20383186", "Q930701", "Q18682623", "Q16969424", "Q1899", "Q2320371", "Q266613", "Q2469647", "Q749794", "Q6241038", "Q5754881", "Q2879448", "Q1630799", "Q447", "Q628319", "Q25515301"]}

0 comments on commit da430dd

Please sign in to comment.