diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index f2cd3e969..286d3d19e 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -31,6 +31,7 @@ from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name from codeflash.code_utils.github_utils import get_github_secrets_page_url +from codeflash.code_utils.oauth_handler import perform_oauth_signin from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc from codeflash.either import is_successful from codeflash.lsp.helpers import is_LSP_enabled @@ -1166,10 +1167,13 @@ def convert(self, value: str, param: click.Parameter | None, ctx: click.Context # Returns True if the user entered a new API key, False if they used an existing one def prompt_api_key() -> bool: + """Prompt user for API key via OAuth or manual entry.""" + # Check for existing API key try: existing_api_key = get_codeflash_api_key() except OSError: existing_api_key = None + if existing_api_key: display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}" api_key_panel = Panel( @@ -1186,8 +1190,52 @@ def prompt_api_key() -> bool: console.print() return False - enter_api_key_and_save_to_rc() - ph("cli-new-api-key-entered") + # Prompt for authentication method + auth_choices = ["πŸ” Login in with Codeflash", "πŸ”‘ Use Codeflash API key"] + + questions = [ + inquirer.List( + "auth_method", + message="How would you like to authenticate?", + choices=auth_choices, + default=auth_choices[0], + carousel=True, + ) + ] + + answers = inquirer.prompt(questions, theme=CodeflashTheme()) + if not answers: + apologize_and_exit() + + method = answers["auth_method"] + + if method == auth_choices[1]: + enter_api_key_and_save_to_rc() + ph("cli-new-api-key-entered") + return True + + # Perform OAuth sign-in + api_key = perform_oauth_signin() + + if not api_key: + apologize_and_exit() + + # Save API key + shell_rc_path = get_shell_rc_path() + if not shell_rc_path.exists() and os.name == "nt": + shell_rc_path.touch() + click.echo(f"βœ… Created {shell_rc_path}") + + result = save_api_key_to_rc(api_key) + if is_successful(result): + click.echo(result.unwrap()) + click.echo("βœ… Signed in successfully and API key saved!") + else: + click.echo(result.failure()) + click.pause() + + os.environ["CODEFLASH_API_KEY"] = api_key + ph("cli-oauth-signin-completed") return True diff --git a/codeflash/code_utils/oauth_handler.py b/codeflash/code_utils/oauth_handler.py new file mode 100644 index 000000000..65e9f1341 --- /dev/null +++ b/codeflash/code_utils/oauth_handler.py @@ -0,0 +1,791 @@ +from __future__ import annotations + +import base64 +import contextlib +import hashlib +import http.server +import json +import os +import secrets +import socket +import sys +import threading +import time +import urllib.parse +import webbrowser + +import click +import requests + +from codeflash.api.cfapi import get_cfapi_base_urls + + +class OAuthHandler: + """Handle OAuth PKCE flow for CodeFlash authentication.""" + + def __init__(self) -> None: + self.code: str | None = None + self.state: str | None = None + self.error: str | None = None + self.theme: str | None = None + self.is_complete = False + self.token_error: str | None = None + self.manual_code: str | None = None + self.lock = threading.Lock() + + def create_callback_handler(self) -> type[http.server.BaseHTTPRequestHandler]: + """Create HTTP handler for OAuth callback.""" + oauth_handler = self + + class CallbackHandler(http.server.BaseHTTPRequestHandler): + server_version = "CFHTTP" + + def do_GET(self) -> None: + parsed = urllib.parse.urlparse(self.path) + + if parsed.path == "/status": + self.send_response(200) + self.send_header("Content-type", "application/json") + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + + status = { + "success": oauth_handler.token_error is None and oauth_handler.code is not None, + "error": oauth_handler.token_error, + } + self.wfile.write(json.dumps(status).encode()) + return + + if parsed.path != "/callback": + self.send_response(404) + self.end_headers() + return + + params = urllib.parse.parse_qs(parsed.query) + + with oauth_handler.lock: + if not oauth_handler.is_complete: + oauth_handler.code = params.get("code", [None])[0] + oauth_handler.state = params.get("state", [None])[0] + oauth_handler.error = params.get("error", [None])[0] + oauth_handler.theme = params.get("theme", ["light"])[0] + oauth_handler.is_complete = True + + # Send HTML response + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + html_content = self._get_html_response() + self.wfile.write(html_content.encode()) + + def _get_html_response(self) -> str: + """Return simple HTML response.""" + theme = oauth_handler.theme or "light" + if oauth_handler.error: + return self._get_error_html(oauth_handler.error, theme) + if oauth_handler.code: + return self._get_loading_html(theme) + return self._get_error_html("unauthorized", theme) + + @staticmethod + def _get_loading_html(theme: str = "light") -> str: + """Return loading state while exchanging token.""" + theme_class = "dark" if theme == "dark" else "" + return f""" + + + + + + CodeFlash Authentication + + + +
+
+ + +
+
+
+
+
+

Authenticating

+

Please wait while we verify your credentials...

+
+
+ + + + + """ + + @staticmethod + def _get_error_html(error_message: str, theme: str = "light") -> str: + """Return error state HTML.""" + theme_class = "dark" if theme == "dark" else "" + return f""" + + + + + + CodeFlash Authentication + + + +
+
+ + +
+
+
+ + + + + +
+

Authentication Failed

+
{error_message}
+
+
+ + + """ + + def log_message(self, fmt: str, *args: object) -> None: + """Suppress log messages.""" + + return CallbackHandler + + @staticmethod + def get_free_port() -> int: + """Find an available port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + @staticmethod + def generate_pkce_pair() -> tuple[str, str]: + """Generate PKCE code verifier and challenge.""" + code_verifier = "".join( + secrets.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~") for _ in range(64) + ) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() + return code_verifier, code_challenge + + def start_local_server(self, port: int) -> http.server.HTTPServer: + """Start local HTTP server for OAuth callback.""" + handler_class = self.create_callback_handler() + httpd = http.server.HTTPServer(("localhost", port), handler_class) + + def serve_forever_wrapper() -> None: + httpd.serve_forever() + + server_thread = threading.Thread(target=serve_forever_wrapper) + server_thread.daemon = True + server_thread.start() + + return httpd + + def exchange_code_for_token(self, code: str, code_verifier: str, redirect_uri: str) -> str | None: + """Exchange authorization code for API token.""" + token_url = f"{get_cfapi_base_urls().cfwebapp_base_url}/codeflash/auth/oauth/token" + data = { + "grant_type": "authorization_code", + "code": code, + "code_verifier": code_verifier, + "redirect_uri": redirect_uri, + "client_id": "cf-cli-app", + } + + try: + resp = requests.post(token_url, json=data, timeout=10) + resp.raise_for_status() + token_json = resp.json() + api_key = token_json.get("access_token") + + if not api_key: + self.token_error = "No access token in response" # noqa: S105 + return None + + except requests.exceptions.HTTPError: + self.token_error = "Unauthorized" # noqa: S105 + return None + else: + return api_key + + +def get_browser_name_fallback() -> str | None: + try: + controller = webbrowser.get() + # controller.name exists for most browser controllers + return getattr(controller, "name", None) + except Exception: + return None + + +def should_attempt_browser_launch() -> bool: + # A list of browser names that indicate we should not attempt to open a + # web browser for the user. + browser_blocklist = ["www-browser", "lynx", "links", "w3m", "elinks", "links2"] + browser_env = os.environ.get("BROWSER") or get_browser_name_fallback() + if browser_env and browser_env in browser_blocklist: + return False + + # Common environment variables used in CI/CD or other non-interactive shells. + if os.environ.get("CI") or os.environ.get("DEBIAN_FRONTEND") == "noninteractive": + return False + + # The presence of SSH_CONNECTION indicates a remote session. + # We should not attempt to launch a browser unless a display is explicitly available + # (checked below for Linux). + is_ssh = bool(os.environ.get("SSH_CONNECTION")) + + # On Linux, the presence of a display server is a strong indicator of a GUI. + if sys.platform == "linux": + # These are environment variables that can indicate a running compositor on + # Linux. + display_variables = ["DISPLAY", "WAYLAND_DISPLAY", "MIR_SOCKET"] + has_display = any(os.environ.get(v) for v in display_variables) + if not has_display: + return False + + # If in an SSH session on a non-Linux OS (e.g., macOS), don't launch browser. + # The Linux case is handled above (it's allowed if DISPLAY is set). + if is_ssh and sys.platform != "linux": + return False + + # For non-Linux OSes, we generally assume a GUI is available + # unless other signals (like SSH) suggest otherwise. + # The `open` command's error handling will catch final edge cases. + return True + + +def _wait_for_manual_code_input(oauth: OAuthHandler) -> None: + """Thread function to wait for manual code input.""" + try: + code = input() + with oauth.lock: + if not oauth.is_complete: + oauth.manual_code = code.strip() + oauth.is_complete = True + except Exception: # noqa: S110 + pass + + +def perform_oauth_signin() -> str | None: + """Perform OAuth PKCE flow and return API key if successful. + + Returns None if failed. + """ + oauth = OAuthHandler() + + # Setup PKCE + port = oauth.get_free_port() + code_verifier, code_challenge = oauth.generate_pkce_pair() + state = "".join(secrets.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(16)) + + # Build authorization URLs for both local and remote + local_redirect_uri = f"http://localhost:{port}/callback" + remote_redirect_uri = f"{get_cfapi_base_urls().cfwebapp_base_url}/codeflash/auth/callback" + + base_url = f"{get_cfapi_base_urls().cfwebapp_base_url}/codeflash/auth" + params = ( + f"response_type=code" + f"&client_id=cf-cli-app" + f"&code_challenge={code_challenge}" + f"&code_challenge_method=sha256" + f"&state={state}" + ) + local_auth_url = f"{base_url}?{params}&redirect_uri={urllib.parse.quote(local_redirect_uri)}" + remote_auth_url = f"{base_url}?{params}&redirect_uri={urllib.parse.quote(remote_redirect_uri)}" + + # Start local server + try: + httpd = oauth.start_local_server(port) + except Exception: + click.echo("❌ Failed to start local server.") + return None + + if should_attempt_browser_launch(): + # Try to open browser + click.echo("🌐 Opening browser to sign in to CodeFlash…") + with contextlib.suppress(Exception): + webbrowser.open(local_auth_url) + + # Show remote URL and start input thread + click.echo("\nπŸ“‹ If browser didn't open, visit this URL:") + click.echo(f"\n{remote_auth_url}\n") + click.echo("Paste code here if prompted > ", nl=False) + + # Start thread to wait for manual input + input_thread = threading.Thread(target=_wait_for_manual_code_input, args=(oauth,)) + input_thread.daemon = True + input_thread.start() + + waited = 0 + while not oauth.is_complete and waited < 180: + time.sleep(0.5) + waited += 0.5 + + if not oauth.is_complete: + httpd.shutdown() + click.echo("\n❌ Authentication timed out.") + return None + + # Check which method completed + api_key = None + + if oauth.manual_code: + # Manual code was entered + api_key = oauth.exchange_code_for_token(oauth.manual_code, code_verifier, remote_redirect_uri) + elif oauth.code: + # Browser callback received + if oauth.error or not oauth.state or oauth.state != state: + httpd.shutdown() + click.echo("\n❌ Unauthorized.") + return None + + api_key = oauth.exchange_code_for_token(oauth.code, code_verifier, local_redirect_uri) + + # Cleanup + time.sleep(3) + httpd.shutdown() + + if not api_key: + click.echo("\n❌ Authentication failed.") + click.echo("\n") + return api_key