Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 83 additions & 79 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sqlite3
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -68,94 +67,99 @@ def create_trace_replay_test_code(
"""
assert test_framework in ["pytest", "unittest"]

# Create Imports
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""
# Precompute all needed values up-front for efficiency
unittest_import = "import unittest" if test_framework == "unittest" else ""
imports = (
"from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle\n"
f"{unittest_import}\n"
"from codeflash.benchmarking.replay_test import get_next_arg_and_return\n"
)

function_imports = []
functions_to_optimize = set()

# Collect imports and test function names in one pass:
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name", "")
module_name = func["module_name"]
function_name = func["function_name"]
class_name = func.get("class_name")
if class_name:
function_imports.append(
f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}"
)
alias = get_function_alias(module_name, class_name)
function_imports.append(f"from {module_name} import {class_name} as {alias}")
else:
function_imports.append(
f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}"
)

alias = get_function_alias(module_name, function_name)
function_imports.append(f"from {module_name} import {function_name} as {alias}")
if function_name != "__init__":
functions_to_optimize.add(function_name)
imports += "\n".join(function_imports)

functions_to_optimize = sorted(
{func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"}
)
metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""
# Templates for different types of tests
test_function_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = {function_name}(*args, **kwargs)
"""
)
metadata = f'functions = {sorted(functions_to_optimize)}\ntrace_file_path = r"{trace_file}"\n'

test_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
function_name = "{orig_function_name}"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = {class_name_alias}(*args[1:], **kwargs)
else:
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
# Templates, dedented once for speed
test_function_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return("
'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", '
'function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl)\n"
" ret = {function_name}(*args, **kwargs)\n"
)

test_class_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
if not args:
raise ValueError("No arguments provided for the method.")
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
"""
test_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return("
'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", '
'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
' function_name = "{orig_function_name}"\n'
" if not args:\n"
' raise ValueError("No arguments provided for the method.")\n'
' if function_name == "__init__":\n'
" ret = {class_name_alias}(*args[1:], **kwargs)\n"
" else:\n"
" ret = {class_name_alias}{method_name}(*args, **kwargs)\n"
)
test_static_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
test_class_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return("
'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", '
'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
" if not args:\n"
' raise ValueError("No arguments provided for the method.")\n'
" ret = {class_name_alias}{method_name}(*args[1:], **kwargs)\n"
)
test_static_method_body = (
"for args_pkl, kwargs_pkl in get_next_arg_and_return("
'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", '
'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n'
" args = pickle.loads(args_pkl)\n"
" kwargs = pickle.loads(kwargs_pkl){filter_variables}\n"
" ret = {class_name_alias}{method_name}(*args, **kwargs)\n"
)

# Create main body

if test_framework == "unittest":
self = "self"
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
self_arg = "self"
test_header = "\nclass TestTracedFunctions(unittest.TestCase):\n"
def_indent = " "
body_indent = " "
else:
test_template = ""
self = ""
self_arg = ""
test_header = ""
def_indent = ""
body_indent = " "

# String builder technique for fast test template construction
test_template_lines = [test_header]
append = test_template_lines.append # local variable for speed

for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
module_name = func["module_name"]
function_name = func["function_name"]
class_name = func.get("class_name")
file_path = func.get("file_path")
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
file_path = func["file_path"]
benchmark_function_name = func["benchmark_function_name"]
function_properties = func["function_properties"]

if not class_name:
alias = get_function_alias(module_name, function_name)
test_body = test_function_body.format(
Expand All @@ -168,9 +172,7 @@ def create_trace_replay_test_code(
else:
class_name_alias = get_function_alias(module_name, class_name)
alias = get_function_alias(module_name, class_name + "_" + function_name)

filter_variables = ""
# filter_variables = '\n args.pop("cls", None)'
method_name = "." + function_name if function_name != "__init__" else ""
if function_properties.is_classmethod:
test_body = test_class_method_body.format(
Expand Down Expand Up @@ -206,12 +208,14 @@ def create_trace_replay_test_code(
filter_variables=filter_variables,
)

formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")

test_template += " " if test_framework == "unittest" else ""
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
# Manually indent for speed (no textwrap.indent)
test_body_indented = "".join(
body_indent + ln if ln else body_indent for ln in test_body.splitlines(keepends=True)
)
append(f"{def_indent}def test_{alias}({self_arg}):\n{test_body_indented}\n")

return imports + "\n" + metadata + "\n" + test_template
# Final string concatenation
return f"{imports}\n{metadata}\n{''.join(test_template_lines)}"


def generate_replay_test(
Expand Down
Loading