Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ast
import hashlib
import os
from collections import defaultdict
Expand Down Expand Up @@ -510,7 +511,10 @@ def parse_code_and_prune_cst(
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):
return str(filtered_node.code)
code = str(filtered_node.code)
if code_context_type == CodeContextType.HASHING:
code = ast.unparse(ast.parse(code)) # Makes it standard
return code
Comment on lines +514 to +517
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

codeflash/code_utils/code_replacer.py has normalize_code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but it did more things as well which i did not want to happen. I have already removed those sections earlier in the libcst processing code

return ""


Expand Down
81 changes: 23 additions & 58 deletions tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import tempfile
from argparse import Namespace
from collections import defaultdict
Expand Down Expand Up @@ -114,11 +115,10 @@ class HelperClass:
def helper_method(self):
return self.name


class MainClass:

def main_method(self):
self.name = HelperClass.NestedClass("test").nested_method()
self.name = HelperClass.NestedClass('test').nested_method()
return HelperClass(self.name).helper_method()
```
"""
Expand Down Expand Up @@ -181,22 +181,17 @@ class Graph:

def topologicalSortUtil(self, v, visited, stack):
visited[v] = True

for i in self.graph[v]:
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)

stack.insert(0, v)

def topologicalSort(self):
visited = [False] * self.V
stack = []

for i in range(self.V):
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)

# Print contents of stack
return stack
```
"""
Expand Down Expand Up @@ -614,58 +609,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
```python:{file_path.relative_to(opt.args.project_root)}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):

def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
if os.environ.get("NO_CACHE"):
def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any:
if os.environ.get('NO_CACHE'):
return func(*args, **kwargs)

try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
except:
logging.warning('Failed to hash cache key for function: %s', func)
return func(*args, **kwargs)
result_pair = self.get(key=key)

if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
{"cached_time, result = result_pair" if sys.version_info >= (3, 11) else "(cached_time, result) = result_pair"}
if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan:
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
logging.warning('Failed to decode cache data: %s', e)
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
logging.warning('Failed to encode cache data: %s', e)
return result


class _PersistentCache(Generic[_P, _R, _CacheBackendT]):

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "NO_CACHE" in os.environ:
if 'NO_CACHE' in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
Expand Down Expand Up @@ -749,10 +723,12 @@ def __repr__(self):
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:

def target_method(self):
y = HelperClass().helper_method()

class HelperClass:

def helper_method(self):
return self.x
```
Expand Down Expand Up @@ -843,10 +819,12 @@ def __repr__(self):
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:

def target_method(self):
y = HelperClass().helper_method()

class HelperClass:

def helper_method(self):
return self.x
```
Expand Down Expand Up @@ -927,10 +905,12 @@ def helper_method(self):
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:

def target_method(self):
y = HelperClass().helper_method()

class HelperClass:

def helper_method(self):
return self.x
```
Expand Down Expand Up @@ -1116,22 +1096,17 @@ class DataProcessor:
def process_data(self, raw_data: str) -> str:
return raw_data.upper()

def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
def add_prefix(self, data: str, prefix: str='PREFIX_') -> str:
return prefix + data
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()

raw_data = response.text

# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)

return processed
```
"""
Expand Down Expand Up @@ -1225,16 +1200,11 @@ def transform_data(self, data: str) -> str:
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)

raw_data = response.text

# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)

return transformed
```
"""
Expand Down Expand Up @@ -1450,9 +1420,8 @@ def transform_data_all_same_file(self, data):
new_data = update_data(data)
return self.transform_using_own_method(new_data)


def update_data(data):
return data + " updated"
return data + ' updated'
```
"""

Expand Down Expand Up @@ -1591,6 +1560,7 @@ def outside_method():
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:

def target_method(self):
return self.x + self.y
```
Expand Down Expand Up @@ -1640,16 +1610,11 @@ def transform_data(self, data: str) -> str:
expected_hashing_context = """
```python:main.py
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)

raw_data = response.text

# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)

return transformed
```
```python:import_test.py
Expand Down Expand Up @@ -1915,9 +1880,9 @@ def subtract(self, a, b):
return a - b

def calculate(self, operation, x, y):
if operation == "add":
if operation == 'add':
return self.add(x, y)
elif operation == "subtract":
elif operation == 'subtract':
return self.subtract(x, y)
else:
return None
Expand Down
Loading