From 51c936f42167f24beb3c0d2c43d9b44ad3c7138b Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 9 Jun 2025 00:02:59 -0700 Subject: [PATCH 1/2] normalize code before hashing --- codeflash/context/code_context_extractor.py | 6 +- tests/test_code_context_extractor.py | 78 ++++++--------------- 2 files changed, 26 insertions(+), 58 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 2971b4e7f..b520f12b7 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import hashlib import os from collections import defaultdict @@ -510,7 +511,10 @@ def parse_code_and_prune_cst( if not found_target: raise ValueError("No target functions found in the provided code") if filtered_node and isinstance(filtered_node, cst.Module): - return str(filtered_node.code) + code = str(filtered_node.code) + if code_context_type == CodeContextType.HASHING: + code = ast.unparse(ast.parse(code)) # Makes it standard + return code return "" diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 2d4dd56cb..c5f008b43 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -114,11 +114,10 @@ class HelperClass: def helper_method(self): return self.name - class MainClass: def main_method(self): - self.name = HelperClass.NestedClass("test").nested_method() + self.name = HelperClass.NestedClass('test').nested_method() return HelperClass(self.name).helper_method() ``` """ @@ -181,22 +180,17 @@ class Graph: def topologicalSortUtil(self, v, visited, stack): visited[v] = True - for i in self.graph[v]: if visited[i] == False: self.topologicalSortUtil(i, visited, stack) - stack.insert(0, v) def topologicalSort(self): visited = [False] * self.V stack = [] - for i in range(self.V): if visited[i] == False: self.topologicalSortUtil(i, visited, stack) - - # Print contents of stack return stack ``` """ @@ -614,58 +608,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]): ```python:{file_path.relative_to(opt.args.project_root)} class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): - def get_cache_or_call( - self, - *, - func: Callable[_P, Any], - args: tuple[Any, ...], - kwargs: dict[str, Any], - lifespan: datetime.timedelta, - ) -> Any: # noqa: ANN401 - if os.environ.get("NO_CACHE"): + def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any: + if os.environ.get('NO_CACHE'): return func(*args, **kwargs) - try: key = self.hash_key(func=func, args=args, kwargs=kwargs) - except: # noqa: E722 - # If we can't create a cache key, we should just call the function. - logging.warning("Failed to hash cache key for function: %s", func) + except: + logging.warning('Failed to hash cache key for function: %s', func) return func(*args, **kwargs) result_pair = self.get(key=key) - if result_pair is not None: cached_time, result = result_pair - if not os.environ.get("RE_CACHE") and ( - datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005 - ): + if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan: try: return self.decode(data=result) except CacheBackendDecodeError as e: - logging.warning("Failed to decode cache data: %s", e) - # If decoding fails we will treat this as a cache miss. - # This might happens if underlying class definition of the data changes. + logging.warning('Failed to decode cache data: %s', e) self.delete(key=key) result = func(*args, **kwargs) try: self.put(key=key, data=self.encode(data=result)) except CacheBackendEncodeError as e: - logging.warning("Failed to encode cache data: %s", e) - # If encoding fails, we should still return the result. + logging.warning('Failed to encode cache data: %s', e) return result - class _PersistentCache(Generic[_P, _R, _CacheBackendT]): def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: - if "NO_CACHE" in os.environ: + if 'NO_CACHE' in os.environ: return self.__wrapped__(*args, **kwargs) os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True) - return self.__backend__.get_cache_or_call( - func=self.__wrapped__, - args=args, - kwargs=kwargs, - lifespan=self.__duration__, - ) + return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ assert read_write_context.strip() == expected_read_write_context.strip() @@ -749,10 +722,12 @@ def __repr__(self): expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: + def target_method(self): y = HelperClass().helper_method() class HelperClass: + def helper_method(self): return self.x ``` @@ -843,10 +818,12 @@ def __repr__(self): expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: + def target_method(self): y = HelperClass().helper_method() class HelperClass: + def helper_method(self): return self.x ``` @@ -927,10 +904,12 @@ def helper_method(self): expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: + def target_method(self): y = HelperClass().helper_method() class HelperClass: + def helper_method(self): return self.x ``` @@ -1116,22 +1095,17 @@ class DataProcessor: def process_data(self, raw_data: str) -> str: return raw_data.upper() - def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: + def add_prefix(self, data: str, prefix: str='PREFIX_') -> str: return prefix + data ``` ```python:{path_to_file.relative_to(project_root)} def fetch_and_process_data(): - # Use the global variable for the request response = requests.get(API_URL) response.raise_for_status() - raw_data = response.text - - # Use code from another file (utils.py) processor = DataProcessor() processed = processor.process_data(raw_data) processed = processor.add_prefix(processed) - return processed ``` """ @@ -1225,16 +1199,11 @@ def transform_data(self, data: str) -> str: ``` ```python:{path_to_file.relative_to(project_root)} def fetch_and_transform_data(): - # Use the global variable for the request response = requests.get(API_URL) - raw_data = response.text - - # Use code from another file (utils.py) processor = DataProcessor() processed = processor.process_data(raw_data) transformed = processor.transform_data(processed) - return transformed ``` """ @@ -1450,9 +1419,8 @@ def transform_data_all_same_file(self, data): new_data = update_data(data) return self.transform_using_own_method(new_data) - def update_data(data): - return data + " updated" + return data + ' updated' ``` """ @@ -1591,6 +1559,7 @@ def outside_method(): expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: + def target_method(self): return self.x + self.y ``` @@ -1640,16 +1609,11 @@ def transform_data(self, data: str) -> str: expected_hashing_context = """ ```python:main.py def fetch_and_transform_data(): - # Use the global variable for the request response = requests.get(API_URL) - raw_data = response.text - - # Use code from another file (utils.py) processor = DataProcessor() processed = processor.process_data(raw_data) transformed = processor.transform_data(processed) - return transformed ``` ```python:import_test.py @@ -1915,9 +1879,9 @@ def subtract(self, a, b): return a - b def calculate(self, operation, x, y): - if operation == "add": + if operation == 'add': return self.add(x, y) - elif operation == "subtract": + elif operation == 'subtract': return self.subtract(x, y) else: return None From 7167d2b75da666781f3b174ff7760999e550c994 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 9 Jun 2025 00:14:45 -0700 Subject: [PATCH 2/2] edge case for python 39 --- tests/test_code_context_extractor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index c5f008b43..010d3bc65 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import tempfile from argparse import Namespace from collections import defaultdict @@ -618,7 +619,7 @@ def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], k return func(*args, **kwargs) result_pair = self.get(key=key) if result_pair is not None: - cached_time, result = result_pair + {"cached_time, result = result_pair" if sys.version_info >= (3, 11) else "(cached_time, result) = result_pair"} if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan: try: return self.decode(data=result)