diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 8b71f139b6..02affc7a16 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -830,6 +830,7 @@ dependencies = [ "url", "urlencoding", "webbrowser", + "wiremock", ] [[package]] diff --git a/codex-rs/cli/src/login.rs b/codex-rs/cli/src/login.rs index f0816d0b29..8dd4fb8333 100644 --- a/codex-rs/cli/src/login.rs +++ b/codex-rs/cli/src/login.rs @@ -6,6 +6,7 @@ use codex_core::auth::logout; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_login::ServerOptions; +use codex_login::run_device_code_login; use codex_login::run_login_server; use codex_protocol::mcp_protocol::AuthMode; use std::path::PathBuf; @@ -55,6 +56,32 @@ pub async fn run_login_with_api_key( } } +/// Login using the OAuth device code flow. +pub async fn run_login_with_device_code( + cli_config_overrides: CliConfigOverrides, + issuer_base_url: Option, + client_id: Option, +) -> ! { + let config = load_config_or_exit(cli_config_overrides); + let mut opts = ServerOptions::new( + config.codex_home, + client_id.unwrap_or(CLIENT_ID.to_string()), + ); + if let Some(iss) = issuer_base_url { + opts.issuer = iss; + } + match run_device_code_login(opts).await { + Ok(()) => { + eprintln!("Successfully logged in"); + std::process::exit(0); + } + Err(e) => { + eprintln!("Error logging in with device code: {e}"); + std::process::exit(1); + } + } +} + pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! { let config = load_config_or_exit(cli_config_overrides); diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index b1e9601c8c..e9d4e17c45 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -10,6 +10,7 @@ use codex_cli::SeatbeltCommand; use codex_cli::login::run_login_status; use codex_cli::login::run_login_with_api_key; use codex_cli::login::run_login_with_chatgpt; +use codex_cli::login::run_login_with_device_code; use codex_cli::login::run_logout; use codex_cli::proto; use codex_common::CliConfigOverrides; @@ -133,6 +134,20 @@ struct LoginCommand { #[arg(long = "api-key", value_name = "API_KEY")] api_key: Option, + /// EXPERIMENTAL: Use device code flow (not yet supported) + /// This feature is experimental and may changed in future releases. + #[arg(long = "experimental_use-device-code", hide = true)] + use_device_code: bool, + + /// EXPERIMENTAL: Use custom OAuth issuer base URL (advanced) + /// Override the OAuth issuer base URL (advanced) + #[arg(long = "experimental_issuer", value_name = "URL", hide = true)] + issuer_base_url: Option, + + /// EXPERIMENTAL: Use custom OAuth client ID (advanced) + #[arg(long = "experimental_client-id", value_name = "CLIENT_ID", hide = true)] + client_id: Option, + #[command(subcommand)] action: Option, } @@ -282,7 +297,14 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() run_login_status(login_cli.config_overrides).await; } None => { - if let Some(api_key) = login_cli.api_key { + if login_cli.use_device_code { + run_login_with_device_code( + login_cli.config_overrides, + login_cli.issuer_base_url, + login_cli.client_id, + ) + .await; + } else if let Some(api_key) = login_cli.api_key { run_login_with_api_key(login_cli.config_overrides, api_key).await; } else { run_login_with_chatgpt(login_cli.config_overrides).await; diff --git a/codex-rs/login/Cargo.toml b/codex-rs/login/Cargo.toml index 5d358361c1..e2a693e9ac 100644 --- a/codex-rs/login/Cargo.toml +++ b/codex-rs/login/Cargo.toml @@ -33,3 +33,4 @@ webbrowser = { workspace = true } anyhow = { workspace = true } core_test_support = { workspace = true } tempfile = { workspace = true } +wiremock = { workspace = true } diff --git a/codex-rs/login/src/device_code_auth.rs b/codex-rs/login/src/device_code_auth.rs new file mode 100644 index 0000000000..0d04a61337 --- /dev/null +++ b/codex-rs/login/src/device_code_auth.rs @@ -0,0 +1,196 @@ +use reqwest::StatusCode; +use serde::Deserialize; +use serde::Serialize; +use serde::de::Deserializer; +use serde::de::{self}; +use std::time::Duration; +use std::time::Instant; + +use crate::pkce::PkceCodes; +use crate::server::ServerOptions; +use std::io::Write; +use std::io::{self}; + +#[derive(Deserialize)] +struct UserCodeResp { + device_auth_id: String, + #[serde(alias = "user_code", alias = "usercode")] + user_code: String, + #[serde(default, deserialize_with = "deserialize_interval")] + interval: u64, +} + +#[derive(Serialize)] +struct UserCodeReq { + client_id: String, +} + +#[derive(Serialize)] +struct TokenPollReq { + device_auth_id: String, + user_code: String, +} + +fn deserialize_interval<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + s.trim() + .parse::() + .map_err(|e| de::Error::custom(format!("invalid u64 string: {e}"))) +} + +#[derive(Deserialize)] +struct CodeSuccessResp { + authorization_code: String, + code_challenge: String, + code_verifier: String, +} + +/// Request the user code and polling interval. +async fn request_user_code( + client: &reqwest::Client, + auth_base_url: &str, + client_id: &str, +) -> std::io::Result { + let url = format!("{auth_base_url}/deviceauth/usercode"); + let body = serde_json::to_string(&UserCodeReq { + client_id: client_id.to_string(), + }) + .map_err(std::io::Error::other)?; + let resp = client + .post(url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + .map_err(std::io::Error::other)?; + + if !resp.status().is_success() { + return Err(std::io::Error::other(format!( + "device code request failed with status {}", + resp.status() + ))); + } + + let body = resp.text().await.map_err(std::io::Error::other)?; + serde_json::from_str(&body).map_err(std::io::Error::other) +} + +/// Poll token endpoint until a code is issued or timeout occurs. +async fn poll_for_token( + client: &reqwest::Client, + auth_base_url: &str, + device_auth_id: &str, + user_code: &str, + interval: u64, +) -> std::io::Result { + let url = format!("{auth_base_url}/deviceauth/token"); + let max_wait = Duration::from_secs(15 * 60); + let start = Instant::now(); + + loop { + let body = serde_json::to_string(&TokenPollReq { + device_auth_id: device_auth_id.to_string(), + user_code: user_code.to_string(), + }) + .map_err(std::io::Error::other)?; + let resp = client + .post(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + .map_err(std::io::Error::other)?; + + let status = resp.status(); + + if status.is_success() { + return resp.json().await.map_err(std::io::Error::other); + } + + if status == StatusCode::FORBIDDEN || status == StatusCode::NOT_FOUND { + if start.elapsed() >= max_wait { + return Err(std::io::Error::other( + "device auth timed out after 15 minutes", + )); + } + let sleep_for = Duration::from_secs(interval).min(max_wait - start.elapsed()); + tokio::time::sleep(sleep_for).await; + continue; + } + + return Err(std::io::Error::other(format!( + "device auth failed with status {}", + resp.status() + ))); + } +} + +// Helper to print colored text if terminal supports ANSI +fn print_colored_warning_device_code() { + // ANSI escape code for bright yellow + const YELLOW: &str = "\x1b[93m"; + const RESET: &str = "\x1b[0m"; + let warning = "WARN!!! device code authentication has potential risks and\n\ + should be used with caution only in cases where browser support \n\ + is missing. This is prone to attacks.\n\ + \n\ + - This code is valid for 15 minutes.\n\ + - Do not share this code with anyone.\n\ + "; + let mut stdout = io::stdout().lock(); + let _ = write!(stdout, "{YELLOW}{warning}{RESET}"); + let _ = stdout.flush(); +} + +/// Full device code login flow. +pub async fn run_device_code_login(opts: ServerOptions) -> std::io::Result<()> { + let client = reqwest::Client::new(); + let auth_base_url = opts.issuer.trim_end_matches('/').to_owned(); + print_colored_warning_device_code(); + println!("⏳ Generating a new 9-digit device code for authentication...\n"); + let uc = request_user_code(&client, &auth_base_url, &opts.client_id).await?; + + println!( + "To authenticate, visit: {}/deviceauth/authorize and enter code: {}", + opts.issuer.trim_end_matches('/'), + uc.user_code + ); + + let code_resp = poll_for_token( + &client, + &auth_base_url, + &uc.device_auth_id, + &uc.user_code, + uc.interval, + ) + .await?; + + let pkce = PkceCodes { + code_verifier: code_resp.code_verifier, + code_challenge: code_resp.code_challenge, + }; + println!("authorization code received"); + let redirect_uri = format!("{}/deviceauth/callback", opts.issuer.trim_end_matches('/')); + + let tokens = crate::server::exchange_code_for_tokens( + &opts.issuer, + &opts.client_id, + &redirect_uri, + &pkce, + &code_resp.authorization_code, + ) + .await + .map_err(|err| std::io::Error::other(format!("device code exchange failed: {err}")))?; + + crate::server::persist_tokens_async( + &opts.codex_home, + None, + tokens.id_token, + tokens.access_token, + tokens.refresh_token, + ) + .await +} diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index a737af22c8..d5e5836f0f 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -1,6 +1,8 @@ +mod device_code_auth; mod pkce; mod server; +pub use device_code_auth::run_device_code_login; pub use server::LoginServer; pub use server::ServerOptions; pub use server::ShutdownHandle; diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index a1e5e6c3b0..7df9038bc5 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -393,13 +393,13 @@ fn bind_server(port: u16) -> io::Result { } } -struct ExchangedTokens { - id_token: String, - access_token: String, - refresh_token: String, +pub(crate) struct ExchangedTokens { + pub id_token: String, + pub access_token: String, + pub refresh_token: String, } -async fn exchange_code_for_tokens( +pub(crate) async fn exchange_code_for_tokens( issuer: &str, client_id: &str, redirect_uri: &str, @@ -443,7 +443,7 @@ async fn exchange_code_for_tokens( }) } -async fn persist_tokens_async( +pub(crate) async fn persist_tokens_async( codex_home: &Path, api_key: Option, id_token: String, @@ -562,7 +562,11 @@ fn jwt_auth_claims(jwt: &str) -> serde_json::Map { serde_json::Map::new() } -async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result { +pub(crate) async fn obtain_api_key( + issuer: &str, + client_id: &str, + id_token: &str, +) -> io::Result { // Token exchange for an API key access token #[derive(serde::Deserialize)] struct ExchangeResp { diff --git a/codex-rs/login/tests/suite/device_code_login.rs b/codex-rs/login/tests/suite/device_code_login.rs new file mode 100644 index 0000000000..0b63b98a6d --- /dev/null +++ b/codex-rs/login/tests/suite/device_code_login.rs @@ -0,0 +1,248 @@ +#![allow(clippy::unwrap_used)] + +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use codex_core::auth::get_auth_file; +use codex_core::auth::try_read_auth_json; +use codex_login::ServerOptions; +use codex_login::run_device_code_login; +use serde_json::json; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use tempfile::tempdir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::Request; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +use core_test_support::skip_if_no_network; + +// ---------- Small helpers ---------- + +fn make_jwt(payload: serde_json::Value) -> String { + let header = json!({ "alg": "none", "typ": "JWT" }); + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).unwrap()); + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap()); + let signature_b64 = URL_SAFE_NO_PAD.encode(b"sig"); + format!("{header_b64}.{payload_b64}.{signature_b64}") +} + +async fn mock_usercode_success(server: &MockServer) { + Mock::given(method("POST")) + .and(path("/deviceauth/usercode")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "device_auth_id": "device-auth-123", + "user_code": "CODE-12345", + // NOTE: Interval is kept 0 in order to avoid waiting for the interval to pass + "interval": "0" + }))) + .mount(server) + .await; +} + +async fn mock_usercode_failure(server: &MockServer, status: u16) { + Mock::given(method("POST")) + .and(path("/deviceauth/usercode")) + .respond_with(ResponseTemplate::new(status)) + .mount(server) + .await; +} + +async fn mock_poll_token_two_step( + server: &MockServer, + counter: Arc, + first_response_status: u16, +) { + let c = counter.clone(); + Mock::given(method("POST")) + .and(path("/deviceauth/token")) + .respond_with(move |_: &Request| { + let attempt = c.fetch_add(1, Ordering::SeqCst); + if attempt == 0 { + ResponseTemplate::new(first_response_status) + } else { + ResponseTemplate::new(200).set_body_json(json!({ + "authorization_code": "poll-code-321", + "code_challenge": "code-challenge-321", + "code_verifier": "code-verifier-321" + })) + } + }) + .expect(2) + .mount(server) + .await; +} + +async fn mock_poll_token_single(server: &MockServer, endpoint: &str, response: ResponseTemplate) { + Mock::given(method("POST")) + .and(path(endpoint)) + .respond_with(response) + .mount(server) + .await; +} + +async fn mock_oauth_token_single(server: &MockServer, jwt: String) { + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id_token": jwt.clone(), + "access_token": "access-token-123", + "refresh_token": "refresh-token-123" + }))) + .mount(server) + .await; +} + +fn server_opts(codex_home: &tempfile::TempDir, issuer: String) -> ServerOptions { + let mut opts = ServerOptions::new(codex_home.path().to_path_buf(), "client-id".to_string()); + opts.issuer = issuer; + opts.open_browser = false; + opts +} + +#[tokio::test] +async fn device_code_login_integration_succeeds() { + skip_if_no_network!(); + + let codex_home = tempdir().unwrap(); + let mock_server = MockServer::start().await; + + mock_usercode_success(&mock_server).await; + + mock_poll_token_two_step(&mock_server, Arc::new(AtomicUsize::new(0)), 404).await; + + let jwt = make_jwt(json!({ + "https://api.openai.com/auth": { + "chatgpt_account_id": "acct_321" + } + })); + + mock_oauth_token_single(&mock_server, jwt.clone()).await; + + let issuer = mock_server.uri(); + let opts = server_opts(&codex_home, issuer); + + run_device_code_login(opts) + .await + .expect("device code login integration should succeed"); + + let auth_path = get_auth_file(codex_home.path()); + let auth = try_read_auth_json(&auth_path).expect("auth.json written"); + // assert_eq!(auth.openai_api_key.as_deref(), Some("api-key-321")); + let tokens = auth.tokens.expect("tokens persisted"); + assert_eq!(tokens.access_token, "access-token-123"); + assert_eq!(tokens.refresh_token, "refresh-token-123"); + assert_eq!(tokens.id_token.raw_jwt, jwt); + assert_eq!(tokens.account_id.as_deref(), Some("acct_321")); +} + +#[tokio::test] +async fn device_code_login_integration_handles_usercode_http_failure() { + skip_if_no_network!(); + + let codex_home = tempdir().unwrap(); + let mock_server = MockServer::start().await; + + mock_usercode_failure(&mock_server, 503).await; + + let issuer = mock_server.uri(); + + let opts = server_opts(&codex_home, issuer); + + let err = run_device_code_login(opts) + .await + .expect_err("usercode HTTP failure should bubble up"); + assert!( + err.to_string() + .contains("device code request failed with status"), + "unexpected error: {err:?}" + ); + + let auth_path = get_auth_file(codex_home.path()); + assert!(!auth_path.exists()); +} + +#[tokio::test] +async fn device_code_login_integration_persists_without_api_key_on_exchange_failure() { + skip_if_no_network!(); + + let codex_home = tempdir().unwrap(); + + let mock_server = MockServer::start().await; + + mock_usercode_success(&mock_server).await; + + mock_poll_token_two_step(&mock_server, Arc::new(AtomicUsize::new(0)), 404).await; + + let jwt = make_jwt(json!({})); + + mock_oauth_token_single(&mock_server, jwt.clone()).await; + + let issuer = mock_server.uri(); + + let mut opts = ServerOptions::new(codex_home.path().to_path_buf(), "client-id".to_string()); + opts.issuer = issuer; + opts.open_browser = false; + + run_device_code_login(opts) + .await + .expect("device login should succeed without API key exchange"); + + let auth_path = get_auth_file(codex_home.path()); + let auth = try_read_auth_json(&auth_path).expect("auth.json written"); + assert!(auth.openai_api_key.is_none()); + let tokens = auth.tokens.expect("tokens persisted"); + assert_eq!(tokens.access_token, "access-token-123"); + assert_eq!(tokens.refresh_token, "refresh-token-123"); + assert_eq!(tokens.id_token.raw_jwt, jwt); +} + +#[tokio::test] +async fn device_code_login_integration_handles_error_payload() { + skip_if_no_network!(); + + let codex_home = tempdir().unwrap(); + + // Start WireMock + let mock_server = MockServer::start().await; + + mock_usercode_success(&mock_server).await; + + // // /deviceauth/token → returns error payload with status 401 + mock_poll_token_single( + &mock_server, + "/deviceauth/token", + ResponseTemplate::new(401).set_body_json(json!({ + "error": "authorization_declined", + "error_description": "Denied" + })), + ) + .await; + + // (WireMock will automatically 404 for other paths) + + let issuer = mock_server.uri(); + + let mut opts = ServerOptions::new(codex_home.path().to_path_buf(), "client-id".to_string()); + opts.issuer = issuer; + opts.open_browser = false; + + let err = run_device_code_login(opts) + .await + .expect_err("integration failure path should return error"); + + // Accept either the specific error payload, a 400, or a 404 (since the client may return 404 if the flow is incomplete) + assert!( + err.to_string().contains("authorization_declined") || err.to_string().contains("401"), + "Expected an authorization_declined / 400 / 404 error, got {err:?}" + ); + + let auth_path = get_auth_file(codex_home.path()); + assert!( + !auth_path.exists(), + "auth.json should not be created when device auth fails" + ); +} diff --git a/codex-rs/login/tests/suite/mod.rs b/codex-rs/login/tests/suite/mod.rs index 3259e72434..b84b264bec 100644 --- a/codex-rs/login/tests/suite/mod.rs +++ b/codex-rs/login/tests/suite/mod.rs @@ -1,2 +1,3 @@ // Aggregates all former standalone integration tests as modules. +mod device_code_login; mod login_server_e2e;