In [1]:
import argparse
import json
import multiprocessing
import os
import pickle
import platform
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

from multiprocessing import Value

from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

import numpy as np

from evalplus.data.utils import CACHE_DIR

from evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
    load_solutions,
)

from evalplus.eval.utils import (
    create_tempdir,
    reliability_guard,
    swallow_io,
    time_limit,
)

from evalplus.eval._special_oracle import (
    MBPP_OUTPUT_NOT_NONE_TASKS,
    MBPP_OUTPUT_SET_EQ_TASKS,
    _poly,
)

from evalplus.gen.util import trusted_exec

def is_floats(x) -> bool:
    # check if it is float; List[float]; Tuple[float]
    if isinstance(x, float):
        return True
    if isinstance(x, (list, tuple)):
        return all(isinstance(i, float) for i in x)
    if isinstance(x, np.ndarray):
        return x.dtype == np.float64 or x.dtype == np.float32
    return False

import resource
import signal

def time_limit_exceeded(signum, frame):
    raise TimeoutError("Time limit exceeded")

def set_time_limit(seconds):
    signal.signal(signal.SIGALRM, time_limit_exceeded)
    signal.alarm(seconds)

def set_memory_limit(maximum_memory_bytes):
    import resource

    resource.setrlimit(
        resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
    )
    resource.setrlimit(
        resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
    )
    if not platform.uname().system == "Darwin":
        resource.setrlimit(
            resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
        )

def wrapped_ut_exact_match(
    hyp_ut, 
    ref_ut, 
    entry_point, 
    dataset, 
    inp=None, 
    atol=0, 
    time_bound = False,
    time_limit=1, 
    memory_bound = True,
    memory_limit=40*1024*1024, 
    ):
    final_result = Value('d', 0)
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        if time_bound:
            set_time_limit(time_limit)
        try:
            if memory_bound:
                reliability_guard(maximum_memory_bytes=memory_limit)
            if time_bound:
                set_time_limit(time_limit)
            final_result = ut_exact_match(hyp_ut, ref_ut, entry_point, dataset, inp, atol, time_limit, memory_limit, memory_bound)
        except MemoryError:
            final_result = 0.0
        except TimeoutError:
            final_result = 0.0
    if time_bound:
        signal.alarm(0)
    shutil.rmtree = rmtree
    os.rmdir = rmdir
    os.chdir = chdir
    return final_result.value

def ut_exact_match(
    hyp_ut, 
    ref_ut, 
    entry_point, 
    dataset, 
    inp=None, 
    atol=0, # need to change this later
    time_limit=1, # seconds
    memory_limit=4*1024*1024*1024, # 4GB
    ):
    
    #try:
        #set_time_limit(time_limit)
        #set_memory_limit(memory_limit)

    exact_match = hyp_ut == ref_ut

    # ================================================ #
    # ============== special oracles ================= #
    if dataset == "mbpp":
        if "are_equivalent" == entry_point:  # Mbpp/164 special oracle
            exact_match = exact_match or True
        elif "sum_div" == entry_point:  # Mbpp/295 special oracle
            exact_match = exact_match or hyp_ut == 0 or ref_ut == 0
        elif entry_point in MBPP_OUTPUT_SET_EQ_TASKS:
            exact_match = set(hyp_ut) == set(ref_ut)
        elif entry_point in MBPP_OUTPUT_NOT_NONE_TASKS:
            # exp is True  if not None
            #        False if None
            if isinstance(hyp_ut, bool):
                hyp_ut = hyp_ut is not None
            if isinstance(ref_ut, bool):
                ref_ut = ref_ut is not None
            exact_match = hyp_ut == ref_ut

    if dataset == "humaneval":
        if "find_zero" == entry_point:
            hyp_ut = _poly(*inp, hyp_ut) <= atol
            ref_ut = _poly(*inp, ref_ut) <= atol
            exact_match = hyp_ut == ref_ut
    # ============== special oracles ================= #
    # ================================================ #

    if atol == 0 and (is_floats(ref_ut) or is_floats(hyp_ut)):
        atol = 1e-6  # enforce atol for float comparison
    if not exact_match and atol != 0:
        # explicitly set rtol=1e-07
        # to match `np.testing.assert_allclose`'s default values
        exact_match =  np.allclose(hyp_ut, ref_ut, rtol=1e-07, atol=atol)
    
    return int(exact_match)

def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        #print(f"Load from ground-truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    #print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )

        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    #print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)

    return expected_output

def mbr_exec(hyp_uts, ref_uts, entry_point, dataset, n_uts, inps=None, granular=False):
    n_matches = 0
    for i in range(n_uts):
        # skip if either hyp_ut or ref_ut is not in the list
        if i not in hyp_uts or i not in ref_uts:
            continue
        # if there's an error, we return 0
        if type(hyp_uts[i]) == str and hyp_uts[i].startswith("failed:"):
            return 0 
        if type(ref_uts[i]) == str and ref_uts[i].startswith("failed:"):
            return 0
        # we start counting the number of matches
        try:
            n_matches += ut_exact_match(
                hyp_uts[i], 
                ref_uts[i], 
                entry_point, 
                dataset, 
                inp=inps[i] if inps else None
                )
        except:
            n_matches += 0
        
    if granular:
        try:
            return n_matches/ n_uts
        except:
            return 0
    else:
        return int(n_matches == n_uts)

In [46]:
work_dir = "/mnt/scratch-artemis/haausing/code_reranking/evalplus_outputs"
dataset = "mbpp"
gen_dir = "code-llama-13b-instruct_temp_1.6"
#gen_dir = "deepseek-coder-6.7b-instruct_temp_1.2"
#debug_gen_dir = gen_dir + "_debug1_not_change_positive"
#debug_gen_dir = gen_dir + "_debug1_sd-ut"
#debug_3times_gen_dir = gen_dir + "_debug1_sd-ut"
# load exec_outputs

# load problems
if dataset == "mbpp":
    problems = get_mbpp_plus()
    dataset_hash = get_mbpp_plus_hash()
    expected_output = get_groundtruth(
        problems,
        dataset_hash,
        MBPP_OUTPUT_NOT_NONE_TASKS,
    )
elif dataset == "humaneval":
    problems = get_human_eval_plus()
    dataset_hash = get_human_eval_plus_hash()
    expected_output = get_groundtruth(
        problems,
        dataset_hash,
        []
    )
else:
    raise ValueError("Invalid dataset")

with open(f"{work_dir}/{dataset}/{gen_dir}/exec_outputs_v2.pkl", "rb") as f:
    exec_outputs = pickle.load(f)
print("exec_outputs loaded")

# load eval_results
with open(f"{work_dir}/{dataset}/{gen_dir}/eval_results.json", "r") as f:
    eval_results = json.load(f)
for task_id in eval_results["eval"]:
    eval_results["eval"][task_id] = sorted(eval_results["eval"][task_id], key=lambda x: int(x["solution_id"]))
    
#pop out ["Mbpp/6", "Mbpp/7", "Mbpp/8", "Mbpp/9"]
if dataset == "mbpp":
    for task_id in ["Mbpp/6", "Mbpp/7", "Mbpp/8", "Mbpp/9"]:
        eval_results["eval"].pop(task_id)
print("eval_results loaded")

exec_outputs loaded
eval_results loaded


In [47]:
if dataset == "humaneval":
    with open(f"{work_dir}/{dataset}/{gen_dir}/errors.pkl", "rb") as f:
        errors = pickle.load(f)
with open(f"{work_dir}/{dataset}/{gen_dir}/logprobs.pkl", "rb") as f:
    logprobs = pickle.load(f)
print("logprobs loaded")
with open(f"{work_dir}/{dataset}/{gen_dir}/reviewer_logprobs.pkl", "rb") as f:
    reviewer_logprobs = pickle.load(f)
print("reviewer_logprobs loaded")

logprobs loaded
reviewer_logprobs loaded


In [54]:
from concurrent.futures import ThreadPoolExecutor

def process_task(task_id, eval_results_dict, exec_outputs_dict, max_hyps = 200, start_id = 0, num_plus_test_cases=3, granular=False, filter = True):
    p_name = task_id.replace("/", "_")
    test_all_plus_cases = True
    assert start_id >= 0 and start_id + max_hyps <= 200
    n_expected_outputs_base = len(expected_output[task_id]["base"])
    n_expected_outputs_plus = len(expected_output[task_id]["plus"])
    if num_plus_test_cases < n_expected_outputs_plus:
        test_all_plus_cases = False
    task_utility_base = []
    task_utility_plus = []
    
    for hyp_id, hyp in enumerate(eval_results_dict["eval"][task_id]):
        if hyp_id >= max_hyps + start_id or hyp_id < start_id:
            continue
        hyp_base_outputs = exec_outputs_dict[task_id][hyp_id]["base"]
        hyp_plus_outputs = exec_outputs_dict[task_id][hyp_id]["plus"]
        if not test_all_plus_cases:
            hyp_plus_outputs = {i: hyp_plus_outputs[i] for i in range(num_plus_test_cases) if i < min(num_plus_test_cases, len(hyp_plus_outputs))}
        hyp_utility_base = []
        hyp_utility_plus = []
        
        for ref_id, ref in enumerate(eval_results_dict["eval"][task_id]):
            if ref_id >= max_hyps + start_id or ref_id < start_id:
                continue
            if ref["base_status"] == ref["plus_status"] == hyp["base_status"] == hyp["plus_status"] == "pass":
                hyp_utility_base.append(1)
                hyp_utility_plus.append(1)
                continue
            if filter:
                ### add the filtering baseline
                if dataset in "humaneval":
                    if errors[task_id][hyp_id]["base"]["status"] != "pass":
                        hyp_utility_base.append(0)
                        hyp_utility_plus.append(0)
                        continue
                elif dataset == "mbpp":
                    if len(hyp["base_details"]) == 0 or hyp["base_details"][0] == 0:
                        hyp_utility_base.append(0)
                        hyp_utility_plus.append(0)
                        continue
                ### end of filtering baseline
            ref_base_outputs = exec_outputs_dict[task_id][ref_id]["base"]
            ref_plus_outputs = exec_outputs_dict[task_id][ref_id]["plus"]
            if not test_all_plus_cases:
                ref_plus_outputs = {i: ref_plus_outputs[i] for i in range(num_plus_test_cases) if i < min(num_plus_test_cases, len(ref_plus_outputs))}
            
            util_score_base = mbr_exec(hyp_base_outputs, ref_base_outputs, problems[task_id]["entry_point"], "mbpp", n_expected_outputs_base, granular=granular)
            util_score_plus = mbr_exec(hyp_plus_outputs, ref_plus_outputs, problems[task_id]["entry_point"], "mbpp", n_expected_outputs_plus, granular=granular)
            hyp_utility_base.append(util_score_base)
            if granular:
                hyp_utility_plus.append((util_score_plus*n_expected_outputs_plus+util_score_base*n_expected_outputs_base)/(n_expected_outputs_plus+n_expected_outputs_base))
            else:
                hyp_utility_plus.append(int(util_score_plus==util_score_base==1))
            #hyp_utility_plus.append(mbr_exec(hyp_plus_outputs, ref_plus_outputs, problems[task_id]["entry_point"], "mbpp", n_expected_outputs_plus))
            #hyp_utility_plus.extend(hyp_utility_base)
        task_utility_base.append(np.mean(hyp_utility_base) 
                                 * np.exp(
                                     np.mean(reviewer_logprobs[p_name][hyp_id]) +
                                     np.mean(logprobs[p_name][hyp_id])
                                     )
                                 )
        task_utility_plus.append(np.mean(hyp_utility_plus) 
                                 * np.exp(
                                     np.mean(reviewer_logprobs[p_name][hyp_id]) +
                                     np.mean(logprobs[p_name][hyp_id])
                                     )
                                 )
    
    # get argmax
    argmax_base = np.argmax(task_utility_base) + start_id
    argmax_plus = np.argmax(task_utility_plus) + start_id
    assert argmax_base == int(eval_results_dict["eval"][task_id][argmax_base]["solution_id"])
    assert argmax_plus == int(eval_results_dict["eval"][task_id][argmax_plus]["solution_id"])
    base_status = eval_results_dict["eval"][task_id][argmax_base]["base_status"]
    plus_status = eval_results_dict["eval"][task_id][argmax_plus]["plus_status"]
    if not test_all_plus_cases:
        argmax_plus_solution = eval_results_dict["eval"][task_id][argmax_plus]
        if len(argmax_plus_solution["plus_details"]) < min(num_plus_test_cases, len(expected_output[task_id]["plus"])):
            plus_status = "fail"
        else:
            if all(argmax_plus_solution["plus_details"][:num_plus_test_cases]):
                plus_status = "pass"
            else:
                plus_status = "fail"
    return (int(base_status == "pass"), int(base_status == plus_status == "pass"), argmax_base, argmax_plus)

In [55]:
def get_results(eval_results, exec_outputs, max_hyps=200, start_id=0, num_plus_test_cases=300, granular=False, filter=True, workers=20):
    for task_id in eval_results["eval"]:
        eval_results["eval"][task_id] = sorted(eval_results["eval"][task_id], key=lambda x: int(x["solution_id"]))
        
    base_results = {}
    plus_results = {}
    argmax_bases = {}
    argmax_pluss = {}
    
    def process_single_task(task_id):
        base_result, plus_result, argmax_base, argmax_plus = process_task(task_id, 
                                                                          eval_results, 
                                                                          exec_outputs, 
                                                                          max_hyps=max_hyps, 
                                                                          start_id=start_id, 
                                                                          num_plus_test_cases=num_plus_test_cases, 
                                                                          granular=granular, 
                                                                          filter=filter)
        base_results[task_id] = base_result
        plus_results[task_id] = plus_result
        argmax_bases[task_id] = argmax_base
        argmax_pluss[task_id] = argmax_plus

    with ThreadPoolExecutor(max_workers=workers) as executor:
        list(tqdm(executor.map(process_single_task, eval_results["eval"]), total=len(eval_results["eval"])))
    
    return base_results, plus_results

In [58]:
num_plus_test_cases = 300
max_hyps = 50
granular = False
filter = True
workers = 1
base_score = []
plus_score = []
for start_id in range(0, 200, max_hyps):
    #the thing that dont pass is between 120 and 125s
    #if start_id == 120:
    #    continue
    base_results, plus_results = get_results(eval_results, 
                                             exec_outputs, 
                                             max_hyps=max_hyps, 
                                             start_id=start_id, 
                                             num_plus_test_cases=num_plus_test_cases,
                                             granular=granular,
                                             filter=filter, 
                                             workers=workers)

    base_score.append(sum(base_results.values())/len(base_results))
    plus_score.append(sum(plus_results.values())/len(plus_results))

  3%|▎         | 13/395 [00:00<00:03, 98.63it/s]

100%|██████████| 395/395 [00:05<00:00, 69.29it/s] 
100%|██████████| 395/395 [00:05<00:00, 71.02it/s] 
100%|██████████| 395/395 [00:05<00:00, 74.34it/s]
100%|██████████| 395/395 [00:05<00:00, 67.74it/s] 


In [59]:
round_digits = 10
print("filter: {}".format(filter))
print("task: {}".format(dataset))
print("model: {}".format(gen_dir))
print("number of hypotheses: {}".format(max_hyps))
print("MBR base     ", np.round(sum(base_score)/len(base_score) * 100, round_digits), "%")
print("-"*100)
print("MBR plus     ", np.round(sum(plus_score)/len(plus_score) * 100, round_digits), "%")

filter: True
task: mbpp
model: code-llama-13b-instruct_temp_1.6
number of hypotheses: 50
MBR base      84.1772151899 %
----------------------------------------------------------------------------------------------------
MBR plus      74.8101265823 %
