diff --git a/codeflash/benchmarking/trace_benchmarks.py b/codeflash/benchmarking/trace_benchmarks.py index 8f68030cb..8d14068e7 100644 --- a/codeflash/benchmarking/trace_benchmarks.py +++ b/codeflash/benchmarking/trace_benchmarks.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import re import subprocess from pathlib import Path @@ -9,6 +10,11 @@ def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root: Path, trace_file: Path, timeout:int = 300) -> None: + benchmark_env = os.environ.copy() + if "PYTHONPATH" not in benchmark_env: + benchmark_env["PYTHONPATH"] = str(project_root) + else: + benchmark_env["PYTHONPATH"] += os.pathsep + str(project_root) result = subprocess.run( [ SAFE_SYS_EXECUTABLE, @@ -21,7 +27,7 @@ def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root check=False, capture_output=True, text=True, - env={"PYTHONPATH": str(project_root)}, + env=benchmark_env, timeout=timeout, ) if result.returncode != 0: diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index ed0dbd760..29fd5d7f1 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -167,6 +167,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: # in this case, the ".." becomes outside project scope, causing issues with un-importable paths args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path) args.tests_root = Path(args.tests_root).resolve() + args.benchmarks_root = Path(args.benchmarks_root).resolve() args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path) return handle_optimize_all_arg_parsing(args) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 6e4f744d7..8adfb4e00 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -299,7 +299,7 @@ def get_all_replay_test_functions( if valid_function.qualified_name == function_name ] ) - if len(filtered_list): + if filtered_list: filtered_valid_functions[file_path] = filtered_list return filtered_valid_functions diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 7f42b58c4..946e3e822 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -2,9 +2,9 @@ import ast import os -import time import shutil import tempfile +import time from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING @@ -97,7 +97,7 @@ def run(self) -> None: ) function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} total_benchmark_timings: dict[BenchmarkKey, int] = {} - if self.args.benchmark: + if self.args.benchmark and num_optimizable_functions > 0: with progress_bar( f"Running benchmarks in {self.args.benchmarks_root}", transient=True,