Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import isort
import libcst as cst

from codeflash.discovery.functions_to_optimize import FunctionToOptimize

if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -85,16 +87,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
The modified source code as a string

"""
target_functions = set()
for function_to_optimize in functions_to_optimize:
class_name = ""
if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))
# Use a generator expression for faster creation and avoid multiple attribute lookups
target_functions = {
(fto.parents[0].name if len(fto.parents) == 1 and fto.parents[0].type == "ClassDef" else "", fto.function_name)
for fto in functions_to_optimize
}

transformer = AddDecoratorTransformer(target_functions=target_functions)

# If code is already a CSTModule, skip reparsing it, but here we assume it is always a str
module = cst.parse_module(code)
# Short-circuit if no target functions (fast return, avoids unnecessary CST tree walk)
if not target_functions:
return code
modified_module = module.visit(transformer)
return modified_module.code

Expand Down
Loading