From e4a5901d8e730118f25eaaf7f72980fb2c241182 Mon Sep 17 00:00:00 2001 From: LaelLuo Date: Sun, 9 Nov 2025 16:59:50 +0800 Subject: [PATCH 1/2] :sparkles: feat(rmcp-client): persist OAuth credential expiry timestamps --- codex-rs/rmcp-client/src/oauth.rs | 81 ++++++++++++++++--- .../rmcp-client/src/perform_oauth_login.rs | 3 + 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index bd6833fca412..91343e3c099b 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -57,6 +57,8 @@ pub struct StoredOAuthTokens { pub url: String, pub client_id: String, pub token_response: WrappedOAuthTokenResponse, + #[serde(default)] + pub expires_at: Option, } /// Determine where Codex should store and read MCP credentials. @@ -113,6 +115,22 @@ pub(crate) fn has_oauth_tokens( Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some()) } +fn refresh_expires_in_from_timestamp(tokens: &mut StoredOAuthTokens) { + let Some(expires_at) = tokens.expires_at else { + return; + }; + + match expires_in_from_timestamp(expires_at) { + Some(seconds) => { + let duration = Duration::from_secs(seconds); + tokens.token_response.0.set_expires_in(Some(&duration)); + } + None => { + tokens.token_response.0.set_expires_in(None); + } + } +} + fn load_oauth_tokens_from_keyring_with_fallback_to_file( keyring_store: &K, server_name: &str, @@ -137,8 +155,9 @@ fn load_oauth_tokens_from_keyring( let key = compute_store_key(server_name, url)?; match keyring_store.load(KEYRING_SERVICE, &key) { Ok(Some(serialized)) => { - let tokens: StoredOAuthTokens = serde_json::from_str(&serialized) + let mut tokens: StoredOAuthTokens = serde_json::from_str(&serialized) .context("failed to deserialize OAuth tokens from keyring")?; + refresh_expires_in_from_timestamp(&mut tokens); Ok(Some(tokens)) } Ok(None) => Ok(None), @@ -286,11 +305,13 @@ impl OAuthPersistor { match maybe_credentials { Some(credentials) => { + let expires_at = compute_expires_at_millis(&credentials); let stored = StoredOAuthTokens { server_name: self.inner.server_name.clone(), url: self.inner.url.clone(), client_id, token_response: WrappedOAuthTokenResponse(credentials.clone()), + expires_at, }; let mut last_credentials = self.inner.last_credentials.lock().await; if last_credentials.as_ref() != Some(&stored) { @@ -366,19 +387,14 @@ fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result Result<()> { let mut store = read_fallback_file()?.unwrap_or_default(); let token_response = &tokens.token_response.0; + let expires_at = tokens + .expires_at + .or_else(|| compute_expires_at_millis(token_response)); let refresh_token = token_response .refresh_token() .map(|token| token.secret().to_string()); @@ -403,7 +422,7 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> { server_url: tokens.url.clone(), client_id: tokens.client_id.clone(), access_token: token_response.access_token().secret().to_string(), - expires_at: compute_expires_at_millis(token_response), + expires_at, refresh_token, scopes, }; @@ -427,7 +446,7 @@ fn delete_oauth_tokens_from_file(key: &str) -> Result { Ok(removed) } -fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option { +pub(crate) fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option { let expires_in = response.expires_in()?; let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -750,6 +769,43 @@ mod tests { Ok(()) } + #[test] + fn refresh_expires_in_from_timestamp_restores_future_durations() { + let mut tokens = sample_tokens(); + let expires_at = tokens.expires_at.expect("expires_at should be set"); + + tokens.token_response.0.set_expires_in(None); + super::refresh_expires_in_from_timestamp(&mut tokens); + + let actual = tokens + .token_response + .0 + .expires_in() + .expect("expires_in should be restored") + .as_secs(); + let expected = super::expires_in_from_timestamp(expires_at) + .expect("expires_at should still be in the future"); + let diff = actual.abs_diff(expected); + assert!(diff <= 1, "expires_in drift too large: diff={diff}"); + } + + #[test] + fn refresh_expires_in_from_timestamp_clears_expired_tokens() { + let mut tokens = sample_tokens(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + let expired_at = now.as_millis() as u64; + tokens.expires_at = Some(expired_at.saturating_sub(1000)); + + let duration = Duration::from_secs(600); + tokens.token_response.0.set_expires_in(Some(&duration)); + + super::refresh_expires_in_from_timestamp(&mut tokens); + + assert!(tokens.token_response.0.expires_in().is_none()); + } + fn assert_tokens_match_without_expiry( actual: &StoredOAuthTokens, expected: &StoredOAuthTokens, @@ -757,6 +813,7 @@ mod tests { assert_eq!(actual.server_name, expected.server_name); assert_eq!(actual.url, expected.url); assert_eq!(actual.client_id, expected.client_id); + assert_eq!(actual.expires_at, expected.expires_at); assert_token_response_match_without_expiry( &actual.token_response, &expected.token_response, @@ -803,12 +860,14 @@ mod tests { ])); let expires_in = Duration::from_secs(3600); response.set_expires_in(Some(&expires_in)); + let expires_at = super::compute_expires_at_millis(&response); StoredOAuthTokens { server_name: "test-server".to_string(), url: "https://example.test".to_string(), client_id: "client-id".to_string(), token_response: WrappedOAuthTokenResponse(response), + expires_at, } } } diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index 425e124d7dab..d8ffdd3949a2 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -17,6 +17,7 @@ use urlencoding::decode; use crate::OAuthCredentialsStoreMode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; +use crate::oauth::compute_expires_at_millis; use crate::save_oauth_tokens; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; @@ -91,11 +92,13 @@ pub async fn perform_oauth_login( let credentials = credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; + let expires_at = compute_expires_at_millis(&credentials); let stored = StoredOAuthTokens { server_name: server_name.to_string(), url: server_url.to_string(), client_id, token_response: WrappedOAuthTokenResponse(credentials), + expires_at, }; save_oauth_tokens(server_name, &stored, store_mode)?; From f9365c41960abb693649f31a6a0e920d6fdee4a0 Mon Sep 17 00:00:00 2001 From: LaelLuo Date: Thu, 13 Nov 2025 10:51:14 +0800 Subject: [PATCH 2/2] :sparkles: feat(rmcp-client): auto refresh OAuth tokens using expires_at --- codex-rs/rmcp-client/src/oauth.rs | 58 ++++++++++++++++++++++--- codex-rs/rmcp-client/src/rmcp_client.rs | 13 ++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index 91343e3c099b..f8eafaf23e1f 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -50,6 +50,7 @@ use tokio::sync::Mutex; use crate::find_codex_home::find_codex_home; const KEYRING_SERVICE: &str = "Codex MCP Credentials"; +const REFRESH_SKEW_MILLIS: u64 = 30_000; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct StoredOAuthTokens { @@ -305,15 +306,24 @@ impl OAuthPersistor { match maybe_credentials { Some(credentials) => { - let expires_at = compute_expires_at_millis(&credentials); + let mut last_credentials = self.inner.last_credentials.lock().await; + let new_token_response = WrappedOAuthTokenResponse(credentials.clone()); + let same_token = last_credentials + .as_ref() + .map(|prev| prev.token_response == new_token_response) + .unwrap_or(false); + let expires_at = if same_token { + last_credentials.as_ref().and_then(|prev| prev.expires_at) + } else { + compute_expires_at_millis(&credentials) + }; let stored = StoredOAuthTokens { server_name: self.inner.server_name.clone(), url: self.inner.url.clone(), client_id, - token_response: WrappedOAuthTokenResponse(credentials.clone()), + token_response: new_token_response, expires_at, }; - let mut last_credentials = self.inner.last_credentials.lock().await; if last_credentials.as_ref() != Some(&stored) { save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; *last_credentials = Some(stored); @@ -338,6 +348,30 @@ impl OAuthPersistor { Ok(()) } + + pub(crate) async fn refresh_if_needed(&self) -> Result<()> { + let expires_at = { + let guard = self.inner.last_credentials.lock().await; + guard.as_ref().and_then(|tokens| tokens.expires_at) + }; + + if !token_needs_refresh(expires_at) { + return Ok(()); + } + + { + let manager = self.inner.authorization_manager.clone(); + let guard = manager.lock().await; + guard.refresh_token().await.with_context(|| { + format!( + "failed to refresh OAuth tokens for server {}", + self.inner.server_name + ) + })?; + } + + self.persist_if_needed().await + } } const FALLBACK_FILENAME: &str = ".credentials.json"; @@ -473,6 +507,19 @@ fn expires_in_from_timestamp(expires_at: u64) -> Option { } } +fn token_needs_refresh(expires_at: Option) -> bool { + let Some(expires_at) = expires_at else { + return false; + }; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_millis() as u64; + + now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at +} + fn compute_store_key(server_name: &str, server_url: &str) -> Result { let mut payload = JsonMap::new(); payload.insert( @@ -608,8 +655,9 @@ mod tests { store.save(KEYRING_SERVICE, &key, &serialized)?; let loaded = - super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?; - assert_eq!(loaded, Some(expected)); + super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)? + .expect("tokens should load from keyring"); + assert_tokens_match_without_expiry(&loaded, &expected); Ok(()) } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index bc1980f1f5e4..b8ccef27d19a 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -259,6 +259,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -276,6 +277,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -293,6 +295,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -310,6 +313,7 @@ impl RmcpClient { params: ReadResourceRequestParams, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?; let fut = service.read_resource(rmcp_params); @@ -325,6 +329,7 @@ impl RmcpClient { arguments: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let params = CallToolRequestParams { arguments, name }; let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; @@ -363,6 +368,14 @@ impl RmcpClient { warn!("failed to persist OAuth tokens: {error}"); } } + + async fn refresh_oauth_if_needed(&self) { + if let Some(runtime) = self.oauth_persistor().await + && let Err(error) = runtime.refresh_if_needed().await + { + warn!("failed to refresh OAuth tokens: {error}"); + } + } } async fn create_oauth_transport_and_runtime(