## MongoDB HumanEval Prediction Inconsistency Database Builder

This notebook initializes a new prediction inconsistency database on MongoDB for the HumanEval benchmark.

!!! Note: You **MUST** build the **Code Generation HumanEval** database before running this notebook, because building the prediction inconsistency database relies on pre-processed data from it. The notebook for the code generation HumanEval database is at `code_generation/test_notebooks/code_generation_test_notebook_humaneval.ipynb`, relative to the project root directory.

In [None]:
import ast
import builtins
import copy
import inspect
import operator
import os
import sys
from typing import List, Any

from dotenv import load_dotenv
from tqdm import tqdm

In [None]:
# Make project-local packages importable: the directory two levels above the
# current working directory is appended to sys.path.
# NOTE(review): this assumes the notebook is launched from a folder two levels
# below the project root -- confirm against the repository layout.
curr_dir = os.getcwd()
par_dir = os.path.dirname(curr_dir)
proj_dir = os.path.dirname(par_dir)
sys.path.append(proj_dir)
# Load environment variables from a .env file; later cells read the MONGODB_*
# settings through os.getenv.
load_dotenv()

In [None]:
from database import MongoDBHelper
from prediction_inconsistency.prediction_inconsistency_tester import PredictionInconsistencyHumanEvalHelper
from code_mutation.mutation_functions import CodeMutator

In [None]:
# Create the project's MongoDB helper and report when its connectivity check
# succeeds. Nothing is printed by this cell when the check returns a falsy value.
db = MongoDBHelper()
if db.check_database_connectivity():
    print("MongoDB connected")

In [None]:
# Source data: the benchmark database and the HumanEval code-generation
# collection inside it. Both names come from environment variables.
base_qns_db = db.client[os.getenv("MONGODB_BENCHMARK_DATABASE")]
question_database = base_qns_db[os.getenv("MONGODB_HUMANEVAL_CG_COLLECTION")]

In [None]:
# Target collection (same benchmark database) that receives one document per
# extracted input/output test case.
tf_question_database = base_qns_db[os.getenv('MONGODB_HUMANEVAL_IO_COLLECTION')]

In [None]:
from typing import Tuple, Dict
class AST_Helper:

    COMPARE_OP_MAP = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.Lt: "<",
        ast.LtE: "<=",
        ast.Gt: ">",
        ast.GtE: ">=",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.In: "in",
        ast.NotIn: "not in",
    }

    def route_ast_type(code: ast.AST):
        if isinstance(code, ast.Call):
            return AST_Helper.extract_ast_call_args(code)

        elif isinstance(code, ast.Constant):
            return AST_Helper.extract_ast_constant(code)
        
        elif isinstance(code, ast.UnaryOp):
            return AST_Helper.extract_ast_unaryop(code)
        
        elif isinstance(code, ast.List):
            return AST_Helper.extract_ast_list(code)
        
        elif isinstance(code, list):
            res = [AST_Helper.route_ast_type(c) for c in code]
            return res[0] if len(res) == 1 else res
        
        elif isinstance(code, ast.Tuple):
            return AST_Helper.extract_ast_tuple(code)
        
        elif isinstance(code, ast.Dict):
            return AST_Helper.extract_ast_dict(code)
            
        elif isinstance(code, ast.BinOp):
            return AST_Helper.extract_ast_bin_op(code)
        
        elif isinstance(code, ast.Name):
            return AST_Helper.extract_ast_name(code)
        
        elif isinstance(code, ast.operator):
            if isinstance(code, ast.Mult):
                return "*"
            elif isinstance(code, ast.Sub):
                return "-"
            elif isinstance(code, ast.Div):
                return "/"
            elif isinstance(code, ast.Add):
                return "+"
            elif isinstance(code, ast.Pow):
                return "**"
            else:
                raise ValueError(f"A method to process {type(code)} operator type has not been developed")
        
        elif not isinstance(code, ast.AST):
            return code
        
        else:
            raise ValueError(f"A method to process {type(code)} has not been developed")

    def extract_ast_dict(code: ast.Dict) -> Dict[Any, Any]:
        if isinstance(code, ast.Dict):
            dict_values = code.values
            dict_keys = code.keys
            key_val_pairs = zip(dict_keys, dict_values)
            res = {}
            for pair in key_val_pairs:
                res[AST_Helper.route_ast_type(pair[0])] = AST_Helper.route_ast_type(pair[1])
            return res

        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Dict type.")

    def extract_ast_call_args(code: ast.Call) -> Tuple[str, list[str]]:
        if isinstance(code, ast.Call):
            test_args = [AST_Helper.route_ast_type(arg) for arg in code.args]
            if code.func.id not in dir(__builtins__):
                args_meta_data = [type(arg).__name__ for arg in test_args]
                return test_args[0] if len(test_args) == 1 else test_args, args_meta_data[0] if len(args_meta_data) == 1 else args_meta_data
            else:
                return test_args[0] if len(test_args) == 1 else test_args
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Call type.")
        
    def extract_ast_constant(code: ast.Constant) -> str:
        if isinstance(code, ast.Constant):
            return code.value
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Constant type.")
    
    def extract_ast_compare(code: ast.Compare) -> Tuple[Any, str | List[str], Any]:
        if isinstance(code, ast.Compare):
            left_side = code.left
            if isinstance(left_side, ast.AST):
                left = AST_Helper.route_ast_type(left_side)

            comparators = AST_Helper.route_ast_type(code.comparators)

            ops = [AST_Helper.COMPARE_OP_MAP[type(op)] for op in code.ops]                    # a list is used here as there could be more than 1 ops
            
            return (
                left, 
                ops[0] if len(ops) == 1 else ops, 
                comparators
            )
        
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Constant type.")
        
    def extract_ast_tuple(code: ast.Tuple) -> Tuple[Any]:
        if isinstance(code, ast.Tuple):
            elts = code.elts
            return tuple(AST_Helper.route_ast_type(elt) for elt in elts)
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Tuple type.")
    
    def extract_ast_bin_op(code: ast.BinOp) -> int:
        if isinstance(code, ast.BinOp):
            left = AST_Helper.route_ast_type(code.left)
            right = AST_Helper.route_ast_type(code.right)
            oper = AST_Helper.route_ast_type(code.op)

            def format_val(val):
                return repr(val) if isinstance(val, str) else val
            
            return (eval(f"{format_val(left)} {oper} {format_val(right)}"))
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Tuple type.")

    def extract_ast_name(code: ast.Name) -> str:
        if isinstance(code, ast.Name):
            return code.id
        else:
            raise ValueError("Incorrect extraction method used, code snippet is not ast.Name type.")

    def extract_ast_unaryop(code: ast.UnaryOp) -> int | float:
        unary_map = {
            ast.USub: "-",
            ast.UAdd: "+",
            ast.Not: "not ",
            ast.Invert: "~"
        }
        if isinstance(code, ast.UnaryOp):
            symbol = unary_map[type(code.op)]
            value = str(code.operand.value)
            try:
                return int(symbol + value)
            except ValueError:
                return float(symbol + value)
            except Exception as e:
                raise ValueError(f"Unable to extract the unaryop due to the following error: {e}")

    def extract_ast_list(code: ast.List) -> List[Any]:
        if isinstance(code, ast.List):
            list_elements = code.elts
            return [AST_Helper.route_ast_type(elt) for elt in list_elements]
        else:
            print(type(code))
            raise ValueError("Incorrect extraction method used, code snippet is not ast.List type.")

In [None]:
def extract_assert_cases(code: str) -> Tuple[int, List, int]:
    """Extract (test parameters, expected output) pairs from a check function's asserts.

    Only ``assert`` statements that sit directly in the body of a top-level
    function in ``code`` are inspected. Supported forms:

    * ``assert candidate(x) == y``  -> (extracted call, y)
    * ``assert candidate(x)``       -> (extracted call, True)
    * ``assert not candidate(x)``   -> (extracted operand, False)

    Asserts on a bare constant (``assert True``) are ignored. Comparisons that
    are not a single ``==`` and unary asserts other than ``not`` are counted as
    rejected, because they cannot be expressed as an input / expected-output
    pair. Any other assert form is printed and counted in ``num_cases`` without
    being extracted, so callers can detect the mismatch with the list length.

    Args:
        code: Source code containing the check function(s).

    Returns:
        ``(num_cases, test_cases, rejected_cases)``: the number of counted
        asserts, the list of ``(test_params, test_outputs)`` tuples, and the
        number of rejected asserts.

    Raises:
        SyntaxError: If ``code`` cannot be parsed.
        ValueError: Propagated from ``AST_Helper`` for unsupported expressions.
    """
    test_cases = []             # extracted (test parameters, expected output) pairs
    num_cases = 0               # asserts counted as test cases
    rejected_cases = 0          # asserts that cannot be turned into an input/output pair

    tree = ast.parse(code)
    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue

        for subnode in node.body:       # each statement directly inside the check function
            if not isinstance(subnode, ast.Assert):
                continue

            test_expr = subnode.test

            if isinstance(test_expr, ast.Constant):
                # e.g. "assert True": carries no test case, so it is not counted
                continue

            if isinstance(test_expr, ast.Compare):
                # e.g. "assert candidate([1,2,3]) == 3"; anything other than a
                # single "==" (including chained comparisons) is rejected
                if len(test_expr.ops) != 1 or not isinstance(test_expr.ops[0], ast.Eq):
                    rejected_cases += 1
                    continue
                test_params, _, test_outputs = AST_Helper.extract_ast_compare(test_expr)

            elif isinstance(test_expr, ast.Call):
                # e.g. "assert candidate([1,2,3])"
                test_params = AST_Helper.extract_ast_call_args(test_expr)
                test_outputs = True

            elif isinstance(test_expr, ast.UnaryOp):
                # e.g. "assert not candidate([1,2,3])"; other unary operators
                # have no known expected output and must not be stored
                if not isinstance(test_expr.op, ast.Not):
                    rejected_cases += 1
                    continue
                test_params = AST_Helper.route_ast_type(test_expr.operand)
                test_outputs = False

            else:
                # counted but not extracted, so num_cases != len(test_cases) flags it
                num_cases += 1
                print("Can't decide: ", type(test_expr))
                continue

            num_cases += 1
            test_cases.append((test_params, test_outputs))

    return num_cases, test_cases, rejected_cases


In [None]:
# Build the input/output test-case collection.
# For every task in the code-generation collection: extract the test cases from
# its check function, verify each one against the canonical solution, upsert it
# into the target collection, then read it back and re-validate the stored form.
c = set()                   # task ids for which no test case could be extracted
tot = 0                     # running number of test cases; also drives the HumanEvalTF id numbering
repurposed_qn = {}          # task id -> extracted (test params, expected output) pairs
rej_tot = 0                 # running number of rejected assert statements
for i in tqdm(range(
    question_database.count_documents({})
    )):
    task_id = f"HumanEvalo{i}"
    sample_qn = question_database.find_one({"_id" : task_id})
    if sample_qn is None:
        # A missing document would otherwise abort the whole build with a TypeError below.
        print(f"{task_id} was not found in the code generation collection, skipping.")
        continue

    qn = sample_qn['qn']
    qn_desc = sample_qn['qn_desc']
    examples = sample_qn['examples']
    canon_sol = sample_qn['canon_solution']
    check = sample_qn['check']
    original_id = sample_qn['original_id']
    func_name = sample_qn['func_name']

    complete_sol = qn + '\n' + canon_sol

    try:
        num_cases, test_cases, rejected_cases = extract_assert_cases(check)
        if num_cases != len(test_cases):
            print(f"{task_id}: num cases = {num_cases}, test cases extracted = {test_cases}")
        tot += num_cases
        rej_tot += rejected_cases
        if len(test_cases) < 1:
            c.add(task_id)
        else:
            repurposed_qn[task_id] = test_cases

    except Exception as e:
        print(task_id)
        print(f"Failed due to following error: {e}")
        # Skip this task: without this, the loop below would reuse the previous
        # task's test cases and overwrite their stored entries with this task's solution.
        continue

    for idx, test_case_details in enumerate(test_cases):
        # ``tot`` is decremented when a case fails the secondary check, so the
        # next case takes over the id of the removed one.
        test_case_id = f"HumanEvalTF{tot - len(test_cases) + idx}"

        test_case, expected_output = test_case_details
        test_input, args_meta_data = test_case

        # Run the canonical solution in a fresh namespace for every test case.
        namespace = {}
        exec(complete_sol, namespace)

        sig = inspect.signature(namespace[func_name])
        # Copy mutable inputs so the solution cannot modify the value that gets stored.
        input_copy = copy.deepcopy(test_input) if isinstance(test_input, (list, dict, set, tuple)) else test_input

        try:
            if len(sig.parameters) > 1 and isinstance(test_input, list):
                assert namespace[func_name](*input_copy) == expected_output
            else:
                assert namespace[func_name](input_copy) == expected_output

        except Exception:
            print(f"Did not pass test case. Double check task_id {task_id}, test_case {test_input}")

        # Inputs that are, or contain, a dict are stored as their string form.
        if isinstance(test_input, dict):
            test_input = str(test_input)
        elif isinstance(test_input, (tuple, list)):
            if any(isinstance(item, dict) for item in test_input):
                test_input = str(test_input)

        ### Storing / updating entry in the database
        db_entry = {
            "_id" : test_case_id,
            "full_sol" : complete_sol,
            "qn_desc": qn_desc,
            "input" : {
                "args": test_input,
                "metadata": args_meta_data,
                },
            "output": {
                "args" : str(expected_output),
                "metadata": type(expected_output).__name__,
                },
            "examples": examples,
            "original_id": original_id,
            "func_name": func_name
        }

        try:
            tf_question_database.update_one(
                filter={"_id": test_case_id},
                update={"$set": db_entry},
                upsert=True
            )
        except Exception as e:
            print(f"Could not enter test case {test_case_id} into TF database due to the following error: {e}")
            print(f"Testcase: {test_case}")


        ## Secondary check where the question is pulled from the database and tested against the check function
        ## This step is necessary as MongoDB does not store these details in standard Python data formats and a secondary step is needed for sanity check.
        ## For example, tuples are stored as "arrays" in MongoDB, which are converted to Lists in Python.

        stored_qn = tf_question_database.find_one({"_id" : test_case_id})
        if stored_qn is None:
            continue
        full_sol = stored_qn['full_sol']                    # full canonical solution for the task

        stored_inputs = stored_qn['input']                  # unpacking input args and metadata
        input_args = stored_inputs['args']                  # test input args
        input_metadata = stored_inputs['metadata']          # test input metadata

        stored_outputs = stored_qn['output']                # unpacking output args and metadata
        output_args = stored_outputs['args']                # test output args
        output_metadata = stored_outputs['metadata']        # test output metadata
        stored_func_name = stored_qn['func_name']

        try:
            # NOTE(review): eval() executes whatever string is stored in the
            # collection; only run this against a database you trust.
            input_args = eval(input_args) if isinstance(input_args, str) and input_metadata != str.__name__  else input_args
        except Exception as e:
            print(input_metadata)
            print(original_id)
            print(e)
            continue

        ## Processing of output args and metadata
        output_args = ast.literal_eval(output_args) if output_metadata != str.__name__ else output_args

        check_stored_soln_validity = PredictionInconsistencyHumanEvalHelper.check_input_output(
            full_sol= full_sol,
            test_input= copy.deepcopy(input_args),
            expected_output= output_args,
            func_name=stored_func_name,
            input_metadata = input_metadata
        )

        if check_stored_soln_validity is not True:
            # Remove the test-case document itself; it is keyed by test_case_id, not task_id.
            tf_question_database.find_one_and_delete({"_id" : test_case_id})
            print(f'{test_case_id} from {task_id} failed the secondary checks and was not added to the database.')
            tot-=1



In [None]:
# Summary: number of documents currently in the target collection, the running
# count of test cases kept by the build loop, and the number of rejected asserts.
print(f"{tf_question_database.count_documents({})} total test cases in the database")
print(f'{tot} valid test cases successfully stored in MongoDB')
print(f'{rej_tot} test cases were rejected')