Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 140 additions & 25 deletions code-rs/core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ impl CodexAuth {
if !access_token_is_still_valid(&tokens.access_token, Utc::now()) {
return Err(err);
}
self.record_proactive_refresh_fallback(Utc::now());
}
}
}
Expand Down Expand Up @@ -282,6 +283,23 @@ impl CodexAuth {
self.get_current_auth_json().and_then(|t| t.tokens.clone())
}

fn record_proactive_refresh_fallback(&self, timestamp: DateTime<Utc>) {
let updated = {
let mut guard = self.auth_dot_json.lock().unwrap();
let Some(auth_dot_json) = guard.as_mut() else {
return;
};
auth_dot_json.last_refresh = Some(timestamp);
auth_dot_json.clone()
};

if !self.auth_file.as_os_str().is_empty() {
if let Err(err) = write_auth_json(&self.auth_file, &updated) {
tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}");
}
}
}

/// Consider this private to integration tests.
pub fn create_dummy_chatgpt_auth_for_testing() -> Self {
let auth_dot_json = AuthDotJson {
Expand Down Expand Up @@ -363,16 +381,29 @@ fn should_proactively_refresh_auth(
last_refresh: Option<DateTime<Utc>>,
access_token: Option<&str>,
) -> bool {
let now = Utc::now();
if let Some(access_token) = access_token
&& let Ok(Some(expires_at)) = parse_jwt_expiration(access_token)
{
return expires_at
<= Utc::now()
+ chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES);
if expires_at <= now {
return true;
}
if expires_at
<= now + chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES)
{
return last_refresh.is_none_or(|last_refresh| {
last_refresh
< now
- chrono::Duration::minutes(
CHATGPT_ACCESS_TOKEN_REFRESH_RETRY_COOLDOWN_MINUTES,
)
});
}
return false;
}

last_refresh.is_some_and(|last_refresh| {
last_refresh < Utc::now() - chrono::Duration::days(28)
last_refresh < now - chrono::Duration::days(28)
})
}

Expand Down Expand Up @@ -503,21 +534,8 @@ pub async fn auth_for_stored_account(
})?;
let mut last_refresh = account.last_refresh;
let now = Utc::now();
let refresh_needed = if account.mode == AuthMode::ChatGPT {
if let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token) {
expires_at
<= now
+ chrono::Duration::minutes(
CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES,
)
} else {
last_refresh
.map(|last| last < now - chrono::Duration::days(28))
.unwrap_or(true)
}
} else {
false
};
let refresh_needed = account.mode == AuthMode::ChatGPT
&& should_refresh_stored_account_auth(last_refresh, &tokens.access_token);

if refresh_needed {
let client = crate::default_client::create_client(originator);
Expand All @@ -534,6 +552,10 @@ pub async fn auth_for_stored_account(
Ok(response) => response,
Err(err) => {
if access_token_is_still_valid(&tokens.access_token, Utc::now()) {
last_refresh = Some(record_stored_account_proactive_refresh_fallback(
code_home,
&account.id,
));
return Ok(CodexAuth::from_tokens_with_originator_and_mode(
tokens,
last_refresh,
Expand Down Expand Up @@ -563,6 +585,10 @@ pub async fn auth_for_stored_account(
}
}
if access_token_is_still_valid(&tokens.access_token, Utc::now()) {
last_refresh = Some(record_stored_account_proactive_refresh_fallback(
code_home,
&account.id,
));
return Ok(CodexAuth::from_tokens_with_originator_and_mode(
tokens,
last_refresh,
Expand Down Expand Up @@ -602,6 +628,30 @@ pub async fn auth_for_stored_account(
}
}

fn record_stored_account_proactive_refresh_fallback(
code_home: &Path,
account_id: &str,
) -> DateTime<Utc> {
let now = Utc::now();
if let Err(err) = crate::auth_accounts::update_account_last_refresh(code_home, account_id, now) {
tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}");
}
now
}

fn should_refresh_stored_account_auth(
last_refresh: Option<DateTime<Utc>>,
access_token: &str,
) -> bool {
if let Ok(Some(_)) = parse_jwt_expiration(access_token) {
return should_proactively_refresh_auth(last_refresh, Some(access_token));
}

last_refresh
.map(|last| last < Utc::now() - chrono::Duration::days(28))
.unwrap_or(true)
}

/// Activate a stored account by writing its credentials to auth.json and
/// marking it active in the account store.
pub fn activate_account(code_home: &Path, account_id: &str) -> std::io::Result<()> {
Expand Down Expand Up @@ -1022,6 +1072,7 @@ pub struct AuthDotJson {
// Shared constant for token refresh (client id used for oauth token refresh flow)
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
const CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES: i64 = 5;
const CHATGPT_ACCESS_TOKEN_REFRESH_RETRY_COOLDOWN_MINUTES: i64 = 5;

use std::sync::RwLock;

Expand Down Expand Up @@ -1432,6 +1483,16 @@ mod tests {
assert!(!should_proactively_refresh_auth(Some(stale), Some(&future_access)));
assert!(should_proactively_refresh_auth(Some(fresh), Some(&expiring_access)));
assert!(should_proactively_refresh_auth(Some(fresh), Some(&expired_access)));

let just_attempted = Utc::now();
assert!(!should_proactively_refresh_auth(
Some(just_attempted),
Some(&expiring_access)
));
assert!(should_proactively_refresh_auth(
Some(just_attempted),
Some(&expired_access)
));
}

#[test]
Expand All @@ -1455,13 +1516,44 @@ mod tests {
let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri());
let code_home = tempdir().unwrap();
let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() + 240 }));
let account = stored_chatgpt_account(access_token.clone());
let account = stored_chatgpt_account(
access_token.clone(),
Some(Utc::now() - chrono::Duration::minutes(10)),
);
let account = crate::auth_accounts::upsert_chatgpt_account(
code_home.path(),
account.tokens.clone().expect("account has tokens"),
account.last_refresh.expect("account has refresh time"),
account.label.clone(),
false,
)
.expect("seed stored account");

let auth = auth_for_stored_account(code_home.path(), &account, "test")
.await
.expect("valid cached token should survive proactive refresh failure");

assert_eq!(auth.get_token().await.unwrap(), access_token);
let returned_last_refresh = auth
.get_current_auth_json()
.and_then(|auth| auth.last_refresh)
.expect("fallback should record refresh cooldown");
assert!(returned_last_refresh > account.last_refresh.unwrap());

let accounts = crate::auth_accounts::list_accounts(code_home.path())
.expect("list stored accounts");
assert_eq!(accounts.len(), 1, "fallback should not duplicate account");
let stored = crate::auth_accounts::find_account(code_home.path(), &account.id)
.expect("read stored account")
.expect("original account should remain stored");
assert_eq!(stored.id, account.id);
let stored_tokens = stored.tokens.expect("stored account keeps tokens");
let account_tokens = account.tokens.expect("seeded account has tokens");
assert_eq!(stored_tokens.id_token.raw_jwt, account_tokens.id_token.raw_jwt);
assert_eq!(stored_tokens.access_token, account_tokens.access_token);
assert_eq!(stored_tokens.refresh_token, account_tokens.refresh_token);
assert_eq!(stored_tokens.account_id, account_tokens.account_id);
assert!(stored.last_refresh.unwrap() > account.last_refresh.unwrap());
}

#[tokio::test]
Expand All @@ -1477,7 +1569,7 @@ mod tests {
let tokens = token_data_for_access(access_token.clone());
let auth = CodexAuth::from_tokens_with_originator_and_mode(
tokens,
Some(Utc::now()),
Some(Utc::now() - chrono::Duration::minutes(10)),
"test",
AuthMode::ChatGPT,
);
Expand All @@ -1488,6 +1580,16 @@ mod tests {
.expect("valid cached token should survive proactive refresh failure");

assert_eq!(token_data.access_token, access_token);
let requests_after_fallback = server.received_requests().await.unwrap().len();
assert_eq!(requests_after_fallback, 4);

let token_data_again = auth
.get_token_data()
.await
.expect("fallback should suppress immediate retry");

assert_eq!(token_data_again.access_token, access_token);
assert_eq!(server.received_requests().await.unwrap().len(), requests_after_fallback);
}

#[tokio::test]
Expand All @@ -1501,7 +1603,7 @@ mod tests {
let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri());
let code_home = tempdir().unwrap();
let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() - 60 }));
let account = stored_chatgpt_account(access_token);
let account = stored_chatgpt_account(access_token, Some(Utc::now()));

let err = auth_for_stored_account(code_home.path(), &account, "test")
.await
Expand Down Expand Up @@ -1555,22 +1657,35 @@ mod tests {
format!("{header_b64}.{payload_b64}.{signature_b64}")
}

fn stored_chatgpt_account(access_token: String) -> crate::auth_accounts::StoredAccount {
fn stored_chatgpt_account(
access_token: String,
last_refresh: Option<DateTime<Utc>>,
) -> crate::auth_accounts::StoredAccount {
crate::auth_accounts::StoredAccount {
id: "account-id".to_string(),
mode: AuthMode::ChatGPT,
label: None,
openai_api_key: None,
tokens: Some(token_data_for_access(access_token)),
last_refresh: Some(Utc::now()),
last_refresh,
created_at: None,
last_used_at: None,
}
}

fn token_data_for_access(access_token: String) -> TokenData {
let raw_jwt = build_jwt(serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "plus"
}
}));
TokenData {
id_token: IdTokenInfo::default(),
id_token: IdTokenInfo {
raw_jwt,
email: None,
..Default::default()
},
access_token,
refresh_token: "refresh-token".to_string(),
account_id: Some("account-id".to_string()),
Expand Down
17 changes: 17 additions & 0 deletions code-rs/core/src/auth_accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,23 @@ pub fn find_account(code_home: &Path, account_id: &str) -> io::Result<Option<Sto
.find(|acc| acc.id == account_id))
}

pub fn update_account_last_refresh(
code_home: &Path,
account_id: &str,
last_refresh: DateTime<Utc>,
) -> io::Result<Option<StoredAccount>> {
let path = accounts_file_path(code_home);
let mut data = read_accounts_file(&path)?;

let Some(account) = data.accounts.iter_mut().find(|acc| acc.id == account_id) else {
return Ok(None);
};
account.last_refresh = Some(last_refresh);
let updated = account.clone();
write_accounts_file(&path, &data)?;
Ok(Some(updated))
}

pub fn set_active_account_id(
code_home: &Path,
account_id: Option<String>,
Expand Down