Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 52 additions & 25 deletions codeflash/benchmarking/codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from typing import Callable

from codeflash.cli_cmds.cli import logger
from codeflash.picklepatch.pickle_patcher import PicklePatcher


Expand Down Expand Up @@ -42,10 +43,8 @@ def setup(self, trace_path: str) -> None:
)
self._connection.commit()
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
self._connection.close()
self._connection = None
logger.error(f"Database setup error: {e}")
self.close()
raise

def write_function_timings(self) -> None:
Expand All @@ -63,18 +62,17 @@ def write_function_timings(self) -> None:

try:
cur = self._connection.cursor()
# Insert data into the benchmark_function_timings table
cur.executemany(
"INSERT INTO benchmark_function_timings"
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.function_calls_data
self.function_calls_data,
)
self._connection.commit()
self.function_calls_data = []
except Exception as e:
print(f"Error writing to function timings database: {e}")
logger.error(f"Error writing to function timings database: {e}")
if self._connection:
self._connection.rollback()
raise
Expand All @@ -100,9 +98,10 @@ def __call__(self, func: Callable) -> Callable:
The wrapped function

"""
func_id = (func.__module__,func.__name__)
func_id = (func.__module__, func.__name__)

@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: tuple, **kwargs: dict) -> object:
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()
Expand All @@ -123,25 +122,33 @@ def wrapper(*args, **kwargs):
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
self._thread_local.active_functions.remove(func_id)
return result
# Get benchmark info from environment

benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
# Get class name
class_name = ""
qualname = func.__qualname__
if "." in qualname:
class_name = qualname.split(".")[0]

# Limit pickle count so memory does not explode
if self.function_call_count > self.pickle_count_limit:
print("Pickle limit reached")
logger.debug("CodeflashTrace: Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result

Expand All @@ -150,30 +157,50 @@ def wrapper(*args, **kwargs):
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
logger.debug(f"CodeflashTrace: Error pickling arguments for function {func.__name__}: {e}")
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result
# Flush to database every 100 calls
if len(self.function_calls_data) > 100:
self.write_function_timings()

# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, pickled_args, pickled_kwargs)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
pickled_args,
pickled_kwargs,
)
)
return result

return wrapper

# Create a singleton instance

codeflash_trace = CodeflashTrace()
62 changes: 21 additions & 41 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,46 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
self.added_codeflash_trace = False
self.class_name = ""
self.function_name = ""
self.decorator = cst.Decorator(
decorator=cst.Name(value="codeflash_trace")
)
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))

def leave_ClassDef(self, original_node, updated_node):
def leave_ClassDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
if self.class_name == original_node.name.value:
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
return updated_node

def visit_ClassDef(self, node):
if self.class_name: # Don't go into nested class
def visit_ClassDef(self, node): # noqa: ANN001, ANN201, N802
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value
self.class_name = node.name.value # noqa: RET503

def visit_FunctionDef(self, node):
if self.function_name: # Don't go into nested function
def visit_FunctionDef(self, node): # noqa: ANN001, ANN201, N802
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value
self.function_name = node.name.value # noqa: RET503

def leave_FunctionDef(self, original_node, updated_node):
def leave_FunctionDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
if self.function_name == original_node.name.value:
self.function_name = ""
if (self.class_name, original_node.name.value) in self.target_functions:
# Add the new decorator after any existing decorators, so it gets executed first
updated_decorators = list(updated_node.decorators) + [self.decorator]
updated_decorators = [*list(updated_node.decorators), self.decorator]
self.added_codeflash_trace = True
return updated_node.with_changes(
decorators=updated_decorators
)
return updated_node.with_changes(decorators=updated_decorators)

return updated_node

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002, N802
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node
import_stmt = cst.SimpleStatementLine(
body=[
cst.ImportFrom(
module=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="codeflash"),
attr=cst.Name(value="benchmarking")
),
attr=cst.Name(value="codeflash_trace")
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
attr=cst.Name(value="codeflash_trace"),
),
names=[
cst.ImportAlias(
name=cst.Name(value="codeflash_trace")
)
]
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
)
]
)
Expand All @@ -73,12 +62,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

return updated_node.with_changes(body=new_body)


def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
"""Add codeflash_trace to a function.

Args:
code: The source code as a string
function_to_optimize: The FunctionToOptimize instance containing function details
functions_to_optimize: List of FunctionToOptimize instances containing function details

Returns:
The modified source code as a string
Expand All @@ -91,27 +81,17 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))

transformer = AddDecoratorTransformer(
target_functions = target_functions,
)
transformer = AddDecoratorTransformer(target_functions=target_functions)

module = cst.parse_module(code)
modified_module = module.visit(transformer)
return modified_module.code


def instrument_codeflash_trace_decorator(
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> None:
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
"""Instrument codeflash_trace decorator to functions to optimize."""
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(
original_code,
functions_to_optimize
)
# Modify the code
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
modified_code = isort.code(code=new_code, float_to_top=True)

# Write the modified code back to the file
file_path.write_text(modified_code, encoding="utf-8")
Loading
Loading