diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index e7f195154..08a81e1a1 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -16,6 +16,7 @@ import tomlkit from git import InvalidGitRepositoryError, Repo from pydantic.dataclasses import dataclass +from rich.prompt import Confirm from codeflash.api.cfapi import is_github_app_installed_on_repo from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path @@ -45,6 +46,7 @@ f"{LF}" ) + @dataclass(frozen=True) class SetupInfo: module_root: str @@ -70,7 +72,6 @@ def init_codeflash() -> None: did_add_new_key = prompt_api_key() if should_modify_pyproject_toml(): - setup_info: SetupInfo = collect_setup_info() configure_pyproject_toml(setup_info) @@ -83,7 +84,6 @@ def init_codeflash() -> None: if "setup_info" in locals(): module_string = f" you selected ({setup_info.module_root})" - click.echo( f"{LF}" f"⚡️ Codeflash is now set up! You can now run:{LF}" @@ -125,11 +125,13 @@ def ask_run_end_to_end_test(args: Namespace) -> None: bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args) run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path) + def should_modify_pyproject_toml() -> bool: """Check if the current directory contains a valid pyproject.toml file with codeflash config If it does, ask the user if they want to re-configure it. """ from rich.prompt import Confirm + pyproject_toml_path = Path.cwd() / "pyproject.toml" if not pyproject_toml_path.exists(): return True @@ -144,7 +146,9 @@ def should_modify_pyproject_toml() -> bool: return True create_toml = Confirm.ask( - "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True + "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", + default=False, + show_default=True, ) return create_toml @@ -160,7 +164,18 @@ def collect_setup_info() -> SetupInfo: # Check for the existence of pyproject.toml or setup.py project_name = check_for_toml_or_setup_file() - ignore_subdirs = ["venv", "node_modules", "dist", "build", "build_temp", "build_scripts", "env", "logs", "tmp", "__pycache__"] + ignore_subdirs = [ + "venv", + "node_modules", + "dist", + "build", + "build_temp", + "build_scripts", + "env", + "logs", + "tmp", + "__pycache__", + ] valid_subdirs = [ d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs ] @@ -225,7 +240,7 @@ def collect_setup_info() -> SetupInfo: else: apologize_and_exit() else: - tests_root = Path(curdir) / Path(cast(str, tests_root_answer)) + tests_root = Path(curdir) / Path(cast("str", tests_root_answer)) tests_root = tests_root.relative_to(curdir) ph("cli-tests-root-provided") @@ -262,13 +277,13 @@ def collect_setup_info() -> SetupInfo: benchmarks_options.append(create_benchmarks_option) benchmarks_options.append(custom_dir_option) - benchmarks_answer = inquirer_wrapper( inquirer.list_input, message="Where are your performance benchmarks located? (benchmarks must be a sub directory of your tests root directory)", choices=benchmarks_options, default=( - default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]), + default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0] + ), ) if benchmarks_answer == create_benchmarks_option: @@ -288,7 +303,7 @@ def collect_setup_info() -> SetupInfo: elif benchmarks_answer == no_benchmarks_option: benchmarks_root = None else: - benchmarks_root = tests_root / Path(cast(str, benchmarks_answer)) + benchmarks_root = tests_root / Path(cast("str", benchmarks_answer)) # TODO: Implement other benchmark framework options # if benchmarks_root: @@ -304,7 +319,6 @@ def collect_setup_info() -> SetupInfo: # carousel=True, # ) - formatter = inquirer_wrapper( inquirer.list_input, message="Which code formatter do you use?", @@ -340,10 +354,10 @@ def collect_setup_info() -> SetupInfo: return SetupInfo( module_root=str(module_root), tests_root=str(tests_root), - benchmarks_root = str(benchmarks_root) if benchmarks_root else None, - test_framework=cast(str, test_framework), + benchmarks_root=str(benchmarks_root) if benchmarks_root else None, + test_framework=cast("str", test_framework), ignore_paths=ignore_paths, - formatter=cast(str, formatter), + formatter=cast("str", formatter), git_remote=str(git_remote), ) @@ -453,7 +467,7 @@ def check_for_toml_or_setup_file() -> str | None: click.echo("⏩️ Skipping pyproject.toml creation.") apologize_and_exit() click.echo() - return cast(str, project_name) + return cast("str", project_name) def install_github_actions(override_formatter_check: bool = False) -> None: @@ -499,19 +513,22 @@ def install_github_actions(override_formatter_check: bool = False) -> None: return workflows_path.mkdir(parents=True, exist_ok=True) from importlib.resources import files + benchmark_mode = False if "benchmarks_root" in config: benchmark_mode = inquirer_wrapper( inquirer.confirm, message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? " - " This will show the impact of Codeflash's suggested optimizations on your benchmarks", + " This will show the impact of Codeflash's suggested optimizations on your benchmarks", default=True, ) optimize_yml_content = ( files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8") ) - materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode) + materialized_optimize_yml_content = customize_codeflash_yaml_content( + optimize_yml_content, config, git_root, benchmark_mode + ) with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: optimize_yml_file.write(materialized_optimize_yml_content) click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}") @@ -941,12 +958,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test def ask_for_telemetry() -> bool: """Prompt the user to enable or disable telemetry.""" - from rich.prompt import Confirm - - enable_telemetry = Confirm.ask( + return Confirm.ask( "⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?", default=True, show_default=True, ) - - return enable_telemetry