diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 48ecf396a..7a7641a57 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import Counter, defaultdict +from functools import lru_cache from typing import TYPE_CHECKING import libcst as cst @@ -13,6 +14,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + import enum import re import sys @@ -23,11 +25,13 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError +from pydantic import (AfterValidator, BaseModel, ConfigDict, PrivateAttr, + ValidationError) from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code +from codeflash.code_utils.code_utils import (module_name_from_file_path, + validate_python_code) from codeflash.code_utils.env_utils import is_end_to_end from codeflash.verification.comparator import comparator @@ -513,23 +517,22 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option return None def get_src_code(self, test_path: Path) -> Optional[str]: - if not test_path.exists(): - return None - test_src = test_path.read_text(encoding="utf-8") - module_node = cst.parse_module(test_src) - - if self.test_class_name: - for stmt in module_node.body: - if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name: - func_node = self.find_func_in_class(stmt, self.test_function_name) - if func_node: - return module_node.code_for_node(func_node).strip() - # class not found + module_node = self._parse_module_by_path(str(test_path)) + if module_node is None: return None + test_func_name = self.test_function_name + test_class_name = self.test_class_name + found_func = None + # Otherwise, look for a top level function for stmt in module_node.body: - if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name: + if test_class_name is not None and isinstance(stmt, cst.ClassDef) and stmt.name.value == test_class_name: + found_func = self.find_func_in_class(stmt, test_func_name) + if found_func: + return module_node.code_for_node(found_func).strip() + return None # Class found but function not found + if test_class_name is None and isinstance(stmt, cst.FunctionDef) and stmt.name.value == test_func_name: return module_node.code_for_node(stmt).strip() return None @@ -552,6 +555,17 @@ def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId iteration_id=iteration_id if iteration_id else components[3], ) + # All attribute definitions are preserved + + @staticmethod + @lru_cache(maxsize=32) + def _parse_module_by_path(test_path_str: str) -> Optional[cst.Module]: + path = Path(test_path_str) + if not path.exists(): + return None + test_src = path.read_text(encoding="utf-8") + return cst.parse_module(test_src) + @dataclass(frozen=True) class FunctionTestInvocation: @@ -631,7 +645,8 @@ def get_all_ids(self) -> set[InvocationId]: return {test_result.id for test_result in self.test_results} def get_all_unique_invocation_loop_ids(self) -> set[str]: - return {test_result.unique_invocation_loop_id for test_result in self.test_results} + # generator expression for memory efficiency + return set(tr.unique_invocation_loop_id for tr in self.test_results) def number_of_loops(self) -> int: if not self.test_results: diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index b752a0af7..febbfb5b9 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -13,7 +13,8 @@ import sentry_sdk from codeflash.cli_cmds.console import logger -from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError +from codeflash.picklepatch.pickle_placeholder import \ + PicklePlaceholderAccessError HAS_NUMPY = find_spec("numpy") is not None HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None @@ -34,11 +35,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 # distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__: return False - if isinstance(orig, (list, tuple, deque, ChainMap)): - if len(orig) != len(new): - return False - return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) + # Cheap, common types first if isinstance( orig, ( @@ -65,6 +63,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if math.isnan(orig) and math.isnan(new): return True return math.isclose(orig, new) + if isinstance(orig, (list, tuple, deque, ChainMap)): + if len(orig) != len(new): + return False + for elem1, elem2 in zip(orig, new): + if not comparator(elem1, elem2, superset_obj): + return False + return True + if isinstance(orig, BaseException): if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError): # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object. @@ -78,15 +84,16 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")} return comparator(orig_dict, new_dict, superset_obj) + # JAX, XARRAY, NUMPY, PANDAS, TORCH modules imported once per function call if needed + np = None + pandas = None if HAS_JAX: import jax # type: ignore # noqa: PGH003 import jax.numpy as jnp # type: ignore # noqa: PGH003 # Handle JAX arrays first to avoid boolean context errors in other conditions if isinstance(orig, jax.Array): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: + if orig.dtype != new.dtype or orig.shape != new.shape: return False return bool(jnp.allclose(orig, new, equal_nan=True)) @@ -101,11 +108,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 import sqlalchemy # type: ignore # noqa: PGH003 try: - insp = sqlalchemy.inspection.inspect(orig) - insp = sqlalchemy.inspection.inspect(new) # noqa: F841 + sqlalchemy.inspection.inspect(orig) + sqlalchemy.inspection.inspect(new) orig_keys = orig.__dict__ new_keys = new.__dict__ - for key in list(orig_keys.keys()): + for key in orig_keys: if key.startswith("_"): continue if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj): @@ -117,16 +124,20 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if HAS_SCIPY: import scipy # type: ignore # noqa: PGH003 - # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it - if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): + + # Dict support/Sparse + is_sparse = HAS_SCIPY and "scipy" in globals() and isinstance(orig, scipy.sparse.spmatrix) + if isinstance(orig, dict) and not is_sparse: if superset_obj: - return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items()) + for k, v in orig.items(): + if k not in new or not comparator(v, new[k], superset_obj): + return False + return True + # Strict equality check if len(orig) != len(new): return False - for key in orig: - if key not in new: - return False - if not comparator(orig[key], new[key], superset_obj): + for k, v in orig.items(): + if k not in new or not comparator(v, new[k], superset_obj): return False return True @@ -134,15 +145,15 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 import numpy as np # type: ignore # noqa: PGH003 if isinstance(orig, np.ndarray): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: + if orig.dtype != new.dtype or orig.shape != new.shape: return False try: return np.allclose(orig, new, equal_nan=True) except Exception: - # fails at "ufunc 'isfinite' not supported for the input types" - return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)]) + for x, y in zip(orig, new): + if not comparator(x, y, superset_obj): + return False + return True if isinstance(orig, (np.floating, np.complex64, np.complex128)): return np.isclose(orig, new) @@ -153,12 +164,24 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, np.void): if orig.dtype != new.dtype: return False - return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + for field in orig.dtype.fields: + if not comparator(orig[field], new[field], superset_obj): + return False + return True + # nan/inf for numpy base types + try: + if np.isnan(orig): + return np.isnan(new) + except Exception: + pass + try: + if np.isinf(orig): + return np.isinf(new) + except Exception: + pass - if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): - if orig.dtype != new.dtype: - return False - if orig.get_shape() != new.get_shape(): + if is_sparse: + if orig.dtype != new.dtype or orig.get_shape() != new.get_shape(): return False return (orig != new).nnz == 0 @@ -176,35 +199,23 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return True if isinstance(orig, array.array): - if orig.typecode != new.typecode: - return False - if len(orig) != len(new): + if orig.typecode != new.typecode or len(orig) != len(new): return False - return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) - - # This should be at the end of all numpy checking - try: - if HAS_NUMPY and np.isnan(orig): - return np.isnan(new) - except Exception: # noqa: S110 - pass - try: - if HAS_NUMPY and np.isinf(orig): - return np.isinf(new) - except Exception: # noqa: S110 - pass + for elem1, elem2 in zip(orig, new): + if not comparator(elem1, elem2, superset_obj): + return False + return True if HAS_TORCH: import torch # type: ignore # noqa: PGH003 if isinstance(orig, torch.Tensor): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: - return False - if orig.requires_grad != new.requires_grad: - return False - if orig.device != new.device: + if ( + orig.dtype != new.dtype + or orig.shape != new.shape + or orig.requires_grad != new.requires_grad + or orig.device != new.device + ): return False return torch.allclose(orig, new, equal_nan=True) @@ -242,12 +253,12 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if attr.eq: attr_name = attr.name new_attrs_dict[attr_name] = getattr(new, attr_name, None) - return all( - k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items() - ) + for k, v in orig_dict.items(): + if k not in new_attrs_dict or not comparator(v, new_attrs_dict[k], superset_obj): + return False + return True return comparator(orig_dict, new_dict, superset_obj) - # re.Pattern can be made better by DFA Minimization and then comparing if isinstance( orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern) ): @@ -275,8 +286,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")} if superset_obj: - # allow new object to be a superset of the original object - return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items()) + for k, v in orig_keys.items(): + if k not in new_keys or not comparator(v, new_keys[k], superset_obj): + return False + return True if isinstance(orig, ast.AST): orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"} diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 77798d88f..2f641fb86 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -38,19 +38,25 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_recursion_limit = sys.getrecursionlimit() if original_recursion_limit < INCREASED_RECURSION_LIMIT: sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError - test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( - set(candidate_results.get_all_unique_invocation_loop_ids()) - ) + test_ids_superset = original_results.get_all_unique_invocation_loop_ids() + test_ids_superset = test_ids_superset.union(candidate_results.get_all_unique_invocation_loop_ids()) + test_diffs: list[TestDiff] = [] did_all_timeout: bool = True + # Cache candidate failures dict lookup outside loop + candidate_test_failures = candidate_results.test_failures + # Loop with cached function calls + get_cdd_result = candidate_results.get_by_unique_invocation_loop_id + get_orig_result = original_results.get_by_unique_invocation_loop_id + for test_id in test_ids_superset: - original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) - cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) - candidate_test_failures = candidate_results.test_failures + original_test_result = get_orig_result(test_id) + cdd_test_result = get_cdd_result(test_id) + # This is just caching the pytest error extraction branch to single lookup # original_test_failures = original_results.test_failures cdd_pytest_error = ( candidate_test_failures.get(original_test_result.id.test_function_name, "") - if candidate_test_failures + if candidate_test_failures and original_test_result is not None else "" ) # original_pytest_error = ( @@ -59,9 +65,9 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR if cdd_test_result is not None and original_test_result is None: continue - # If helper function instance_state verification is not present, that's ok. continue if ( - original_test_result.verification_type + original_test_result + and original_test_result.verification_type and original_test_result.verification_type == VerificationType.INIT_STATE_HELPER and cdd_test_result is None ): @@ -71,12 +77,13 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue - superset_obj = False - if original_test_result.verification_type and ( + superset_obj = ( original_test_result.verification_type in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} - ): - superset_obj = True + if original_test_result.verification_type + else False + ) + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): test_diffs.append( @@ -101,8 +108,12 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR except Exception as e: logger.error(e) break - if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( - original_test_result.stdout, cdd_test_result.stdout + + # Fast fail: check stdout + if ( + original_test_result.stdout + and cdd_test_result.stdout + and not comparator(original_test_result.stdout, cdd_test_result.stdout) ): test_diffs.append( TestDiff( @@ -115,12 +126,17 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ) break - if original_test_result.test_type in { - TestType.EXISTING_UNIT_TEST, - TestType.CONCOLIC_COVERAGE_TEST, - TestType.GENERATED_REGRESSION, - TestType.REPLAY_TEST, - } and (cdd_test_result.did_pass != original_test_result.did_pass): + # TestType mismatch + if ( + original_test_result.test_type + in { + TestType.EXISTING_UNIT_TEST, + TestType.CONCOLIC_COVERAGE_TEST, + TestType.GENERATED_REGRESSION, + TestType.REPLAY_TEST, + } + and cdd_test_result.did_pass != original_test_result.did_pass + ): test_diffs.append( TestDiff( scope=TestDiffScope.DID_PASS,