Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,25 @@ def get_code_optimization_context(
) -> CodeOptimizationContext:
# Get FunctionSource representation of helpers of FTO
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path)

# Add function to optimize into helpers of FTO dict, as they'll be processed together
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)

# Format data to search for helpers of helpers using get_function_sources_from_jedi
helpers_of_fto_qualified_names_dict = {
file_path: {source.qualified_name for source in sources}
for file_path, sources in helpers_of_fto_dict.items()
}

# __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist)
# This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO
for qualified_names in helpers_of_fto_qualified_names_dict.values():
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if '.' in qn})

# Get FunctionSource representation of helpers of helpers of FTO
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path)

# Add function to optimize into helpers of FTO dict, as they'll be processed together
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)

# Extract code context for optimization
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code
read_only_code_markdown = extract_code_markdown_context_from_files(
Expand Down
55 changes: 54 additions & 1 deletion tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def main_method(self):
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()


def test_class_method_dependencies() -> None:
file_path = Path(__file__).resolve()

Expand Down Expand Up @@ -1260,3 +1259,57 @@ def __repr__(self) -> str:

assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()

def test_indirect_init_helper() -> None:
code = """
class MyClass:
def __init__(self):
self.x = 1
self.y = outside_method()
def target_method(self):
return self.x + self.y

def outside_method():
return 1
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)

code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class MyClass:
def __init__(self):
self.x = 1
self.y = outside_method()
def target_method(self):
return self.x + self.y
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
def outside_method():
return 1
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
Loading