Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: add refresh loop for tunnel, touch GH token #194834

Merged
merged 1 commit into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
163 changes: 132 additions & 31 deletions cli/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use crate::{
constants::{get_default_user_agent, PRODUCT_NAME_LONG},
debug, info, log,
debug, error, info, log,
state::{LauncherPaths, PersistedState},
trace,
util::{
Expand All @@ -18,7 +18,7 @@ use crate::{
warning,
};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use chrono::{DateTime, Utc};
use gethostname::gethostname;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{cell::Cell, fmt::Display, path::PathBuf, sync::Arc, thread};
Expand Down Expand Up @@ -112,6 +112,20 @@ pub struct StoredCredential {
expires_at: Option<DateTime<Utc>>,
}

const GH_USER_ENDPOINT: &str = "https://api.github.com/user";

async fn get_github_user(
client: &reqwest::Client,
access_token: &str,
) -> Result<reqwest::Response, reqwest::Error> {
client
.get(GH_USER_ENDPOINT)
.header("Authorization", format!("token {}", access_token))
.header("User-Agent", get_default_user_agent())
.send()
.await
}

impl StoredCredential {
pub async fn is_expired(&self, log: &log::Logger, client: &reqwest::Client) -> bool {
match self.provider {
Expand All @@ -124,12 +138,7 @@ impl StoredCredential {
// only on a verifiable 4xx code. We don't error on any failed
// request since then a drop in connection could "require" a refresh
AuthProvider::Github => {
let res = client
.get("https://api.github.com/user")
.header("Authorization", format!("token {}", self.access_token))
.header("User-Agent", get_default_user_agent())
.send()
.await;
let res = get_github_user(client, &self.access_token).await;
let res = match res {
Ok(r) => r,
Err(e) => {
Expand All @@ -154,7 +163,9 @@ impl StoredCredential {
provider,
access_token: auth.access_token,
refresh_token: auth.refresh_token,
expires_at: auth.expires_in.map(|e| Utc::now() + Duration::seconds(e)),
expires_at: auth
.expires_in
.map(|e| Utc::now() + chrono::Duration::seconds(e)),
}
}
}
Expand Down Expand Up @@ -489,7 +500,7 @@ impl Auth {
let entry = match self.get_current_credential() {
Ok(Some(old_creds)) => {
trace!(self.log, "Found token in keyring");
match self.get_refreshed_token(&old_creds).await {
match self.maybe_refresh_token(&old_creds).await {
Ok(Some(new_creds)) => {
self.store_credentials(new_creds.clone());
new_creds
Expand Down Expand Up @@ -555,29 +566,40 @@ impl Auth {

/// Refreshes the token in the credentials if necessary. Returns None if
/// the token is up to date, or Some new token otherwise.
async fn get_refreshed_token(
async fn maybe_refresh_token(
&self,
creds: &StoredCredential,
) -> Result<Option<StoredCredential>, AnyError> {
if !creds.is_expired(&self.log, &self.client).await {
return Ok(None);
}

let refresh_token = match &creds.refresh_token {
Some(t) => t,
None => return Err(AnyError::from(RefreshTokenNotAvailableError())),
};
self.do_refresh_token(creds).await
}

self.do_grant(
creds.provider,
format!(
"client_id={}&grant_type=refresh_token&refresh_token={}",
creds.provider.client_id(),
refresh_token
),
)
.await
.map(Some)
/// Refreshes the token in the credentials. Returns an error if the process failed.
/// Returns None if the token didn't change.
async fn do_refresh_token(
&self,
creds: &StoredCredential,
) -> Result<Option<StoredCredential>, AnyError> {
match &creds.refresh_token {
Some(t) => self
.do_grant(
creds.provider,
format!(
"client_id={}&grant_type=refresh_token&refresh_token={}",
creds.provider.client_id(),
t
),
)
.await
.map(Some),
None => match creds.provider {
AuthProvider::Github => self.touch_github_token(creds).await.map(|_| None),
_ => Err(RefreshTokenNotAvailableError().into()),
},
}
}

/// Does a "grant token" request.
Expand All @@ -600,22 +622,47 @@ impl Auth {
return Ok(StoredCredential::from_response(body, provider));
}

return Err(Auth::handle_grant_error(
provider.grant_uri(),
status_code,
body,
));
}

/// GH doesn't have a refresh token, but does limit to the 10 most recently
/// used tokens per user (#9052), so for the github "refresh" just request
/// the current user.
async fn touch_github_token(&self, credential: &StoredCredential) -> Result<(), AnyError> {
let response = get_github_user(&self.client, &credential.access_token).await?;
if response.status().is_success() {
return Ok(());
}

let status_code = response.status().as_u16();
let body = response.bytes().await?;
Err(Auth::handle_grant_error(
GH_USER_ENDPOINT,
status_code,
body,
))
}

fn handle_grant_error(url: &str, status_code: u16, body: bytes::Bytes) -> AnyError {
if let Ok(res) = serde_json::from_slice::<AuthenticationError>(&body) {
return Err(OAuthError {
return OAuthError {
error: res.error,
error_description: res.error_description,
}
.into());
.into();
}

return Err(StatusError {
return StatusError {
body: String::from_utf8_lossy(&body).to_string(),
status_code,
url: provider.grant_uri().to_string(),
url: url.to_string(),
}
.into());
.into();
}

/// Implements the device code flow, returning the credentials upon success.
async fn do_device_code_flow(&self) -> Result<StoredCredential, AnyError> {
let provider = self.prompt_for_provider().await?;
Expand Down Expand Up @@ -683,13 +730,67 @@ impl Auth {
interval_s += 5; // https://www.rfc-editor.org/rfc/rfc8628#section-3.5
trace!(self.log, "refresh poll failed, slowing down");
}
// Github returns a non-standard 429 to slow down
Err(AnyError::StatusError(e)) if e.status_code == 429 => {
interval_s += 5; // https://www.rfc-editor.org/rfc/rfc8628#section-3.5
trace!(self.log, "refresh poll failed, slowing down");
}
Err(e) => {
trace!(self.log, "refresh poll failed, retrying: {}", e);
}
}
}
}
}

/// Maintains the stored credential by refreshing it against the service
/// to ensure its stays current. Returns a future that should be polled and
/// only errors if a refresh fails in a consistent way.
pub async fn keep_token_alive(self) -> Result<(), AnyError> {
let this = self.clone();
let default_refresh = std::time::Duration::from_secs(60 * 60);
let min_refresh = std::time::Duration::from_secs(10);

let mut credential = this.get_credential().await?;
let mut last_did_error = false;
loop {
let sleep_time = if last_did_error {
min_refresh
} else {
match credential.expires_at {
Some(d) => ((d - Utc::now()) * 2 / 3).to_std().unwrap_or(min_refresh),
None => default_refresh,
}
};

// to_std errors on negative duration, fall back to a 60s refresh
tokio::time::sleep(sleep_time.max(min_refresh)).await;

match this.do_refresh_token(&credential).await {
// 4xx error means this token is probably not good any mode
Err(AnyError::StatusError(e)) if e.status_code >= 400 && e.status_code < 500 => {
error!(this.log, "failed to keep token alive: {:?}", e);
return Err(e.into());
}
Err(e) if matches!(e, AnyError::RefreshTokenNotAvailableError(_)) => {
return Ok(());
}
Err(e) => {
warning!(this.log, "error refreshing token: {:?}", e);
last_did_error = true;
continue;
}
Ok(c) => {
trace!(this.log, "token was successfully refreshed in keepalive");
last_did_error = false;
if let Some(c) = c {
this.store_credentials(c.clone());
credential = c;
}
}
}
}
}
}

#[async_trait]
Expand Down
33 changes: 30 additions & 3 deletions cli/src/tunnels/dev_tunnels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::util::errors::{
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
use async_trait::async_trait;
use futures::TryFutureExt;
use futures::future::BoxFuture;
use futures::{FutureExt, TryFutureExt};
use lazy_static::lazy_static;
use rand::prelude::IteratorRandom;
use regex::Regex;
Expand Down Expand Up @@ -62,6 +63,11 @@ impl PersistedTunnel {
trait AccessTokenProvider: Send + Sync {
/// Gets the current access token.
async fn refresh_token(&self) -> Result<String, WrappedError>;

/// Maintains the stored credential by refreshing it against the service
/// to ensure its stays current. Returns a future that should be polled and
/// only completes if a refresh fails in a consistent way.
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>>;
}

/// Access token provider that provides a fixed token without refreshing.
Expand All @@ -78,10 +84,15 @@ impl AccessTokenProvider for StaticAccessTokenProvider {
async fn refresh_token(&self) -> Result<String, WrappedError> {
Ok(self.0.clone())
}

fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
futures::future::pending().boxed()
}
}

/// Access token provider that looks up the token from the tunnels API.
struct LookupAccessTokenProvider {
auth: auth::Auth,
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
Expand All @@ -90,12 +101,14 @@ struct LookupAccessTokenProvider {

impl LookupAccessTokenProvider {
pub fn new(
auth: auth::Auth,
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
initial_token: Option<String>,
) -> Self {
Self {
auth,
client,
locator,
log,
Expand Down Expand Up @@ -130,10 +143,16 @@ impl AccessTokenProvider for LookupAccessTokenProvider {
Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
}
}

fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
let auth = self.auth.clone();
auth.keep_token_alive().boxed()
}
}

#[derive(Clone)]
pub struct DevTunnels {
auth: auth::Auth,
log: log::Logger,
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
client: TunnelManagementClient,
Expand Down Expand Up @@ -276,9 +295,10 @@ impl DevTunnels {
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth);
client.authorization_provider(auth.clone());

DevTunnels {
auth,
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
Expand All @@ -293,9 +313,10 @@ impl DevTunnels {
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth);
client.authorization_provider(auth.clone());

DevTunnels {
auth,
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
Expand Down Expand Up @@ -491,6 +512,7 @@ impl DevTunnels {
&persisted,
self.client.clone(),
LookupAccessTokenProvider::new(
self.auth.clone(),
self.client.clone(),
locator,
self.log.clone(),
Expand Down Expand Up @@ -1013,6 +1035,7 @@ impl ActiveTunnelManager {
access_token_provider: impl AccessTokenProvider + 'static,
status: StatusLock,
) {
let mut token_ka = access_token_provider.keep_alive();
let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));

macro_rules! fail {
Expand Down Expand Up @@ -1069,6 +1092,10 @@ impl ActiveTunnelManager {
backoff.delay().await;
}
},
Err(e) = &mut token_ka => {
error!(log, "access token is no longer valid, exiting: {}", e);
return;
},
_ = close_rx.recv() => {
trace!(log, "Tunnel closing gracefully");
trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
Expand Down