Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,4 @@ jobs:
run: uvx poetry install --with dev

- name: Unit tests
run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
if: matrix.python-version == '3.12.1'
with:
token: ${{ secrets.CODECOV_TOKEN }}
run: uvx poetry run pytest tests/ --benchmark-skip -m "not ci_skip"
9 changes: 8 additions & 1 deletion codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,14 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list


def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
    """Return True if *function_node* contains at least one ``return`` statement.

    Traverses the function's AST with an explicit DFS stack rather than
    ``ast.walk`` so the search stops as soon as the first ``ast.Return`` node
    is found, without materializing the remainder of the traversal queue.

    Note: like ``ast.walk``, this also descends into nested function and
    class definitions, so a ``return`` inside a nested function counts.

    :param function_node: the ``ast.FunctionDef`` or ``ast.AsyncFunctionDef``
        node to inspect.
    :return: ``True`` if any ``return`` statement is reachable in the subtree,
        ``False`` otherwise.
    """
    # The diff artifact had left the old `return any(...)` one-liner above the
    # DFS, making everything below it dead code; only the DFS survives here.
    stack: list[ast.AST] = [function_node]
    while stack:
        node = stack.pop()
        if isinstance(node, ast.Return):
            return True
        stack.extend(ast.iter_child_nodes(node))
    return False


def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
Expand Down
214 changes: 137 additions & 77 deletions codeflash/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#
from __future__ import annotations

import contextlib
import importlib.machinery
import io
import json
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
)
disable = True
self.disable = disable
self._db_lock: threading.Lock | None = None
if self.disable:
return
if sys.getprofile() is not None or sys.gettrace() is not None:
Expand All @@ -108,6 +110,9 @@ def __init__(
)
self.disable = True
return

self._db_lock = threading.Lock()

self.con = None
self.output_file = Path(output).resolve()
self.functions = functions
Expand All @@ -130,6 +135,7 @@ def __init__(
self.timeout = timeout
self.next_insert = 1000
self.trace_count = 0
self.path_cache = {} # Cache for resolved file paths

# Profiler variables
self.bias = 0 # calibration constant
Expand Down Expand Up @@ -178,34 +184,55 @@ def __enter__(self) -> None:
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
if self.disable:
if self.disable or self._db_lock is None:
return
sys.setprofile(None)
self.con.commit()
console.rule("Codeflash: Traced Program Output End", style="bold blue")
self.create_stats()
threading.setprofile(None)

cur = self.con.cursor()
cur.execute(
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
"cumulative_time_ns INTEGER, callers BLOB)"
)
for func, (cc, nc, tt, ct, callers) in self.stats.items():
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
with self._db_lock:
if self.con is None:
return

self.con.commit() # Commit any pending from tracer_logic
console.rule("Codeflash: Traced Program Output End", style="bold blue")
self.create_stats() # This calls snapshot_stats which uses self.timings

cur = self.con.cursor()
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
"cumulative_time_ns INTEGER, callers BLOB)"
)
self.con.commit()
# self.stats is populated by snapshot_stats() called within create_stats()
# Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
# For now, assuming self.stats is primarily in-memory after create_stats()
for func, (cc, nc, tt, ct, callers) in self.stats.items():
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
str(Path(func[0]).resolve()),
func[1],
func[2],
func[3],
cc,
nc,
tt,
ct,
json.dumps(remapped_callers),
),
)
self.con.commit()

self.make_pstats_compatible()
self.print_stats("tottime")
cur = self.con.cursor()
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
self.con.commit()
self.con.close()
self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory
self.print_stats("tottime") # Uses self.stats, prints to console

cur = self.con.cursor() # New cursor
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
self.con.commit()
self.con.close()
self.con = None # Mark connection as closed

# filter any functions where we did not capture the return
self.function_modules = [
Expand Down Expand Up @@ -245,18 +272,29 @@ def __exit__(
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
if event != "call":
return
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
sys.setprofile(None)
threading.setprofile(None)
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
return
code = frame.f_code
if self.disable or self._db_lock is None or self.con is None:
return

file_name = Path(code.co_filename).resolve()
# TODO : It currently doesn't log the last return call from the first function
code = frame.f_code

# Check function name first before resolving path
if code.co_name in self.ignored_functions:
return

# Now resolve file path only if we need it
co_filename = code.co_filename
if co_filename in self.path_cache:
file_name = self.path_cache[co_filename]
else:
file_name = Path(co_filename).resolve()
self.path_cache[co_filename] = file_name
# TODO : It currently doesn't log the last return call from the first function

if not file_name.is_relative_to(self.project_root):
return
if not file_name.exists():
Expand All @@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
class_name = None
arguments = frame.f_locals
try:
if (
"self" in arguments
and hasattr(arguments["self"], "__class__")
and hasattr(arguments["self"].__class__, "__name__")
):
class_name = arguments["self"].__class__.__name__
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
class_name = arguments["cls"].__name__
self_arg = arguments.get("self")
if self_arg is not None:
try:
class_name = self_arg.__class__.__name__
except AttributeError:
cls_arg = arguments.get("cls")
if cls_arg is not None:
with contextlib.suppress(AttributeError):
class_name = cls_arg.__name__
else:
cls_arg = arguments.get("cls")
if cls_arg is not None:
with contextlib.suppress(AttributeError):
class_name = cls_arg.__name__
except: # noqa: E722
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
return
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"

try:
function_qualified_name = f"{file_name}:{code.co_qualname}"
except AttributeError:
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"

if function_qualified_name in self.ignored_qualified_functions:
return
if function_qualified_name not in self.function_count:
Expand Down Expand Up @@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911

# TODO: Also check if this function arguments are unique from the values logged earlier

cur = self.con.cursor()
with self._db_lock:
# Check connection again inside lock, in case __exit__ closed it.
if self.con is None:
return

t_ns = time.perf_counter_ns()
original_recursion_limit = sys.getrecursionlimit()
try:
# pickling can be a recursive operator, so we need to increase the recursion limit
sys.setrecursionlimit(10000)
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
# leaks, bad references or side effects when unpickling.
arguments = dict(arguments.items())
if class_name and code.co_name == "__init__":
del arguments["self"]
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
cur = self.con.cursor()

t_ns = time.perf_counter_ns()
original_recursion_limit = sys.getrecursionlimit()
try:
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
# pickling can be a recursive operator, so we need to increase the recursion limit
sys.setrecursionlimit(10000)
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
# leaks, bad references or side effects when unpickling.
arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals
if class_name and code.co_name == "__init__" and "self" in arguments_copy:
del arguments_copy["self"]
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
# arguments_copy should be used here as well if defined above
local_vars = dill.dumps(
arguments_copy if "arguments_copy" in locals() else dict(arguments.items()),
protocol=dill.HIGHEST_PROTOCOL,
)
sys.setrecursionlimit(original_recursion_limit)

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
self.function_count[function_qualified_name] -= 1
return

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
# give up
self.function_count[function_qualified_name] -= 1
return
cur.execute(
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
(
event,
code.co_name,
class_name,
str(file_name),
frame.f_lineno,
frame.f_back.__hash__(),
t_ns,
local_vars,
),
)
self.trace_count += 1
self.next_insert -= 1
if self.next_insert == 0:
self.next_insert = 1000
self.con.commit()
cur.execute(
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
(
event,
code.co_name,
class_name,
str(file_name),
frame.f_lineno,
frame.f_back.__hash__(),
t_ns,
local_vars,
),
)
self.trace_count += 1
self.next_insert -= 1
if self.next_insert == 0:
self.next_insert = 1000
self.con.commit()

def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
# profiler section
Expand Down Expand Up @@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
cc = cc + 1

if pfn in callers:
callers[pfn] = callers[pfn] + 1 # TODO: gather more
# stats such as the amount of time added to ct courtesy
# Increment call count between these functions
callers[pfn] = callers[pfn] + 1
# Note: This tracks stats such as the amount of time added to ct
# of this specific call, and the contribution to cc
# courtesy of this call.
else:
Expand Down Expand Up @@ -703,7 +763,7 @@ def create_stats(self) -> None:

def snapshot_stats(self) -> None:
self.stats = {}
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()):
callers = caller_dict.copy()
nc = 0
for callcnt in callers.values():
Expand Down
21 changes: 1 addition & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ types-cffi = ">=1.16.0.20240331"
types-openpyxl = ">=3.1.5.20241020"
types-regex = ">=2024.9.11.20240912"
types-python-dateutil = ">=2.9.0.20241003"
pytest-cov = "^6.0.0"
pytest-benchmark = ">=5.1.0"
types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
Expand Down
Loading
Loading