diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 9af373921..77a193a20 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -32,7 +32,7 @@ from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name from codeflash.code_utils.github_utils import get_github_secrets_page_url from codeflash.code_utils.oauth_handler import perform_oauth_signin -from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc +from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc from codeflash.either import is_successful from codeflash.lsp.helpers import is_LSP_enabled from codeflash.telemetry.posthog_cf import ph @@ -136,7 +136,10 @@ def init_codeflash() -> None: completion_message += ( "\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!" ) - reload_cmd = f"call {get_shell_rc_path()}" if os.name == "nt" else f"source {get_shell_rc_path()}" + if os.name == "nt": + reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}" + else: + reload_cmd = f"source {get_shell_rc_path()}" completion_message += f"\nOr run: {reload_cmd}" completion_panel = Panel( @@ -1087,7 +1090,7 @@ def configure_pyproject_toml( with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) - click.echo(f"āœ… Added Codeflash configuration to {toml_path}") + click.echo(f"Added Codeflash configuration to {toml_path}") click.echo() return True @@ -1264,7 +1267,8 @@ def enter_api_key_and_save_to_rc() -> None: browser_launched = True # This does not work on remote consoles shell_rc_path = get_shell_rc_path() if not shell_rc_path.exists() and os.name == "nt": - # On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key) + # On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory + shell_rc_path.parent.mkdir(parents=True, exist_ok=True) shell_rc_path.touch() click.echo(f"āœ… Created {shell_rc_path}") get_user_id(api_key=api_key) # Used to verify whether the API key is valid. diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index d74c99408..4987e6d8d 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -59,17 +59,26 @@ def get_codeflash_api_key() -> str: # Check environment variable first env_api_key = os.environ.get("CODEFLASH_API_KEY") shell_api_key = read_api_key_from_shell_config() - + logger.debug( + f"env_utils.py:get_codeflash_api_key - env_api_key: {'***' + env_api_key[-4:] if env_api_key else None}, shell_api_key: {'***' + shell_api_key[-4:] if shell_api_key else None}" + ) # If we have an env var but it's not in shell config, save it for persistence if env_api_key and not shell_api_key: try: from codeflash.either import is_successful + logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config") result = save_api_key_to_rc(env_api_key) if is_successful(result): - logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}") + logger.debug( + f"env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: {result.unwrap()}" + ) + else: + logger.debug(f"env_utils.py:get_codeflash_api_key - Failed to save API key: {result.failure()}") except Exception as e: - logger.debug(f"Failed to automatically save API key to shell config: {e}") + logger.debug( + f"env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: {e}" + ) # Prefer the shell configuration over environment variables for lsp, # as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart diff --git a/codeflash/code_utils/shell_utils.py b/codeflash/code_utils/shell_utils.py index 79b211111..60da8e3ba 100644 --- a/codeflash/code_utils/shell_utils.py +++ b/codeflash/code_utils/shell_utils.py @@ -1,40 +1,107 @@ from __future__ import annotations +import contextlib import os import re from pathlib import Path from typing import TYPE_CHECKING, Optional +from codeflash.cli_cmds.console import logger from codeflash.code_utils.compat import LF from codeflash.either import Failure, Success if TYPE_CHECKING: from codeflash.either import Result -if os.name == "nt": # Windows - SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE) - SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY=" -else: - SHELL_RC_EXPORT_PATTERN = re.compile( - r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE - ) - SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY=" +# PowerShell patterns and prefixes +POWERSHELL_RC_EXPORT_PATTERN = re.compile( + r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE +) +POWERSHELL_RC_EXPORT_PREFIX = "$env:CODEFLASH_API_KEY = " + +# CMD/Batch patterns and prefixes +CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE) +CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY=" + +# Unix shell patterns and prefixes +UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE) +UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY=" + + +def is_powershell() -> bool: + """Detect if we're running in PowerShell on Windows. + + Uses multiple heuristics: + 1. PSModulePath environment variable (PowerShell always sets this) + 2. COMSPEC pointing to powershell.exe + 3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell) + """ + if os.name != "nt": + return False + + # Primary check: PSMODULEPATH is set by PowerShell + # This is the most reliable indicator as PowerShell always sets this + ps_module_path = os.environ.get("PSMODULEPATH") + if ps_module_path: + logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath") + return True + + # Secondary check: COMSPEC points to PowerShell + comspec = os.environ.get("COMSPEC", "").lower() + if "powershell" in comspec: + logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: {comspec}") + return True + + # Tertiary check: Windows Terminal often uses PowerShell by default + # But we only use this if other indicators are ambiguous + term_program = os.environ.get("TERM_PROGRAM", "").lower() + # Check if we can find evidence of CMD (cmd.exe in COMSPEC) + # If not, assume PowerShell for Windows Terminal + if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec: + logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: {comspec})") + return True + + logger.debug(f"shell_utils.py:is_powershell - Not PowerShell (COMSPEC: {comspec}, TERM_PROGRAM: {term_program})") + return False def read_api_key_from_shell_config() -> Optional[str]: + """Read API key from shell configuration file.""" + shell_rc_path = get_shell_rc_path() + # Ensure shell_rc_path is a Path object for consistent handling + if not isinstance(shell_rc_path, Path): + shell_rc_path = Path(shell_rc_path) + + # Determine the correct pattern to use based on the file extension and platform + if os.name == "nt": # Windows + pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN + else: # Unix-like + pattern = UNIX_RC_EXPORT_PATTERN + try: - shell_rc_path = get_shell_rc_path() - with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123 + # Convert Path to string using as_posix() for cross-platform path compatibility + shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path) + with open(shell_rc_path_str, encoding="utf8") as shell_rc: # noqa: PTH123 shell_contents = shell_rc.read() - matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents) - return matches[-1] if matches else None + matches = pattern.findall(shell_contents) + if matches: + logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Found API key in file: {shell_rc_path}") + return matches[-1] + logger.debug(f"shell_utils.py:read_api_key_from_shell_config - No API key found in file: {shell_rc_path}") + return None except FileNotFoundError: + logger.debug(f"shell_utils.py:read_api_key_from_shell_config - File not found: {shell_rc_path}") + return None + except Exception as e: + logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Error reading file: {e}") return None def get_shell_rc_path() -> Path: """Get the path to the user's shell configuration file.""" - if os.name == "nt": # on Windows, we use a batch file in the user's home directory + if os.name == "nt": # Windows + if is_powershell(): + return Path.home() / "codeflash_env.ps1" return Path.home() / "codeflash_env.bat" shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1] shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get( @@ -44,40 +111,123 @@ def get_shell_rc_path() -> Path: def get_api_key_export_line(api_key: str) -> str: - return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"' + """Get the appropriate export line based on the shell type.""" + if os.name == "nt": # Windows + if is_powershell(): + return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"' + return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"' + # Unix-like + return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"' def save_api_key_to_rc(api_key: str) -> Result[str, str]: + """Save API key to the appropriate shell configuration file.""" shell_rc_path = get_shell_rc_path() + # Ensure shell_rc_path is a Path object for consistent handling + if not isinstance(shell_rc_path, Path): + shell_rc_path = Path(shell_rc_path) api_key_line = get_api_key_export_line(api_key) + + logger.debug(f"shell_utils.py:save_api_key_to_rc - Saving API key to: {shell_rc_path}") + logger.debug(f"shell_utils.py:save_api_key_to_rc - API key line format: {api_key_line[:30]}...") + + # Determine the correct pattern to use for replacement + if os.name == "nt": # Windows + if is_powershell(): + pattern = POWERSHELL_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern") + else: + pattern = CMD_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern") + else: # Unix-like + pattern = UNIX_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern") + try: - with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123 - shell_contents = shell_file.read() - if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file + # Create directory if it doesn't exist (ignore errors - file operation will fail if needed) + # Directory creation failed, but we'll still try to open the file + # The file operation itself will raise the appropriate exception if there are permission issues + with contextlib.suppress(OSError, PermissionError): + shell_rc_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert Path to string using as_posix() for cross-platform path compatibility + shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path) + + # Try to open in r+ mode (read and write in single operation) + # Handle FileNotFoundError if file doesn't exist (r+ requires file to exist) + try: + with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123 + shell_contents = shell_file.read() + logger.debug(f"shell_utils.py:save_api_key_to_rc - Read existing file, length: {len(shell_contents)}") + + # Initialize empty file with header for batch files if needed + if not shell_contents: + logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing") + if os.name == "nt" and not is_powershell(): + shell_contents = "@echo off" + logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file") + + # Check if API key already exists in the current file + matches = pattern.findall(shell_contents) + existing_in_file = bool(matches) + logger.debug(f"shell_utils.py:save_api_key_to_rc - Existing key in file: {existing_in_file}") + + if existing_in_file: + # Replace the existing API key line in this file + updated_shell_contents = re.sub(pattern, api_key_line, shell_contents) + action = "Updated CODEFLASH_API_KEY in" + logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key") + else: + # Append the new API key line + if shell_contents and not shell_contents.endswith(LF): + updated_shell_contents = shell_contents + LF + api_key_line + LF + else: + updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}" + action = "Added CODEFLASH_API_KEY to" + logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key") + + # Write the updated contents + shell_file.seek(0) + shell_file.write(updated_shell_contents) + shell_file.truncate() + except FileNotFoundError: + # File doesn't exist, create it first with initial content + logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new") + shell_contents = "" + # Initialize with header for batch files if needed + if os.name == "nt" and not is_powershell(): shell_contents = "@echo off" - existing_api_key = read_api_key_from_shell_config() + logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file") + + # Create the file by opening in write mode + with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: # noqa: PTH123 + shell_file.write(shell_contents) - if existing_api_key: - # Replace the existing API key line - updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents) - action = "Updated CODEFLASH_API_KEY in" - else: + # Re-open in r+ mode to add the API key (r+ allows both read and write) + with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123 # Append the new API key line updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}" action = "Added CODEFLASH_API_KEY to" + logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file") + + # Write the updated contents + shell_file.seek(0) + shell_file.write(updated_shell_contents) + shell_file.truncate() + + logger.debug(f"shell_utils.py:save_api_key_to_rc - Successfully wrote to {shell_rc_path}") - shell_file.seek(0) - shell_file.write(updated_shell_contents) - shell_file.truncate() return Success(f"āœ… {action} {shell_rc_path}") - except PermissionError: + except PermissionError as e: + logger.debug(f"shell_utils.py:save_api_key_to_rc - Permission error: {e}") return Failure( f"šŸ’” I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}" f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}" ) - except FileNotFoundError: + except Exception as e: + logger.debug(f"shell_utils.py:save_api_key_to_rc - Error: {e}") return Failure( - f"šŸ’” I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}" + f"šŸ’” I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}" f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}" f"{LF}{api_key_line}{LF}" ) diff --git a/tests/test_shell_utils.py b/tests/test_shell_utils.py index 0ceeba5d8..d7ee8de5d 100644 --- a/tests/test_shell_utils.py +++ b/tests/test_shell_utils.py @@ -15,7 +15,7 @@ def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file): api_key = "cf-12345" result = save_api_key_to_rc(api_key) self.assertTrue(isinstance(result, Success)) - mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8") + mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8") handle = mock_file() handle.write.assert_called_once() handle.truncate.assert_called_once()