diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 6134451e4..2b041ffff 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -6,8 +6,9 @@ import subprocess import sys from enum import Enum, auto +from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast import click import git @@ -32,6 +33,7 @@ from codeflash.code_utils.github_utils import get_github_secrets_page_url from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc from codeflash.either import is_successful +from codeflash.lsp.helpers import is_LSP_enabled from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as version @@ -52,14 +54,23 @@ @dataclass(frozen=True) -class SetupInfo: +class CLISetupInfo: module_root: str tests_root: str benchmarks_root: Union[str, None] test_framework: str ignore_paths: list[str] - formatter: str + formatter: Union[str, list[str]] git_remote: str + enable_telemetry: bool + + +@dataclass(frozen=True) +class VsCodeSetupInfo: + module_root: str + tests_root: str + test_framework: str + formatter: Union[str, list[str]] class DependencyManager(Enum): @@ -91,9 +102,11 @@ def init_codeflash() -> None: git_remote = config.get("git_remote", "origin") if config else "origin" if should_modify: - setup_info: SetupInfo = collect_setup_info() + setup_info: CLISetupInfo = collect_setup_info() git_remote = setup_info.git_remote - configure_pyproject_toml(setup_info) + configured = configure_pyproject_toml(setup_info) + if not configured: + apologize_and_exit() install_github_app(git_remote) @@ -158,30 +171,30 @@ def ask_run_end_to_end_test(args: Namespace) -> None: run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path) -def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[dict[str, Any] | None, str]: # noqa: PLR0911 +def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911 if not pyproject_toml_path.exists(): - return None, f"Configuration file not found: {pyproject_toml_path}" + return False, None, f"Configuration file not found: {pyproject_toml_path}" try: config, _ = parse_config_file(pyproject_toml_path) except Exception as e: - return None, f"Failed to parse configuration: {e}" + return False, None, f"Failed to parse configuration: {e}" module_root = config.get("module_root") if not module_root: - return None, "Missing required field: 'module_root'" + return False, config, "Missing required field: 'module_root'" if not Path(module_root).is_dir(): - return None, f"Invalid 'module_root': directory does not exist at {module_root}" + return False, config, f"Invalid 'module_root': directory does not exist at {module_root}" tests_root = config.get("tests_root") if not tests_root: - return None, "Missing required field: 'tests_root'" + return False, config, "Missing required field: 'tests_root'" if not Path(tests_root).is_dir(): - return None, f"Invalid 'tests_root': directory does not exist at {tests_root}" + return False, config, f"Invalid 'tests_root': directory does not exist at {tests_root}" - return config, "" + return True, config, "" def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]: @@ -193,8 +206,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]: pyproject_toml_path = Path.cwd() / "pyproject.toml" - config, _message = is_valid_pyproject_toml(pyproject_toml_path) - if config is None: + valid, config, _message = is_valid_pyproject_toml(pyproject_toml_path) + if not valid: + # needs to be re-configured return True, None return Confirm.ask( @@ -217,17 +231,19 @@ def __init__(self) -> None: self.Checkbox.unselected_icon = "⬜" -def collect_setup_info() -> SetupInfo: - curdir = Path.cwd() - # Check if the cwd is writable - if not os.access(curdir, os.W_OK): - click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}") - click.echo("It's likely you don't have write permissions for this folder.") - sys.exit(1) +# common sections between normal mode and lsp mode +class CommonSections(Enum): + module_root = "module_root" + tests_root = "tests_root" + test_framework = "test_framework" + formatter_cmds = "formatter_cmds" + + def get_toml_key(self) -> str: + return self.value.replace("_", "-") - # Check for the existence of pyproject.toml or setup.py - project_name = check_for_toml_or_setup_file() +@lru_cache(maxsize=1) +def get_valid_subdirs(current_dir: Optional[Path] = None) -> list[str]: ignore_subdirs = [ "venv", "node_modules", @@ -240,11 +256,41 @@ def collect_setup_info() -> SetupInfo: "tmp", "__pycache__", ] - valid_subdirs = [ - d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs + path_str = str(current_dir) if current_dir else "." + return [ + d + for d in next(os.walk(path_str))[1] + if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs ] - valid_module_subdirs = [d for d in valid_subdirs if d != "tests"] + +def get_suggestions(section: str) -> tuple(list[str], Optional[str]): + valid_subdirs = get_valid_subdirs() + if section == CommonSections.module_root: + return [d for d in valid_subdirs if d != "tests"], None + if section == CommonSections.tests_root: + default = "tests" if "tests" in valid_subdirs else None + return valid_subdirs, default + if section == CommonSections.test_framework: + auto_detected = detect_test_framework_from_config_files(Path.cwd()) + return ["pytest", "unittest"], auto_detected + if section == CommonSections.formatter_cmds: + return ["disabled", "ruff", "black"], "disabled" + msg = f"Unknown section: {section}" + raise ValueError(msg) + + +def collect_setup_info() -> CLISetupInfo: + curdir = Path.cwd() + # Check if the cwd is writable + if not os.access(curdir, os.W_OK): + click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}") + click.echo("It's likely you don't have write permissions for this folder.") + sys.exit(1) + + # Check for the existence of pyproject.toml or setup.py + project_name = check_for_toml_or_setup_file() + valid_module_subdirs, _ = get_suggestions(CommonSections.module_root) curdir_option = f"current directory ({curdir})" custom_dir_option = "enter a custom directory…" @@ -308,10 +354,10 @@ def collect_setup_info() -> SetupInfo: ph("cli-project-root-provided") # Discover test directory - default_tests_subdir = "tests" create_for_me_option = f"🆕 Create a new tests{os.pathsep} directory for me!" - test_subdir_options = [sub_dir for sub_dir in valid_subdirs if sub_dir != module_root] - if "tests" not in valid_subdirs: + tests_suggestions, default_tests_subdir = get_suggestions(CommonSections.tests_root) + test_subdir_options = [sub_dir for sub_dir in tests_suggestions if sub_dir != module_root] + if "tests" not in tests_suggestions: test_subdir_options.append(create_for_me_option) custom_dir_option = "📁 Enter a custom directory…" test_subdir_options.append(custom_dir_option) @@ -334,7 +380,7 @@ def collect_setup_info() -> SetupInfo: "tests_root", message="Where are your tests located?", choices=test_subdir_options, - default=(default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]), + default=(default_tests_subdir or test_subdir_options[0]), carousel=True, ) ] @@ -385,7 +431,8 @@ def collect_setup_info() -> SetupInfo: ph("cli-tests-root-provided") - autodetected_test_framework = detect_test_framework(curdir, tests_root) + test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework) + autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root) framework_message = "⚗️ Let's configure your test framework.\n\n" if autodetected_test_framework: @@ -396,11 +443,19 @@ def collect_setup_info() -> SetupInfo: console.print(framework_panel) console.print() + framework_choices = [] + # add icons based on the detected framework + for choice in test_framework_choices: + if choice == "pytest": + framework_choices.append(("🧪 pytest", "pytest")) + elif choice == "unittest": + framework_choices.append(("🐍 unittest", "unittest")) + framework_questions = [ inquirer.List( "test_framework", message="Which test framework do you use?", - choices=[("🧪 pytest", "pytest"), ("🐍 unittest", "unittest")], + choices=framework_choices, default=autodetected_test_framework or "pytest", carousel=True, ) @@ -502,8 +557,10 @@ def collect_setup_info() -> SetupInfo: except InvalidGitRepositoryError: git_remote = "" + enable_telemetry = ask_for_telemetry() + ignore_paths: list[str] = [] - return SetupInfo( + return CLISetupInfo( module_root=str(module_root), tests_root=str(tests_root), benchmarks_root=str(benchmarks_root) if benchmarks_root else None, @@ -511,10 +568,11 @@ def collect_setup_info() -> SetupInfo: ignore_paths=ignore_paths, formatter=cast("str", formatter), git_remote=str(git_remote), + enable_telemetry=enable_telemetry, ) -def detect_test_framework(curdir: Path, tests_root: Path) -> str | None: +def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]: test_framework = None pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"] pytest_config_patterns = { @@ -532,27 +590,31 @@ def detect_test_framework(curdir: Path, tests_root: Path) -> str | None: test_framework = "pytest" break test_framework = "pytest" - else: - # Check if any python files contain a class that inherits from unittest.TestCase - for filename in tests_root.iterdir(): - if filename.suffix == ".py": - with filename.open(encoding="utf8") as file: - contents = file.read() - try: - node = ast.parse(contents) - except SyntaxError: - continue - if any( - isinstance(item, ast.ClassDef) - and any( - (isinstance(base, ast.Attribute) and base.attr == "TestCase") - or (isinstance(base, ast.Name) and base.id == "TestCase") - for base in item.bases - ) - for item in node.body - ): - test_framework = "unittest" - break + return test_framework + + +def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]: + test_framework = None + # Check if any python files contain a class that inherits from unittest.TestCase + for filename in tests_root.iterdir(): + if filename.suffix == ".py": + with filename.open(encoding="utf8") as file: + contents = file.read() + try: + node = ast.parse(contents) + except SyntaxError: + continue + if any( + isinstance(item, ast.ClassDef) + and any( + (isinstance(base, ast.Attribute) and base.attr == "TestCase") + or (isinstance(base, ast.Name) and base.id == "TestCase") + for base in item.bases + ) + for item in node.body + ): + test_framework = "unittest" + break return test_framework @@ -607,41 +669,41 @@ def check_for_toml_or_setup_file() -> str | None: apologize_and_exit() create_toml = toml_answers["create_toml"] if create_toml: - ph("cli-create-pyproject-toml") - # Define a minimal pyproject.toml content - new_pyproject_toml = tomlkit.document() - new_pyproject_toml["tool"] = {"codeflash": {}} - try: - pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8") - - # Check if the pyproject.toml file was created - if pyproject_toml_path.exists(): - success_panel = Panel( - Text( - f"✅ Created a pyproject.toml file at {pyproject_toml_path}\n\n" - "Your project is now ready for Codeflash configuration!", - style="green", - justify="center", - ), - title="🎉 Success!", - border_style="bright_green", - ) - console.print(success_panel) - console.print("\n📍 Press any key to continue...") - console.input() - ph("cli-created-pyproject-toml") - except OSError: - click.echo( - "❌ Failed to create pyproject.toml. Please check your disk permissions and available space." - ) - apologize_and_exit() - else: - click.echo("⏩️ Skipping pyproject.toml creation.") - apologize_and_exit() + create_empty_pyproject_toml(pyproject_toml_path) click.echo() return cast("str", project_name) +def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None: + ph("cli-create-pyproject-toml") + lsp_mode = is_LSP_enabled() + # Define a minimal pyproject.toml content + new_pyproject_toml = tomlkit.document() + new_pyproject_toml["tool"] = {"codeflash": {}} + try: + pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8") + + # Check if the pyproject.toml file was created + if pyproject_toml_path.exists() and not lsp_mode: + success_panel = Panel( + Text( + f"✅ Created a pyproject.toml file at {pyproject_toml_path}\n\n" + "Your project is now ready for Codeflash configuration!", + style="green", + justify="center", + ), + title="🎉 Success!", + border_style="bright_green", + ) + console.print(success_panel) + console.print("\n📍 Press any key to continue...") + console.input() + ph("cli-created-pyproject-toml") + except OSError: + click.echo("❌ Failed to create pyproject.toml. Please check your disk permissions and available space.") + apologize_and_exit() + + def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002 try: config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check) @@ -931,9 +993,27 @@ def customize_codeflash_yaml_content( return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd) +def get_formatter_cmds(formatter: str) -> list[str]: + if formatter == "black": + return ["black $file"] + if formatter == "ruff": + return ["ruff check --exit-zero --fix $file", "ruff format $file"] + if formatter == "other": + click.echo( + "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code." + ) + return ["your-formatter $file"] + if formatter in {"don't use a formatter", "disabled"}: + return ["disabled"] + return [formatter] + + # Create or update the pyproject.toml file with the Codeflash dependency & configuration -def configure_pyproject_toml(setup_info: SetupInfo) -> None: - toml_path = Path.cwd() / "pyproject.toml" +def configure_pyproject_toml( + setup_info: Union[VsCodeSetupInfo, CLISetupInfo], config_file: Optional[Path] = None +) -> bool: + for_vscode = isinstance(setup_info, VsCodeSetupInfo) + toml_path = config_file or Path.cwd() / "pyproject.toml" try: with toml_path.open(encoding="utf8") as pyproject_file: pyproject_data = tomlkit.parse(pyproject_file.read()) @@ -942,44 +1022,51 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: f"I couldn't find a pyproject.toml in the current directory.{LF}" f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file." ) - apologize_and_exit() - - enable_telemetry = ask_for_telemetry() + return False codeflash_section = tomlkit.table() codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory.")) - codeflash_section["module-root"] = setup_info.module_root - codeflash_section["tests-root"] = setup_info.tests_root - codeflash_section["test-framework"] = setup_info.test_framework - codeflash_section["ignore-paths"] = setup_info.ignore_paths - if not enable_telemetry: - codeflash_section["disable-telemetry"] = not enable_telemetry - if setup_info.git_remote not in ["", "origin"]: - codeflash_section["git-remote"] = setup_info.git_remote + + if for_vscode: + for section in CommonSections: + if hasattr(setup_info, section.value): + codeflash_section[section.get_toml_key()] = getattr(setup_info, section.value) + else: + codeflash_section["module-root"] = setup_info.module_root + codeflash_section["tests-root"] = setup_info.tests_root + codeflash_section["test-framework"] = setup_info.test_framework + codeflash_section["ignore-paths"] = setup_info.ignore_paths + if not setup_info.enable_telemetry: + codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry + if setup_info.git_remote not in ["", "origin"]: + codeflash_section["git-remote"] = setup_info.git_remote + formatter = setup_info.formatter - formatter_cmds = [] - if formatter == "black": - formatter_cmds.append("black $file") - elif formatter == "ruff": - formatter_cmds.extend(["ruff check --exit-zero --fix $file", "ruff format $file"]) - elif formatter == "other": - formatter_cmds.append("your-formatter $file") - click.echo( - "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code." - ) - elif formatter == "don't use a formatter": - formatter_cmds.append("disabled") + + formatter_cmds = formatter if isinstance(formatter, list) else get_formatter_cmds(formatter) + check_formatter_installed(formatter_cmds, exit_on_failure=False) codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) - tool_section["codeflash"] = codeflash_section + + if for_vscode: + # merge the existing codeflash section, instead of overwriting it + existing_codeflash = tool_section.get("codeflash", tomlkit.table()) + + for key, value in codeflash_section.items(): + existing_codeflash[key] = value + tool_section["codeflash"] = existing_codeflash + else: + tool_section["codeflash"] = codeflash_section + pyproject_data["tool"] = tool_section with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") click.echo() + return True def install_github_app(git_remote: str) -> None: @@ -1001,43 +1088,47 @@ def install_github_app(git_remote: str) -> None: ) else: - click.prompt( - f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}" - f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}" - f"Press Enter to open the page to let you install the app…{LF}", - default="", - type=click.STRING, - prompt_suffix="", - show_default=False, - ) - click.launch("https://github.com/apps/codeflash-ai/installations/select_target") - click.prompt( - f"Press Enter once you've finished installing the github app from https://github.com/apps/codeflash-ai/installations/select_target{LF}", - default="", - type=click.STRING, - prompt_suffix="", - show_default=False, - ) - - count = 2 - while not is_github_app_installed_on_repo(owner, repo, suppress_errors=True): - if count == 0: - click.echo( - f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" - f"You won't be able to create PRs with Codeflash until you install the app.{LF}" - f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}" - ) - break + try: + click.prompt( + f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}" + f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}" + f"Press Enter to open the page to let you install the app…{LF}", + default="", + type=click.STRING, + prompt_suffix="", + show_default=False, + ) + click.launch("https://github.com/apps/codeflash-ai/installations/select_target") click.prompt( - f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" - f"Please install it from https://github.com/apps/codeflash-ai/installations/select_target {LF}" - f"Press Enter to continue once you've finished installing the github app…{LF}", + f"Press Enter once you've finished installing the github app from https://github.com/apps/codeflash-ai/installations/select_target{LF}", default="", type=click.STRING, prompt_suffix="", show_default=False, ) - count -= 1 + + count = 2 + while not is_github_app_installed_on_repo(owner, repo, suppress_errors=True): + if count == 0: + click.echo( + f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" + f"You won't be able to create PRs with Codeflash until you install the app.{LF}" + f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}" + ) + break + click.prompt( + f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" + f"Please install it from https://github.com/apps/codeflash-ai/installations/select_target {LF}" + f"Press Enter to continue once you've finished installing the github app…{LF}", + default="", + type=click.STRING, + prompt_suffix="", + show_default=False, + ) + count -= 1 + except (KeyboardInterrupt, EOFError, click.exceptions.Abort): + # leave empty line for the next prompt to be properly rendered + click.echo() class CFAPIKeyType(click.ParamType): diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0a515a080..5335bad56 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -963,19 +963,28 @@ def _is_target_function_call(self, node: ast.Call) -> bool: return False - def _get_call_name(self, func_node) -> Optional[str]: # noqa : ANN001 + def _get_call_name(self, func_node) -> Optional[str]: # noqa: ANN001 """Extract the name being called from a function node.""" + # Fast path short-circuit for ast.Name nodes if isinstance(func_node, ast.Name): return func_node.id + + # Fast attribute chain extraction (speed: append, loop, join, NO reversed) if isinstance(func_node, ast.Attribute): parts = [] current = func_node - while isinstance(current, ast.Attribute): + # Unwind attribute chain as tight as possible (checked at each loop iteration) + while True: parts.append(current.attr) - current = current.value - if isinstance(current, ast.Name): - parts.append(current.id) - return ".".join(reversed(parts)) + val = current.value + if isinstance(val, ast.Attribute): + current = val + continue + if isinstance(val, ast.Name): + parts.append(val.id) + # Join in-place backwards via slice instead of reversed for slight speedup + return ".".join(parts[::-1]) + break return None def _extract_source_code(self, node: ast.FunctionDef) -> str: diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index a6c9eb2d6..c499199a1 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -5,6 +5,8 @@ import tomlkit +from codeflash.lsp.helpers import is_LSP_enabled + PYPROJECT_TOML_CACHE = {} ALL_CONFIG_FILES = {} # map path to closest config file @@ -93,15 +95,23 @@ def parse_config_file( msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}" raise ValueError(msg) from e + lsp_mode = is_LSP_enabled() + try: tool = data["tool"] assert isinstance(tool, dict) - config = tool["codeflash"] + config = tool.get("codeflash", {}) except tomlkit.exceptions.NonExistentKey as e: + if lsp_mode: + # don't fail in lsp mode if codeflash config is not found. + return {}, config_file_path msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file." raise ValueError(msg) from e assert isinstance(config, dict) + if config == {} and lsp_mode: + return {}, config_file_path + # default values: path_keys = ["module-root", "tests-root", "benchmarks-root"] path_list_keys = ["ignore-paths"] @@ -139,12 +149,13 @@ def parse_config_file( else: config[key] = [] - assert config["test-framework"] in {"pytest", "unittest"}, ( - "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." - ) + if config.get("test-framework"): + assert config["test-framework"] in {"pytest", "unittest"}, ( + "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." + ) # see if this is happening during GitHub actions setup - if len(config["formatter-cmds"]) > 0 and not override_formatter_check: - assert config["formatter-cmds"][0] != "your-formatter $file", ( + if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check: + assert config.get("formatter-cmds")[0] != "your-formatter $file", ( "The formatter command is not set correctly in pyproject.toml. Please set the " "formatter command in the 'formatter-cmds' key. More info - https://docs.codeflash.ai/configuration" ) diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 592d8ba58..67b4bacd2 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import os from dataclasses import dataclass from pathlib import Path @@ -8,6 +9,16 @@ from codeflash.api.cfapi import get_codeflash_api_key, get_user_id from codeflash.cli_cmds.cli import process_pyproject_config +from codeflash.cli_cmds.cmd_init import ( + CommonSections, + VsCodeSetupInfo, + configure_pyproject_toml, + create_empty_pyproject_toml, + get_formatter_cmds, + get_suggestions, + get_valid_subdirs, + is_valid_pyproject_toml, +) from codeflash.code_utils.git_utils import git_root_dir from codeflash.code_utils.shell_utils import save_api_key_to_rc from codeflash.discovery.functions_to_optimize import ( @@ -67,6 +78,12 @@ class OptimizableFunctionsInCommitParams: commit_hash: str +@dataclass +class WriteConfigParams: + config_file: str + config: any + + server = CodeflashLanguageServer("codeflash-language-server", "v1.0") @@ -154,11 +171,48 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]: return top_level_pyproject, False +@server.feature("writeConfig") +def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]: + cfg = params.config + cfg_file = Path(params.config_file) if params.config_file else None + + if cfg_file and not cfg_file.exists(): + # the client provided a config path but it doesn't exist + create_empty_pyproject_toml(cfg_file) + + setup_info = VsCodeSetupInfo( + module_root=getattr(cfg, "module_root", ""), + tests_root=getattr(cfg, "tests_root", ""), + test_framework=getattr(cfg, "test_framework", "pytest"), + formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")), + ) + + devnull_writer = open(os.devnull, "w") # noqa + with contextlib.redirect_stdout(devnull_writer): + configured = configure_pyproject_toml(setup_info, cfg_file) + if configured: + return {"status": "success"} + return {"status": "error", "message": "Failed to configure pyproject.toml"} + + +@server.feature("getConfigSuggestions") +def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]: + module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root) + tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root) + test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework) + formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds) + get_valid_subdirs.cache_clear() + return { + "module_root": {"choices": module_root_suggestions, "default": default_module_root}, + "tests_root": {"choices": tests_root_suggestions, "default": default_tests_root}, + "test_framework": {"choices": test_framework_suggestions, "default": default_test_framework}, + "formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter}, + } + + # should be called the first thing to initialize and validate the project @server.feature("initProject") def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: - from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml - # Always process args in the init project, the extension can call server.args_processed_before = False @@ -192,11 +246,14 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) "root": root, } - server.show_message_log("Validating project...", "Info") - config, reason = is_valid_pyproject_toml(pyproject_toml_path) - if config is None: - server.show_message_log("pyproject.toml is not valid", "Error") - return {"status": "error", "message": f"reason: {reason}", "pyprojectPath": pyproject_toml_path} + valid, config, reason = is_valid_pyproject_toml(pyproject_toml_path) + if not valid: + return { + "status": "error", + "message": f"reason: {reason}", + "pyprojectPath": pyproject_toml_path, + "existingConfig": config, + } args = process_args(server) diff --git a/tests/test_cmd_init.py b/tests/test_cmd_init.py new file mode 100644 index 000000000..fd976d659 --- /dev/null +++ b/tests/test_cmd_init.py @@ -0,0 +1,189 @@ +import pytest +import tempfile +from pathlib import Path +from codeflash.cli_cmds.cmd_init import ( + is_valid_pyproject_toml, + configure_pyproject_toml, + CLISetupInfo, + get_formatter_cmds, + VsCodeSetupInfo, + get_valid_subdirs, +) +import os + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield Path(tmpdirname).resolve() + + +def test_is_valid_pyproject_toml_with_empty_config(temp_dir: Path) -> None: + with (temp_dir / "pyproject.toml").open(mode="w") as f: + f.write( + """[tool.codeflash] +""" + ) + f.flush() + valid, _, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml") + assert not valid + assert _message == "Missing required field: 'module_root'" + +def test_is_valid_pyproject_toml_with_incorrect_module_root(temp_dir: Path) -> None: + with (temp_dir / "pyproject.toml").open(mode="w") as f: + wrong_module_root = temp_dir / "invalid_directory" + f.write( + f"""[tool.codeflash] +module-root = "invalid_directory" +""" + ) + f.flush() + valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml") + assert not valid + assert _message == f"Invalid 'module_root': directory does not exist at {wrong_module_root}" + + +def test_is_valid_pyproject_toml_with_incorrect_tests_root(temp_dir: Path) -> None: + with (temp_dir / "pyproject.toml").open(mode="w") as f: + wrong_tests_root = temp_dir / "incorrect_tests_root" + f.write( + f"""[tool.codeflash] +module-root = "." +tests-root = "incorrect_tests_root" +""" + ) + f.flush() + valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml") + assert not valid + assert _message == f"Invalid 'tests_root': directory does not exist at {wrong_tests_root}" + + +def test_is_valid_pyproject_toml_with_valid_config(temp_dir: Path) -> None: + with (temp_dir / "pyproject.toml").open(mode="w") as f: + os.makedirs(temp_dir / "tests") + f.write( + """[tool.codeflash] +module-root = "." +tests-root = "tests" +test-framework = "pytest" +""" + ) + f.flush() + valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml") + assert valid + +def test_get_formatter_cmd(temp_dir: Path) -> None: + assert get_formatter_cmds("black") == ["black $file"] + assert get_formatter_cmds("ruff") == ["ruff check --exit-zero --fix $file", "ruff format $file"] + assert get_formatter_cmds("disabled") == ["disabled"] + assert get_formatter_cmds("don't use a formatter") == ["disabled"] + +def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None: + + pyproject_path = temp_dir / "pyproject.toml" + + with (pyproject_path).open(mode="w") as f: + f.write("") + f.flush() + os.mkdir(temp_dir / "tests") + config = CLISetupInfo( + module_root=".", + tests_root="tests", + benchmarks_root=None, + test_framework="pytest", + ignore_paths=[], + formatter="black", + git_remote="origin", + enable_telemetry=False, + ) + + success = configure_pyproject_toml(config, pyproject_path) + assert success + + config_content = pyproject_path.read_text() + assert """[tool.codeflash] +# All paths are relative to this pyproject.toml's directory. +module-root = "." +tests-root = "tests" +test-framework = "pytest" +ignore-paths = [] +disable-telemetry = true +formatter-cmds = ["black $file"] +""" == config_content + valid, _, _ = is_valid_pyproject_toml(pyproject_path) + assert valid + +def test_configure_pyproject_toml_for_vscode_with_empty_config(temp_dir: Path) -> None: + + pyproject_path = temp_dir / "pyproject.toml" + + with (pyproject_path).open(mode="w") as f: + f.write("") + f.flush() + os.mkdir(temp_dir / "tests") + config = VsCodeSetupInfo( + module_root=".", + tests_root="tests", + test_framework="pytest", + formatter="black", + ) + + success = configure_pyproject_toml(config, pyproject_path) + assert success + + config_content = pyproject_path.read_text() + assert """[tool.codeflash] +module-root = "." +tests-root = "tests" +test-framework = "pytest" +formatter-cmds = ["black $file"] +""" == config_content + valid, _, _ = is_valid_pyproject_toml(pyproject_path) + assert valid + +def test_configure_pyproject_toml_for_vscode_with_existing_config(temp_dir: Path) -> None: + pyproject_path = temp_dir / "pyproject.toml" + + with (pyproject_path).open(mode="w") as f: + f.write("""[tool.codeflash] +module-root = "codeflash" +tests-root = "tests" +benchmarks-root = "tests/benchmarks" +test-framework = "pytest" +formatter-cmds = ["disabled"] +""") + f.flush() + os.mkdir(temp_dir / "tests") + config = VsCodeSetupInfo( + module_root=".", + tests_root="tests", + test_framework="pytest", + formatter="disabled", + ) + + success = configure_pyproject_toml(config, pyproject_path) + assert success + + config_content = pyproject_path.read_text() + # the benchmarks-root shouldn't get overwritten + assert """[tool.codeflash] +module-root = "." +tests-root = "tests" +benchmarks-root = "tests/benchmarks" +test-framework = "pytest" +formatter-cmds = ["disabled"] +""" == config_content + valid, _, _ = is_valid_pyproject_toml(pyproject_path) + assert valid + +def test_get_valid_subdirs(temp_dir: Path) -> None: + os.mkdir(temp_dir / "dir1") + os.mkdir(temp_dir / "dir2") + os.mkdir(temp_dir / "__pycache__") + os.mkdir(temp_dir / ".git") + os.mkdir(temp_dir / "tests") + + dirs = get_valid_subdirs(temp_dir) + assert "tests" in dirs + assert "dir1" in dirs + assert "dir2" in dirs