In [None]:
from evalplus.data import get_human_eval_plus, write_jsonl

In [None]:
# Collect the problem records returned by get_human_eval_plus(), in the
# mapping's iteration order.
# A plain loop (rather than a comprehension) is used deliberately: it leaves
# `task_id` and `problem` bound to the last entry, and a later cell reads
# `problem`.
problems = []
for task_id, problem in get_human_eval_plus().items():
    problems += [problem]


In [98]:
import ast
import copy
import inspect
import json
from typing import Any, Callable, Dict, List, Union

from hypothesis import given, strategies as st, assume
from hypothesis import settings
from hypothesis import Verbosity


def extract_function(code: str, func_name: str) -> Callable:
    """Define ``func_name`` from a source string and return the function object.

    The top-level imports, function definitions and class definitions found in
    ``code`` are executed together in one fresh namespace (seeded with a
    snapshot of this module's globals). The returned function can therefore
    call itself recursively and use helpers or modules that ``code`` defines
    or imports. Every other top-level statement in ``code`` is skipped.

    Args:
        code: Python source text containing a top-level ``def func_name``.
        func_name: Name of the function to return.

    Returns:
        The function object bound to ``func_name``. If the name is defined
        more than once, the last definition wins, as in normal module
        execution.

    Raises:
        SyntaxError: If ``code`` cannot be parsed.
        ValueError: If ``code`` has no top-level function named ``func_name``.
    """
    module = ast.parse(code)

    has_target = any(
        isinstance(node, ast.FunctionDef) and node.name == func_name
        for node in module.body
    )
    if not has_target:
        raise ValueError(f"No top-level function named {func_name!r} found in code")

    # Keep what the target may depend on (imports, sibling helpers, classes)
    # but drop other top-level statements so loading the code has no side
    # effects beyond running imports and creating definitions.
    definitions = [
        node
        for node in module.body
        if isinstance(
            node,
            (
                ast.Import,
                ast.ImportFrom,
                ast.FunctionDef,
                ast.AsyncFunctionDef,
                ast.ClassDef,
            ),
        )
    ]

    # A single dict serves as the executed code's globals, so every definition
    # can see every other one (needed for recursion and helper calls).
    namespace = dict(globals())

    # NOTE(security): exec runs arbitrary code; only pass trusted or sandboxed
    # source to this function.
    exec(
        compile(ast.Module(body=definitions, type_ignores=[]), filename="<ast>", mode="exec"),
        namespace,
    )
    return namespace[func_name]

def infer_type_strategy(value: Any) -> st.SearchStrategy:
    """Return a Hypothesis strategy modelled on a sample value.

    Supported samples are bool, int, float, str and list. Numeric strategies
    span the sample +/- 100; text and list strategies use the sample's length
    as their minimum size. A non-empty list takes its element strategy from
    its first element; an empty list falls back to lists of integers.

    Raises:
        ValueError: If the sample's type is not one of the supported types.
    """
    # bool is checked before int because bool is a subclass of int.
    if isinstance(value, bool):
        return st.booleans()

    if isinstance(value, int):
        lower, upper = value - 100, value + 100
        return st.integers(min_value=lower, max_value=upper)

    if isinstance(value, float):
        return st.floats(
            min_value=value - 100,
            max_value=value + 100,
            allow_nan=False,
            allow_infinity=False,
        )

    if isinstance(value, str):
        length = len(value)
        return st.text(min_size=length, max_size=length + 10)

    if isinstance(value, list):
        size = len(value)
        if size == 0:
            return st.lists(st.integers(), min_size=0, max_size=10)
        # The first element stands in for the type of every element.
        per_item = infer_type_strategy(value[0])
        return st.lists(per_item, min_size=size, max_size=size + 5)

    raise ValueError(f"Unsupported type: {type(value)}")

def generate_hypothesis_test(task: Dict[str, Any]) -> Callable:
    """Build a Hypothesis property test that compares two implementations.

    One strategy is inferred per positional argument from the first entry of
    ``task["base_input"]``. Each generated argument tuple is passed to both
    functions and their results are compared.

    Args:
        task: Task record. Reads ``task["base_input"][0]`` (sample arguments)
            and ``task["atol"]`` (absolute tolerance used when both results
            are floats; treated as 0 when the key is absent).

    Returns:
        A callable invoked as ``test(canonical_func, candidate_func)``;
        Hypothesis supplies the trailing ``args`` parameter. It raises when a
        counterexample is found: ``AssertionError`` on a result mismatch, or
        whatever exception either function raised.
    """
    arg_strategies = [infer_type_strategy(arg) for arg in task["base_input"][0]]
    atol = task.get("atol", 0)

    @settings(max_examples=10000, deadline=None, derandomize=True)
    @given(st.tuples(*arg_strategies))
    def test_equivalence(canonical_func: Callable, candidate_func: Callable, args: tuple):
        try:
            # Each implementation gets its own copy of the inputs so that an
            # in-place mutation (e.g. list.sort()) by one cannot change what
            # the other receives.
            canonical_result = canonical_func(*copy.deepcopy(args))
            candidate_result = candidate_func(*copy.deepcopy(args))

            if isinstance(canonical_result, float) and isinstance(candidate_result, float):
                assert abs(canonical_result - candidate_result) <= atol
            else:
                assert canonical_result == candidate_result
        except Exception:
            print(f"Error occurred with inputs: {args}")
            # Bare raise keeps the original traceback intact.
            raise

    return test_equivalence

def test_implementation(task: Dict, candidate_implementation: str) -> bool:
    """Check a candidate function body against the task's canonical solution.

    Both bodies are appended to ``task["prompt"]`` to form complete functions,
    which are then compared on Hypothesis-generated inputs.

    Args:
        task: Task record providing ``"prompt"``, ``"canonical_solution"`` and
            ``"entry_point"``, plus whatever ``generate_hypothesis_test``
            reads from it.
        candidate_implementation: Source text of the candidate function body.

    Returns:
        True if the property test completed without error, False if running
        it raised (a result mismatch or an exception from either function).

    Raises:
        Exception: Anything raised while extracting the functions or building
            the property test propagates to the caller.
    """
    canonical_func = extract_function(task["prompt"] + task["canonical_solution"], task["entry_point"])
    candidate_func = extract_function(task["prompt"] + candidate_implementation, task["entry_point"])

    # Built outside the try block: a failure here (for example an unsupported
    # base_input type) is a harness problem, not evidence that the candidate
    # differs, so it must not be reported as "not equivalent".
    hypothesis_test = generate_hypothesis_test(task)

    try:
        hypothesis_test(canonical_func, candidate_func)
    except Exception:
        # Deliberately broad: a mismatch (AssertionError) or any exception
        # raised while exercising the two functions counts as non-equivalence.
        return False
    return True
    

    

# Example usage
# Illustrative task record with the keys test_implementation() reads.
# The literal is a raw string so that the \n and \" sequences reach
# json.loads() as JSON escapes; in a normal string Python would expand them
# first, which leaves text that is not valid JSON.
task_json = r'''
{
    "task_id": "HumanEval/0",
    "prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"",
    "canonical_solution": "\n    sorted_numbers = sorted(numbers)\n    for i in range(len(sorted_numbers) - 1):\n        if sorted_numbers[i + 1] - sorted_numbers[i] < threshold:\n            return True\n    return False\n",
    "entry_point": "has_close_elements",
    "atol": 0,
    "base_input": [
        [[1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3],
        [[1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05],
        [[1.0, 2.0, 5.9, 4.0, 5.0], 0.95],
        [[1.0, 2.0, 5.9, 4.0, 5.0], 0.8],
        [[1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1]
    ]
}
'''

# Candidate function body only (no ``def`` line); test_implementation()
# appends it to the task prompt to form a complete function.
candidate_implementation = (
    "\n"
    "    for i in range(len(numbers)):\n"
    "        for j in range(i + 1, len(numbers)):\n"
    "            if abs(numbers[i] - numbers[j]) < threshold:\n"
    "                return True\n"
    "    return False\n"
)

# Compare the candidate with the canonical solution of the first loaded problem.
is_equivalent = test_implementation(problems[0], candidate_implementation)
if is_equivalent:
    print("The candidate implementation is equivalent to the canonical solution.")
else:
    print("The candidate implementation is not equivalent to the canonical solution.")

The candidate implementation is equivalent to the canonical solution.


In [None]:
# Display the `problem` record left bound by the last iteration of the
# problem-loading loop in the earlier cell.
problem