In [1]:
import random
import string
import re
import collections

from pythainlp.tokenize import word_tokenize
import pandas as pd
pd.set_option('display.max_colwidth', 300)

In [2]:
def compute_em(result_df):
    correct_num = 0
    num_example = 0
    wrong_qids = []
    for qid, ans, pred in zip(result_df['question_id'], result_df['ans'], result_df['pred']):
        ans = str(ans).replace(' ', '')
        pred = pred.replace(' ', '')
        if ans == pred:
            correct_num += 1
        else:
            wrong_qids.append(qid)

        num_example += 1
    em = correct_num / num_example
    return f'EM: {em:.4f}', wrong_qids

##################################################

def normalize_thai_answer(s):
    """Lower text and remove whitespace."""
    def white_space_fix(text):
        return str(text).replace(' ', '')
    return white_space_fix(s)

def get_thai_tokens(s):
    if not s: return []
    return word_tokenize(normalize_answer(s))

def compute_thai_f1(a_gold, a_pred):
    gold_toks = get_thai_tokens(a_gold)
    pred_toks = get_thai_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

##################################################

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
    if not s: return []
    return normalize_answer(s).split()

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

##################################################

def compute_score(result_df, lang):
    f1 = 0
    correct_num = 0
    num_example = 0
    wrong_qids = []
    for qid, ans, pred in zip(result_df['question_id'], result_df['ans'], result_df['pred']):
        ans = str(ans).replace(' ', '')
        pred = pred.replace(' ', '')
        if ans == pred:
            correct_num += 1
        else:
            wrong_qids.append(qid)

        num_example += 1

    em = correct_num / num_example
    if lang == 'thai':
        for qid, ans, pred in zip(result_df['question_id'], result_df['ans'], result_df['pred']):
            f1 += compute_thai_f1(str(ans), str(pred))
        avg_f1 = f1 / num_example
    elif lang == 'eng':
        for qid, ans, pred in zip(result_df['question_id'], result_df['ans'], result_df['pred']):
            f1 += compute_f1(ans, pred)
        avg_f1 = f1 / num_example
    return f'EM: {em:.4f}', f'F1: {avg_f1:.4f}'

# Test 50

In [3]:
result_test_50_thai_only = pd.read_pickle('/Users/jeew/works/qas/result/result_test_50_thai_only.p')
result_test_50_mixed_7600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_50_mixed_7600.p')
result_test_50_mixed_23600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_50_mixed_23600.p')
result_test_50_mixed_all = pd.read_pickle('/Users/jeew/works/qas/result/result_test_50_mixed_all.p')

In [4]:
compute_score(result_test_50_thai_only, 'thai')

('EM: 0.3400', 'F1: 0.4397')

In [5]:
compute_score(result_test_50_mixed_7600, 'thai')

('EM: 0.3200', 'F1: 0.3982')

In [6]:
compute_score(result_test_50_mixed_23600, 'thai')

('EM: 0.3800', 'F1: 0.4657')

In [7]:
compute_score(result_test_50_mixed_all, 'thai')

('EM: 0.5600', 'F1: 0.6180')

# Test 400

In [8]:
result_test_400_thai_only = pd.read_pickle('/Users/jeew/works/qas/result/result_test_400_thai_only.p')
result_test_400_mixed_7600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_400_mixed_7600.p')
result_test_400_mixed_23600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_400_mixed_23600.p')
result_test_400_mixed_all = pd.read_pickle('/Users/jeew/works/qas/result/result_test_400_mixed_all.p')

In [9]:
compute_score(result_test_400_thai_only, 'thai')

('EM: 0.5725', 'F1: 0.6002')

In [10]:
compute_score(result_test_400_mixed_7600, 'thai')

('EM: 0.6075', 'F1: 0.6385')

In [11]:
compute_score(result_test_400_mixed_23600, 'thai')

('EM: 0.5950', 'F1: 0.6411')

In [12]:
compute_score(result_test_400_mixed_all, 'thai')

('EM: 0.6125', 'F1: 0.6407')

# Test 1756 (Google translate from SQuAD)

In [13]:
result_test_1756_thai_only = pd.read_pickle('/Users/jeew/works/qas/result/result_test_1756_thai_only.p')
result_test_1756_mixed_7600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_1756_mixed_7600.p')
result_test_1756_mixed_23600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_1756_mixed_23600.p')
result_test_1756_mixed_all = pd.read_pickle('/Users/jeew/works/qas/result/result_test_1756_mixed_all.p')

In [14]:
compute_score(result_test_1756_thai_only, 'thai')

('EM: 0.1640', 'F1: 0.2281')

In [15]:
compute_score(result_test_1756_mixed_7600, 'thai')

('EM: 0.2318', 'F1: 0.3046')

In [16]:
compute_score(result_test_1756_mixed_23600, 'thai')

('EM: 0.2654', 'F1: 0.3416')

In [17]:
compute_score(result_test_1756_mixed_all, 'thai')

('EM: 0.3235', 'F1: 0.4137')

# Test 2000 (English, Sample from SQuAD)

In [18]:
result_test_squad_thai_only = pd.read_pickle('/Users/jeew/works/qas/result/result_test_squad_thai_only.p')
result_test_squad_mixed_7600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_squad_mixed_7600.p')
result_test_squad_mixed_23600 = pd.read_pickle('/Users/jeew/works/qas/result/result_test_squad_mixed_23600.p')
result_test_squad_mixed_all = pd.read_pickle('/Users/jeew/works/qas/result/result_test_squad_mixed_all.p')

In [19]:
compute_score(result_test_squad_thai_only, 'eng')

('EM: 0.1100', 'F1: 0.2011')

In [20]:
compute_score(result_test_squad_mixed_7600, 'eng')

('EM: 0.4085', 'F1: 0.7101')

In [21]:
compute_score(result_test_squad_mixed_23600, 'eng')

('EM: 0.4795', 'F1: 0.7870')

In [22]:
compute_score(result_test_squad_mixed_all, 'eng')

('EM: 0.5255', 'F1: 0.8321')