In [None]:
import numpy as np
import matplotlib.pyplot as plt
import argparse
import torch
import pandas as pd
import seaborn as sns
# Global plotting defaults: seaborn's default theme for every matplotlib figure,
# plus the shared font sizes for axis labels/titles and for tick labels.
sns.set()
font_size, ticks_size = 18, 15

In [None]:
def get_calibration_data(srcind, tgtind, repo_path, dataset='tico19', model='towerinstruct13b', temperature=1.0, top_p=1.0, n_samples=50, split='dev', mbr=True):
    """Load sources, references, sampled hypotheses and segment metrics for one language pair.

    All files live under ``{repo_path}results-{dataset}/{split}/{srcind}-{tgtind}/``:

    * ``src.txt`` / ``ref.txt`` -- one source / reference sentence per line;
    * ``{model}-temp{temperature}-topp{top_p}-n{n_samples}`` -- one hypothesis per line;
    * ``metrics/segment/<same name>`` -- space-separated metrics table, one row per hypothesis;
    * ``<hypotheses file>-mbr-mbrmatrices`` -- read with ``torch.load`` when ``mbr`` is True.

    ``repo_path`` is prepended verbatim, so it must be empty or end with a path separator.

    Args:
        srcind, tgtind: source / target language codes (e.g. "en", "pt-BR").
        repo_path: prefix of the results directory.
        dataset: dataset name; only 'tico19' has a known directory layout.
        model, temperature, top_p, n_samples: identify the sampling run (file-name parts).
        split: split directory, e.g. 'dev' or 'test'.
        mbr: also load and return the MBR matrices.

    Returns:
        ``(metrics_df, sources, targets, translations_df)``, with ``mbr_matrices`` appended
        when ``mbr`` is True. ``metrics_df`` gains a ``src_idx`` column in which every source
        index is repeated ``n_samples`` times, i.e. hypotheses are assumed to be stored
        grouped by source sentence.

    Raises:
        ValueError: if ``dataset`` is not 'tico19', or if the metrics table does not have
            ``len(sources) * n_samples`` rows (raised by pandas on the column assignment).
    """
    if dataset != 'tico19':
        # Only tico19's layout is known; fail early with a clear message.
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'tico19' is supported.")

    lp_dir = f"{repo_path}results-{dataset}/{split}/{srcind}-{tgtind}"
    run_name = f"{model}-temp{temperature}-topp{top_p}-n{n_samples}"
    source_file = f"{lp_dir}/src.txt"
    target_file = f"{lp_dir}/ref.txt"
    output_file = f"{lp_dir}/{run_name}"
    metrics_segment = f"{lp_dir}/metrics/segment/{run_name}"

    # Explicit UTF-8 (targets include non-Latin scripts such as Russian) and context
    # managers so every file handle is closed deterministically.
    with open(source_file, "r", encoding="utf-8") as src_f:
        sources = [s.strip() for s in src_f.readlines()]
    with open(target_file, "r", encoding="utf-8") as tgt_f:
        targets = [s.strip() for s in tgt_f.readlines()]
    with open(output_file, encoding="utf-8") as hyp_f:
        hyps = [line.strip() for line in hyp_f.readlines()]
    translations_df = pd.DataFrame(hyps)

    metrics_df = pd.read_csv(metrics_segment, sep=" ")
    metrics_df['src_idx'] = [num for num in range(len(sources)) for _ in range(n_samples)]

    if mbr:
        # NOTE: torch.load unpickles the file -- only load MBR matrices from trusted sources.
        mbr_matrices = torch.load(f"{output_file}-mbr-mbrmatrices")
        return metrics_df, sources, targets, translations_df, mbr_matrices

    return metrics_df, sources, targets, translations_df

In [None]:
def rerankers(metrics_df, sources, n_samples, n_examples, metric='comet', critical_threshold=0.8):
    """Critical-error rate of a best-of-K reranker for every pool size K = 1..n_samples.

    Args:
        metrics_df: one row per (example, sample), laid out as ``n_examples`` consecutive
            blocks of ``n_samples`` rows with a positional 0..len-1 index (examples are
            grouped by ``index // n_samples``). Must have a 'comet' column and, unless
            ``metric == 'random'``, a ``metric`` column. Modified in place: the columns
            'n_sample' (1-based position within the example) and 'critical' (0/1) are added.
        sources: source sentences; only ``len(sources)`` is used, as the rate denominator.
        n_samples: number of samples per example.
        n_examples: number of examples.
        metric: column used to pick the highest-scoring sample among the first K
            (e.g. 'comet', 'cometkiwi'), or 'random' to always take the K-th sample.
        critical_threshold: a sample is a critical error when its COMET is below this.

    Returns:
        ``(critical_rate, critical_n_lbd)``: ``critical_n_lbd`` is an
        ``(n_samples, n_examples)`` 0/1 matrix whose row ``K-1`` flags the examples whose
        chosen translation is critical with a pool of K samples; ``critical_rate`` is its
        row sum divided by ``len(sources)``.

    Raises:
        ValueError: if ``len(metrics_df) != n_samples * n_examples``.
    """
    expected_rows = n_samples * n_examples
    if len(metrics_df) != expected_rows:
        raise ValueError(
            f"metrics_df has {len(metrics_df)} rows, expected n_samples * n_examples = {expected_rows}")

    # 1-based position of each row within its example's block of n_samples rows.
    metrics_df['n_sample'] = np.tile(np.arange(1, n_samples + 1), n_examples)

    # A critical error is always defined by a COMET threshold, whichever reranker is used.
    metrics_df['critical'] = np.where(metrics_df['comet'] < critical_threshold, 1, 0)

    critical_n_lbd = np.full((n_samples, n_examples), -1)
    for K in range(1, n_samples + 1):
        if metric != 'random':
            # Score-based reranker (three modes overall: comet, cometkiwi, random):
            # among the first K samples of each example keep the highest-`metric` one.
            selected_df = metrics_df[metrics_df['n_sample'] <= K]
            chosen_df = selected_df.loc[selected_df.groupby(selected_df.index // n_samples)[metric].idxmax()]
        else:
            # Random reranker: always keep the K-th sample (presumably samples are
            # exchangeable, so this acts as a random pick).
            chosen_df = metrics_df[metrics_df['n_sample'] == K]
        critical_n_lbd[K-1] = np.array(chosen_df['critical'])

    critical_rate = critical_n_lbd.sum(axis=1) / len(sources)
    return critical_rate, critical_n_lbd

In [None]:
# Experiment configuration.
merge_lps = True                      # pool all target languages into one dev and one test set
tgtinds = ['pt-BR', 'es-LA', 'ru']
threshold_dev = 0.15                  # critical error <=> COMET < 1 - threshold
threshold_devtest = threshold_dev
srcind="en"
repo_path=""
temperature=1.0
top_p=1.0
n_samples=50
dataset="tico19"

if not merge_lps:
    # The cells below need metrics_df, sources, mbr_matrices and their test_* counterparts,
    # which only the merged loader builds. A single language pair can be analysed by
    # keeping merge_lps=True with a one-element tgtinds.
    raise NotImplementedError("merge_lps=False is not supported; use merge_lps=True with a single-element tgtinds.")

dev_frames, test_frames = [], []
sources_merged, test_sources_merged = [], []
# mbr_matrices_merged[i] gathers, across languages, the MBR matrices that the MBR
# selection cell uses for a pool of i + 1 samples.
mbr_matrices_merged = [[] for _ in range(n_samples)]
test_mbr_matrices_merged = [[] for _ in range(n_samples)]

for tgtind in tgtinds:
    for split, frames, merged_sources, merged_mbr in (
        ("dev", dev_frames, sources_merged, mbr_matrices_merged),
        ("test", test_frames, test_sources_merged, test_mbr_matrices_merged),
    ):
        lp_metrics_df, lp_sources, _, _, lp_mbr_matrices = get_calibration_data(
            srcind, tgtind, repo_path, dataset=dataset, model='towerinstruct13b',
            temperature=temperature, top_p=top_p, n_samples=n_samples, split=split, mbr=True)
        frames.append(lp_metrics_df)
        merged_sources.extend(lp_sources)
        for i, new_matrix in enumerate(lp_mbr_matrices):
            merged_mbr[i].extend(new_matrix)

# Concatenate once instead of growing a DataFrame inside the loop.
metrics_df_merged = pd.concat(dev_frames, ignore_index=True)
test_metrics_df_merged = pd.concat(test_frames, ignore_index=True)

# src_idx restarts at 0 in every per-language frame, so renumber it globally. Use the
# merged sentence counts (not the last language's count times len(tgtinds)) so this stays
# correct when language pairs differ in size.
metrics_df_merged['src_idx'] = [num for num in range(len(sources_merged)) for _ in range(n_samples)]
test_metrics_df_merged['src_idx'] = [num for num in range(len(test_sources_merged)) for _ in range(n_samples)]

metrics_df = metrics_df_merged
sources = sources_merged
mbr_matrices = mbr_matrices_merged
test_metrics_df = test_metrics_df_merged
test_sources = test_sources_merged
test_mbr_matrices = test_mbr_matrices_merged

# Critical-error rate vs. pool size for the score-based and random rerankers (dev, then test).
n_examples = len(sources)
critical_rate_comet, critical_n_lbd_comet = rerankers(metrics_df, sources, n_samples, n_examples, metric='comet', critical_threshold=1-threshold_dev)
critical_rate_cometkiwi, critical_n_lbd_cometkiwi = rerankers(metrics_df, sources, n_samples, n_examples, metric='cometkiwi', critical_threshold=1-threshold_dev)
critical_rate_random, critical_n_lbd_random = rerankers(metrics_df, sources, n_samples, n_examples, metric='random', critical_threshold=1-threshold_dev)

test_n_examples = len(test_sources)
test_critical_rate_comet, devtest_critical_n_lbd_comet = rerankers(test_metrics_df, test_sources, n_samples, test_n_examples, metric='comet', critical_threshold=1-threshold_devtest)
test_critical_rate_cometkiwi, devtest_critical_n_lbd_cometkiwi = rerankers(test_metrics_df, test_sources, n_samples, test_n_examples, metric='cometkiwi', critical_threshold=1-threshold_devtest)
test_critical_rate_random, devtest_critical_n_lbd_random = rerankers(test_metrics_df, test_sources, n_samples, test_n_examples, metric='random', critical_threshold=1-threshold_devtest)

In [None]:
def get_mbr_scores(metrics_df, mbr_matrix, n_examples, n_samples):
    """Select the MBR translation for every example and report the selections' mean COMET.

    Args:
        metrics_df: per-sample metrics with a 'src_idx' column; each example's rows appear
            in sample order.
        mbr_matrix: indexable by example index; ``torch.argmax(mbr_matrix[idx])`` is taken as
            the position, within that example's rows, of the chosen sample. NOTE(review):
            presumably a per-candidate expected-utility vector -- argmax is over the
            flattened tensor, so confirm it is 1-D.
        n_examples: examples ``src_idx = 0..n_examples-1`` are processed.
        n_samples: 0-based pool-size index; the output's 'n_samples' column is set to
            ``n_samples + 1``.

    Returns:
        ``(mean COMET of the selected rows, DataFrame of the selected rows)``. Each selected
        row keeps its original index label.

    Raises:
        ValueError: if ``n_examples < 1`` (there would be nothing to select).
    """
    if n_examples < 1:
        raise ValueError("n_examples must be >= 1")

    # Split the frame by example once (O(rows)) instead of re-filtering it per example.
    rows_by_example = {key: group for key, group in metrics_df.groupby('src_idx')}

    selected_rows = []
    for idx in range(n_examples):
        chosen = int(torch.argmax(mbr_matrix[idx][:]))
        selected_rows.append(pd.DataFrame(rows_by_example[idx].iloc[chosen]).transpose())
    # Concatenate once (growing inside the loop is quadratic).
    mbr_df = pd.concat(selected_rows)

    mbr_df['n_samples'] = n_samples + 1
    return mbr_df["comet"].mean(), mbr_df

def _mbr_selection(metrics_df, mbr_matrices, n_examples, n_samples):
    """Run MBR selection for every pool size 1..n_samples.

    Returns ``(mean COMET per pool size, per-pool-size selection frames, all selections
    concatenated and sorted by ['src_idx', 'n_samples'])``.
    """
    scores, frames = [], []
    for i in range(n_samples):
        score, frame = get_mbr_scores(metrics_df, mbr_matrices[i], n_examples, i)
        scores.append(score)
        frames.append(frame)
    selected = pd.concat(frames, axis=0).sort_values(by=['src_idx', 'n_samples'])
    return scores, frames, selected


n_examples = len(sources)
mbr_scores_range, mbr_df_range, metrics_df_mbr = _mbr_selection(
    metrics_df, mbr_matrices, n_examples, n_samples)

test_n_examples = len(test_sources)
test_mbr_scores_range, test_mbr_df_range, test_metrics_df_mbr = _mbr_selection(
    test_metrics_df, test_mbr_matrices, test_n_examples, n_samples)

In [None]:
def mbr_reranker(metrics_df_mbr, n_samples, n_examples, metric='comet', critical_threshold=0.7):
    """Critical-error rate of the MBR-selected translations for each pool size.

    Args:
        metrics_df_mbr: one row per (example, pool size) with an 'n_samples' column
            (1..n_samples) and a ``metric`` column; for each pool size there must be
            exactly ``n_examples`` rows, in example order. Modified in place: a 0/1
            'critical' column is added.
        n_samples: largest pool size.
        n_examples: number of examples per pool size.
        metric: score column thresholded to flag critical errors.
        critical_threshold: a selection is critical when ``metric`` is below this.

    Returns:
        ``(critical_rate, critical_n_lbd)``: a list with the fraction of critical
        selections for pool sizes 1..n_samples, and the ``(n_samples, n_examples)``
        float 0/1 matrix those fractions are averaged from.
    """
    metrics_df_mbr['critical'] = np.where(metrics_df_mbr[metric] < critical_threshold, 1, 0)

    critical_rate = []
    critical_n_lbd = np.zeros((n_samples, n_examples))

    for K in range(1, n_samples+1):
        critical_K = metrics_df_mbr.loc[metrics_df_mbr['n_samples'] == K, 'critical']
        # Average only the 'critical' column: DataFrame.mean() over every column does
        # needless work and raises on non-numeric columns in recent pandas.
        critical_rate.append(critical_K.mean())
        critical_n_lbd[K-1] = critical_K.to_numpy()

    return critical_rate, critical_n_lbd

# MBR critical-error rate per pool size; critical means COMET < 1 - threshold.
critical_rate_mbr, critical_n_lbd_mbr = mbr_reranker(
    metrics_df_mbr, n_samples, n_examples,
    metric='comet', critical_threshold=1 - threshold_dev)
test_critical_rate_mbr, test_critical_n_lbd_mbr = mbr_reranker(
    test_metrics_df_mbr, n_samples, test_n_examples,
    metric='comet', critical_threshold=1 - threshold_devtest)

In [None]:
# Critical-error rate vs. pool size N for every reranker: dev (solid) and test (dashed).
# Left panel: raw rate. Right panel: log rate relative to N = 1.
# Entries: (legend label, colour, dev rates, test rates); list order fixes legend order.
reranker_curves = [
    ("random", 'g', critical_rate_random, test_critical_rate_random),
    ("cometkiwi", 'orange', critical_rate_cometkiwi, test_critical_rate_cometkiwi),
    ("MBR", 'red', critical_rate_mbr, test_critical_rate_mbr),
    ("comet", 'blue', critical_rate_comet, test_critical_rate_comet),
]

fig, axes = plt.subplots(1, 2, figsize=(15, 4))  # 1 row, 2 columns

for is_test in (False, True):
    for label, color, dev_rate, test_rate in reranker_curves:
        rate = np.asarray(test_rate if is_test else dev_rate, dtype=float)
        # Entry k of the rate arrays belongs to pool size N = k + 1; plot against N
        # explicitly so the x-axis matches its "N" label.
        pool_sizes = np.arange(1, len(rate) + 1)
        style = dict(color=color,
                     linestyle='dashed' if is_test else 'solid',
                     label=f"{label} (test)" if is_test else label)
        axes[0].plot(pool_sizes, rate, **style)
        axes[1].plot(pool_sizes, np.log(rate) - np.log(rate[0]), **style)

axes[0].set_xlabel("N", fontsize=font_size)
axes[0].set_ylabel("% critical errors", fontsize=font_size)
axes[0].set_title("% critical errors in the set", fontsize=font_size)
axes[0].tick_params(axis='both', labelsize=ticks_size)

axes[1].set_xlabel("N", fontsize=font_size)
axes[1].set_ylabel("log critical errors", fontsize=font_size)
axes[1].set_title("log critical errors in the set for different rerankers", fontsize=font_size)
axes[1].tick_params(axis='both', labelsize=ticks_size)

# One shared legend below both panels.
axes[0].legend(loc='lower center', bbox_to_anchor=(1.1, -.5), fontsize=font_size, frameon=False, ncol=4)

if merge_lps:
    plt.suptitle(f'{srcind}-{tgtinds}, {temperature}, {top_p}, THR={threshold_dev}', y=1.05, fontsize=font_size)
else:
    plt.suptitle(f'{srcind}-{tgtind}, {temperature}, {top_p}, THR={threshold_dev}', y=1.05, fontsize=font_size)

plt.show()

# Fit reranking laws

In [None]:
import torch
import math
from entmax import sparsemax, entmax15, entmax_bisect
from scipy.optimize import least_squares
import scipy.integrate as integrate
from scipy.special import gammaln, comb, logsumexp

In [None]:
def fit_alpha_beta(p, n_samples_fit, y):
    """Residuals for fitting the oracle reranking law's Beta-Binomial parameters.

    The per-K model weights each K = 0..n by the Beta-Binomial(n, alpha, beta) pmf and by
    ``sum_t q**(t-1) / sum_j q**(j-1)``. With q = 0 (the oracle) that factor is 1 for K = n
    and 0 otherwise, so the sum collapses to

        P_n = prod_{i=1}^{n} (alpha + i - 1) / (alpha + beta + i - 1).

    This is evaluated in log space as a cumulative sum, which avoids the O(n^2) loop and
    the exp/log round-trip underflow for small rates.

    Args:
        p: array-like ``[alpha, beta]`` (both > 0).
        n_samples_fit: fit pool sizes N = 1..n_samples_fit.
        y: observed ``log(rate(N)) - log(rate(1))``, length ``n_samples_fit``.

    Returns:
        Residuals ``log(P_N) - log(eps) - y`` with ``eps = alpha / (alpha + beta)``
        (= P_1), in the form expected by ``scipy.optimize.least_squares``.
    """
    alpha, beta = p[0], p[1]
    n = np.arange(1, n_samples_fit + 1)
    log_failure_rate = np.cumsum(np.log(alpha + n - 1) - np.log(alpha + beta + n - 1))

    eps = alpha / (alpha + beta)  # eps = alpha/(alpha+beta), the N = 1 failure rate
    return log_failure_rate - np.log(eps) - y



def fit_q_entmaxalpha(p, n_samples_fit, fitted_alpha, fitted_beta, y):
    """Residuals for fitting an imperfect reranker's (q, entmax alpha) with (alpha, beta) fixed.

    For each pool size n, the candidates get selection weights
    ``entmax_bisect(n_vec * log(q), alpha=entmax_alpha)`` with ``n_vec = 1..n``. The failure
    rate is the Beta-Binomial(n, fitted_alpha, fitted_beta) pmf at each K, times the sum of
    the last K selection weights, summed over K = 0..n.

    Args:
        p: array-like ``[q, entmax_alpha]``. NOTE(review): the fitting cell bounds
            entmax_alpha to [0.001, 1]; confirm entmax_bisect is meant to be used below
            alpha = 1.
        n_samples_fit: fit pool sizes N = 1..n_samples_fit.
        fitted_alpha, fitted_beta: Beta-Binomial parameters from the first-stage fit.
        y: observed ``log(rate(N)) - log(rate(1))``, length ``n_samples_fit``.

    Returns:
        Residuals ``log(rate_N) - log(eps) - y`` with ``eps = fitted_alpha /
        (fitted_alpha + fitted_beta)``.
    """
    N = np.arange(1, n_samples_fit+1)
    log_failure_rate = np.zeros_like(N, dtype=float)

    for i, n in enumerate(N):
        # The selection weights depend only on n, so compute them once per n instead of
        # once per (n, K).
        x = torch.tensor((N[:n]) * np.log(p[0]))
        selection_probs = np.array(entmax_bisect(x, alpha=p[1]).tolist())

        somation = 0
        for K in range(n+1):
            # log of the Beta-Binomial(n, fitted_alpha, fitted_beta) pmf at K
            somation_log = (np.log(comb(n, K))
                            + np.sum(np.log(fitted_alpha + np.arange(1, K + 1) - 1))
                            + np.sum(np.log(fitted_beta + np.arange(1, n - K + 1) - 1))
                            - np.sum(np.log(fitted_alpha + fitted_beta + np.arange(1, n + 1) - 1)))
            somation += np.exp(somation_log) * selection_probs[n-K:n].sum()

        log_failure_rate[i] = np.log(somation)

    eps = fitted_alpha/(fitted_alpha+fitted_beta) # eps = alpha/(alpha+beta)
    expression_cometkiwi = log_failure_rate - np.log(eps)

    return expression_cometkiwi - y


In [None]:
# Pool sizes N = 1..n_samples_fit are used to fit the reranking laws; set fit_fraction < 1
# to hold the largest pool sizes out of the fit.
fit_fraction = 1.0
n_samples_fit = int(fit_fraction * n_samples)

In [None]:
def _relative_log_rate(rates, n):
    """``log(rates[k]) - log(rates[0])`` for the first ``n`` pool sizes (0 at N = 1)."""
    rates = np.asarray(rates, dtype=float)
    return np.log(rates[:n]) - np.log(rates[0])


# Fit targets: dev-set log failure rates relative to the single-sample baseline.
y_comet = _relative_log_rate(critical_rate_comet, n_samples_fit)
y_cometkiwi = _relative_log_rate(critical_rate_cometkiwi, n_samples_fit)
y_mbr = _relative_log_rate(critical_rate_mbr, n_samples_fit)

# Stage 1: fit the Beta-Binomial (alpha, beta) on the COMET reranker, modelled by the
# oracle (q = 0) law.
p0 = np.array([1.0, 1.0])
first_fit = least_squares(fit_alpha_beta, p0, loss='soft_l1', f_scale=1.0, args=(n_samples_fit, y_comet), max_nfev=1000, bounds=([0.1, 0.1], [1000, 1000]))

# Stage 2: with (alpha, beta) fixed, fit (q, entmax alpha) for each imperfect reranker,
# both started from the same initial guess (least_squares copies x0).
p1 = np.array([0.1, 0.75])
second_fit_cometkiwi = least_squares(fit_q_entmaxalpha, p1, loss='soft_l1', f_scale=1.0, args=(n_samples_fit, first_fit.x[0], first_fit.x[1], y_cometkiwi), max_nfev=100, bounds=([0.001, 0.001], [1, 1]))
second_fit_mbr = least_squares(fit_q_entmaxalpha, p1, loss='soft_l1', f_scale=1.0, args=(n_samples_fit, first_fit.x[0], first_fit.x[1], y_mbr), max_nfev=100, bounds=([0.001, 0.001], [1, 1]))

# Plots: fitted reranking laws vs. observed failure rates (test set, then dev set)

In [None]:
# Reranking-law curves (lines, fitted on dev) vs. observed TEST failure rates (dots),
# all relative to the single-sample (N = 1) baseline.
plt.figure(figsize=(8, 5))
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

alpha = first_fit.x[0]
beta = first_fit.x[1]
eps = alpha / (alpha + beta)  # predicted N = 1 failure rate; subtracted to make curves relative
N = np.arange(1, n_samples+1)

labels = ['qe', 'mbr', 'oracle']
fitted_q_cometkiwi = second_fit_cometkiwi.x[0]
entmax_alpha_cometkiwi = second_fit_cometkiwi.x[1]
fitted_q_mbr = second_fit_mbr.x[0]
entmax_alpha_mbr = second_fit_mbr.x[1]

# q == 0 selects the oracle curve (dashed); the others use their fitted (q, entmax alpha).
q_values = [fitted_q_cometkiwi, fitted_q_mbr, 0]
n_points = n_samples

for q, color, entmax_alpha, label in zip(q_values, [colors[2], colors[3], colors[4]], [entmax_alpha_cometkiwi, entmax_alpha_mbr, 0], labels):
    log_failure_rate = np.zeros_like(N, dtype=float)

    for i, n in enumerate(N):
        if q != 0:
            # Entmax selection weights depend only on n: compute once per n, not per (n, K).
            x = torch.tensor((N[:n]) * np.log(q))
            selection_probs = np.array(entmax_bisect(x, alpha=entmax_alpha).tolist())

        somation = 0
        for K in range(n+1):
            # log of the Beta-Binomial(n, alpha, beta) pmf at K
            somation_log = (np.log(comb(n, K))
                            + np.sum(np.log(alpha + np.arange(1, K + 1) - 1))
                            + np.sum(np.log(beta + np.arange(1, n - K + 1) - 1))
                            - np.sum(np.log(alpha + beta + np.arange(1, n + 1) - 1)))
            if q != 0:
                somation += np.exp(somation_log) * selection_probs[n-K:n].sum()
            else:
                t_values = np.arange(n-K+1, n+1)
                j_values = np.arange(1, n+1)
                somation += np.exp(somation_log) * np.sum(q**(t_values-1)) / np.sum(q**(j_values-1))

        log_failure_rate[i] = np.log(somation)

    plt.plot(N, log_failure_rate - np.log(eps), color=color, linewidth=4.0,
             linestyle='solid' if q != 0 else 'dashed', label=label)

marker_size = .5 * (plt.rcParams['lines.markersize'] ** 2)
plt.scatter(N, np.log(test_critical_rate_cometkiwi) - np.log(test_critical_rate_cometkiwi[0]), color=colors[2], marker='o', s=marker_size)
plt.scatter(N, np.log(test_critical_rate_mbr) - np.log(test_critical_rate_mbr[0]), color=colors[3], marker='o', s=marker_size)
plt.scatter(N, np.log(test_critical_rate_comet[:n_points]) - np.log(test_critical_rate_comet[0]), color=colors[4], marker='o', s=marker_size)
plt.xlabel('$N$', fontsize=font_size)
plt.ylabel('Log failure rate (wrt baseline)', fontsize=font_size)
plt.legend(loc='upper right', fontsize=font_size, ncol=1)
plt.yticks([-1.5, -1.0, -0.5, 0], fontsize=font_size)
plt.xticks([0, 10, 20, 30, 40 ,50], fontsize=font_size)
plt.tick_params(axis='both', labelsize=ticks_size)

# File-name tag: the language code for the known single-language runs, 'all' otherwise.
lp_tag = tgtinds[0] if tgtinds in (['pt-BR'], ['es-LA'], ['ru']) else 'all'
plt.savefig('mt_%s_apred%.4f_bpred%.4f_cometkiwi_qpred%.4f_entpred%.4f_mbr_qpred%.4f_entpred%.4f.pdf'
            % (lp_tag, alpha, beta, fitted_q_cometkiwi, entmax_alpha_cometkiwi, fitted_q_mbr, entmax_alpha_mbr),
            bbox_inches='tight')

plt.show()


In [None]:
# Reranking-law curves (lines, fitted on dev) vs. observed DEV failure rates (dots),
# all relative to the single-sample (N = 1) baseline.
plt.figure(figsize=(8, 5))
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

alpha = first_fit.x[0]
beta = first_fit.x[1]
eps = alpha / (alpha + beta)  # predicted N = 1 failure rate; subtracted to make curves relative
N = np.arange(1, n_samples+1)

labels = ['qe', 'mbr', 'oracle']
fitted_q_cometkiwi = second_fit_cometkiwi.x[0]
entmax_alpha_cometkiwi = second_fit_cometkiwi.x[1]
fitted_q_mbr = second_fit_mbr.x[0]
entmax_alpha_mbr = second_fit_mbr.x[1]

# q == 0 selects the oracle curve (dashed); the others use their fitted (q, entmax alpha).
q_values = [fitted_q_cometkiwi, fitted_q_mbr, 0]
n_points = n_samples

for q, color, entmax_alpha, label in zip(q_values, [colors[2], colors[3], colors[4]], [entmax_alpha_cometkiwi, entmax_alpha_mbr, 0], labels):
    log_failure_rate = np.zeros_like(N, dtype=float)

    for i, n in enumerate(N):
        if q != 0:
            # Entmax selection weights depend only on n: compute once per n, not per (n, K).
            x = torch.tensor((N[:n]) * np.log(q))
            selection_probs = np.array(entmax_bisect(x, alpha=entmax_alpha).tolist())

        somation = 0
        for K in range(n+1):
            # log of the Beta-Binomial(n, alpha, beta) pmf at K
            somation_log = (np.log(comb(n, K))
                            + np.sum(np.log(alpha + np.arange(1, K + 1) - 1))
                            + np.sum(np.log(beta + np.arange(1, n - K + 1) - 1))
                            - np.sum(np.log(alpha + beta + np.arange(1, n + 1) - 1)))
            if q != 0:
                somation += np.exp(somation_log) * selection_probs[n-K:n].sum()
            else:
                t_values = np.arange(n-K+1, n+1)
                j_values = np.arange(1, n+1)
                somation += np.exp(somation_log) * np.sum(q**(t_values-1)) / np.sum(q**(j_values-1))

        log_failure_rate[i] = np.log(somation)

    plt.plot(N, log_failure_rate - np.log(eps), color=color, linewidth=4.0,
             linestyle='solid' if q != 0 else 'dashed', label=label)

marker_size = .5 * (plt.rcParams['lines.markersize'] ** 2)
plt.scatter(N, np.log(critical_rate_cometkiwi) - np.log(critical_rate_cometkiwi[0]), color=colors[2], marker='o', s=marker_size)
plt.scatter(N, np.log(critical_rate_mbr) - np.log(critical_rate_mbr[0]), color=colors[3], marker='o', s=marker_size)
plt.scatter(N, np.log(critical_rate_comet[:n_points]) - np.log(critical_rate_comet[0]), color=colors[4], marker='o', s=marker_size)
plt.xlabel('$N$', fontsize=font_size)
plt.ylabel('Log failure rate (wrt baseline)', fontsize=font_size)
plt.legend(loc='upper right', fontsize=font_size, ncol=1)
plt.yticks([-1.5, -1.0, -0.5, 0], fontsize=font_size)
plt.xticks([0, 10, 20, 30, 40 ,50], fontsize=font_size)
plt.tick_params(axis='both', labelsize=ticks_size)

# File-name tag: the language code for the known single-language runs, 'all' otherwise.
lp_tag = tgtinds[0] if tgtinds in (['pt-BR'], ['es-LA'], ['ru']) else 'all'
plt.savefig('mt_dev_%s_apred%.4f_bpred%.4f_cometkiwi_qpred%.4f_entpred%.4f_mbr_qpred%.4f_entpred%.4f.pdf'
            % (lp_tag, alpha, beta, fitted_q_cometkiwi, entmax_alpha_cometkiwi, fitted_q_mbr, entmax_alpha_mbr),
            bbox_inches='tight')

plt.show()

In [None]:
# Values
# Fitted reranking-law parameters recorded per en->X language pair ("# pt", "# es", "# ru").
fitted_values = {
    'pt-BR': dict(alpha=0.1, beta=0.4743,
                  fitted_q_cometkiwi=0.0105, entmax_alpha_cometkiwi=0.001,
                  fitted_q_mbr=0.001, entmax_alpha_mbr=0.1676),
    'es-LA': dict(alpha=0.1, beta=0.4177,
                  fitted_q_cometkiwi=0.0056, entmax_alpha_cometkiwi=0.001,
                  fitted_q_mbr=0.001, entmax_alpha_mbr=0.1780),
    'ru': dict(alpha=0.1, beta=0.5011,
               fitted_q_cometkiwi=0.0014, entmax_alpha_cometkiwi=0.001,
               fitted_q_mbr=0.001, entmax_alpha_mbr=0.1971),
}

# Language pair whose values are exposed as the plain globals below.
values_lp = 'ru'
alpha = fitted_values[values_lp]['alpha']
beta = fitted_values[values_lp]['beta']
fitted_q_cometkiwi = fitted_values[values_lp]['fitted_q_cometkiwi']
entmax_alpha_cometkiwi = fitted_values[values_lp]['entmax_alpha_cometkiwi']
fitted_q_mbr = fitted_values[values_lp]['fitted_q_mbr']
entmax_alpha_mbr = fitted_values[values_lp]['entmax_alpha_mbr']