Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
from codeflash.code_utils.github_utils import get_github_secrets_page_url
from codeflash.code_utils.oauth_handler import perform_oauth_signin
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc
from codeflash.either import is_successful
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
Expand Down Expand Up @@ -136,7 +136,10 @@ def init_codeflash() -> None:
completion_message += (
"\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
)
reload_cmd = f"call {get_shell_rc_path()}" if os.name == "nt" else f"source {get_shell_rc_path()}"
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
reload_cmd = f"source {get_shell_rc_path()}"
completion_message += f"\nOr run: {reload_cmd}"

completion_panel = Panel(
Expand Down Expand Up @@ -1087,7 +1090,7 @@ def configure_pyproject_toml(

with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"Added Codeflash configuration to {toml_path}")
click.echo(f"Added Codeflash configuration to {toml_path}")
click.echo()
return True

Expand Down Expand Up @@ -1264,7 +1267,8 @@ def enter_api_key_and_save_to_rc() -> None:
browser_launched = True # This does not work on remote consoles
shell_rc_path = get_shell_rc_path()
if not shell_rc_path.exists() and os.name == "nt":
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
# On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
shell_rc_path.touch()
click.echo(f"✅ Created {shell_rc_path}")
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.
Expand Down
15 changes: 12 additions & 3 deletions codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,26 @@ def get_codeflash_api_key() -> str:
# Check environment variable first
env_api_key = os.environ.get("CODEFLASH_API_KEY")
shell_api_key = read_api_key_from_shell_config()

logger.debug(
f"env_utils.py:get_codeflash_api_key - env_api_key: {'***' + env_api_key[-4:] if env_api_key else None}, shell_api_key: {'***' + shell_api_key[-4:] if shell_api_key else None}"
)
# If we have an env var but it's not in shell config, save it for persistence
if env_api_key and not shell_api_key:
try:
from codeflash.either import is_successful

logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config")
result = save_api_key_to_rc(env_api_key)
if is_successful(result):
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
logger.debug(
f"env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: {result.unwrap()}"
)
else:
logger.debug(f"env_utils.py:get_codeflash_api_key - Failed to save API key: {result.failure()}")
except Exception as e:
logger.debug(f"Failed to automatically save API key to shell config: {e}")
logger.debug(
f"env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: {e}"
)

# Prefer the shell configuration over environment variables for lsp,
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
Expand Down
208 changes: 179 additions & 29 deletions codeflash/code_utils/shell_utils.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,107 @@
from __future__ import annotations

import contextlib
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import LF
from codeflash.either import Failure, Success

if TYPE_CHECKING:
from codeflash.either import Result

if os.name == "nt": # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
)
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
# PowerShell patterns and prefixes
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
)
POWERSHELL_RC_EXPORT_PREFIX = "$env:CODEFLASH_API_KEY = "

# CMD/Batch patterns and prefixes
CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="

# Unix shell patterns and prefixes
UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE)
UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="


def is_powershell() -> bool:
"""Detect if we're running in PowerShell on Windows.

Uses multiple heuristics:
1. PSModulePath environment variable (PowerShell always sets this)
2. COMSPEC pointing to powershell.exe
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
"""
if os.name != "nt":
return False

# Primary check: PSMODULEPATH is set by PowerShell
# This is the most reliable indicator as PowerShell always sets this
ps_module_path = os.environ.get("PSMODULEPATH")
if ps_module_path:
logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath")
return True

# Secondary check: COMSPEC points to PowerShell
comspec = os.environ.get("COMSPEC", "").lower()
if "powershell" in comspec:
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: {comspec}")
return True

# Tertiary check: Windows Terminal often uses PowerShell by default
# But we only use this if other indicators are ambiguous
term_program = os.environ.get("TERM_PROGRAM", "").lower()
# Check if we can find evidence of CMD (cmd.exe in COMSPEC)
# If not, assume PowerShell for Windows Terminal
if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec:
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: {comspec})")
return True

logger.debug(f"shell_utils.py:is_powershell - Not PowerShell (COMSPEC: {comspec}, TERM_PROGRAM: {term_program})")
return False


def read_api_key_from_shell_config() -> Optional[str]:
"""Read API key from shell configuration file."""
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):
shell_rc_path = Path(shell_rc_path)

# Determine the correct pattern to use based on the file extension and platform
if os.name == "nt": # Windows
pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN
else: # Unix-like
pattern = UNIX_RC_EXPORT_PATTERN

try:
shell_rc_path = get_shell_rc_path()
with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123
# Convert Path to string using as_posix() for cross-platform path compatibility
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
with open(shell_rc_path_str, encoding="utf8") as shell_rc: # noqa: PTH123
shell_contents = shell_rc.read()
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
return matches[-1] if matches else None
matches = pattern.findall(shell_contents)
if matches:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Found API key in file: {shell_rc_path}")
return matches[-1]
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - No API key found in file: {shell_rc_path}")
return None
except FileNotFoundError:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - File not found: {shell_rc_path}")
return None
except Exception as e:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Error reading file: {e}")
return None


def get_shell_rc_path() -> Path:
"""Get the path to the user's shell configuration file."""
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
if os.name == "nt": # Windows
if is_powershell():
return Path.home() / "codeflash_env.ps1"
return Path.home() / "codeflash_env.bat"
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
Expand All @@ -44,40 +111,123 @@ def get_shell_rc_path() -> Path:


def get_api_key_export_line(api_key: str) -> str:
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
"""Get the appropriate export line based on the shell type."""
if os.name == "nt": # Windows
if is_powershell():
return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"'
return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"'
# Unix-like
return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"'


def save_api_key_to_rc(api_key: str) -> Result[str, str]:
"""Save API key to the appropriate shell configuration file."""
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):
shell_rc_path = Path(shell_rc_path)
api_key_line = get_api_key_export_line(api_key)

logger.debug(f"shell_utils.py:save_api_key_to_rc - Saving API key to: {shell_rc_path}")
logger.debug(f"shell_utils.py:save_api_key_to_rc - API key line format: {api_key_line[:30]}...")

# Determine the correct pattern to use for replacement
if os.name == "nt": # Windows
if is_powershell():
pattern = POWERSHELL_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern")
else:
pattern = CMD_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern")
else: # Unix-like
pattern = UNIX_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern")

try:
with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123
shell_contents = shell_file.read()
if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file
# Create directory if it doesn't exist (ignore errors - file operation will fail if needed)
# Directory creation failed, but we'll still try to open the file
# The file operation itself will raise the appropriate exception if there are permission issues
with contextlib.suppress(OSError, PermissionError):
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)

# Convert Path to string using as_posix() for cross-platform path compatibility
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)

# Try to open in r+ mode (read and write in single operation)
# Handle FileNotFoundError if file doesn't exist (r+ requires file to exist)
try:
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
shell_contents = shell_file.read()
logger.debug(f"shell_utils.py:save_api_key_to_rc - Read existing file, length: {len(shell_contents)}")

# Initialize empty file with header for batch files if needed
if not shell_contents:
logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing")
if os.name == "nt" and not is_powershell():
shell_contents = "@echo off"
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")

# Check if API key already exists in the current file
matches = pattern.findall(shell_contents)
existing_in_file = bool(matches)
logger.debug(f"shell_utils.py:save_api_key_to_rc - Existing key in file: {existing_in_file}")

if existing_in_file:
# Replace the existing API key line in this file
updated_shell_contents = re.sub(pattern, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key")
else:
# Append the new API key line
if shell_contents and not shell_contents.endswith(LF):
updated_shell_contents = shell_contents + LF + api_key_line + LF
else:
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
action = "Added CODEFLASH_API_KEY to"
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key")

# Write the updated contents
shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()
except FileNotFoundError:
# File doesn't exist, create it first with initial content
logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new")
shell_contents = ""
# Initialize with header for batch files if needed
if os.name == "nt" and not is_powershell():
shell_contents = "@echo off"
existing_api_key = read_api_key_from_shell_config()
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")

# Create the file by opening in write mode
with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: # noqa: PTH123
shell_file.write(shell_contents)

if existing_api_key:
# Replace the existing API key line
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
else:
# Re-open in r+ mode to add the API key (r+ allows both read and write)
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
# Append the new API key line
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
action = "Added CODEFLASH_API_KEY to"
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file")

# Write the updated contents
shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()

logger.debug(f"shell_utils.py:save_api_key_to_rc - Successfully wrote to {shell_rc_path}")

shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()
return Success(f"✅ {action} {shell_rc_path}")
except PermissionError:
except PermissionError as e:
logger.debug(f"shell_utils.py:save_api_key_to_rc - Permission error: {e}")
return Failure(
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
)
except FileNotFoundError:
except Exception as e:
logger.debug(f"shell_utils.py:save_api_key_to_rc - Error: {e}")
return Failure(
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}"
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
f"{LF}{api_key_line}{LF}"
)
2 changes: 1 addition & 1 deletion tests/test_shell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file):
api_key = "cf-12345"
result = save_api_key_to_rc(api_key)
self.assertTrue(isinstance(result, Success))
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
handle = mock_file()
handle.write.assert_called_once()
handle.truncate.assert_called_once()
Expand Down
Loading