diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py index 053e25904..1fe6af823 100644 --- a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py +++ b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor def funcA(number): k = 0 for i in range(number * 100): @@ -8,7 +9,14 @@ def funcA(number): # Use a generator expression directly in join for more efficiency return " ".join(str(i) for i in range(number)) +def test_threadpool() -> None: + pool = ThreadPoolExecutor(max_workers=3) + args = list(range(10, 31, 10)) + result = pool.map(funcA, args) + + for r in result: + print(r) + if __name__ == "__main__": - for i in range(10, 31, 10): - funcA(10) + test_threadpool() \ No newline at end of file diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 96c0202f1..eb4df84d4 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -18,19 +18,21 @@ import os import pathlib import pickle -import re import sqlite3 import sys +import threading import time +from argparse import ArgumentParser from collections import defaultdict -from copy import copy -from io import StringIO from pathlib import Path -from types import FrameType -from typing import Any, ClassVar, List +from typing import TYPE_CHECKING, Any, Callable, ClassVar import dill import isort +from rich.align import Align +from rich.panel import Panel +from rich.table import Table +from rich.text import Text from codeflash.cli_cmds.cli import project_root_from_module_root from codeflash.cli_cmds.console import console @@ -41,11 +43,33 @@ from codeflash.tracing.tracing_utils import FunctionModules from codeflash.verification.verification_utils import get_test_file_path +if TYPE_CHECKING: + from types import FrameType, TracebackType + + +class FakeCode: + def __init__(self, filename: str, line: int, name: str) -> None: + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self) -> str: + return repr((self.co_filename, self.co_line, self.co_name, None)) + + +class FakeFrame: + def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None: + self.f_code = code + self.f_back = prior + self.f_locals: dict = {} + # Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. class Tracer: - """Use this class as a 'with' context manager to trace a function call, - input arguments, and profiling info. + """Use this class as a 'with' context manager to trace a function call. + + Traces function calls, input arguments, and profiling info. """ def __init__( @@ -57,7 +81,9 @@ def __init__( max_function_count: int = 256, timeout: int | None = None, # seconds ) -> None: - """:param output: The path to the output trace file + """Use this class to trace function calls. + + :param output: The path to the output trace file :param functions: List of functions to trace. If None, trace all functions :param disable: Disable the tracer if True :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered @@ -68,7 +94,9 @@ def __init__( if functions is None: functions = [] if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": - console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE") + console.rule( + "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red" + ) disable = True self.disable = disable if self.disable: @@ -93,10 +121,10 @@ def __init__( self.max_function_count = max_function_count self.config, found_config_path = parse_config_file(config_file_path) self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path) - print("project_root", self.project_root) + console.rule(f"Project Root: {self.project_root}", style="bold blue") self.ignored_functions = {"", "", "", "", "", ""} - self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") + self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001 assert timeout is None or timeout > 0, "Timeout should be greater than 0" self.timeout = timeout @@ -127,29 +155,34 @@ def __enter__(self) -> None: Tracer.used_once = True if pathlib.Path(self.output_file).exists(): - console.print("Codeflash: Removing existing trace file") + console.rule("Removing existing trace file", style="bold red") + console.rule() pathlib.Path(self.output_file).unlink(missing_ok=True) - self.con = sqlite3.connect(self.output_file) + self.con = sqlite3.connect(self.output_file, check_same_thread=False) cur = self.con.cursor() cur.execute("""PRAGMA synchronous = OFF""") + cur.execute("""PRAGMA journal_mode = WAL""") # TODO: Check out if we need to export the function test name as well cur.execute( "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, " "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" ) - console.print("Codeflash: Tracing started!") - frame = sys._getframe(0) # Get this frame and simulate a call to it + console.rule("Codeflash: Traced Program Output Begin", style="bold blue") + frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001 self.dispatch["call"](self, frame, 0) self.start_time = time.time() sys.setprofile(self.trace_callback) + threading.setprofile(self.trace_callback) - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> None: if self.disable: return sys.setprofile(None) self.con.commit() - + console.rule("Codeflash: Traced Program Output End", style="bold blue") self.create_stats() cur = self.con.cursor() @@ -198,7 +231,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" ) replay_test = isort.code(replay_test) - with open(test_file_path, "w", encoding="utf8") as file: + + with Path(test_file_path).open("w", encoding="utf8") as file: file.write(replay_test) console.print( @@ -236,7 +270,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: class_name = arguments["self"].__class__.__name__ elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): class_name = arguments["cls"].__name__ - except: + except: # noqa: E722 # someone can override the getattr method and raise an exception. I'm looking at you wrapt return function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" @@ -316,7 +350,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: self.next_insert = 1000 self.con.commit() - def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: + def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: # profiler section timer = self.timer t = timer() - self.t - self.bias @@ -332,45 +366,60 @@ def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: else: self.t = timer() - t # put back unrecorded delta - def trace_dispatch_call(self, frame, t) -> int: - if self.cur and frame.f_back is not self.cur[-2]: - rpt, rit, ret, rfn, rframe, rcur = self.cur - if not isinstance(rframe, Tracer.fake_frame): - assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back) - self.trace_dispatch_return(rframe, 0) - assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3]) - fcode = frame.f_code - arguments = frame.f_locals - class_name = None + def trace_dispatch_call(self, frame: FrameType, t: int) -> int: + """Handle call events in the profiler.""" try: - if ( - "self" in arguments - and hasattr(arguments["self"], "__class__") - and hasattr(arguments["self"].__class__, "__name__") - ): - class_name = arguments["self"].__class__.__name__ - elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): - class_name = arguments["cls"].__name__ - except: - pass - fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) - self.cur = (t, 0, 0, fn, frame, self.cur) - timings = self.timings - if fn in timings: - cc, ns, tt, ct, callers = timings[fn] - timings[fn] = cc, ns + 1, tt, ct, callers - else: - timings[fn] = 0, 0, 0, 0, {} - return 1 + # In multi-threaded contexts, we need to be more careful about frame comparisons + if self.cur and frame.f_back is not self.cur[-2]: + # This happens when we're in a different thread + rpt, rit, ret, rfn, rframe, rcur = self.cur + + # Only attempt to handle the frame mismatch if we have a valid rframe + if ( + not isinstance(rframe, FakeFrame) + and hasattr(rframe, "f_back") + and hasattr(frame, "f_back") + and rframe.f_back is frame.f_back + ): + self.trace_dispatch_return(rframe, 0) + + # Get function information + fcode = frame.f_code + arguments = frame.f_locals + class_name = None + try: + if ( + "self" in arguments + and hasattr(arguments["self"], "__class__") + and hasattr(arguments["self"].__class__, "__name__") + ): + class_name = arguments["self"].__class__.__name__ + elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): + class_name = arguments["cls"].__name__ + except Exception: # noqa: BLE001, S110 + pass + + fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns + 1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 # noqa: TRY300 + except Exception: # noqa: BLE001 + # Handle any errors gracefully + return 0 - def trace_dispatch_exception(self, frame, t): + def trace_dispatch_exception(self, frame: FrameType, t: int) -> int: rpt, rit, ret, rfn, rframe, rcur = self.cur if (rframe is not frame) and rcur: return self.trace_dispatch_return(rframe, t) self.cur = rpt, rit + t, ret, rfn, rframe, rcur return 1 - def trace_dispatch_c_call(self, frame, t) -> int: + def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int: fn = ("", 0, self.c_func_name, None) self.cur = (t, 0, 0, fn, frame, self.cur) timings = self.timings @@ -381,15 +430,27 @@ def trace_dispatch_c_call(self, frame, t) -> int: timings[fn] = 0, 0, 0, 0, {} return 1 - def trace_dispatch_return(self, frame, t) -> int: - if frame is not self.cur[-2]: - assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3]) - self.trace_dispatch_return(self.cur[-2], 0) + def trace_dispatch_return(self, frame: FrameType, t: int) -> int: + if not self.cur or not self.cur[-2]: + return 0 + # In multi-threaded environments, frames can get mismatched + if frame is not self.cur[-2]: + # Don't assert in threaded environments - frames can legitimately differ + if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: + self.trace_dispatch_return(self.cur[-2], 0) + else: + # We're in a different thread or context, can't continue with this frame + return 0 # Prefix "r" means part of the Returning or exiting frame. # Prefix "p" means part of the Previous or Parent or older frame. rpt, rit, ret, rfn, frame, rcur = self.cur + + # Guard against invalid rcur (w threading) + if not rcur: + return 0 + rit = rit + t frame_total = rit + ret @@ -397,6 +458,9 @@ def trace_dispatch_return(self, frame, t) -> int: self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur timings = self.timings + if rfn not in timings: + # w threading, rfn can be missing + timings[rfn] = 0, 0, 0, 0, {} cc, ns, tt, ct, callers = timings[rfn] if not ns: # This is the only occurrence of the function on the stack. @@ -418,7 +482,7 @@ def trace_dispatch_return(self, frame, t) -> int: return 1 - dispatch: ClassVar[dict[str, callable]] = { + dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { "call": trace_dispatch_call, "exception": trace_dispatch_exception, "return": trace_dispatch_return, @@ -427,26 +491,10 @@ def trace_dispatch_return(self, frame, t) -> int: "c_return": trace_dispatch_return, } - class fake_code: - def __init__(self, filename, line, name) -> None: - self.co_filename = filename - self.co_line = line - self.co_name = name - self.co_firstlineno = 0 - - def __repr__(self) -> str: - return repr((self.co_filename, self.co_line, self.co_name, None)) - - class fake_frame: - def __init__(self, code, prior) -> None: - self.f_code = code - self.f_back = prior - self.f_locals = {} - - def simulate_call(self, name) -> None: - code = self.fake_code("profiler", 0, name) + def simulate_call(self, name: str) -> None: + code = FakeCode("profiler", 0, name) pframe = self.cur[-2] if self.cur else None - frame = self.fake_frame(code, pframe) + frame = FakeFrame(code, pframe) self.dispatch["call"](self, frame, 0) def simulate_cmd_complete(self) -> None: @@ -459,58 +507,172 @@ def simulate_cmd_complete(self) -> None: t = 0 self.t = get_time() - t - def print_stats(self, sort=-1) -> None: - import pstats + def print_stats(self, sort: str | int | tuple = -1) -> None: + if not self.stats: + console.print("Codeflash: No stats available to print") + self.total_tt = 0 + return if not isinstance(sort, tuple): sort = (sort,) - # The following code customizes the default printing behavior to - # print in milliseconds. - s = StringIO() - stats_obj = pstats.Stats(copy(self), stream=s) - stats_obj.strip_dirs().sort_stats(*sort).print_stats(25) - self.total_tt = stats_obj.total_tt - console.print("total_tt", self.total_tt) - raw_stats = s.getvalue() - m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats) - total_time = None - if m: - total_time = int(m.group(1)) - if total_time is None: - console.print("Failed to get total time from stats") - total_time_ms = total_time / 1e6 - raw_stats = re.sub( - r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats - ) - match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +" - m = re.findall(match_pattern, raw_stats, re.MULTILINE) - ms_times = [] - for tottime, percall, cumtime, percall_cum in m: - tottime_ms = int(tottime) / 1e6 - percall_ms = int(percall) / 1e6 - cumtime_ms = int(cumtime) / 1e6 - percall_cum_ms = int(percall_cum) / 1e6 - ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms]) - split_stats = raw_stats.split("\n") - new_stats = [] - - replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)" - times_index = 0 - for line in split_stats: - if times_index >= len(ms_times): - replaced = line - else: - replaced, n = re.subn( - replace_pattern, - rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>", - line, - count=1, + + # First, convert stats to make them pstats-compatible + try: + # Initialize empty collections for pstats + self.files = [] + self.top_level = [] + + # Create entirely new dictionaries instead of modifying existing ones + new_stats = {} + new_timings = {} + + # Convert stats dictionary + stats_items = list(self.stats.items()) + for func, stats_data in stats_items: + try: + # Make sure we have 5 elements in stats_data + if len(stats_data) != 5: + console.print(f"Skipping malformed stats data for {func}: {stats_data}") + continue + + cc, nc, tt, ct, callers = stats_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func # Keep as is if already in correct format + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + # Store with new format + new_stats[new_func] = (cc, nc, tt, ct, new_callers) + except Exception as e: # noqa: BLE001 + console.print(f"Error converting stats for {func}: {e}") + continue + + timings_items = list(self.timings.items()) + for func, timing_data in timings_items: + try: + if len(timing_data) != 5: + console.print(f"Skipping malformed timing data for {func}: {timing_data}") + continue + + cc, ns, tt, ct, callers = timing_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + new_timings[new_func] = (cc, ns, tt, ct, new_callers) + except Exception as e: # noqa: BLE001 + console.print(f"Error converting timings for {func}: {e}") + continue + + self.stats = new_stats + self.timings = new_timings + + self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values()) + + total_calls = sum(cc for cc, _, _, _, _ in self.stats.values()) + total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values()) + + summary = Text.assemble( + f"{total_calls:,} function calls ", + ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"), + f" in {self.total_tt / 1e6:.3f}milliseconds", + ) + + console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False))) + + table = Table( + show_header=True, + header_style="bold magenta", + border_style="blue", + title="[bold]Function Profile[/bold] (ordered by internal time)", + title_style="cyan", + caption=f"Showing top 25 of {len(self.stats)} functions", + ) + + table.add_column("Calls", justify="right", style="green", width=10) + table.add_column("Time (ms)", justify="right", style="cyan", width=10) + table.add_column("Per Call", justify="right", style="cyan", width=10) + table.add_column("Cum (ms)", justify="right", style="yellow", width=10) + table.add_column("Cum/Call", justify="right", style="yellow", width=10) + table.add_column("Function", style="blue") + + sorted_stats = sorted( + ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3), + key=lambda x: x[1][2], # Sort by tt (internal time) + reverse=True, + )[:25] # Limit to top 25 + + # Format and add each row to the table + for func, (cc, nc, tt, ct, _) in sorted_stats: + filename, lineno, funcname = func + + # Format calls - show recursive format if different + calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}" + + # Convert to milliseconds + tt_ms = tt / 1e6 + ct_ms = ct / 1e6 + + # Calculate per-call times + per_call = tt_ms / cc if cc > 0 else 0 + cum_per_call = ct_ms / nc if nc > 0 else 0 + base_filename = Path(filename).name + file_link = f"[link=file://{filename}]{base_filename}[/link]" + + table.add_row( + calls_str, + f"{tt_ms:.3f}", + f"{per_call:.3f}", + f"{ct_ms:.3f}", + f"{cum_per_call:.3f}", + f"{funcname} [dim]({file_link}:{lineno})[/dim]", ) - if n > 0: - times_index += 1 - new_stats.append(replaced) - console.print("\n".join(new_stats)) + console.print(Align.center(table)) + + except Exception as e: # noqa: BLE001 + console.print(f"[bold red]Error in stats processing:[/bold red] {e}") + console.print(f"Traced {self.trace_count:,} function calls") + self.total_tt = 0 def make_pstats_compatible(self) -> None: # delete the extra class_name item from the function tuple @@ -527,9 +689,8 @@ def make_pstats_compatible(self) -> None: self.stats = new_stats self.timings = new_timings - def dump_stats(self, file) -> None: - with open(file, "wb") as f: - self.create_stats() + def dump_stats(self, file: str) -> None: + with Path(file).open("wb") as f: marshal.dump(self.stats, f) def create_stats(self) -> None: @@ -538,25 +699,23 @@ def create_stats(self) -> None: def snapshot_stats(self) -> None: self.stats = {} - for func, (cc, _ns, tt, ct, callers) in self.timings.items(): - callers = callers.copy() + for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items(): + callers = caller_dict.copy() nc = 0 for callcnt in callers.values(): nc += callcnt self.stats[func] = cc, nc, tt, ct, callers - def runctx(self, cmd, globals, locals): + def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None: self.__enter__() try: - exec(cmd, globals, locals) + exec(cmd, global_vars, local_vars) # noqa: S102 finally: self.__exit__(None, None, None) return self -def main(): - from argparse import ArgumentParser - +def main() -> ArgumentParser: parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", required=True) parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) diff --git a/testbench.py b/testbench.py new file mode 100644 index 000000000..e3856033d --- /dev/null +++ b/testbench.py @@ -0,0 +1,54 @@ +from concurrent.futures import ThreadPoolExecutor + + +def add_numbers(a: int, b: int) -> int: + print(f"[ADD_NUMBERS] Starting with parameters: a={a}, b={b}") + result = a + b + print(f"[ADD_NUMBERS] Returning result: {result}") + return result + + +def test_threadpool() -> None: + print("[TEST_THREADPOOL] Starting thread pool execution") + pool = ThreadPoolExecutor(max_workers=3) + numbers = [(10, 20), (30, 40), (50, 60)] + print("[TEST_THREADPOOL] Submitting tasks to thread pool") + result = pool.map(add_numbers, *zip(*numbers)) + + print("[TEST_THREADPOOL] Processing results") + for r in result: + print(f"[TEST_THREADPOOL] Thread result: {r}") + print("[TEST_THREADPOOL] Finished thread pool execution") + + +def multiply_numbers(a: int, b: int) -> int: + print(f"[MULTIPLY_NUMBERS] Starting with parameters: a={a}, b={b}") + result = a * b + print(f"[MULTIPLY_NUMBERS] Returning result: {result}") + return result + + +if __name__ == "__main__": + print("[MAIN] Starting testbench execution") + + print("[MAIN] Calling test_threadpool()") + test_threadpool() + print("[MAIN] Finished test_threadpool()") + + print("[MAIN] Calling add_numbers(5, 10)") + result1 = add_numbers(5, 10) + print(f"[MAIN] add_numbers result: {result1}") + + print("[MAIN] Calling add_numbers(15, 25)") + result2 = add_numbers(15, 25) + print(f"[MAIN] add_numbers result: {result2}") + + print("[MAIN] Calling multiply_numbers(3, 7)") + result3 = multiply_numbers(3, 7) + print(f"[MAIN] multiply_numbers result: {result3}") + + print("[MAIN] Calling multiply_numbers(5, 9)") + result4 = multiply_numbers(5, 9) + print(f"[MAIN] multiply_numbers result: {result4}") + + print("[MAIN] Testbench execution completed") diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py index 03c778be9..58662448e 100644 --- a/tests/scripts/end_to_end_test_tracer_replay.py +++ b/tests/scripts/end_to_end_test_tracer_replay.py @@ -10,7 +10,7 @@ def run_test(expected_improvement_pct: int) -> bool: min_improvement_x=0.1, expected_unit_tests=1, coverage_expectations=[ - CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[2, 3, 4, 6, 9]) + CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[3, 4, 5, 7, 10]), ], ) cwd = ( @@ -18,6 +18,5 @@ def run_test(expected_improvement_pct: int) -> bool: ).resolve() return run_codeflash_command(cwd, config, expected_improvement_pct) - if __name__ == "__main__": exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 23a67a84a..c961b6fd1 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -202,8 +202,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p return False functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != 3: - logging.error("Expected 3 traced functions") + if not functions_traced or int(functions_traced.group(1)) != 5: + logging.error("Expected 5 traced functions") return False replay_test_path = pathlib.Path(functions_traced.group(2)) @@ -249,4 +249,4 @@ def run_with_retries(test_func, *args, **kwargs) -> bool: logging.error("Test failed after all retries") return 1 - return 1 + return 1 \ No newline at end of file