In [None]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default theme to the matplotlib figures drawn below.
sns.set()

# Shared font sizes: `font_size` is used for axis labels, titles and legends,
# `ticks_size` for tick labels.
font_size = 18
ticks_size = 15

In [None]:
model_name = 'code-davinci-002'
dataset_name = 'strategy_qa' #'svamp' 
# Each line of the JSONL file is parsed as one JSON record. The file is read
# inside a context manager so the handle is closed as soon as parsing is done.
with open(f"{model_name}/{dataset_name}/{dataset_name}_seed1.jsonl") as jsonl_file:
    dt = list(map(json.loads, jsonl_file))
dt_df = pd.DataFrame(dt)
# Number of samples per example, taken from the first record's `scores` list.
n_samples = len(dt_df.iloc[0].scores)
print(n_samples)

In [None]:
# Sanity check: print the position of every record whose number of scores
# differs from `n_samples`. Each record is inspected individually (row i, not
# row 0), otherwise the check could never fire.
for i, scores in enumerate(dt_df['scores']):
    if len(scores) != n_samples:
        print(i)

In [None]:
# Flatten dt_df into one row per (source example, sample) pair.
expanded_rows = [
    {
        'input': row['input'],
        'generation': row['generation'][i],
        'scores': row['scores'][i],
        'answers': row['answers'][i],
        'target': row['target'],
        'probs': row['probs'][i],
        'sample_idx': i,
        'src_idx': idx,
    }
    for idx, row in dt_df.iterrows()
    for i in range(n_samples)
]

expanded_df = pd.DataFrame(expanded_rows)
# A sample is flagged as critical (1) unless its score equals 1.
expanded_df['critical'] = expanded_df['scores'].apply(lambda score: 0 if score == 1 else 1)

n_examples = len(dt_df)
# Work on a copy whose sample_idx runs 1..n_samples within each source example.
data_filtered = expanded_df.copy()
data_filtered['sample_idx'] = np.tile(np.arange(1, n_samples + 1), n_examples)

In [None]:
# Display the expanded frame (one row per source example and sample).
data_filtered

In [None]:
def rerankers(data, n_samples, n_examples, metric, set):
    """Count critical outputs chosen by a reranking strategy for every pool size K.

    For each K in 1..n_samples, only rows with ``sample_idx <= K`` are kept and
    one row per source example is selected according to ``metric``:

    * ``'oracle'``: the row with the smallest ``critical`` value. Rows are
      grouped by ``index // n_samples``, so the frame index must still be the
      positional index of the full expanded frame.
    * ``'random'``: the row whose ``sample_idx`` equals K.
    * ``'majority_voting'``: the first row whose ``answers`` value equals the
      most frequent answer among the K candidates of that source example.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs the columns ``sample_idx`` (1-based), ``src_idx``, ``answers``
        and ``critical``.
    n_samples : int
        Number of samples per source example.
    n_examples : int
        Number of source examples contained in ``data``.
    metric : str
        One of ``'oracle'``, ``'random'`` or ``'majority_voting'``.
    set : str
        Kept for backward compatibility and no longer used. The source ids are
        read from ``data['src_idx']`` instead of being reconstructed from the
        split name and a global dataframe.

    Returns
    -------
    critical_absolute : numpy.ndarray
        Shape ``(n_samples,)``; number of critical selections for each K.
    critical_rate : numpy.ndarray
        ``critical_absolute / n_examples``.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported strategies.
    """
    if metric not in ('oracle', 'random', 'majority_voting'):
        raise ValueError(f"unknown metric: {metric!r}")

    critical_n_lbd = np.full((n_samples, n_examples), -1)
    for K in range(1, n_samples+1):
        selected_df = data[data['sample_idx'] < K+1]

        if metric == 'oracle':
            chosen_df = selected_df.loc[selected_df.groupby(selected_df.index // n_samples)['critical'].idxmin()]
        elif metric == 'random':
            chosen_df = selected_df[selected_df['sample_idx'] == K]
        else:  # majority_voting
            indexes_to_keep = []
            # Iterate over the source ids actually present (in ascending order),
            # so the function works on any split without outside state.
            for _, candidates in selected_df.groupby('src_idx', sort=True):
                answers = candidates['answers']
                mode_output = answers.value_counts().idxmax()
                # idxmax on the boolean mask gives the label of the first match.
                indexes_to_keep.append((answers == mode_output).idxmax())
            chosen_df = selected_df.loc[indexes_to_keep]
        critical_n_lbd[K-1] = np.array(chosen_df['critical'])

    critical_absolute = critical_n_lbd.sum(axis=1)
    critical_rate = critical_absolute / n_examples

    return critical_absolute, critical_rate

In [None]:
# Split by source example: the first floor(n_examples / 2) examples (all of
# their samples) form the dev split, the remaining rows form the test split.
half_length = n_samples * (n_examples // 2)
data_filtered_dev = data_filtered.iloc[:half_length]
data_filtered_test = data_filtered.iloc[half_length:]

n_dev_examples = len(data_filtered_dev) // n_samples
n_test_examples = len(data_filtered_test) // n_samples

metric = 'oracle'
critical_absolute_oracle_dev, critical_rate_oracle_dev = rerankers(
    data_filtered_dev, n_samples, n_dev_examples, metric, 'dev')
critical_absolute_oracle_test, critical_rate_oracle_test = rerankers(
    data_filtered_test, n_samples, n_test_examples, metric, 'test')

metric = 'majority_voting'
critical_absolute_maj_dev, critical_rate_maj_dev = rerankers(
    data_filtered_dev, n_samples, n_dev_examples, metric, 'dev')
critical_absolute_maj_test, critical_rate_maj_test = rerankers(
    data_filtered_test, n_samples, n_test_examples, metric, 'test')



In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))  # 1 row, 2 columns

# (rate array, legend label, line style) for each curve, in drawing order.
curves = [
    (critical_rate_oracle_dev, "oracle dev", dict(color='blue')),
    (critical_rate_maj_dev, "maj dev", dict(color='red')),
    (critical_rate_oracle_test, "oracle test", dict(color='blue', linestyle='dashed')),
    (critical_rate_maj_test, "maj test", dict(color='red', linestyle='dashed')),
]
for rate, curve_label, line_kwargs in curves:
    # Left panel: raw rate. Right panel: log rate relative to its first value.
    axes[0].plot(rate, label=curve_label, **line_kwargs)
    axes[1].plot(np.log(rate)-np.log(rate)[0], label=curve_label, **line_kwargs)

panels = [
    (axes[0], "% critical errors", "% critical errors in the set"),
    (axes[1], "log critical errors", "log critical errors in the set"),
]
for panel, y_label, panel_title in panels:
    panel.set_xlabel("N", fontsize=font_size)
    panel.set_ylabel(y_label, fontsize=font_size)
    panel.set_title(panel_title, fontsize=font_size)
    panel.tick_params(axis='both', labelsize=ticks_size)

# One shared legend, anchored below the two panels.
axes[0].legend(loc='lower center', bbox_to_anchor=(1.1, -0.4), fontsize=font_size, frameon=False, ncol=4)
plt.show()

# Fit reranking laws

In [None]:
import math
from entmax import sparsemax, entmax15, entmax_bisect
from scipy.optimize import least_squares
import scipy.integrate as integrate
from scipy.special import gammaln, comb, logsumexp
import torch

In [None]:
def log_beta_distribution(alpha, beta):
    """Return a function of ``tau`` giving the log-density of Beta(alpha, beta).

    The log normalising constant (log of the Beta function) is computed once
    with ``gammaln`` and captured by the returned closure.
    """
    log_normalizer = -gammaln(alpha + beta) + gammaln(alpha) + gammaln(beta)

    def log_density(tau):
        return (alpha - 1) * np.log(tau) + (beta - 1) * np.log(1 - tau) - log_normalizer

    return log_density

def integrate_expression(tau_start, tau_end):
    """Integrate ``expression(tau, N)`` over ``[tau_start, tau_end]`` with ``quad``.

    Returns only the integral estimate; the error estimate from ``quad`` is
    discarded.

    NOTE(review): ``expression`` and ``N`` are not parameters and are looked up
    as globals at call time. No ``expression`` is defined in the visible part
    of this notebook, and ``N`` is only assigned in later cells, so calling
    this as-is looks like it would fail with a NameError. Confirm whether the
    function is still needed.
    """
    result, _ = integrate.quad(lambda tau: expression(tau, N), tau_start, tau_end)
    return result

def fit_alpha_beta(p, n_samples_fit, y):
    """Residuals of the q = 0 reranking law, for ``scipy.optimize.least_squares``.

    The general law sums, over K = 0..n, a Beta-Binomial weight times
    ``sum_t q**(t-1) / sum_j q**(j-1)``. This function always used ``q = 0``,
    for which that ratio is 1 when ``K == n`` and 0 otherwise, so only the
    ``K == n`` term survives::

        log_rate(n) = sum_{i=1..n} [log(alpha + i - 1) - log(alpha + beta + i - 1)]

    That is a cumulative sum, evaluated here directly instead of looping over
    every (n, K) pair.

    Parameters
    ----------
    p : sequence of float
        ``p[0]`` is alpha and ``p[1]`` is beta; both must be positive.
    n_samples_fit : int
        Largest n to evaluate; the law is evaluated for n = 1..n_samples_fit.
    y : array-like
        Ground-truth values, of length ``n_samples_fit``.

    Returns
    -------
    numpy.ndarray
        ``log_rate(n) - log(alpha / (alpha + beta)) - y`` for each n.
    """
    alpha, beta = p[0], p[1]
    offsets = np.arange(n_samples_fit)  # i - 1 for i = 1..n_samples_fit
    comet_log_failure_rate = np.cumsum(np.log(alpha + offsets) - np.log(alpha + beta + offsets))

    eps = alpha / (alpha + beta)  # value of the law at n = 1
    expression_comet = comet_log_failure_rate - np.log(eps)
    return expression_comet - y

def fit_q_entmaxalpha(p, n_samples_fit, fitted_alpha, fitted_beta, y):
    """Residuals of the entmax-weighted reranking law, for ``least_squares``.

    For every n in 1..n_samples_fit the law is::

        rate(n) = sum_{K=0..n} BetaBinom(K; n, fitted_alpha, fitted_beta) * W_n(K)

    where ``W_n(K)`` is the sum of the last K entries of
    ``entmax_bisect([1, ..., n] * log(p[0]), alpha=p[1])``.

    Parameters
    ----------
    p : sequence of float
        ``p[0]`` is q (its log scales the entmax scores); ``p[1]`` is the
        entmax alpha.
    n_samples_fit : int
        Largest n to evaluate.
    fitted_alpha, fitted_beta : float
        Beta parameters, held fixed here.
    y : array-like
        Ground-truth values, of length ``n_samples_fit``.

    Returns
    -------
    numpy.ndarray
        ``log(rate(n)) - log(fitted_alpha / (fitted_alpha + fitted_beta)) - y``.
    """
    N = np.arange(1, n_samples_fit+1)
    log_failure_rate = np.zeros_like(N, dtype=float)

    for i, n in enumerate(N):
        # The entmax weights and the normalising sum depend only on n, so they
        # are computed once per n rather than once per (n, K) pair.
        x = torch.tensor((N[:n]) * np.log(p[0]))
        entmax_weights = np.array(entmax_bisect(x, alpha=p[1]).tolist())
        i_values_3 = np.arange(1, n+1)
        log_normalizer = np.sum(np.log(fitted_alpha + fitted_beta + i_values_3 - 1))

        somation = 0
        for K in range(n+1):
            i_values_1 = np.arange(1, K + 1)
            i_values_2 = np.arange(1, n - K + 1)

            somation_log = np.log(comb(n, K)) + np.sum(np.log(fitted_alpha + i_values_1 - 1)) + np.sum(np.log(fitted_beta + i_values_2 - 1)) - log_normalizer
            somation += np.exp(somation_log) * entmax_weights[n-K:n].sum()

        log_failure_rate[i] = np.log(somation)

    eps = fitted_alpha/(fitted_alpha+fitted_beta) # eps = alpha/(alpha+beta)
    expression_cometkiwi = log_failure_rate - np.log(eps)

    return expression_cometkiwi - y

In [None]:
# Number of leading N values used when fitting; the 5/5 factor currently
# selects all n_samples of them.
n_samples_fit = int(5/5 * n_samples)

In [None]:
# Fit targets: dev-split log rates relative to their first (N = 1) value.
oracle_dev_head = critical_rate_oracle_dev[:n_samples_fit]
maj_dev_head = critical_rate_maj_dev[:n_samples_fit]
y_comet = np.log(oracle_dev_head) - np.log(critical_rate_oracle_dev[0])
y_mbr = np.log(maj_dev_head) - np.log(critical_rate_maj_dev[0])

# First fit: (alpha, beta) against the oracle curve.
p0 = np.array([1.0, 1.0])
first_fit = least_squares(
    fit_alpha_beta,
    p0,
    loss='soft_l1',
    f_scale=1.0,
    args=(n_samples_fit, y_comet),
    max_nfev=1000,
    bounds=([0.1, 0.1], [1000, 1000]),
)

# Second fit: (q, entmax alpha) against the majority-voting curve, with
# alpha and beta held at the values from the first fit.
p1 = np.array([0.1, 0.75])
second_fit_mbr = least_squares(
    fit_q_entmaxalpha,
    p1,
    loss='soft_l1',
    f_scale=1.0,
    args=(n_samples_fit, first_fit.x[0], first_fit.x[1], y_mbr),
    max_nfev=100,
    bounds=([0.001, 0.001], [1, 1]),
)

# Plots

In [None]:
# Overlay the fitted reranking laws (lines) on the empirical test-split
# curves (dots), all as log rates relative to N = 1.
plt.figure(figsize=(8, 5))
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

alpha = first_fit.x[0]
beta = first_fit.x[1]
# Subtracting log(eps) makes every fitted curve start at 0 for N = 1.
eps = alpha / (alpha + beta)
N = np.arange(1, n_samples+1)

fitted_q_mbr = second_fit_mbr.x[0]
entmax_alpha_mbr = second_fit_mbr.x[1]
fitted_qs = [fitted_q_mbr]
entmax_alphas = [entmax_alpha_mbr]
labels = ['self-consistency', 'oracle']


for fitted_q, entmax_alpha, color_reranker in zip(fitted_qs, entmax_alphas, ['red']):
    # q = fitted_q draws the fitted (entmax-weighted) law, q = 0 the oracle law.
    q_values = [fitted_q, 0]
    n_points = n_samples
    for q, color, label in zip(q_values, [colors[3], colors[4]], labels):
        log_failure_rate = np.zeros_like(N, dtype=float)
        for i, n in enumerate(N):
            i_values_3 = np.arange(1, n+1)
            j_values = np.arange(1, n+1)
            if q != 0:
                # The entmax weights depend only on n (not on K): evaluate
                # them once per n instead of once per (n, K) pair.
                x = torch.tensor((N[:n]) * np.log(q))
                results_entmax = np.array([entmax_bisect(x, alpha=entmax_alpha).tolist()])
            somation = 0
            for K in range(n+1):
                i_values_1 = np.arange(1, K + 1)
                i_values_2 = np.arange(1, n - K + 1)
                t_values = np.arange(n-K+1, n+1)
                # Log of the Beta-Binomial weight of K out of n.
                somation_log = np.log(comb(n, K)) + np.sum(np.log(alpha + i_values_1 - 1)) + np.sum(np.log(beta + i_values_2 - 1)) - np.sum(np.log(alpha + beta + i_values_3 - 1))
                if q != 0:
                    somation += np.exp(somation_log) * results_entmax[0][n-K:n].sum()
                else:
                    somation += np.exp(somation_log)*np.sum(q**(t_values-1)) / np.sum(q**(j_values-1))
            result = somation
            log_failure_rate[i] = np.log(result)
        # Solid line for the fitted law, dashed for the oracle law.
        linestyle = 'solid' if q != 0 else 'dashed'
        plt.plot(N, log_failure_rate - np.log(eps), color=color, linewidth=4.0, linestyle=linestyle, label=label)

plt.scatter(N, np.log(critical_rate_maj_test) - np.log(critical_rate_maj_test[0]), color=colors[3], marker='o', s=.5 * (plt.rcParams['lines.markersize'] ** 2))
plt.scatter(N, np.log(critical_rate_oracle_test[:n_points]) - np.log(critical_rate_oracle_test[0]), color=colors[4], marker='o', s=.5 * (plt.rcParams['lines.markersize'] ** 2))
plt.xlabel('$N$', fontsize=font_size)
plt.ylabel('Log failure rate (wrt baseline)', fontsize=font_size)
plt.legend(loc='lower left', fontsize=font_size, ncol=1)
# NOTE(review): tick positions are hard-coded; they presumably assume
# n_samples is about 64 and log rates within [-6, 0]. Confirm for other runs.
plt.yticks([-6, -5, -4, -3, -2, -1, 0], fontsize=font_size)
plt.xticks([0, 16, 32, 48, 64], fontsize=font_size)
plt.tick_params(axis='both', labelsize=ticks_size)

plt.savefig(f"{model_name}_{dataset_name}_apred%.4f_bpred%.4f_qpred%.4f_entpred%.4f_.pdf" % (alpha,beta,fitted_q,entmax_alpha), bbox_inches='tight')
plt.show()


In [None]:
# Report the fitted parameters: (q, entmax alpha) from the second fit, then
# (alpha, beta) from the first fit.
print(fitted_q_mbr, entmax_alpha_mbr)
print(alpha, beta)

In [None]:
# Overlay the fitted reranking laws (lines) on the empirical dev-split curves
# (dots), all as log rates relative to N = 1.
plt.figure(figsize=(8, 5))
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Subtracting log(eps) makes every fitted curve start at 0 for N = 1.
eps = alpha / (alpha + beta)
N = np.arange(1, n_samples+1)
fitted_qs = [fitted_q_mbr]
entmax_alphas = [entmax_alpha_mbr]
labels = ['self-consistency', 'oracle']


for fitted_q, entmax_alpha, color_reranker in zip(fitted_qs, entmax_alphas, ['red']):
    # q = fitted_q draws the fitted (entmax-weighted) law, q = 0 the oracle law.
    q_values = [fitted_q, 0]
    n_points = n_samples
    for q, color, label in zip(q_values, [colors[3], colors[4]], labels):
        log_failure_rate = np.zeros_like(N, dtype=float)
        for i, n in enumerate(N):
            i_values_3 = np.arange(1, n+1)
            j_values = np.arange(1, n+1)
            if q != 0:
                # The entmax weights depend only on n (not on K): evaluate
                # them once per n instead of once per (n, K) pair.
                x = torch.tensor((N[:n]) * np.log(q))
                results_entmax = np.array([entmax_bisect(x, alpha=entmax_alpha).tolist()])
            somation = 0
            for K in range(n+1):
                i_values_1 = np.arange(1, K + 1)
                i_values_2 = np.arange(1, n - K + 1)
                t_values = np.arange(n-K+1, n+1)
                # Log of the Beta-Binomial weight of K out of n.
                somation_log = np.log(comb(n, K)) + np.sum(np.log(alpha + i_values_1 - 1)) + np.sum(np.log(beta + i_values_2 - 1)) - np.sum(np.log(alpha + beta + i_values_3 - 1))
                if q != 0:
                    somation += np.exp(somation_log) * results_entmax[0][n-K:n].sum()
                else:
                    somation += np.exp(somation_log)*np.sum(q**(t_values-1)) / np.sum(q**(j_values-1))
            result = somation
            log_failure_rate[i] = np.log(result)
        # Solid line for the fitted law, dashed for the oracle law.
        linestyle = 'solid' if q != 0 else 'dashed'
        plt.plot(N, log_failure_rate - np.log(eps), color=color, linewidth=4.0, linestyle=linestyle, label=label)

plt.scatter(N, np.log(critical_rate_maj_dev) - np.log(critical_rate_maj_dev[0]), color=colors[3], marker='o', s=.5 * (plt.rcParams['lines.markersize'] ** 2))
plt.scatter(N, np.log(critical_rate_oracle_dev[:n_points]) - np.log(critical_rate_oracle_dev[0]), color=colors[4], marker='o', s=.5 * (plt.rcParams['lines.markersize'] ** 2))
plt.xlabel('$N$', fontsize=font_size)
plt.ylabel('Log failure rate (wrt baseline)', fontsize=font_size)
plt.legend(loc='lower left', fontsize=font_size, ncol=1)
# NOTE(review): tick positions are hard-coded; they presumably assume
# n_samples is about 64 and log rates within [-6, 0]. Confirm for other runs.
plt.yticks([-6, -5, -4, -3, -2, -1, 0], fontsize=font_size)
plt.xticks([0, 16, 32, 48, 64], fontsize=font_size)
plt.tick_params(axis='both', labelsize=ticks_size)

plt.savefig(f"{model_name}_{dataset_name}_dev_apred%.4f_bpred%.4f_qpred%.4f_entpred%.4f_.pdf" % (alpha,beta,fitted_q,entmax_alpha), bbox_inches='tight')
plt.show()
