diff --git a/code-rs/core/src/auth.rs b/code-rs/core/src/auth.rs index 5244fff76a4..12c617d633f 100644 --- a/code-rs/core/src/auth.rs +++ b/code-rs/core/src/auth.rs @@ -227,6 +227,7 @@ impl CodexAuth { if !access_token_is_still_valid(&tokens.access_token, Utc::now()) { return Err(err); } + self.record_proactive_refresh_fallback(Utc::now()); } } } @@ -282,6 +283,23 @@ impl CodexAuth { self.get_current_auth_json().and_then(|t| t.tokens.clone()) } + fn record_proactive_refresh_fallback(&self, timestamp: DateTime) { + let updated = { + let mut guard = self.auth_dot_json.lock().unwrap(); + let Some(auth_dot_json) = guard.as_mut() else { + return; + }; + auth_dot_json.last_refresh = Some(timestamp); + auth_dot_json.clone() + }; + + if !self.auth_file.as_os_str().is_empty() { + if let Err(err) = write_auth_json(&self.auth_file, &updated) { + tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}"); + } + } + } + /// Consider this private to integration tests. pub fn create_dummy_chatgpt_auth_for_testing() -> Self { let auth_dot_json = AuthDotJson { @@ -363,16 +381,29 @@ fn should_proactively_refresh_auth( last_refresh: Option>, access_token: Option<&str>, ) -> bool { + let now = Utc::now(); if let Some(access_token) = access_token && let Ok(Some(expires_at)) = parse_jwt_expiration(access_token) { - return expires_at - <= Utc::now() - + chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES); + if expires_at <= now { + return true; + } + if expires_at + <= now + chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES) + { + return last_refresh.is_none_or(|last_refresh| { + last_refresh + < now + - chrono::Duration::minutes( + CHATGPT_ACCESS_TOKEN_REFRESH_RETRY_COOLDOWN_MINUTES, + ) + }); + } + return false; } last_refresh.is_some_and(|last_refresh| { - last_refresh < Utc::now() - chrono::Duration::days(28) + last_refresh < now - chrono::Duration::days(28) }) } @@ -503,21 +534,8 @@ pub async fn auth_for_stored_account( })?; let mut last_refresh = account.last_refresh; let now = Utc::now(); - let refresh_needed = if account.mode == AuthMode::ChatGPT { - if let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token) { - expires_at - <= now - + chrono::Duration::minutes( - CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES, - ) - } else { - last_refresh - .map(|last| last < now - chrono::Duration::days(28)) - .unwrap_or(true) - } - } else { - false - }; + let refresh_needed = account.mode == AuthMode::ChatGPT + && should_refresh_stored_account_auth(last_refresh, &tokens.access_token); if refresh_needed { let client = crate::default_client::create_client(originator); @@ -534,6 +552,10 @@ pub async fn auth_for_stored_account( Ok(response) => response, Err(err) => { if access_token_is_still_valid(&tokens.access_token, Utc::now()) { + last_refresh = Some(record_stored_account_proactive_refresh_fallback( + code_home, + &account.id, + )); return Ok(CodexAuth::from_tokens_with_originator_and_mode( tokens, last_refresh, @@ -563,6 +585,10 @@ pub async fn auth_for_stored_account( } } if access_token_is_still_valid(&tokens.access_token, Utc::now()) { + last_refresh = Some(record_stored_account_proactive_refresh_fallback( + code_home, + &account.id, + )); return Ok(CodexAuth::from_tokens_with_originator_and_mode( tokens, last_refresh, @@ -602,6 +628,30 @@ pub async fn auth_for_stored_account( } } +fn record_stored_account_proactive_refresh_fallback( + code_home: &Path, + account_id: &str, +) -> DateTime { + let now = Utc::now(); + if let Err(err) = crate::auth_accounts::update_account_last_refresh(code_home, account_id, now) { + tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}"); + } + now +} + +fn should_refresh_stored_account_auth( + last_refresh: Option>, + access_token: &str, +) -> bool { + if let Ok(Some(_)) = parse_jwt_expiration(access_token) { + return should_proactively_refresh_auth(last_refresh, Some(access_token)); + } + + last_refresh + .map(|last| last < Utc::now() - chrono::Duration::days(28)) + .unwrap_or(true) +} + /// Activate a stored account by writing its credentials to auth.json and /// marking it active in the account store. pub fn activate_account(code_home: &Path, account_id: &str) -> std::io::Result<()> { @@ -1022,6 +1072,7 @@ pub struct AuthDotJson { // Shared constant for token refresh (client id used for oauth token refresh flow) pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; const CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES: i64 = 5; +const CHATGPT_ACCESS_TOKEN_REFRESH_RETRY_COOLDOWN_MINUTES: i64 = 5; use std::sync::RwLock; @@ -1432,6 +1483,16 @@ mod tests { assert!(!should_proactively_refresh_auth(Some(stale), Some(&future_access))); assert!(should_proactively_refresh_auth(Some(fresh), Some(&expiring_access))); assert!(should_proactively_refresh_auth(Some(fresh), Some(&expired_access))); + + let just_attempted = Utc::now(); + assert!(!should_proactively_refresh_auth( + Some(just_attempted), + Some(&expiring_access) + )); + assert!(should_proactively_refresh_auth( + Some(just_attempted), + Some(&expired_access) + )); } #[test] @@ -1455,13 +1516,44 @@ mod tests { let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri()); let code_home = tempdir().unwrap(); let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() + 240 })); - let account = stored_chatgpt_account(access_token.clone()); + let account = stored_chatgpt_account( + access_token.clone(), + Some(Utc::now() - chrono::Duration::minutes(10)), + ); + let account = crate::auth_accounts::upsert_chatgpt_account( + code_home.path(), + account.tokens.clone().expect("account has tokens"), + account.last_refresh.expect("account has refresh time"), + account.label.clone(), + false, + ) + .expect("seed stored account"); let auth = auth_for_stored_account(code_home.path(), &account, "test") .await .expect("valid cached token should survive proactive refresh failure"); assert_eq!(auth.get_token().await.unwrap(), access_token); + let returned_last_refresh = auth + .get_current_auth_json() + .and_then(|auth| auth.last_refresh) + .expect("fallback should record refresh cooldown"); + assert!(returned_last_refresh > account.last_refresh.unwrap()); + + let accounts = crate::auth_accounts::list_accounts(code_home.path()) + .expect("list stored accounts"); + assert_eq!(accounts.len(), 1, "fallback should not duplicate account"); + let stored = crate::auth_accounts::find_account(code_home.path(), &account.id) + .expect("read stored account") + .expect("original account should remain stored"); + assert_eq!(stored.id, account.id); + let stored_tokens = stored.tokens.expect("stored account keeps tokens"); + let account_tokens = account.tokens.expect("seeded account has tokens"); + assert_eq!(stored_tokens.id_token.raw_jwt, account_tokens.id_token.raw_jwt); + assert_eq!(stored_tokens.access_token, account_tokens.access_token); + assert_eq!(stored_tokens.refresh_token, account_tokens.refresh_token); + assert_eq!(stored_tokens.account_id, account_tokens.account_id); + assert!(stored.last_refresh.unwrap() > account.last_refresh.unwrap()); } #[tokio::test] @@ -1477,7 +1569,7 @@ mod tests { let tokens = token_data_for_access(access_token.clone()); let auth = CodexAuth::from_tokens_with_originator_and_mode( tokens, - Some(Utc::now()), + Some(Utc::now() - chrono::Duration::minutes(10)), "test", AuthMode::ChatGPT, ); @@ -1488,6 +1580,16 @@ mod tests { .expect("valid cached token should survive proactive refresh failure"); assert_eq!(token_data.access_token, access_token); + let requests_after_fallback = server.received_requests().await.unwrap().len(); + assert_eq!(requests_after_fallback, 4); + + let token_data_again = auth + .get_token_data() + .await + .expect("fallback should suppress immediate retry"); + + assert_eq!(token_data_again.access_token, access_token); + assert_eq!(server.received_requests().await.unwrap().len(), requests_after_fallback); } #[tokio::test] @@ -1501,7 +1603,7 @@ mod tests { let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri()); let code_home = tempdir().unwrap(); let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() - 60 })); - let account = stored_chatgpt_account(access_token); + let account = stored_chatgpt_account(access_token, Some(Utc::now())); let err = auth_for_stored_account(code_home.path(), &account, "test") .await @@ -1555,22 +1657,35 @@ mod tests { format!("{header_b64}.{payload_b64}.{signature_b64}") } - fn stored_chatgpt_account(access_token: String) -> crate::auth_accounts::StoredAccount { + fn stored_chatgpt_account( + access_token: String, + last_refresh: Option>, + ) -> crate::auth_accounts::StoredAccount { crate::auth_accounts::StoredAccount { id: "account-id".to_string(), mode: AuthMode::ChatGPT, label: None, openai_api_key: None, tokens: Some(token_data_for_access(access_token)), - last_refresh: Some(Utc::now()), + last_refresh, created_at: None, last_used_at: None, } } fn token_data_for_access(access_token: String) -> TokenData { + let raw_jwt = build_jwt(serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "plus" + } + })); TokenData { - id_token: IdTokenInfo::default(), + id_token: IdTokenInfo { + raw_jwt, + email: None, + ..Default::default() + }, access_token, refresh_token: "refresh-token".to_string(), account_id: Some("account-id".to_string()), diff --git a/code-rs/core/src/auth_accounts.rs b/code-rs/core/src/auth_accounts.rs index 0c9f4afab24..044c4e7e425 100644 --- a/code-rs/core/src/auth_accounts.rs +++ b/code-rs/core/src/auth_accounts.rs @@ -247,6 +247,23 @@ pub fn find_account(code_home: &Path, account_id: &str) -> io::Result, +) -> io::Result> { + let path = accounts_file_path(code_home); + let mut data = read_accounts_file(&path)?; + + let Some(account) = data.accounts.iter_mut().find(|acc| acc.id == account_id) else { + return Ok(None); + }; + account.last_refresh = Some(last_refresh); + let updated = account.clone(); + write_accounts_file(&path, &data)?; + Ok(Some(updated)) +} + pub fn set_active_account_id( code_home: &Path, account_id: Option,