In [17]:
import pandas as pd
from tqdm import tqdm
import json
from scipy import stats
import itertools
import numpy as np
import sacrebleu

In [18]:
def read_json(path):
    f = open (path, "r")
    data = json.loads(f.read())
    k = list(data.keys())[0]

    src = []
    mt = []
    ref = []
    COMET_score = []

    for i in data[k]:
        src.append(i['src'])
        mt.append(i['mt'])
        ref.append(i['ref'])
        COMET_score.append(float(i['COMET']))

    f.close()
    
    df = pd.DataFrame(data=np.array([src, mt, ref, COMET_score]).T, 
                      columns=['src', 'mt', 'ref', 'comet'])
    
    return df

## en-de

In [19]:
# specify your home path
home_path = '/home/glushkovato/robustness'

split1 = pd.read_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/split1-updated.csv')
split2 = pd.read_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/split2-updated.csv')

In [20]:
mqm21 = pd.concat([split1, split2], axis=0)
mqm21_ende = mqm21[mqm21.lp == 'en-de']
mqm21_enru = mqm21[mqm21.lp == 'en-ru']
mqm21_zhen = mqm21[mqm21.lp == 'zh-en']

In [21]:
mqm21_zhen.iloc[:3]

Unnamed: 0.2,Unnamed: 0,System,src,mt,ref,raw_score,score,z_score,lp,Unnamed: 0.1
3120,0,metricsystem5,张继科林小宅组队力挺有肌少年-新华网,Zhang Jike Lin Xiaozhai team up to support mus...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-5.0,0.016176,0.016176,zh-en,
3121,1,SMU,张继科林小宅组队力挺有肌少年-新华网,Zhang Jilin small house formed a team to suppo...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-5.0,0.016176,0.016176,zh-en,
3122,2,DIDI-NLP,张继科林小宅组队力挺有肌少年-新华网,Zhang Jike and Lin Xiaozhai team up to support...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-1.1,0.68438,0.68438,zh-en,


In [12]:
lps = ['en-de', 'en-ru', 'zh-en']
dfs = [mqm21_ende, mqm21_enru, mqm21_zhen]

for i, df in enumerate(dfs):
    lp = lps[i]
    
    with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/" + lp + "_mqm21_src.txt", "w") as f:
        for i in df.src.tolist():
            print(i, file=f)

    with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/" + lp + "_mqm21_mt.txt", "w") as f:
        for i in df.mt.tolist():
            print(i, file=f)

    with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/" + lp + "_mqm21_ref.txt", "w") as f:
        for i in df.ref.tolist():
            print(i, file=f)

    with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/" + lp + "_mqm21_score.txt", "w") as f:
        for i in df.score.tolist():
            print(i, file=f)

# compute features
### EN-DE

In [13]:
mqm21_ende.iloc[:3]

Unnamed: 0.2,Unnamed: 0,System,src,mt,ref,raw_score,score,z_score,lp,Unnamed: 0.1
0,0,metricsystem3,Couple MACED at California dog park for not we...,"Paar im kalifornischen Hundepark GEMAUT, weil ...",Paar in Hundepark in Kalifornien mit Pfeffersp...,-5.0,-0.981905,-0.981905,en-de,
1,1,VolcTrans-AT,Couple MACED at California dog park for not we...,"Paar zerfleischt im kalifornischen Hundepark, ...",Paar in Hundepark in Kalifornien mit Pfeffersp...,-5.0,-0.981905,-0.981905,en-de,
2,2,metricsystem1,Couple MACED at California dog park for not we...,"Paar MACED im kalifornischen Hundepark, weil e...",Paar in Hundepark in Kalifornien mit Pfeffersp...,-15.0,-4.007126,-4.007126,en-de,


In [14]:
# compute bleu

mqm21_ende_scores_bleu = []
refs_mqm21_ende = mqm21_ende.ref.tolist()
mts_mqm21_ende = mqm21_ende.mt.tolist()

for i in tqdm(range(len(mts_mqm21_ende))):
    mqm21_ende_scores_bleu.append(sacrebleu.sentence_bleu(mts_mqm21_ende[i], [refs_mqm21_ende[i]]))
    
mqm21_ende_scores_bleu = np.array([i.score for i in mqm21_ende_scores_bleu])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/en-de/mqm21_ende_scores_bleu.txt", "w") as f:
    for i in mqm21_ende_scores_bleu:
        print(i, file=f)


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 6851/6851 [00:01<00:00, 3836.59it/s]


In [9]:
# compute chrf

mqm21_ende_scores_chrf = []
refs_mqm21_ende = mqm21_ende.ref.tolist()
mts_mqm21_ende = mqm21_ende.mt.tolist()

for i in tqdm(range(len(mts_mqm21_ende))):
    mqm21_ende_scores_chrf.append(sacrebleu.sentence_chrf(mts_mqm21_ende[i], [refs_mqm21_ende[i]]))
    
mqm21_ende_scores_chrf = np.array([i.score for i in mqm21_ende_scores_chrf])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/en-de/mqm21_ende_scores_chrf.txt", "w") as f:
    for i in mqm21_ende_scores_chrf:
        print(i, file=f)



100%|█████████████████████████████████████████████████████████████████████████████████████████████| 6851/6851 [00:02<00:00, 2714.25it/s]


In [53]:
mqm21_ende_feats = pd.DataFrame(data=np.array([mqm21_ende_scores_bleu, mqm21_ende_scores_chrf]).T, columns=['f1', 'f2'])
mqm21_ende_feats.to_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/en-de/mqm21_ende_features.csv', index=None)


In [54]:
mqm21_ende_feats.head()

Unnamed: 0,f1,f2
0,47.06366,60.460395
1,51.889248,64.418421
2,37.285764,52.067971
3,29.060361,49.19406
4,37.522511,60.708039


### EN-RU

In [10]:
mqm21_enru.iloc[:3]

Unnamed: 0.2,Unnamed: 0,System,src,mt,ref,raw_score,score,z_score,lp,Unnamed: 0.1
7865,0,metricsystem1,Dominic Raab: Government can't make apologies ...,Доминик Рааб: Правительство не может извинятьс...,Доминик Рааб: Правительство не может извинитьс...,-0.0,100.0,0.557256,en-ru,
7866,1,metricsystem1,Dominic Raab has defended the Government's dec...,Доминик Рааб защитил решение правительства о п...,Доминик Рааб выступил в защиту решения правите...,-0.0,100.0,0.557256,en-ru,
7867,2,metricsystem1,Ministers announced on Saturday that holidayma...,"В субботу министры объявили, что отдыхающие, н...","В субботу министры объявили, что отдыхающие, н...",-0.0,100.0,0.557256,en-ru,


In [10]:
# compute bleu

mqm21_enru_scores_bleu = []
refs_mqm21_enru = mqm21_enru.ref.tolist()
mts_mqm21_enru = mqm21_enru.mt.tolist()

for i in tqdm(range(len(mts_mqm21_enru))):
    mqm21_enru_scores_bleu.append(sacrebleu.sentence_bleu(mts_mqm21_enru[i], [refs_mqm21_enru[i]]))
    
mqm21_enru_scores_bleu = np.array([i.score for i in mqm21_enru_scores_bleu])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/en-ru/mqm21_enru_scores_bleu.txt", "w") as f:
    for i in mqm21_enru_scores_bleu:
        print(i, file=f)

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 7378/7378 [00:01<00:00, 4050.48it/s]


In [11]:
# compute chrf

mqm21_enru_scores_chrf = []
refs_mqm21_enru = mqm21_enru.ref.tolist()
mts_mqm21_enru = mqm21_enru.mt.tolist()

for i in tqdm(range(len(mts_mqm21_enru))):
    mqm21_enru_scores_chrf.append(sacrebleu.sentence_chrf(mts_mqm21_enru[i], [refs_mqm21_enru[i]]))
    
mqm21_enru_scores_chrf = np.array([i.score for i in mqm21_enru_scores_chrf])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/en-ru/mqm21_enru_scores_chrf.txt", "w") as f:
    for i in mqm21_enru_scores_chrf:
        print(i, file=f)

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 7378/7378 [00:03<00:00, 2399.86it/s]


In [41]:
mqm21_enru_feats = pd.DataFrame(data=np.array([mqm21_enru_scores_bleu, mqm21_enru_scores_chrf]).T, columns=['f1', 'f2'])
mqm21_enru_feats.to_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/en-ru/mqm21_enru_features.csv', index=None)


In [42]:
mqm21_enru_feats.head()

Unnamed: 0,f1,f2
0,41.128253,75.731222
1,6.437165,47.586519
2,60.039834,80.267139
3,6.250382,59.357699
4,12.684775,48.387963


### ZH-EN

In [12]:
mqm21_zhen.mt = mqm21_zhen.mt.fillna(' ')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  mqm21_zhen.mt = mqm21_zhen.mt.fillna(' ')


In [13]:
mqm21_zhen.iloc[:3]

Unnamed: 0.2,Unnamed: 0,System,src,mt,ref,raw_score,score,z_score,lp,Unnamed: 0.1
3120,0,metricsystem5,张继科林小宅组队力挺有肌少年-新华网,Zhang Jike Lin Xiaozhai team up to support mus...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-5.0,0.016176,0.016176,zh-en,
3121,1,SMU,张继科林小宅组队力挺有肌少年-新华网,Zhang Jilin small house formed a team to suppo...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-5.0,0.016176,0.016176,zh-en,
3122,2,DIDI-NLP,张继科林小宅组队力挺有肌少年-新华网,Zhang Jike and Lin Xiaozhai team up to support...,Zhang Jike and Lin Xiaozhai Form a Team to Sup...,-1.1,0.68438,0.68438,zh-en,


In [14]:
# compute bleu

mqm21_zhen_scores_bleu = []
refs_mqm21_zhen = mqm21_zhen.ref.tolist()
mts_mqm21_zhen = mqm21_zhen.mt.tolist()

for i in tqdm(range(len(mts_mqm21_zhen))):
    mqm21_zhen_scores_bleu.append(sacrebleu.sentence_bleu(mts_mqm21_zhen[i], [refs_mqm21_zhen[i]]))
    
mqm21_zhen_scores_bleu = np.array([i.score for i in mqm21_zhen_scores_bleu])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/zh-en/mqm21_zhen_scores_bleu.txt", "w") as f:
    for i in mqm21_zhen_scores_bleu:
        print(i, file=f)


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 8450/8450 [00:02<00:00, 3002.21it/s]


In [15]:
# compute chrf

mqm21_zhen_scores_chrf = []
refs_mqm21_zhen = mqm21_zhen.ref.tolist()
mts_mqm21_zhen = mqm21_zhen.mt.tolist()

for i in tqdm(range(len(mts_mqm21_zhen))):
    mqm21_zhen_scores_chrf.append(sacrebleu.sentence_chrf(mts_mqm21_zhen[i], [refs_mqm21_zhen[i]]))
    
mqm21_zhen_scores_chrf = np.array([i.score for i in mqm21_zhen_scores_chrf])

with open(home_path + "/robust_MT_evaluation/data/test/mqm2021/zh-en/mqm21_zhen_scores_chrf.txt", "w") as f:
    for i in mqm21_zhen_scores_chrf:
        print(i, file=f)



100%|█████████████████████████████████████████████████████████████████████████████████████████████| 8450/8450 [00:04<00:00, 1865.16it/s]


In [49]:
mqm21_zhen_feats = pd.DataFrame(data=np.array([mqm21_zhen_scores_bleu, mqm21_zhen_scores_chrf]).T, columns=['f1', 'f2'])
mqm21_zhen_feats.to_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/zh-en/mqm21_zhen_features.csv', index=None)


In [50]:
mqm21_zhen_feats.head()

Unnamed: 0,f1,f2
0,6.496184,42.857062
1,2.858684,32.090638
2,21.282441,50.27676
3,2.579588,41.500319
4,3.016552,40.083697


In [28]:
def optimize_weights(true_scores, bleu_scores, chrf_scores, comet_scores, scale_a, scale_b, scale_c, 
                     num_bins=200, metric='pearson'):
    scale_a = np.linspace(0, 1, num_bins)
    scale_b = np.linspace(0, 1, num_bins)
    scale_c = np.linspace(0, 1, num_bins)
    best_value = -np.inf
    best_weights = np.nan
    combinations = [i for i in tqdm(itertools.product(scale_a, scale_b, scale_c)) if np.sum(i) == 1]
#     combinations = [i for i in tqdm(itertools.product(scale_a, scale_b, scale_c))]
    print(metric)
    for (a, b, c) in tqdm(combinations):
        weighted_ensemble = a*np.array(bleu_scores) + b*np.array(chrf_scores) + c*np.array(comet_scores)
        if metric == 'pearson':
            value = stats.pearsonr(true_scores, weighted_ensemble)
        elif metric == 'spearman':
            value = stats.spearmanr(true_scores, weighted_ensemble)
        else:
            value = stats.kendalltau(true_scores, weighted_ensemble)
            
        if value[0] > best_value:
            best_value = value[0]
            best_weights = [a, b, c]
    return best_value, best_weights

In [23]:
def compute_norm(x):
    mean = np.mean(x)
    std = np.std(x)
    return [mean, std]

def apply_norm(mean, std, x):
    xn = (x - mean)/std
    return np.array(xn)

In [18]:
v = '24'
path_ende = home_path + '/robust_MT_evaluation/data/test/mqm2021/en-de/mqm21_newstest_output_v' + v + 'e1.json'
mqm21_ende_output = read_json(path_ende)
comet_ende = mqm21_ende_output.comet.astype(float).tolist()

path_enru = home_path + '/robust_MT_evaluation/data/test/mqm2021/en-ru/mqm21_newstest_output_v' + v + 'e1.json'
mqm21_enru_output = read_json(path_enru)
comet_enru = mqm21_enru_output.comet.astype(float).tolist()

path_zhen = home_path + '/robust_MT_evaluation/data/test/mqm2021/zh-en/mqm21_newstest_output_v' + v + 'e1.json'
mqm21_zhen_output = read_json(path_zhen)
comet_zhen = mqm21_zhen_output.comet.astype(float).tolist()


all_bleu = list(mqm21_ende_scores_bleu) + list(mqm21_enru_scores_bleu) + list(mqm21_zhen_scores_bleu)
all_chrf = list(mqm21_ende_scores_chrf) + list(mqm21_enru_scores_chrf) + list(mqm21_zhen_scores_chrf)
all_mqm = mqm21_ende.z_score.tolist() + mqm21_enru.z_score.tolist() + mqm21_zhen.z_score.tolist() 
all_comet = comet_ende + comet_enru + comet_zhen

In [25]:
true_scores = all_mqm

bleu_mean, bleu_std = compute_norm(all_bleu)
scores_bleu = apply_norm(bleu_mean, bleu_std, all_bleu)

chrf_mean, chrf_std = compute_norm(all_chrf)
scores_chrf = apply_norm(chrf_mean, chrf_std, all_chrf)

comet_mean, comet_std = compute_norm(all_comet)
scores_comet = apply_norm(comet_mean, comet_std, all_comet)

In [None]:
bleu_mean = 28.759837809513634
bleu_std = 18.47107097319373
chrf_mean = 58.992697061544284
chrf_std = 14.286372518233168
comet_mean = 0.46782439675103793
comet_std = 0.37521584265953595

true_scores = all_mqm
scores_bleu = apply_norm(bleu_mean, bleu_std, all_bleu)
scores_chrf = apply_norm(chrf_mean, chrf_std, all_chrf)
scores_comet = apply_norm(comet_mean, comet_std, all_comet)

In [36]:
bleu_mean, bleu_std

(28.759837809513634, 18.47107097319373)

In [37]:
chrf_mean, chrf_std

(58.992697061544284, 14.286372518233168)

In [38]:
comet_mean, comet_std

(0.46782439675103793, 0.37521584265953595)

In [20]:
# df = pd.DataFrame(data=np.array([true_scores, scores_bleu, scores_chrf, scores_comet]).T, 
#                   columns=['mqm', 'bleu', 'chrf', 'comet'])

In [28]:
# df.to_csv(home_path + '/robust_MT_evaluation/data/test/mqm2021/mqm2021_scores_for_w_opt.csv', index=None)

In [27]:
num_bins = 200
scale_a = np.linspace(0, 1, num_bins)
scale_b = np.linspace(0, 1, num_bins)
scale_c = np.linspace(0, 1, num_bins)

In [34]:
best_kendall, best_weights = optimize_weights(true_scores, scores_bleu, 
                                              scores_chrf, scores_comet,
                                              scale_a, scale_b, scale_c, metric='kendall')


0it [00:00, ?it/s][A
17348it [00:00, 173464.86it/s][A
34695it [00:00, 169655.91it/s][A
52999it [00:00, 175689.26it/s][A
71815it [00:00, 180571.33it/s][A
90528it [00:00, 182923.26it/s][A
108828it [00:00, 182563.84it/s][A
127247it [00:00, 183091.09it/s][A
145713it [00:00, 183587.04it/s][A
164326it [00:00, 184375.93it/s][A
182775it [00:01, 184406.69it/s][A
201576it [00:01, 185505.78it/s][A
225743it [00:01, 202579.76it/s][A
250615it [00:01, 216549.61it/s][A
274695it [00:01, 223869.24it/s][A
297742it [00:01, 225856.92it/s][A
322699it [00:01, 232991.34it/s][A
346986it [00:01, 235959.21it/s][A
371362it [00:01, 238300.05it/s][A
395833it [00:01, 240222.58it/s][A
420484it [00:02, 242107.85it/s][A
444781it [00:02, 242363.92it/s][A
469018it [00:02, 225239.31it/s][A
491785it [00:02, 211629.63it/s][A
513242it [00:02, 203195.32it/s][A
533788it [00:02, 197505.61it/s][A
553690it [00:02, 194107.51it/s][A
573194it [00:02, 186616.39it/s][A
591938it [00:02, 185898.43it/s][A
61

kendall



  0%|                                                                                                         | 0/19478 [00:00<?, ?it/s][A
  0%|                                                                                               | 16/19478 [00:00<02:02, 158.46it/s][A
  0%|▏                                                                                              | 32/19478 [00:00<02:02, 158.22it/s][A
  0%|▏                                                                                              | 48/19478 [00:00<02:02, 158.05it/s][A
  0%|▎                                                                                              | 64/19478 [00:00<02:03, 157.79it/s][A
  0%|▍                                                                                              | 80/19478 [00:00<02:03, 157.68it/s][A
  0%|▍                                                                                              | 96/19478 [00:00<02:02, 158.07it/s][A
  1%|▌             

 12%|███████████▏                                                                                 | 2340/19478 [00:11<01:21, 211.24it/s][A
 12%|███████████▎                                                                                 | 2362/19478 [00:12<01:21, 210.61it/s][A
 12%|███████████▍                                                                                 | 2384/19478 [00:12<01:21, 210.38it/s][A
 12%|███████████▍                                                                                 | 2406/19478 [00:12<01:21, 210.10it/s][A
 12%|███████████▌                                                                                 | 2428/19478 [00:12<01:21, 210.08it/s][A
 13%|███████████▋                                                                                 | 2450/19478 [00:12<01:20, 212.15it/s][A
 13%|███████████▊                                                                                 | 2472/19478 [00:12<01:19, 213.79it/s][A
 13%|███████████▉   

 24%|██████████████████████▍                                                                      | 4710/19478 [00:24<01:32, 159.57it/s][A
 24%|██████████████████████▌                                                                      | 4727/19478 [00:24<01:32, 159.80it/s][A
 24%|██████████████████████▋                                                                      | 4743/19478 [00:24<01:32, 159.28it/s][A
 24%|██████████████████████▋                                                                      | 4759/19478 [00:24<01:32, 159.10it/s][A
 25%|██████████████████████▊                                                                      | 4776/19478 [00:24<01:32, 159.64it/s][A
 25%|██████████████████████▉                                                                      | 4792/19478 [00:24<01:32, 159.45it/s][A
 25%|██████████████████████▉                                                                      | 4808/19478 [00:24<01:32, 159.12it/s][A
 25%|███████████████

 37%|██████████████████████████████████                                                           | 7130/19478 [00:36<00:58, 210.61it/s][A
 37%|██████████████████████████████████▏                                                          | 7152/19478 [00:36<00:58, 210.20it/s][A
 37%|██████████████████████████████████▎                                                          | 7174/19478 [00:36<00:58, 210.10it/s][A
 37%|██████████████████████████████████▎                                                          | 7196/19478 [00:36<00:58, 209.78it/s][A
 37%|██████████████████████████████████▍                                                          | 7217/19478 [00:36<00:58, 209.77it/s][A
 37%|██████████████████████████████████▌                                                          | 7238/19478 [00:36<00:58, 209.20it/s][A
 37%|██████████████████████████████████▋                                                          | 7259/19478 [00:36<00:58, 208.96it/s][A
 37%|███████████████

 49%|█████████████████████████████████████████████▉                                               | 9633/19478 [00:48<00:46, 210.42it/s][A
 50%|██████████████████████████████████████████████                                               | 9655/19478 [00:48<00:46, 212.51it/s][A
 50%|██████████████████████████████████████████████▏                                              | 9677/19478 [00:48<00:45, 213.99it/s][A
 50%|██████████████████████████████████████████████▎                                              | 9699/19478 [00:48<00:45, 212.60it/s][A
 50%|██████████████████████████████████████████████▍                                              | 9721/19478 [00:48<00:46, 211.07it/s][A
 50%|██████████████████████████████████████████████▌                                              | 9743/19478 [00:48<00:46, 210.48it/s][A
 50%|██████████████████████████████████████████████▌                                              | 9765/19478 [00:48<00:46, 210.15it/s][A
 50%|███████████████

 61%|████████████████████████████████████████████████████████▍                                   | 11947/19478 [01:00<00:35, 210.78it/s][A
 61%|████████████████████████████████████████████████████████▌                                   | 11969/19478 [01:00<00:36, 207.82it/s][A
 62%|████████████████████████████████████████████████████████▋                                   | 11990/19478 [01:00<00:35, 208.22it/s][A
 62%|████████████████████████████████████████████████████████▋                                   | 12011/19478 [01:00<00:35, 208.49it/s][A
 62%|████████████████████████████████████████████████████████▊                                   | 12032/19478 [01:00<00:35, 208.59it/s][A
 62%|████████████████████████████████████████████████████████▉                                   | 12053/19478 [01:00<00:35, 208.44it/s][A
 62%|█████████████████████████████████████████████████████████                                   | 12074/19478 [01:00<00:35, 208.30it/s][A
 62%|███████████████

 74%|████████████████████████████████████████████████████████████████████                        | 14402/19478 [01:12<00:24, 209.33it/s][A
 74%|████████████████████████████████████████████████████████████████████                        | 14423/19478 [01:12<00:24, 208.86it/s][A
 74%|████████████████████████████████████████████████████████████████████▏                       | 14444/19478 [01:12<00:24, 208.46it/s][A
 74%|████████████████████████████████████████████████████████████████████▎                       | 14466/19478 [01:12<00:23, 210.81it/s][A
 74%|████████████████████████████████████████████████████████████████████▍                       | 14488/19478 [01:12<00:23, 210.02it/s][A
 74%|████████████████████████████████████████████████████████████████████▌                       | 14510/19478 [01:12<00:23, 209.62it/s][A
 75%|████████████████████████████████████████████████████████████████████▋                       | 14531/19478 [01:12<00:26, 185.35it/s][A
 75%|███████████████

 86%|███████████████████████████████████████████████████████████████████████████████▌            | 16843/19478 [01:24<00:15, 173.08it/s][A
 87%|███████████████████████████████████████████████████████████████████████████████▋            | 16861/19478 [01:24<00:15, 168.33it/s][A
 87%|███████████████████████████████████████████████████████████████████████████████▋            | 16878/19478 [01:24<00:15, 165.70it/s][A
 87%|███████████████████████████████████████████████████████████████████████████████▊            | 16895/19478 [01:24<00:15, 164.00it/s][A
 87%|███████████████████████████████████████████████████████████████████████████████▉            | 16912/19478 [01:24<00:15, 162.70it/s][A
 87%|███████████████████████████████████████████████████████████████████████████████▉            | 16929/19478 [01:24<00:15, 161.64it/s][A
 87%|████████████████████████████████████████████████████████████████████████████████            | 16946/19478 [01:24<00:15, 160.77it/s][A
 87%|███████████████

 98%|██████████████████████████████████████████████████████████████████████████████████████████▍ | 19145/19478 [01:36<00:01, 208.54it/s][A
 98%|██████████████████████████████████████████████████████████████████████████████████████████▌ | 19166/19478 [01:36<00:01, 207.99it/s][A
 99%|██████████████████████████████████████████████████████████████████████████████████████████▋ | 19187/19478 [01:36<00:01, 207.90it/s][A
 99%|██████████████████████████████████████████████████████████████████████████████████████████▋ | 19208/19478 [01:36<00:01, 207.56it/s][A
 99%|██████████████████████████████████████████████████████████████████████████████████████████▊ | 19229/19478 [01:36<00:01, 207.70it/s][A
 99%|██████████████████████████████████████████████████████████████████████████████████████████▉ | 19251/19478 [01:36<00:01, 210.22it/s][A
 99%|███████████████████████████████████████████████████████████████████████████████████████████ | 19273/19478 [01:36<00:00, 209.49it/s][A
 99%|███████████████

In [35]:
best_kendall, best_weights

(0.22566684283236826,
 [0.02512562814070352, 0.04522613065326633, 0.9296482412060302])

In [62]:
versions = ['24', '25', '29', '83']

In [56]:
def compute_correlations(df):
    pearson = np.round(stats.pearsonr(df.comet, df.mqm), 5)
    spearman = np.round(stats.spearmanr(df.comet, df.mqm), 5)
    kendall = np.round(stats.kendalltau(df.comet, df.mqm), 5)
    return pearson[0], spearman[0], kendall[0]

## en-de

In [59]:
pearsons = []
spearmans = []
kendalls = []

for v in versions:
    path = '/home/glushkovato/robustness/COMET/data/dev/predictions/en-de/mqm21_newstest_output_v' + v + 'e1.json'
    mqm21_ende_output = read_json(path)
    mqm21_ende_output['mqm'] = mqm21_ende.score.tolist()
    mqm21_ende_output.comet = mqm21_ende_output.comet.astype(float)
    mqm21_ende_output.mqm = mqm21_ende_output.mqm.astype(float)

    p, s, k = compute_correlations(mqm21_ende_output)
    pearsons.append(p)
    spearmans.append(s)
    kendalls.append(k)
    print(p, s, k)

0.26747 0.30603 0.23256
0.24256 0.26363 0.19987
0.27094 0.29343 0.22262
0.25374 0.28464 0.21602
0.25169 0.28599 0.21737
0.26651 0.29569 0.2245


## en-ru

In [60]:
pearsons = []
spearmans = []
kendalls = []

for v in versions:
    path = '/home/glushkovato/robustness/COMET/data/dev/predictions/en-ru/mqm21_newstest_output_v' + v + 'e1.json'
    mqm21_enru_output = read_json(path)
    mqm21_enru_output['mqm'] = mqm21_enru.score.tolist()
    mqm21_enru_output.comet = mqm21_enru_output.comet.astype(float)
    mqm21_enru_output.mqm = mqm21_enru_output.mqm.astype(float)

    p, s, k = compute_correlations(mqm21_enru_output)
    pearsons.append(p)
    spearmans.append(s)
    kendalls.append(k)
    print(p, s, k)

0.33841 0.35588 0.27266
0.3228 0.34875 0.26705
0.40141 0.39594 0.30524
0.36862 0.36259 0.27905
0.35703 0.35825 0.27526
0.36945 0.36587 0.28065


## zh-en

In [61]:
pearsons = []
spearmans = []
kendalls = []

for v in versions:
    path = '/home/glushkovato/robustness/COMET/data/dev/predictions/zh-en/mqm21_newstest_output_v' + v + 'e1.json'
    mqm21_zhen_output = read_json(path)
    mqm21_zhen_output['mqm'] = mqm21_zhen.score.tolist()
    mqm21_zhen_output.comet = mqm21_zhen_output.comet.astype(float)
    mqm21_zhen_output.mqm = mqm21_zhen_output.mqm.astype(float)

    p, s, k = compute_correlations(mqm21_zhen_output)
    pearsons.append(p)
    spearmans.append(s)
    kendalls.append(k)
    print(p, s, k)

0.35353 0.43859 0.31605
0.38834 0.44822 0.32399
0.38276 0.44158 0.31854
0.3885 0.4562 0.33012
0.37932 0.45002 0.32466
0.35644 0.4376 0.31536
