diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 7940e4cbeb..ff7bc42626 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1064,6 +1064,7 @@ dependencies = [ "codex-async-utils", "codex-file-search", "codex-git-tooling", + "codex-keyring-store", "codex-otel", "codex-protocol", "codex-rmcp-client", @@ -1082,6 +1083,7 @@ dependencies = [ "http", "image", "indexmap 2.10.0", + "keyring", "landlock", "libc", "maplit", @@ -1098,6 +1100,7 @@ dependencies = [ "serde_json", "serial_test", "sha1", + "sha2", "shlex", "similar", "strum_macros 0.27.2", @@ -1221,6 +1224,14 @@ dependencies = [ "walkdir", ] +[[package]] +name = "codex-keyring-store" +version = "0.0.0" +dependencies = [ + "keyring", + "tracing", +] + [[package]] name = "codex-linux-sandbox" version = "0.0.0" @@ -1386,6 +1397,7 @@ version = "0.0.0" dependencies = [ "anyhow", "axum", + "codex-keyring-store", "codex-protocol", "dirs", "escargot", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 1d6c937e8a..f9d0865f32 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -16,6 +16,7 @@ members = [ "core", "exec", "execpolicy", + "keyring-store", "file-search", "git-tooling", "linux-sandbox", @@ -67,6 +68,7 @@ codex-exec = { path = "exec" } codex-feedback = { path = "feedback" } codex-file-search = { path = "file-search" } codex-git-tooling = { path = "git-tooling" } +codex-keyring-store = { path = "keyring-store" } codex-linux-sandbox = { path = "linux-sandbox" } codex-login = { path = "login" } codex-mcp-server = { path = "mcp-server" } diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 8ab7ffe827..1962411b6e 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -21,15 +21,16 @@ bytes = { workspace = true } chrono = { workspace = true, features = ["serde"] } codex-app-server-protocol = { workspace = true } codex-apply-patch = { workspace = true } +codex-async-utils = { workspace = true } codex-file-search = { workspace = true } +codex-git-tooling = { workspace = true } +codex-keyring-store = { workspace = true } codex-otel = { workspace = true, features = ["otel"] } codex-protocol = { workspace = true } -codex-git-tooling = { workspace = true } codex-rmcp-client = { workspace = true } -codex-async-utils = { workspace = true } -codex-utils-string = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-readiness = { workspace = true } +codex-utils-string = { workspace = true } codex-utils-tokenizer = { workspace = true } dirs = { workspace = true } dunce = { workspace = true } @@ -38,6 +39,7 @@ eventsource-stream = { workspace = true } futures = { workspace = true } http = { workspace = true } indexmap = { workspace = true } +keyring = { workspace = true } libc = { workspace = true } mcp-types = { workspace = true } os_info = { workspace = true } @@ -47,6 +49,7 @@ reqwest = { workspace = true, features = ["json", "stream"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha1 = { workspace = true } +sha2 = { workspace = true } shlex = { workspace = true } similar = { workspace = true } strum_macros = { workspace = true } diff --git a/codex-rs/core/src/auth/storage.rs b/codex-rs/core/src/auth/storage.rs index 508adc89fe..a238eb9c38 100644 --- a/codex-rs/core/src/auth/storage.rs +++ b/codex-rs/core/src/auth/storage.rs @@ -2,6 +2,8 @@ use chrono::DateTime; use chrono::Utc; use serde::Deserialize; use serde::Serialize; +use sha2::Digest; +use sha2::Sha256; use std::fmt::Debug; use std::fs::File; use std::fs::OpenOptions; @@ -12,8 +14,11 @@ use std::os::unix::fs::OpenOptionsExt; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; +use tracing::warn; use crate::token_data::TokenData; +use codex_keyring_store::DefaultKeyringStore; +use codex_keyring_store::KeyringStore; /// Determine where Codex should store CLI auth credentials. #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -22,7 +27,10 @@ pub enum AuthCredentialsStoreMode { #[default] /// Persist credentials in CODEX_HOME/auth.json. File, - // TODO: Implement keyring support. + /// Persist credentials in the keyring. Fail if unavailable. + Keyring, + /// Use keyring when available; otherwise, fall back to a file in CODEX_HOME. + Auto, } /// Expected structure for $CODEX_HOME/auth.json. @@ -114,25 +122,177 @@ impl AuthStorageBackend for FileAuthStorage { } } +const KEYRING_SERVICE: &str = "Codex Auth"; + +// turns codex_home path into a stable, short key string +fn compute_store_key(codex_home: &Path) -> std::io::Result { + let canonical = codex_home + .canonicalize() + .unwrap_or_else(|_| codex_home.to_path_buf()); + let path_str = canonical.to_string_lossy(); + let mut hasher = Sha256::new(); + hasher.update(path_str.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{digest:x}"); + let truncated = hex.get(..16).unwrap_or(&hex); + Ok(format!("cli|{truncated}")) +} + +#[derive(Clone, Debug)] +struct KeyringAuthStorage { + codex_home: PathBuf, + keyring_store: Arc, +} + +impl KeyringAuthStorage { + fn new(codex_home: PathBuf, keyring_store: Arc) -> Self { + Self { + codex_home, + keyring_store, + } + } + + fn load_from_keyring(&self, key: &str) -> std::io::Result> { + match self.keyring_store.load(KEYRING_SERVICE, key) { + Ok(Some(serialized)) => serde_json::from_str(&serialized).map(Some).map_err(|err| { + std::io::Error::other(format!( + "failed to deserialize CLI auth from keyring: {err}" + )) + }), + Ok(None) => Ok(None), + Err(error) => Err(std::io::Error::other(format!( + "failed to load CLI auth from keyring: {}", + error.message() + ))), + } + } + + fn save_to_keyring(&self, key: &str, value: &str) -> std::io::Result<()> { + match self.keyring_store.save(KEYRING_SERVICE, key, value) { + Ok(()) => Ok(()), + Err(error) => { + let message = format!( + "failed to write OAuth tokens to keyring: {}", + error.message() + ); + warn!("{message}"); + Err(std::io::Error::other(message)) + } + } + } +} + +impl AuthStorageBackend for KeyringAuthStorage { + fn load(&self) -> std::io::Result> { + let key = compute_store_key(&self.codex_home)?; + self.load_from_keyring(&key) + } + + fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> { + let key = compute_store_key(&self.codex_home)?; + // Simpler error mapping per style: prefer method reference over closure + let serialized = serde_json::to_string(auth).map_err(std::io::Error::other)?; + self.save_to_keyring(&key, &serialized)?; + if let Err(err) = delete_file_if_exists(&self.codex_home) { + warn!("failed to remove CLI auth fallback file: {err}"); + } + Ok(()) + } + + fn delete(&self) -> std::io::Result { + let key = compute_store_key(&self.codex_home)?; + let keyring_removed = self + .keyring_store + .delete(KEYRING_SERVICE, &key) + .map_err(|err| { + std::io::Error::other(format!("failed to delete auth from keyring: {err}")) + })?; + let file_removed = delete_file_if_exists(&self.codex_home)?; + Ok(keyring_removed || file_removed) + } +} + +#[derive(Clone, Debug)] +struct AutoAuthStorage { + keyring_storage: Arc, + file_storage: Arc, +} + +impl AutoAuthStorage { + fn new(codex_home: PathBuf, keyring_store: Arc) -> Self { + Self { + keyring_storage: Arc::new(KeyringAuthStorage::new(codex_home.clone(), keyring_store)), + file_storage: Arc::new(FileAuthStorage::new(codex_home)), + } + } +} + +impl AuthStorageBackend for AutoAuthStorage { + fn load(&self) -> std::io::Result> { + match self.keyring_storage.load() { + Ok(Some(auth)) => Ok(Some(auth)), + Ok(None) => self.file_storage.load(), + Err(err) => { + warn!("failed to load CLI auth from keyring, falling back to file storage: {err}"); + self.file_storage.load() + } + } + } + + fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> { + match self.keyring_storage.save(auth) { + Ok(()) => Ok(()), + Err(err) => { + warn!("failed to save auth to keyring, falling back to file storage: {err}"); + self.file_storage.save(auth) + } + } + } + + fn delete(&self) -> std::io::Result { + // Keyring storage will delete from disk as well + self.keyring_storage.delete() + } +} + pub(super) fn create_auth_storage( codex_home: PathBuf, mode: AuthCredentialsStoreMode, +) -> Arc { + let keyring_store: Arc = Arc::new(DefaultKeyringStore); + create_auth_storage_with_keyring_store(codex_home, mode, keyring_store) +} + +fn create_auth_storage_with_keyring_store( + codex_home: PathBuf, + mode: AuthCredentialsStoreMode, + keyring_store: Arc, ) -> Arc { match mode { AuthCredentialsStoreMode::File => Arc::new(FileAuthStorage::new(codex_home)), + AuthCredentialsStoreMode::Keyring => { + Arc::new(KeyringAuthStorage::new(codex_home, keyring_store)) + } + AuthCredentialsStoreMode::Auto => Arc::new(AutoAuthStorage::new(codex_home, keyring_store)), } } #[cfg(test)] mod tests { use super::*; + use crate::token_data::IdTokenInfo; use anyhow::Context; + use base64::Engine; use pretty_assertions::assert_eq; + use serde_json::json; use tempfile::tempdir; + use codex_keyring_store::tests::MockKeyringStore; + use keyring::Error as KeyringError; + #[tokio::test] async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> { - let codex_home = tempdir().unwrap(); + let codex_home = tempdir()?; let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); let auth_dot_json = AuthDotJson { openai_api_key: Some("test-key".to_string()), @@ -151,7 +311,7 @@ mod tests { #[tokio::test] async fn file_storage_save_persists_auth_dot_json() -> anyhow::Result<()> { - let codex_home = tempdir().unwrap(); + let codex_home = tempdir()?; let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); let auth_dot_json = AuthDotJson { openai_api_key: Some("test-key".to_string()), @@ -188,4 +348,325 @@ mod tests { assert!(!dir.path().join("auth.json").exists()); Ok(()) } + + fn seed_keyring_and_fallback_auth_file_for_delete( + mock_keyring: &MockKeyringStore, + codex_home: &Path, + compute_key: F, + ) -> anyhow::Result<(String, PathBuf)> + where + F: FnOnce() -> std::io::Result, + { + let key = compute_key()?; + mock_keyring.save(KEYRING_SERVICE, &key, "{}")?; + let auth_file = get_auth_file(codex_home); + std::fs::write(&auth_file, "stale")?; + Ok((key, auth_file)) + } + + fn seed_keyring_with_auth( + mock_keyring: &MockKeyringStore, + compute_key: F, + auth: &AuthDotJson, + ) -> anyhow::Result<()> + where + F: FnOnce() -> std::io::Result, + { + let key = compute_key()?; + let serialized = serde_json::to_string(auth)?; + mock_keyring.save(KEYRING_SERVICE, &key, &serialized)?; + Ok(()) + } + + fn assert_keyring_saved_auth_and_removed_fallback( + mock_keyring: &MockKeyringStore, + key: &str, + codex_home: &Path, + expected: &AuthDotJson, + ) { + let saved_value = mock_keyring + .saved_value(key) + .expect("keyring entry should exist"); + let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth"); + assert_eq!(saved_value, expected_serialized); + let auth_file = get_auth_file(codex_home); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after keyring save" + ); + } + + fn id_token_with_prefix(prefix: &str) -> IdTokenInfo { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = json!({ + "email": format!("{prefix}@example.com"), + "https://api.openai.com/auth": { + "chatgpt_account_id": format!("{prefix}-account"), + }, + }); + let encode = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = encode(&serde_json::to_vec(&header).expect("serialize header")); + let payload_b64 = encode(&serde_json::to_vec(&payload).expect("serialize payload")); + let signature_b64 = encode(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + crate::token_data::parse_id_token(&fake_jwt).expect("fake JWT should parse") + } + + fn auth_with_prefix(prefix: &str) -> AuthDotJson { + AuthDotJson { + openai_api_key: Some(format!("{prefix}-api-key")), + tokens: Some(TokenData { + id_token: id_token_with_prefix(prefix), + access_token: format!("{prefix}-access"), + refresh_token: format!("{prefix}-refresh"), + account_id: Some(format!("{prefix}-account-id")), + }), + last_refresh: None, + } + } + + #[test] + fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let expected = AuthDotJson { + openai_api_key: Some("sk-test".to_string()), + tokens: None, + last_refresh: None, + }; + seed_keyring_with_auth( + &mock_keyring, + || compute_store_key(codex_home.path()), + &expected, + )?; + + let loaded = storage.load()?; + assert_eq!(Some(expected), loaded); + Ok(()) + } + + #[test] + fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> { + let codex_home = PathBuf::from("~/.codex"); + + let key = compute_store_key(codex_home.as_path())?; + + assert_eq!(key, "cli|940db7b1d0e4eb40"); + Ok(()) + } + + #[test] + fn keyring_auth_storage_save_persists_and_removes_fallback_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let auth_file = get_auth_file(codex_home.path()); + std::fs::write(&auth_file, "stale")?; + let auth = AuthDotJson { + openai_api_key: None, + tokens: Some(TokenData { + id_token: Default::default(), + access_token: "access".to_string(), + refresh_token: "refresh".to_string(), + account_id: Some("account".to_string()), + }), + last_refresh: Some(Utc::now()), + }; + + storage.save(&auth)?; + + let key = compute_store_key(codex_home.path())?; + assert_keyring_saved_auth_and_removed_fallback( + &mock_keyring, + &key, + codex_home.path(), + &auth, + ); + Ok(()) + } + + #[test] + fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete( + &mock_keyring, + codex_home.path(), + || compute_store_key(codex_home.path()), + )?; + + let removed = storage.delete()?; + + assert!(removed, "delete should report removal"); + assert!( + !mock_keyring.contains(&key), + "keyring entry should be removed" + ); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after keyring delete" + ); + Ok(()) + } + + #[test] + fn auto_auth_storage_load_prefers_keyring_value() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let keyring_auth = auth_with_prefix("keyring"); + seed_keyring_with_auth( + &mock_keyring, + || compute_store_key(codex_home.path()), + &keyring_auth, + )?; + + let file_auth = auth_with_prefix("file"); + storage.file_storage.save(&file_auth)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(keyring_auth)); + Ok(()) + } + + #[test] + fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(mock_keyring)); + + let expected = auth_with_prefix("file-only"); + storage.file_storage.save(&expected)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(expected)); + Ok(()) + } + + #[test] + fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + let expected = auth_with_prefix("fallback"); + storage.file_storage.save(&expected)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(expected)); + Ok(()) + } + + #[test] + fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + + let stale = auth_with_prefix("stale"); + storage.file_storage.save(&stale)?; + + let expected = auth_with_prefix("to-save"); + storage.save(&expected)?; + + assert_keyring_saved_auth_and_removed_fallback( + &mock_keyring, + &key, + codex_home.path(), + &expected, + ); + Ok(()) + } + + #[test] + fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); + + let auth = auth_with_prefix("fallback"); + storage.save(&auth)?; + + let auth_file = get_auth_file(codex_home.path()); + assert!( + auth_file.exists(), + "fallback auth.json should be created when keyring save fails" + ); + let saved = storage + .file_storage + .load()? + .context("fallback auth should exist")?; + assert_eq!(saved, auth); + assert!( + mock_keyring.saved_value(&key).is_none(), + "keyring should not contain value when save fails" + ); + Ok(()) + } + + #[test] + fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete( + &mock_keyring, + codex_home.path(), + || compute_store_key(codex_home.path()), + )?; + + let removed = storage.delete()?; + + assert!(removed, "delete should report removal"); + assert!( + !mock_keyring.contains(&key), + "keyring entry should be removed" + ); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after delete" + ); + Ok(()) + } } diff --git a/codex-rs/keyring-store/Cargo.toml b/codex-rs/keyring-store/Cargo.toml new file mode 100644 index 0000000000..94d3d54493 --- /dev/null +++ b/codex-rs/keyring-store/Cargo.toml @@ -0,0 +1,11 @@ +[package] +edition = "2024" +name = "codex-keyring-store" +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +keyring = { workspace = true } +tracing = { workspace = true } diff --git a/codex-rs/keyring-store/src/lib.rs b/codex-rs/keyring-store/src/lib.rs new file mode 100644 index 0000000000..10dad3a98a --- /dev/null +++ b/codex-rs/keyring-store/src/lib.rs @@ -0,0 +1,226 @@ +use keyring::Entry; +use keyring::Error as KeyringError; +use std::error::Error; +use std::fmt; +use std::fmt::Debug; +use tracing::trace; + +#[derive(Debug)] +pub enum CredentialStoreError { + Other(KeyringError), +} + +impl CredentialStoreError { + pub fn new(error: KeyringError) -> Self { + Self::Other(error) + } + + pub fn message(&self) -> String { + match self { + Self::Other(error) => error.to_string(), + } + } + + pub fn into_error(self) -> KeyringError { + match self { + Self::Other(error) => error, + } + } +} + +impl fmt::Display for CredentialStoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Other(error) => write!(f, "{error}"), + } + } +} + +impl Error for CredentialStoreError {} + +/// Shared credential store abstraction for keyring-backed implementations. +pub trait KeyringStore: Debug + Send + Sync { + fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError>; + fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>; + fn delete(&self, service: &str, account: &str) -> Result; +} + +#[derive(Debug)] +pub struct DefaultKeyringStore; + +impl KeyringStore for DefaultKeyringStore { + fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError> { + trace!("keyring.load start, service={service}, account={account}"); + let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; + match entry.get_password() { + Ok(password) => { + trace!("keyring.load success, service={service}, account={account}"); + Ok(Some(password)) + } + Err(keyring::Error::NoEntry) => { + trace!("keyring.load no entry, service={service}, account={account}"); + Ok(None) + } + Err(error) => { + trace!("keyring.load error, service={service}, account={account}, error={error}"); + Err(CredentialStoreError::new(error)) + } + } + } + + fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> { + trace!( + "keyring.save start, service={service}, account={account}, value_len={}", + value.len() + ); + let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; + match entry.set_password(value) { + Ok(()) => { + trace!("keyring.save success, service={service}, account={account}"); + Ok(()) + } + Err(error) => { + trace!("keyring.save error, service={service}, account={account}, error={error}"); + Err(CredentialStoreError::new(error)) + } + } + } + + fn delete(&self, service: &str, account: &str) -> Result { + trace!("keyring.delete start, service={service}, account={account}"); + let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; + match entry.delete_credential() { + Ok(()) => { + trace!("keyring.delete success, service={service}, account={account}"); + Ok(true) + } + Err(keyring::Error::NoEntry) => { + trace!("keyring.delete no entry, service={service}, account={account}"); + Ok(false) + } + Err(error) => { + trace!("keyring.delete error, service={service}, account={account}, error={error}"); + Err(CredentialStoreError::new(error)) + } + } + } +} + +pub mod tests { + use super::CredentialStoreError; + use super::KeyringStore; + use keyring::Error as KeyringError; + use keyring::credential::CredentialApi as _; + use keyring::mock::MockCredential; + use std::collections::HashMap; + use std::sync::Arc; + use std::sync::Mutex; + use std::sync::PoisonError; + + #[derive(Default, Clone, Debug)] + pub struct MockKeyringStore { + credentials: Arc>>>, + } + + impl MockKeyringStore { + pub fn credential(&self, account: &str) -> Arc { + let mut guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard + .entry(account.to_string()) + .or_insert_with(|| Arc::new(MockCredential::default())) + .clone() + } + + pub fn saved_value(&self, account: &str) -> Option { + let credential = { + let guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard.get(account).cloned() + }?; + credential.get_password().ok() + } + + pub fn set_error(&self, account: &str, error: KeyringError) { + let credential = self.credential(account); + credential.set_error(error); + } + + pub fn contains(&self, account: &str) -> bool { + let guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard.contains_key(account) + } + } + + impl KeyringStore for MockKeyringStore { + fn load( + &self, + _service: &str, + account: &str, + ) -> Result, CredentialStoreError> { + let credential = { + let guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard.get(account).cloned() + }; + + let Some(credential) = credential else { + return Ok(None); + }; + + match credential.get_password() { + Ok(password) => Ok(Some(password)), + Err(KeyringError::NoEntry) => Ok(None), + Err(error) => Err(CredentialStoreError::new(error)), + } + } + + fn save( + &self, + _service: &str, + account: &str, + value: &str, + ) -> Result<(), CredentialStoreError> { + let credential = self.credential(account); + credential + .set_password(value) + .map_err(CredentialStoreError::new) + } + + fn delete(&self, _service: &str, account: &str) -> Result { + let credential = { + let guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard.get(account).cloned() + }; + + let Some(credential) = credential else { + return Ok(false); + }; + + let removed = match credential.delete_credential() { + Ok(()) => Ok(true), + Err(KeyringError::NoEntry) => Ok(false), + Err(error) => Err(CredentialStoreError::new(error)), + }?; + + let mut guard = self + .credentials + .lock() + .unwrap_or_else(PoisonError::into_inner); + guard.remove(account); + Ok(removed) + } + } +} diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index c515cdbfa3..e9f832e655 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -12,7 +12,10 @@ axum = { workspace = true, default-features = false, features = [ "http1", "tokio", ] } +codex-keyring-store = { workspace = true } codex-protocol = { workspace = true } +dirs = { workspace = true } +futures = { workspace = true, default-features = false, features = ["std"] } keyring = { workspace = true, features = [ "apple-native", "crypto-rust", @@ -20,6 +23,12 @@ keyring = { workspace = true, features = [ "windows-native", ] } mcp-types = { path = "../mcp-types" } +oauth2 = "5" +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "rustls-tls", +] } rmcp = { workspace = true, default-features = false, features = [ "auth", "base64", @@ -31,17 +40,9 @@ rmcp = { workspace = true, default-features = false, features = [ "transport-streamable-http-client-reqwest", "transport-streamable-http-server", ] } -futures = { workspace = true, default-features = false, features = ["std"] } -reqwest = { version = "0.12", default-features = false, features = [ - "json", - "stream", - "rustls-tls", -] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = { workspace = true } -dirs = { workspace = true } -oauth2 = "5" tiny_http = { workspace = true } tokio = { workspace = true, features = [ "io-util", diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index afa0e907b1..bd6833fca4 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -17,8 +17,8 @@ //! If the keyring is not available or fails, we fall back to CODEX_HOME/.credentials.json which is consistent with other coding CLI agents. use anyhow::Context; +use anyhow::Error; use anyhow::Result; -use keyring::Entry; use oauth2::AccessToken; use oauth2::EmptyExtraTokenFields; use oauth2::RefreshToken; @@ -33,7 +33,6 @@ use serde_json::map::Map as JsonMap; use sha2::Digest; use sha2::Sha256; use std::collections::BTreeMap; -use std::fmt; use std::fs; use std::io::ErrorKind; use std::path::PathBuf; @@ -43,6 +42,8 @@ use std::time::SystemTime; use std::time::UNIX_EPOCH; use tracing::warn; +use codex_keyring_store::DefaultKeyringStore; +use codex_keyring_store::KeyringStore; use rmcp::transport::auth::AuthorizationManager; use tokio::sync::Mutex; @@ -73,64 +74,6 @@ pub enum OAuthCredentialsStoreMode { Keyring, } -#[derive(Debug)] -struct CredentialStoreError(anyhow::Error); - -impl CredentialStoreError { - fn new(error: impl Into) -> Self { - Self(error.into()) - } - - fn message(&self) -> String { - self.0.to_string() - } - - fn into_error(self) -> anyhow::Error { - self.0 - } -} - -impl fmt::Display for CredentialStoreError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl std::error::Error for CredentialStoreError {} - -trait KeyringStore { - fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError>; - fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>; - fn delete(&self, service: &str, account: &str) -> Result; -} - -struct DefaultKeyringStore; - -impl KeyringStore for DefaultKeyringStore { - fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError> { - let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; - match entry.get_password() { - Ok(password) => Ok(Some(password)), - Err(keyring::Error::NoEntry) => Ok(None), - Err(error) => Err(CredentialStoreError::new(error)), - } - } - - fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> { - let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; - entry.set_password(value).map_err(CredentialStoreError::new) - } - - fn delete(&self, service: &str, account: &str) -> Result { - let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; - match entry.delete_credential() { - Ok(()) => Ok(true), - Err(keyring::Error::NoEntry) => Ok(false), - Err(error) => Err(CredentialStoreError::new(error)), - } - } -} - /// Wrap OAuthTokenResponse to allow for partial equality comparison. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WrappedOAuthTokenResponse(pub OAuthTokenResponse); @@ -199,7 +142,7 @@ fn load_oauth_tokens_from_keyring( Ok(Some(tokens)) } Ok(None) => Ok(None), - Err(error) => Err(error.into_error()), + Err(error) => Err(Error::new(error.into_error())), } } @@ -243,7 +186,7 @@ fn save_oauth_tokens_with_keyring( error.message() ); warn!("{message}"); - Err(error.into_error().context(message)) + Err(Error::new(error.into_error()).context(message)) } } } @@ -595,109 +538,14 @@ mod tests { use super::*; use anyhow::Result; use keyring::Error as KeyringError; - use keyring::credential::CredentialApi as _; - use keyring::mock::MockCredential; use pretty_assertions::assert_eq; - use std::collections::HashMap; - use std::sync::Arc; use std::sync::Mutex; use std::sync::MutexGuard; use std::sync::OnceLock; use std::sync::PoisonError; use tempfile::tempdir; - #[derive(Default, Clone)] - struct MockCredentialStore { - credentials: Arc>>>, - } - - impl MockCredentialStore { - fn credential(&self, account: &str) -> Arc { - let mut guard = self.credentials.lock().unwrap(); - guard - .entry(account.to_string()) - .or_insert_with(|| Arc::new(MockCredential::default())) - .clone() - } - - fn saved_value(&self, account: &str) -> Option { - let credential = { - let guard = self.credentials.lock().unwrap(); - guard.get(account).cloned() - }?; - credential.get_password().ok() - } - - fn set_error(&self, account: &str, error: KeyringError) { - let credential = self.credential(account); - credential.set_error(error); - } - - fn contains(&self, account: &str) -> bool { - let guard = self.credentials.lock().unwrap(); - guard.contains_key(account) - } - } - - impl KeyringStore for MockCredentialStore { - fn load( - &self, - _service: &str, - account: &str, - ) -> Result, CredentialStoreError> { - let credential = { - let guard = self.credentials.lock().unwrap(); - guard.get(account).cloned() - }; - - let Some(credential) = credential else { - return Ok(None); - }; - - match credential.get_password() { - Ok(password) => Ok(Some(password)), - Err(KeyringError::NoEntry) => Ok(None), - Err(error) => Err(CredentialStoreError::new(error)), - } - } - - fn save( - &self, - _service: &str, - account: &str, - value: &str, - ) -> Result<(), CredentialStoreError> { - let credential = self.credential(account); - credential - .set_password(value) - .map_err(CredentialStoreError::new) - } - - fn delete(&self, _service: &str, account: &str) -> Result { - let credential = { - let guard = self.credentials.lock().unwrap(); - guard.get(account).cloned() - }; - - let Some(credential) = credential else { - return Ok(false); - }; - - match credential.delete_credential() { - Ok(()) => { - let mut guard = self.credentials.lock().unwrap(); - guard.remove(account); - Ok(true) - } - Err(KeyringError::NoEntry) => { - let mut guard = self.credentials.lock().unwrap(); - guard.remove(account); - Ok(false) - } - Err(error) => Err(CredentialStoreError::new(error)), - } - } - } + use codex_keyring_store::tests::MockKeyringStore; struct TempCodexHome { _guard: MutexGuard<'static, ()>, @@ -733,7 +581,7 @@ mod tests { #[test] fn load_oauth_tokens_reads_from_keyring_when_available() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let expected = tokens.clone(); let serialized = serde_json::to_string(&tokens)?; @@ -749,7 +597,7 @@ mod tests { #[test] fn load_oauth_tokens_falls_back_when_missing_in_keyring() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let expected = tokens.clone(); @@ -768,7 +616,7 @@ mod tests { #[test] fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let expected = tokens.clone(); let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; @@ -789,7 +637,7 @@ mod tests { #[test] fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; @@ -811,7 +659,7 @@ mod tests { #[test] fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; store.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); @@ -841,7 +689,7 @@ mod tests { #[test] fn delete_oauth_tokens_removes_all_storage() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let serialized = serde_json::to_string(&tokens)?; let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; @@ -863,7 +711,7 @@ mod tests { #[test] fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let serialized = serde_json::to_string(&tokens)?; let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; @@ -885,7 +733,7 @@ mod tests { #[test] fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> { let _env = TempCodexHome::new(); - let store = MockCredentialStore::default(); + let store = MockKeyringStore::default(); let tokens = sample_tokens(); let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));