In [11]:
%matplotlib inline
import pandas as pd
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import numpy as np

# The 'seaborn-whitegrid' style was renamed 'seaborn-v0_8-whitegrid' in
# matplotlib 3.6 and the old name was removed in 3.8 (style.use then raises
# OSError). Try the new name first; fall back for older matplotlib versions.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')

# Pretrained encoder whose BERTScore results are analysed; `read` uses this name
# to build the bert_score/{lang}--{model}--{pretrained}.csv path.
# Other encoders that were evaluated are kept commented out for switching.
pretrained = 'bert-base-multilingual-uncased'
#pretrained = 'bert-base-multilingual-cased' 
#pretrained = 'xlm-roberta-base'
#pretrained = 'xlm-roberta-large'

# Languages with human annotations, iterated by the Focus and Coverage cells.
# Defined here so the notebook survives Restart & Run All; the order matches
# the order of the printed per-model results.
LANGS = ['EN', 'ID', 'FR', 'TR', 'ZH', 'RU', 'DE', 'ES']

In [13]:
def read_array(arr):
    """Parse a stringified list of numbers, e.g. ``'[0.5, 0.25]'``, into floats.

    All square brackets are removed and the remainder is split on commas.
    Whitespace around items is ignored and empty items are skipped, so ``'[]'``
    parses to ``[]`` and ``'[1,2]'`` (no space after the comma) also works.

    Args:
        arr: string form of a flat list of numbers.

    Returns:
        list of float, in the original order.

    Raises:
        ValueError: if a non-empty item is not a valid float.
    """
    body = arr.replace('[', '').replace(']', '')
    # float() tolerates surrounding whitespace, so only blank items need skipping.
    return [float(item) for item in body.split(',') if item.strip()]

def read_array2(arr):
    """Parse a stringified list of strings, e.g. ``"['12', '34']"``, into a list of str.

    All square brackets are removed and the remainder is split on commas. Each
    item has surrounding whitespace and then surrounding single/double quotes
    stripped. Blank items are skipped, so ``'[]'`` parses to ``[]``; an explicit
    empty string ``''`` is kept, so positions stay aligned with the scores.

    Args:
        arr: string form of a flat list of (quoted) strings.

    Returns:
        list of str, in the original order.
    """
    body = arr.replace('[', '').replace(']', '')
    return [item.strip().strip('\'"') for item in body.split(',') if item.strip()]

def match(keys, values):
    """Pair ``keys[i]`` with ``values[i]`` into a dict.

    If a key repeats, the value at its last position wins.

    Args:
        keys: sequence of keys (here: document ids).
        values: sequence of values of the same length (here: metric scores).

    Returns:
        dict mapping each key to its positional value.

    Raises:
        ValueError: if ``keys`` and ``values`` differ in length; a mismatch means
            the ids and scores of a row are misaligned and must not be paired.
    """
    if len(keys) != len(values):
        raise ValueError(f'match: {len(keys)} keys but {len(values)} values')
    return dict(zip(keys, values))
        
def correlation(human_metric, id2score):
    """Pearson correlation between human scores and metric scores on shared ids.

    Args:
        human_metric: DataFrame with at least ``'id'`` and ``'score'`` columns.
        id2score: mapping from document id (any value ``int()`` accepts, e.g.
            ``'12'``) to the metric score for that document.

    Returns:
        ``(r, n)``: the Pearson r over the ids present in both inputs, and the
        number ``n`` of such ids. Ids in ``id2score`` with no human row are
        skipped. If several human rows share an id, the first row's score is used.

    Note:
        scipy's ``pearsonr`` rejects inputs with fewer than two points, so fewer
        than two matched ids raises from scipy.
    """
    # Build the id -> score lookup once rather than boolean-filtering the whole
    # frame (twice) for every key; keep='first' preserves first-row semantics.
    human_by_id = (
        human_metric.drop_duplicates('id', keep='first')
        .set_index('id')['score']
        .to_dict()
    )

    filtered_human = []
    filtered_metric = []
    for key, metric_score in id2score.items():
        doc_id = int(key)
        if doc_id not in human_by_id:
            continue
        filtered_human.append(human_by_id[doc_id])
        filtered_metric.append(metric_score)
    return pearsonr(filtered_human, filtered_metric)[0], len(filtered_human)

def get_max(arr):
    """Print where the extremes of ``arr`` are: the maximum first, then the minimum.

    Each line has the form ``<1-based position> -- <value>``. Returns nothing.
    """
    for locate in (np.argmax, np.argmin):
        pos = locate(arr)
        print(pos + 1, '--', arr[pos])

def get_best_rank(corr_arrs):
    """Return the index with the highest correlation averaged over all arrays.

    Args:
        corr_arrs: sequence of per-layer correlation sequences (one per
            (language, model) in this notebook). The number of layers is taken
            from the first sequence.

    Returns:
        The 0-based index of the best average. Ties go to the earliest index.
        Averages that are NaN never win. Also prints ``Best layer: <index+1>``.

    Raises:
        ValueError: if ``corr_arrs`` is empty, or no layer has a comparable
            (non-NaN, greater than -inf) average.
    """
    if len(corr_arrs) == 0:
        raise ValueError('get_best_rank: corr_arrs must not be empty')

    # Start from -inf, not 0: correlations can all be <= 0, and a 0 start
    # would leave rank at -1, which callers would silently use as "last layer".
    best_score = float('-inf')
    rank = -1
    for idx in range(len(corr_arrs[0])):
        avg = sum(corr_arr[idx] for corr_arr in corr_arrs) / len(corr_arrs)
        if avg > best_score:  # NaN compares False, so NaN layers are skipped
            rank = idx
            best_score = avg

    if rank == -1:
        raise ValueError('get_best_rank: every layer average is NaN or -inf')
    print('Best layer:', rank+1)
    return rank

In [14]:
def read_human_annotation(lang, model, types='focus'):
    """Load one aspect of the human judgements for a language, restricted to one model.

    Args:
        lang: language code; selects ``mturk/annotation_result/{lang}/``.
        model: value compared against the CSV's ``'model'`` column
            (the notebook uses ``'PG'`` and ``'BERT'``).
        types: ``'focus'`` or ``'coverage'``. Selects ``human_focus_final.csv``
            or ``human_coverage_final.csv``.

    Returns:
        DataFrame with the rows whose ``'model'`` equals ``model``.

    Raises:
        ValueError: if ``types`` is neither ``'focus'`` nor ``'coverage'``.
    """
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if types not in ('focus', 'coverage'):
        raise ValueError(f"types must be 'focus' or 'coverage', got {types!r}")
    path_human = f'mturk/annotation_result/{lang}/human_{types}_final.csv'
    human = pd.read_csv(path_human)
    return human[human['model'] == model]

def read(lang, model, human_score, prec_or_rec, pretrained_name=None):
    """Per-layer correlation between BERTScore and human scores for one (lang, model).

    Reads ``bert_score/{lang}--{model}--{pretrained}.csv``. Each row must provide
    a ``'layer'`` value, a stringified list of document ids in ``'fnames'``, and
    a stringified list of scores in the column named by ``prec_or_rec``.

    Args:
        lang, model: select the CSV file.
        human_score: human annotation frame passed on to ``correlation``.
        prec_or_rec: name of the score column to correlate (this notebook passes
            ``'precision'`` for focus and ``'recall'`` for coverage).
        pretrained_name: encoder name used in the file name. ``None`` (default)
            uses the module-level ``pretrained`` setting at call time.

    Returns:
        ``(xlabel, cors)``: the layer values and the matching Pearson r values,
        in CSV row order.
    """
    if pretrained_name is None:
        pretrained_name = pretrained
    path = f'bert_score/{lang}--{model}--{pretrained_name}.csv'
    df = pd.read_csv(path)

    xlabel = []
    cors = []
    # Column-wise zip avoids building a Series per row as iterrows() does.
    for cur_layer, raw_scores, raw_ids in zip(df['layer'], df[prec_or_rec], df['fnames']):
        id2score = match(read_array2(raw_ids), read_array(raw_scores))
        xlabel.append(cur_layer)
        cors.append(correlation(human_score, id2score)[0])
    return xlabel, cors

In [17]:
# Focus: correlate BERTScore precision with the human focus judgements for
# every (language, summarisation model) pair, then report the layer with the
# best average correlation and each pair's score at that layer.
cors = {}
for lang in LANGS:
    for model in ('PG', 'BERT'):
        human_score = read_human_annotation(lang, model, 'focus')
        cors[(lang, model)] = read(lang, model, human_score, 'precision')[1]

print('Focus of', pretrained)
layer = get_best_rank([per_layer for per_layer in cors.values()])
print('Individual model performance:')
for key, per_layer in cors.items():
    print(key, per_layer[layer])

Focus of bert-base-multilingual-uncased
Best layer: 12
Individual model performance:
('EN', 'PG') 0.6366980318400038
('EN', 'BERT') 0.6100619544153179
('ID', 'PG') 0.6826102703970842
('ID', 'BERT') 0.7183780273572422
('FR', 'PG') 0.6720144746324448
('FR', 'BERT') 0.7330420943115018
('TR', 'PG') 0.860145104227336
('TR', 'BERT') 0.8092996194988722
('ZH', 'PG') 0.7944995696980006
('ZH', 'BERT') 0.7744011505746847
('RU', 'PG') 0.4806242731518413
('RU', 'BERT') 0.6933440880722743
('DE', 'PG') 0.888640529382837
('DE', 'BERT') 0.8994458047094434
('ES', 'PG') 0.7015689507007323
('ES', 'BERT') 0.570212706066844


In [18]:
# Coverage: correlate BERTScore recall with the human coverage judgements for
# every (language, summarisation model) pair, then report the layer with the
# best average correlation and each pair's score at that layer.
cors = {}
for lang in LANGS:
    for model in ('PG', 'BERT'):
        human_score = read_human_annotation(lang, model, 'coverage')
        cors[(lang, model)] = read(lang, model, human_score, 'recall')[1]

print('Coverage of', pretrained)
layer = get_best_rank([per_layer for per_layer in cors.values()])
print('Individual model performance:')
for key, per_layer in cors.items():
    print(key, per_layer[layer])

Coverage of bert-base-multilingual-uncased
Best layer: 6
Individual model performance:
('EN', 'PG') 0.6337745838190866
('EN', 'BERT') 0.6432568225528881
('ID', 'PG') 0.7539321018441252
('ID', 'BERT') 0.7291109531081518
('FR', 'PG') 0.6966728453074331
('FR', 'BERT') 0.7584867035912555
('TR', 'PG') 0.8497323827865577
('TR', 'BERT') 0.8906809021022299
('ZH', 'PG') 0.810508927683462
('ZH', 'BERT') 0.7653215496265204
('RU', 'PG') 0.6744441649350216
('RU', 'BERT') 0.6811374244786897
('DE', 'PG') 0.9068873906259574
('DE', 'BERT') 0.9033131846111686
('ES', 'PG') 0.7123259131508018
('ES', 'BERT') 0.7331705718492405
