Skip to content
Merged
1 change: 1 addition & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions codex-rs/cli/src/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use codex_core::auth::logout;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_login::ServerOptions;
use codex_login::run_device_code_login;
use codex_login::run_login_server;
use codex_protocol::mcp_protocol::AuthMode;
use std::path::PathBuf;
Expand Down Expand Up @@ -55,6 +56,32 @@ pub async fn run_login_with_api_key(
}
}

/// Login using the OAuth device code flow.
pub async fn run_login_with_device_code(
cli_config_overrides: CliConfigOverrides,
issuer_base_url: Option<String>,
client_id: Option<String>,
) -> ! {
let config = load_config_or_exit(cli_config_overrides);
let mut opts = ServerOptions::new(
config.codex_home,
client_id.unwrap_or(CLIENT_ID.to_string()),
);
if let Some(iss) = issuer_base_url {
opts.issuer = iss;
}
match run_device_code_login(opts).await {
Ok(()) => {
eprintln!("Successfully logged in");
std::process::exit(0);
}
Err(e) => {
eprintln!("Error logging in with device code: {e}");
std::process::exit(1);
}
}
}

pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! {
let config = load_config_or_exit(cli_config_overrides);

Expand Down
24 changes: 23 additions & 1 deletion codex-rs/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use codex_cli::SeatbeltCommand;
use codex_cli::login::run_login_status;
use codex_cli::login::run_login_with_api_key;
use codex_cli::login::run_login_with_chatgpt;
use codex_cli::login::run_login_with_device_code;
use codex_cli::login::run_logout;
use codex_cli::proto;
use codex_common::CliConfigOverrides;
Expand Down Expand Up @@ -133,6 +134,20 @@ struct LoginCommand {
#[arg(long = "api-key", value_name = "API_KEY")]
api_key: Option<String>,

/// EXPERIMENTAL: Use device code flow (not yet supported)
/// This feature is experimental and may changed in future releases.
#[arg(long = "experimental_use-device-code", hide = true)]
use_device_code: bool,

/// EXPERIMENTAL: Use custom OAuth issuer base URL (advanced)
/// Override the OAuth issuer base URL (advanced)
#[arg(long = "experimental_issuer", value_name = "URL", hide = true)]
issuer_base_url: Option<String>,

/// EXPERIMENTAL: Use custom OAuth client ID (advanced)
#[arg(long = "experimental_client-id", value_name = "CLIENT_ID", hide = true)]
client_id: Option<String>,

#[command(subcommand)]
action: Option<LoginSubcommand>,
}
Expand Down Expand Up @@ -282,7 +297,14 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
run_login_status(login_cli.config_overrides).await;
}
None => {
if let Some(api_key) = login_cli.api_key {
if login_cli.use_device_code {
run_login_with_device_code(
login_cli.config_overrides,
login_cli.issuer_base_url,
login_cli.client_id,
)
.await;
} else if let Some(api_key) = login_cli.api_key {
run_login_with_api_key(login_cli.config_overrides, api_key).await;
} else {
run_login_with_chatgpt(login_cli.config_overrides).await;
Expand Down
1 change: 1 addition & 0 deletions codex-rs/login/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ webbrowser = { workspace = true }
anyhow = { workspace = true }
core_test_support = { workspace = true }
tempfile = { workspace = true }
wiremock = { workspace = true }
196 changes: 196 additions & 0 deletions codex-rs/login/src/device_code_auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use reqwest::StatusCode;
use serde::Deserialize;
use serde::Serialize;
use serde::de::Deserializer;
use serde::de::{self};
use std::time::Duration;
use std::time::Instant;

use crate::pkce::PkceCodes;
use crate::server::ServerOptions;
use std::io::Write;
use std::io::{self};

#[derive(Deserialize)]
struct UserCodeResp {
device_auth_id: String,
#[serde(alias = "user_code", alias = "usercode")]
user_code: String,
#[serde(default, deserialize_with = "deserialize_interval")]
interval: u64,
}

#[derive(Serialize)]
struct UserCodeReq {
client_id: String,
}

#[derive(Serialize)]
struct TokenPollReq {
device_auth_id: String,
user_code: String,
}

fn deserialize_interval<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.trim()
.parse::<u64>()
.map_err(|e| de::Error::custom(format!("invalid u64 string: {e}")))
}

#[derive(Deserialize)]
struct CodeSuccessResp {
authorization_code: String,
code_challenge: String,
code_verifier: String,
}

/// Request the user code and polling interval.
async fn request_user_code(
client: &reqwest::Client,
auth_base_url: &str,
client_id: &str,
) -> std::io::Result<UserCodeResp> {
let url = format!("{auth_base_url}/deviceauth/usercode");
let body = serde_json::to_string(&UserCodeReq {
client_id: client_id.to_string(),
})
.map_err(std::io::Error::other)?;
let resp = client
.post(url)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
.map_err(std::io::Error::other)?;

if !resp.status().is_success() {
return Err(std::io::Error::other(format!(
"device code request failed with status {}",
resp.status()
)));
}

let body = resp.text().await.map_err(std::io::Error::other)?;
serde_json::from_str(&body).map_err(std::io::Error::other)
}

/// Poll token endpoint until a code is issued or timeout occurs.
async fn poll_for_token(
client: &reqwest::Client,
auth_base_url: &str,
device_auth_id: &str,
user_code: &str,
interval: u64,
) -> std::io::Result<CodeSuccessResp> {
let url = format!("{auth_base_url}/deviceauth/token");
let max_wait = Duration::from_secs(15 * 60);
let start = Instant::now();

loop {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this method can be split up

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I split the function into multiple smaller functions.

let body = serde_json::to_string(&TokenPollReq {
device_auth_id: device_auth_id.to_string(),
user_code: user_code.to_string(),
})
.map_err(std::io::Error::other)?;
let resp = client
.post(&url)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
.map_err(std::io::Error::other)?;

let status = resp.status();

if status.is_success() {
return resp.json().await.map_err(std::io::Error::other);
}

if status == StatusCode::FORBIDDEN || status == StatusCode::NOT_FOUND {
if start.elapsed() >= max_wait {
return Err(std::io::Error::other(
"device auth timed out after 15 minutes",
));
}
let sleep_for = Duration::from_secs(interval).min(max_wait - start.elapsed());
tokio::time::sleep(sleep_for).await;
continue;
}

return Err(std::io::Error::other(format!(
"device auth failed with status {}",
resp.status()
)));
}
}

// Helper to print colored text if terminal supports ANSI
fn print_colored_warning_device_code() {
// ANSI escape code for bright yellow
const YELLOW: &str = "\x1b[93m";
const RESET: &str = "\x1b[0m";
let warning = "WARN!!! device code authentication has potential risks and\n\
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we might want to improve this copy

should be used with caution only in cases where browser support \n\
is missing. This is prone to attacks.\n\
\n\
- This code is valid for 15 minutes.\n\
- Do not share this code with anyone.\n\
";
let mut stdout = io::stdout().lock();
let _ = write!(stdout, "{YELLOW}{warning}{RESET}");
let _ = stdout.flush();
}

/// Full device code login flow.
pub async fn run_device_code_login(opts: ServerOptions) -> std::io::Result<()> {
let client = reqwest::Client::new();
let auth_base_url = opts.issuer.trim_end_matches('/').to_owned();
print_colored_warning_device_code();
println!("⏳ Generating a new 9-digit device code for authentication...\n");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this message be before the call?

let uc = request_user_code(&client, &auth_base_url, &opts.client_id).await?;

println!(
"To authenticate, visit: {}/deviceauth/authorize and enter code: {}",
opts.issuer.trim_end_matches('/'),
uc.user_code
);

let code_resp = poll_for_token(
&client,
&auth_base_url,
&uc.device_auth_id,
&uc.user_code,
uc.interval,
)
.await?;

let pkce = PkceCodes {
code_verifier: code_resp.code_verifier,
code_challenge: code_resp.code_challenge,
};
println!("authorization code received");
let redirect_uri = format!("{}/deviceauth/callback", opts.issuer.trim_end_matches('/'));

let tokens = crate::server::exchange_code_for_tokens(
&opts.issuer,
&opts.client_id,
&redirect_uri,
&pkce,
&code_resp.authorization_code,
)
.await
.map_err(|err| std::io::Error::other(format!("device code exchange failed: {err}")))?;

crate::server::persist_tokens_async(
&opts.codex_home,
None,
tokens.id_token,
tokens.access_token,
tokens.refresh_token,
)
.await
}
2 changes: 2 additions & 0 deletions codex-rs/login/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod device_code_auth;
mod pkce;
mod server;

pub use device_code_auth::run_device_code_login;
pub use server::LoginServer;
pub use server::ServerOptions;
pub use server::ShutdownHandle;
Expand Down
18 changes: 11 additions & 7 deletions codex-rs/login/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,13 @@ fn bind_server(port: u16) -> io::Result<Server> {
}
}

struct ExchangedTokens {
id_token: String,
access_token: String,
refresh_token: String,
pub(crate) struct ExchangedTokens {
pub id_token: String,
pub access_token: String,
pub refresh_token: String,
}

async fn exchange_code_for_tokens(
pub(crate) async fn exchange_code_for_tokens(
issuer: &str,
client_id: &str,
redirect_uri: &str,
Expand Down Expand Up @@ -443,7 +443,7 @@ async fn exchange_code_for_tokens(
})
}

async fn persist_tokens_async(
pub(crate) async fn persist_tokens_async(
codex_home: &Path,
api_key: Option<String>,
id_token: String,
Expand Down Expand Up @@ -562,7 +562,11 @@ fn jwt_auth_claims(jwt: &str) -> serde_json::Map<String, serde_json::Value> {
serde_json::Map::new()
}

async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> {
pub(crate) async fn obtain_api_key(
issuer: &str,
client_id: &str,
id_token: &str,
) -> io::Result<String> {
// Token exchange for an API key access token
#[derive(serde::Deserialize)]
struct ExchangeResp {
Expand Down
Loading
Loading