Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,23 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
run_end_to_end_test(args, file_path)


def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
def config_found(pyproject_toml_path: Union[str, Path]) -> tuple[bool, str]:
pyproject_toml_path = Path(pyproject_toml_path)

if not pyproject_toml_path.exists():
return False, None, f"Configuration file not found: {pyproject_toml_path}"
return False, f"Configuration file not found: {pyproject_toml_path}"

if not pyproject_toml_path.is_file():
return False, f"Configuration file is not a file: {pyproject_toml_path}"

if pyproject_toml_path.suffix != ".toml":
return False, f"Configuration file is not a .toml file: {pyproject_toml_path}"

return True, ""


def is_valid_pyproject_toml(pyproject_toml_path: Union[str, Path]) -> tuple[bool, dict[str, Any] | None, str]:
pyproject_toml_path = Path(pyproject_toml_path)
try:
config, _ = parse_config_file(pyproject_toml_path)
except Exception as e:
Expand Down Expand Up @@ -206,6 +219,10 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:

pyproject_toml_path = Path.cwd() / "pyproject.toml"

found, _ = config_found(pyproject_toml_path)
if not found:
return True, None

valid, config, _message = is_valid_pyproject_toml(pyproject_toml_path)
if not valid:
# needs to be re-configured
Expand Down
5 changes: 5 additions & 0 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from codeflash.cli_cmds.cmd_init import (
CommonSections,
VsCodeSetupInfo,
config_found,
configure_pyproject_toml,
create_empty_pyproject_toml,
create_find_common_tags_file,
Expand Down Expand Up @@ -263,6 +264,10 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:
"root": root,
}

found, message = config_found(pyproject_toml_path)
if not found:
return {"status": "error", "message": message}

valid, config, reason = is_valid_pyproject_toml(pyproject_toml_path)
if not valid:
return {
Expand Down
Loading