## MongoDB CruxEval Prediction Inconsistency Database Builder

This notebook initializes a new prediction inconsistency database in MongoDB for the CruxEval benchmark.

In [None]:
import os
import sys
import copy
import inspect
from tqdm import tqdm
import pandas as pd
import ast
from dotenv import load_dotenv

In [None]:
# Walk two directory levels up from the notebook's working directory to find
# the project root, make that root importable, then load variables from .env.
curr_dir = os.getcwd()
par_dir = os.path.split(curr_dir)[0]
proj_dir = os.path.split(par_dir)[0]
sys.path.append(proj_dir)
load_dotenv()

In [None]:
from database import MongoDBHelper
from prediction_inconsistency.utility.cruxeval_helper import PredictionInconsistencyCruxEvalHelper

In [None]:
# Create the MongoDB helper used by every later cell and report the outcome of
# its connectivity check.
db = MongoDBHelper()
if db.check_database_connectivity():
    print("MongoDB connected")
else:
    # Make a failed check visible here instead of letting later cells fail
    # with less obvious errors.
    print("MongoDB connection check failed")

In [None]:
# Handle to the database whose name is read from MONGODB_BENCHMARK_DATABASE.
base_qns_db = db.client[os.environ.get('MONGODB_BENCHMARK_DATABASE')]

# Load the CruxEval test CSV from the project tree.
# quoting=1 is the numeric value of csv.QUOTE_ALL.
cruxeval_csv_path = os.path.join(proj_dir, "datasets/open_ended_format/cruxeval_test.csv")
cruxeval_database = pd.read_csv(cruxeval_csv_path, header=0, quoting=1, encoding="utf-8")

In [None]:
# Collection that receives the CruxEval entries; its name is read from
# MONGODB_CRUXEVAL_COLLECTION.
tf_question_database = base_qns_db[os.environ.get('MONGODB_CRUXEVAL_COLLECTION')]

In [None]:
c = set()
tot = 0
repurposed_qn = {}
rej_tot = 0
for i in tqdm(range(
    len(cruxeval_database)
    )):
    task_id = f"CruxEvalTF{i}"
    sample_qn = cruxeval_database.iloc[i]
    full_sol = sample_qn['code']
    test_input = sample_qn['input']
    expected_output = sample_qn['output']
    original_id = sample_qn['id']

    namespace = {}

    func_name  = PredictionInconsistencyCruxEvalHelper.extract_func_name(full_sol)

    exec(full_sol, namespace)


    try:
        test_input = eval(test_input, namespace) if not isinstance(test_input, float) else None
        match expected_output:
            case "FALSE":
                expected_output = False
            case "TRUE":
                expected_output = True
            case _:
                expected_output = eval(expected_output)
    except Exception as e:
        print(test_input)
        print(task_id, f"failed due to following error: {e}")
        continue

    sig = inspect.signature(namespace[func_name])
    input_copy = copy.deepcopy(test_input) if isinstance(test_input, (list, dict, tuple, set)) else test_input
    try:
        if len(sig.parameters) == 0:
            assert namespace[func_name]() == expected_output
        elif len(sig.parameters) > 1:
            assert namespace[func_name](*input_copy) == expected_output
        else:
            assert namespace[func_name](input_copy) == expected_output

    except:
        print(f"Did not pass test case. Double check task_id {task_id}, test_case {test_input}")
    
    input_metadata = type(test_input).__name__

    func_in_input = None

    if input_metadata not in ('str', 'NoneType'):
        input_args_tree = ast.parse(sample_qn['input'])
        func_in_input = any(isinstance(node, (ast.Call, ast.Lambda)) for node in ast.walk(input_args_tree))

    if isinstance(test_input, dict):
        test_input = str(test_input)
    elif isinstance(test_input, (tuple, list)):
        test_input = str(test_input) if any(i for i in test_input if isinstance(i, dict)) else test_input
    
    ### Storing / updating entry in the database
    db_entry = {
        "_id" : task_id,
        "full_sol" : full_sol,
        "input" : {
            "args": test_input if not func_in_input else sample_qn['input'],
            "metadata": input_metadata,
            },
        "output": {
            "args" : str(expected_output),
            "metadata": type(expected_output).__name__,
            },
        "original_id": original_id
    }

    try:
        tf_question_database.update_one(
            filter={"_id": task_id},
            update={"$set": db_entry},
            upsert=True
        )
    except Exception as e:
        print(f"Could not enter test case {original_id} into TF database due to the following error: {e}")


    ## Secondary check where the question is pulled from the database and tested against the check function
    ## This step is necessary as MongoDB does not store these details in standard Python data formats and a secondary step is needed for sanity check. 
    ## For example, tuples are stored as "arrays" in MongoDB, which are converted to Lists in Python.
    
    qn = tf_question_database.find_one({"_id" : task_id})
    if qn is None:
        continue
    full_sol = qn['full_sol']                           # full canonical solution for the task

    test_inputs = qn['input']                           # unpacking input args and metadata from qn
    input_args = test_inputs['args']                    # test input args
    input_metadata = test_inputs['metadata']            # test input metadata

    test_outputs = qn['output']                         # unpacking outputs args and metadata from qn
    output_args = test_outputs['args']                  # test output args
    output_metadata = test_outputs['metadata']          # test output metadata

    try:
        input_args = eval(input_args) if isinstance(input_args, str) and input_metadata != str.__name__  else input_args
    except Exception as e:
        print(original_id)
        print(e)
        continue
    output_args = ast.literal_eval(output_args) if output_metadata != str.__name__ else output_args

    check_stored_soln_validity = PredictionInconsistencyCruxEvalHelper.check_input_output(
        full_sol=full_sol,
        test_input=input_args,
        input_metadata=input_metadata,
        expected_output= output_args,
        func_name= func_name,
    )

    if check_stored_soln_validity is not True:
        tf_question_database.find_one_and_delete({"_id" : task_id})
        print(f'{original_id} from {task_id} failed the secondary checks and was not added to the database.')
        tot-=1

Note on sample_452: this sample was modified so that the "," following the initial test input is removed and its single quotation marks are replaced with double quotation marks.

In [None]:
# Report the collection size together with the counters kept by the build loop.
print(
    f"{tf_question_database.count_documents({})} total test cases in the database",
    f'{tot} valid test cases',
    f'{rej_tot} test cases were rejected',
    sep="\n",
)