From 0e6751bc1378f36052cd15f963185c57cb3431a9 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 1 Apr 2025 17:07:02 -0500 Subject: [PATCH] remove workspace config & cleanup --- .gitignore | 3 + codeflash.code-workspace | 80 ------------------- codeflash/cli_cmds/cli_common.py | 2 +- codeflash/cli_cmds/cmd_init.py | 24 +++--- codeflash/code_utils/code_extractor.py | 6 +- codeflash/code_utils/config_parser.py | 2 +- codeflash/context/code_context_extractor.py | 68 +++++++++------- codeflash/discovery/functions_to_optimize.py | 2 +- .../discovery/pytest_new_process_discovery.py | 2 +- codeflash/models/models.py | 4 +- codeflash/optimization/function_context.py | 15 ++-- codeflash/optimization/function_optimizer.py | 2 +- codeflash/tracing/replay_test.py | 12 ++- codeflash/verification/codeflash_capture.py | 7 +- codeflash/verification/comparator.py | 6 +- codeflash/verification/concolic_testing.py | 10 ++- .../instrument_codeflash_capture.py | 7 +- 17 files changed, 100 insertions(+), 152 deletions(-) delete mode 100644 codeflash.code-workspace diff --git a/.gitignore b/.gitignore index 535acfb3e..e2f8c3f50 100644 --- a/.gitignore +++ b/.gitignore @@ -254,3 +254,6 @@ fabric.properties # Mac .DS_Store + +# VSC Workspace settings +codeflash.code-workspace \ No newline at end of file diff --git a/codeflash.code-workspace b/codeflash.code-workspace deleted file mode 100644 index 7f62bcc77..000000000 --- a/codeflash.code-workspace +++ /dev/null @@ -1,80 +0,0 @@ -{ - "folders": [ - { - "path": ".", - "name": "codeflash", - "extensions": [ - "charliermarsh.ruff", - "ms-python.python", - ] - } - ], - "settings": { - "python.defaultInterpreterPath": "~/miniforge3/envs/codeflash312/bin/python", - "python.terminal.activateEnvironment": true, - "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": ["tests/", "-vv"], - }, - "launch": { - "version": "0.2.0", - "configurations": [ - { - "name": "bubble_sort", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder:codeflash}/codeflash/main.py", - "args": [ - "--file", - "code_to_optimize/bubble_sort.py", - "--module-root", - "${workspaceFolder:codeflash}", - "--function", - "sorter", - "--test-framework", - "pytest", - "--tests-root", - "code_to_optimize/tests/pytest" - ], - "cwd": "${workspaceFolder:codeflash}", - "console": "integratedTerminal", - "env": { - "PYTHONUNBUFFERED": "1" - }, - }, - { - "name": "bubble_sort -all", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder:codeflash}/codeflash/main.py", - "args": [ - "--all", - "--test-framework", - "pytest", - "--tests-root", - "code_to_optimize/tests/pytest", - "--module-root", - "code_to_optimize" - ], - "cwd": "${workspaceFolder:codeflash}", - "console": "integratedTerminal", - "env": { - "PYTHONUNBUFFERED": "1" - }, - }, - { - "name": "bubble_sort --file bubble_sort.py (MBR)", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder:codeflash}/codeflash/main.py", - "args": [ - "--all", - ], - "cwd": "/Users/krrt7/Desktop/work/my-best-repo", - "console": "integratedTerminal", - "env": { - "PYTHONUNBUFFERED": "1" - }, - } - ] - } -} \ No newline at end of file diff --git a/codeflash/cli_cmds/cli_common.py b/codeflash/cli_cmds/cli_common.py index 942b8f634..b8f04ec6e 100644 --- a/codeflash/cli_cmds/cli_common.py +++ b/codeflash/cli_cmds/cli_common.py @@ -74,7 +74,7 @@ def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str] | None: new_kwargs["message"] = last_message new_args.append(args[0]) - return cast(dict[str, str], inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)])) + return cast("dict[str, str]", inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)])) def split_string_to_fit_width(string: str, width: int) -> list[str]: diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 5df70828c..a399e35c2 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -68,7 +68,6 @@ def init_codeflash() -> None: did_add_new_key = prompt_api_key() if should_modify_pyproject_toml(): - setup_info: SetupInfo = collect_setup_info() configure_pyproject_toml(setup_info) @@ -81,7 +80,6 @@ def init_codeflash() -> None: if "setup_info" in locals(): module_string = f" you selected ({setup_info.module_root})" - click.echo( f"{LF}" f"⚡️ Codeflash is now set up! You can now run:{LF}" @@ -123,18 +121,19 @@ def ask_run_end_to_end_test(args: Namespace) -> None: bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args) run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path) + def should_modify_pyproject_toml() -> bool: - """ - Check if the current directory contains a valid pyproject.toml file with codeflash config + """Check if the current directory contains a valid pyproject.toml file with codeflash config If it does, ask the user if they want to re-configure it. """ from rich.prompt import Confirm + pyproject_toml_path = Path.cwd() / "pyproject.toml" if not pyproject_toml_path.exists(): return True try: config, config_file_path = parse_config_file(pyproject_toml_path) - except Exception as e: + except Exception: return True if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir(): @@ -142,10 +141,11 @@ def should_modify_pyproject_toml() -> bool: if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir(): return True - create_toml = Confirm.ask( - f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True + return Confirm.ask( + "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", + default=False, + show_default=True, ) - return create_toml def collect_setup_info() -> SetupInfo: @@ -224,7 +224,7 @@ def collect_setup_info() -> SetupInfo: else: apologize_and_exit() else: - tests_root = Path(curdir) / Path(cast(str, tests_root_answer)) + tests_root = Path(curdir) / Path(cast("str", tests_root_answer)) tests_root = tests_root.relative_to(curdir) ph("cli-tests-root-provided") @@ -278,9 +278,9 @@ def collect_setup_info() -> SetupInfo: return SetupInfo( module_root=str(module_root), tests_root=str(tests_root), - test_framework=cast(str, test_framework), + test_framework=cast("str", test_framework), ignore_paths=ignore_paths, - formatter=cast(str, formatter), + formatter=cast("str", formatter), git_remote=str(git_remote), ) @@ -390,7 +390,7 @@ def check_for_toml_or_setup_file() -> str | None: click.echo("⏩️ Skipping pyproject.toml creation.") apologize_and_exit() click.echo() - return cast(str, project_name) + return cast("str", project_name) def install_github_actions(override_formatter_check: bool = False) -> None: diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index a06a45c92..eeefa70fe 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -from pathlib import Path from typing import TYPE_CHECKING import libcst as cst @@ -11,12 +10,15 @@ from libcst.helpers import calculate_module_and_package from codeflash.cli_cmds.console import logger -from codeflash.models.models import FunctionParent, FunctionSource +from codeflash.models.models import FunctionParent if TYPE_CHECKING: + from pathlib import Path + from libcst.helpers import ModuleNameAndPackage from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import FunctionSource class FutureAliasedImportTransformer(cst.CSTTransformer): diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index d814f12d0..73eafd1dd 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -87,7 +87,7 @@ def parse_config_file( "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." ) if len(config["formatter-cmds"]) > 0: - #see if this is happening during GitHub actions setup + # see if this is happening during GitHub actions setup if not override_formatter_check: assert config["formatter-cmds"][0] != "your-formatter $file", ( "The formatter command is not set correctly in pyproject.toml. Please set the " diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 3827239ed..6a534af38 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -3,18 +3,15 @@ import os from collections import defaultdict from itertools import chain -from pathlib import Path +from typing import TYPE_CHECKING, Optional import jedi import libcst as cst import tiktoken -from jedi.api.classes import Name -from libcst import CSTNode from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, @@ -24,6 +21,15 @@ ) from codeflash.optimization.function_context import belongs_to_function_qualified +if TYPE_CHECKING: + from pathlib import Path + + from jedi.api.classes import Name + from libcst import CSTNode + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from typing import Callable + def get_code_optimization_context( function_to_optimize: FunctionToOptimize, @@ -75,7 +81,8 @@ def get_code_optimization_context( tokenizer = tiktoken.encoding_for_model("gpt-4o") final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code)) if final_read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") + msg = "Read-writable code has exceeded token limit, cannot proceed" + raise ValueError(msg) # Setup preexisting objects for code replacer preexisting_objects = set( @@ -122,7 +129,8 @@ def get_code_optimization_context( testgen_context_code = testgen_code_markdown.code testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) if testgen_context_code_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + msg = "Testgen code context has exceeded token limit, cannot proceed" + raise ValueError(msg) return CodeOptimizationContext( testgen_context_code=testgen_context_code, @@ -143,7 +151,7 @@ def extract_code_string_context_from_files( """Extract code context from files containing target functions and their helpers. This function processes two sets of files: 1. Files containing the function to optimize (fto) and their first-degree helpers - 2. Files containing only helpers of helpers (with no overlap with the first set) + 2. Files containing only helpers of helpers (with no overlap with the first set). For each file, it extracts relevant code based on the specified context type, adds necessary imports, and combines them. @@ -358,7 +366,7 @@ def get_function_to_optimize_as_function_source( and name.name == function_to_optimize.function_name and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name ): - function_source = FunctionSource( + return FunctionSource( file_path=function_to_optimize.file_path, qualified_name=function_to_optimize.qualified_name, fully_qualified_name=name.full_name, @@ -366,11 +374,9 @@ def get_function_to_optimize_as_function_source( source_code=name.get_line_code(), jedi_definition=name, ) - return function_source - raise ValueError( - f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}" - ) + msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}" + raise ValueError(msg) def get_function_sources_from_jedi( @@ -436,7 +442,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]: def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: - """Removes the docstring from an indented block if it exists""" + """Removes the docstring from an indented block if it exists.""" if not isinstance(indented_block.body[0], cst.SimpleStatementLine): return indented_block first_stmt = indented_block.body[0].body[0] @@ -464,16 +470,13 @@ def parse_code_and_prune_cst( filtered_node, found_target = prune_cst_for_testgen_code( module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings ) - else: - raise ValueError(f"Unknown code_context_type: {code_context_type}") - if not found_target: - raise ValueError("No target functions found in the provided code") + msg = "No target functions found in the provided code" + raise ValueError(msg) if filtered_node and isinstance(filtered_node, cst.Module): return str(filtered_node.code) return "" - def prune_cst_for_read_writable_code( node: cst.CSTNode, target_functions: set[str], prefix: str = "" ) -> tuple[cst.CSTNode | None, bool]: @@ -500,7 +503,8 @@ def prune_cst_for_read_writable_code( return None, False # Assuming always an IndentedBlock if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") + msg = "ClassDef body is not an IndentedBlock" + raise ValueError(msg) class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value new_body = [] found_target = False @@ -593,7 +597,8 @@ def prune_cst_for_read_only_code( return None, False # Assuming always an IndentedBlock if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") + msg = "ClassDef body is not an IndentedBlock" + raise ValueError(msg) class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value @@ -612,10 +617,12 @@ def prune_cst_for_read_only_code( return None, False if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True - return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True + return ( + node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))) + if new_class_body + else None + ), True + return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True # For other nodes, keep the node and recursively filter children section_names = get_section_names(node) @@ -698,7 +705,8 @@ def prune_cst_for_testgen_code( return None, False # Assuming always an IndentedBlock if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") + msg = "ClassDef body is not an IndentedBlock" + raise ValueError(msg) class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value @@ -717,10 +725,12 @@ def prune_cst_for_testgen_code( return None, False if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True - return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True + return ( + node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))) + if new_class_body + else None + ), True + return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True # For other nodes, keep the node and recursively filter children section_names = get_section_names(node) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index a234a2827..f91977c0f 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -300,7 +300,7 @@ def get_all_replay_test_functions( if valid_function.qualified_name == function_name ] ) - if len(filtered_list): + if filtered_list: filtered_valid_functions[file_path] = filtered_list return filtered_valid_functions diff --git a/codeflash/discovery/pytest_new_process_discovery.py b/codeflash/discovery/pytest_new_process_discovery.py index 2d8583255..f2a0c59e9 100644 --- a/codeflash/discovery/pytest_new_process_discovery.py +++ b/codeflash/discovery/pytest_new_process_discovery.py @@ -16,7 +16,7 @@ def pytest_collection_finish(self, session) -> None: collected_tests.extend(session.items) pytest_rootdir = session.config.rootdir - def pytest_collection_modifyitems(config, items): + def pytest_collection_modifyitems(self, items) -> None: skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests") for item in items: if "benchmark" in item.fixturenames: diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1366fcc0b..5134988c1 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -11,7 +11,7 @@ import enum import re import sys -from collections.abc import Collection, Iterator +from collections.abc import Collection from enum import Enum, IntEnum from pathlib import Path from re import Pattern @@ -527,7 +527,7 @@ def __eq__(self, other: object) -> bool: if len(self) != len(other): return False original_recursion_limit = sys.getrecursionlimit() - cast(TestResults, other) + cast("TestResults", other) for test_result in self: other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) if other_test_result is None: diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index 3c28a92db..2207bf355 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -1,9 +1,12 @@ from __future__ import annotations -from jedi.api.classes import Name +from typing import TYPE_CHECKING from codeflash.code_utils.code_utils import get_qualified_name +if TYPE_CHECKING: + from jedi.api.classes import Name + def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool: """Check if the given name belongs to the specified method.""" @@ -14,9 +17,8 @@ def belongs_to_function(name: Name, function_name: str) -> bool: """Check if the given jedi Name is a direct child of the specified function.""" if name.name == function_name: # Handles function definition and recursive function calls return False - if name := name.parent(): - if name.type == "function": - return name.name == function_name + if (name := name.parent()) and name.type == "function": + return name.name == function_name return False @@ -34,9 +36,8 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b if get_qualified_name(name.module_name, name.full_name) == qualified_function_name: # Handles function definition and recursive function calls return False - if name := name.parent(): - if name.type == "function": - return get_qualified_name(name.module_name, name.full_name) == qualified_function_name + if (name := name.parent()) and name.type == "function": + return get_qualified_name(name.module_name, name.full_name) == qualified_function_name return False except ValueError: return False diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 36d6c6f76..17b616967 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -41,7 +41,6 @@ from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( @@ -74,6 +73,7 @@ if TYPE_CHECKING: from argparse import Namespace + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate from codeflash.verification.verification_utils import TestConfig diff --git a/codeflash/tracing/replay_test.py b/codeflash/tracing/replay_test.py index eca1e50ef..4fe1f60b5 100644 --- a/codeflash/tracing/replay_test.py +++ b/codeflash/tracing/replay_test.py @@ -2,11 +2,15 @@ import sqlite3 import textwrap -from collections.abc import Generator -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods -from codeflash.tracing.tracing_utils import FunctionModules +from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods + +if TYPE_CHECKING: + from collections.abc import Generator + + from codeflash.discovery.functions_to_optimize import FunctionProperties + from codeflash.tracing.tracing_utils import FunctionModules def get_next_arg_and_return( diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index 0f2e285aa..e986bae25 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -78,7 +78,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is def decorator(wrapped): @functools.wraps(wrapped) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> None: # Dynamic information retrieved from stack test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack(tests_root) @@ -103,7 +103,7 @@ def wrapper(*args, **kwargs): # Generate invocation id invocation_id = f"{line_id}_{codeflash_test_index}" - print( + print( # noqa: T201 f"!######{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}######!" ) # Connect to sqlite @@ -129,7 +129,8 @@ def wrapper(*args, **kwargs): 0 ].__dict__ # self is always the first argument, this is ensured during instrumentation else: - raise ValueError("Instance state could not be captured.") + msg = "Instance state could not be captured." + raise ValueError(msg) codeflash_cur.execute( "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" ) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 60372fcb4..caaf123f9 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -31,7 +31,7 @@ HAS_SCIPY = False try: - import pandas + import pandas # noqa: ICN001 HAS_PANDAS = True except ImportError: @@ -100,7 +100,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: if HAS_SQLALCHEMY: try: insp = sqlalchemy.inspection.inspect(orig) - insp = sqlalchemy.inspection.inspect(new) + insp = sqlalchemy.inspection.inspect(new) # noqa: F841 orig_keys = orig.__dict__ new_keys = new.__dict__ for key in list(orig_keys.keys()): @@ -115,7 +115,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): if superset_obj: - return all(k in new.keys() and comparator(v, new[k], superset_obj) for k, v in orig.items()) + return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items()) if len(orig) != len(new): return False for key in orig: diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 14acdbec6..fc8935c85 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -3,19 +3,23 @@ import ast import subprocess import tempfile -from argparse import Namespace from pathlib import Path +from typing import TYPE_CHECKING from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.static_analysis import has_typed_parameters from codeflash.discovery.discover_unit_tests import discover_unit_tests -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionCalledInTest from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig +if TYPE_CHECKING: + from argparse import Namespace + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import FunctionCalledInTest + def generate_concolic_tests( test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index c54f3e5d7..903a2e470 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -2,11 +2,14 @@ import ast from pathlib import Path +from typing import TYPE_CHECKING import isort from codeflash.code_utils.code_utils import get_run_tmp_file -from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize def instrument_codeflash_capture( @@ -109,7 +112,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: func=ast.Name(id="codeflash_capture", ctx=ast.Load()), args=[], keywords=[ - ast.keyword(arg="function_name", value=ast.Constant(value=".".join([node.name, "__init__"]))), + ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))), ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),