Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tomlkit
from git import InvalidGitRepositoryError, Repo
from pydantic.dataclasses import dataclass
from rich.prompt import Confirm

from codeflash.api.cfapi import is_github_app_installed_on_repo
from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path
Expand Down Expand Up @@ -45,6 +46,7 @@
f"{LF}"
)


@dataclass(frozen=True)
class SetupInfo:
module_root: str
Expand All @@ -70,7 +72,6 @@ def init_codeflash() -> None:
did_add_new_key = prompt_api_key()

if should_modify_pyproject_toml():

setup_info: SetupInfo = collect_setup_info()

configure_pyproject_toml(setup_info)
Expand All @@ -83,7 +84,6 @@ def init_codeflash() -> None:
if "setup_info" in locals():
module_string = f" you selected ({setup_info.module_root})"


click.echo(
f"{LF}"
f"⚡️ Codeflash is now set up! You can now run:{LF}"
Expand Down Expand Up @@ -125,11 +125,13 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


def should_modify_pyproject_toml() -> bool:
"""Check if the current directory contains a valid pyproject.toml file with codeflash config
If it does, ask the user if they want to re-configure it.
"""
from rich.prompt import Confirm

pyproject_toml_path = Path.cwd() / "pyproject.toml"
if not pyproject_toml_path.exists():
return True
Expand All @@ -144,7 +146,9 @@ def should_modify_pyproject_toml() -> bool:
return True

create_toml = Confirm.ask(
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
default=False,
show_default=True,
)
return create_toml

Expand All @@ -160,7 +164,18 @@ def collect_setup_info() -> SetupInfo:
# Check for the existence of pyproject.toml or setup.py
project_name = check_for_toml_or_setup_file()

ignore_subdirs = ["venv", "node_modules", "dist", "build", "build_temp", "build_scripts", "env", "logs", "tmp", "__pycache__"]
ignore_subdirs = [
"venv",
"node_modules",
"dist",
"build",
"build_temp",
"build_scripts",
"env",
"logs",
"tmp",
"__pycache__",
]
valid_subdirs = [
d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
]
Expand Down Expand Up @@ -225,7 +240,7 @@ def collect_setup_info() -> SetupInfo:
else:
apologize_and_exit()
else:
tests_root = Path(curdir) / Path(cast(str, tests_root_answer))
tests_root = Path(curdir) / Path(cast("str", tests_root_answer))
tests_root = tests_root.relative_to(curdir)
ph("cli-tests-root-provided")

Expand Down Expand Up @@ -262,13 +277,13 @@ def collect_setup_info() -> SetupInfo:
benchmarks_options.append(create_benchmarks_option)
benchmarks_options.append(custom_dir_option)


benchmarks_answer = inquirer_wrapper(
inquirer.list_input,
message="Where are your performance benchmarks located? (benchmarks must be a sub directory of your tests root directory)",
choices=benchmarks_options,
default=(
default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]),
default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]
),
)

if benchmarks_answer == create_benchmarks_option:
Expand All @@ -288,7 +303,7 @@ def collect_setup_info() -> SetupInfo:
elif benchmarks_answer == no_benchmarks_option:
benchmarks_root = None
else:
benchmarks_root = tests_root / Path(cast(str, benchmarks_answer))
benchmarks_root = tests_root / Path(cast("str", benchmarks_answer))

# TODO: Implement other benchmark framework options
# if benchmarks_root:
Expand All @@ -304,7 +319,6 @@ def collect_setup_info() -> SetupInfo:
# carousel=True,
# )


formatter = inquirer_wrapper(
inquirer.list_input,
message="Which code formatter do you use?",
Expand Down Expand Up @@ -340,10 +354,10 @@ def collect_setup_info() -> SetupInfo:
return SetupInfo(
module_root=str(module_root),
tests_root=str(tests_root),
benchmarks_root = str(benchmarks_root) if benchmarks_root else None,
test_framework=cast(str, test_framework),
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
test_framework=cast("str", test_framework),
ignore_paths=ignore_paths,
formatter=cast(str, formatter),
formatter=cast("str", formatter),
git_remote=str(git_remote),
)

Expand Down Expand Up @@ -453,7 +467,7 @@ def check_for_toml_or_setup_file() -> str | None:
click.echo("⏩️ Skipping pyproject.toml creation.")
apologize_and_exit()
click.echo()
return cast(str, project_name)
return cast("str", project_name)


def install_github_actions(override_formatter_check: bool = False) -> None:
Expand Down Expand Up @@ -499,19 +513,22 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
return
workflows_path.mkdir(parents=True, exist_ok=True)
from importlib.resources import files

benchmark_mode = False
if "benchmarks_root" in config:
benchmark_mode = inquirer_wrapper(
inquirer.confirm,
message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
" This will show the impact of Codeflash's suggested optimizations on your benchmarks",
" This will show the impact of Codeflash's suggested optimizations on your benchmarks",
default=True,
)

optimize_yml_content = (
files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
)
materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
materialized_optimize_yml_content = customize_codeflash_yaml_content(
optimize_yml_content, config, git_root, benchmark_mode
)
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(materialized_optimize_yml_content)
click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}")
Expand Down Expand Up @@ -941,12 +958,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test

def ask_for_telemetry() -> bool:
"""Prompt the user to enable or disable telemetry."""
from rich.prompt import Confirm

enable_telemetry = Confirm.ask(
return Confirm.ask(
"⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?",
default=True,
show_default=True,
)

return enable_telemetry
Loading