From 64655bf3f572975aa769199181113bacab47fb11 Mon Sep 17 00:00:00 2001 From: Michalis Mengisoglou Date: Wed, 13 Mar 2024 15:00:04 +0200 Subject: [PATCH] Expose leeway in clients Commit 3da1fdc introduced a "leeway" parameter for proactive token refreshing. Expose this parameter in clients (e.g., requests client) to allow configuring it by the library's users. --- authlib/integrations/httpx_client/oauth2_client.py | 6 +++--- .../requests_client/assertion_session.py | 7 ++++--- .../integrations/requests_client/oauth2_session.py | 7 +++++-- authlib/oauth2/client.py | 10 ++++++++-- authlib/oauth2/rfc7521/client.py | 3 ++- docs/client/oauth2.rst | 4 ++++ tests/clients/test_requests/test_oauth2_session.py | 12 ++++++++++++ 7 files changed, 38 insertions(+), 11 deletions(-) diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index d4ee0f58..5b2d3fdd 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -58,7 +58,7 @@ def __init__(self, client_id=None, client_secret=None, revocation_endpoint_auth_method=None, scope=None, redirect_uri=None, token=None, token_placement='header', - update_token=None, **kwargs): + update_token=None, leeway=60, **kwargs): # extract httpx.Client kwargs client_kwargs = self._extract_session_request_params(kwargs) @@ -75,7 +75,7 @@ def __init__(self, client_id=None, client_secret=None, revocation_endpoint_auth_method=revocation_endpoint_auth_method, scope=scope, redirect_uri=redirect_uri, token=token, token_placement=token_placement, - update_token=update_token, **kwargs + update_token=update_token, leeway=leeway, **kwargs ) async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): @@ -106,7 +106,7 @@ async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAUL async def ensure_active_token(self, token): async with self._token_refresh_lock: - if self.token.is_expired(): + if self.token.is_expired(leeway=self.leeway): refresh_token = token.get('refresh_token') url = self.metadata.get('token_endpoint') if refresh_token and url: diff --git a/authlib/integrations/requests_client/assertion_session.py b/authlib/integrations/requests_client/assertion_session.py index d07c0016..de41dceb 100644 --- a/authlib/integrations/requests_client/assertion_session.py +++ b/authlib/integrations/requests_client/assertion_session.py @@ -7,7 +7,7 @@ class AssertionAuth(OAuth2Auth): def ensure_active_token(self): - if not self.token or self.token.is_expired() and self.client: + if self.client and (not self.token or self.token.is_expired(self.client.leeway)): return self.client.refresh_token() @@ -25,7 +25,8 @@ class AssertionSession(AssertionClient, Session): DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, default_timeout=None, **kwargs): + claims=None, token_placement='header', scope=None, default_timeout=None, + leeway=60, **kwargs): Session.__init__(self) self.default_timeout = default_timeout update_session_configure(self, kwargs) @@ -33,7 +34,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No self, session=self, token_endpoint=token_endpoint, issuer=issuer, subject=subject, audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + token_placement=token_placement, scope=scope, leeway=leeway, **kwargs ) def request(self, method, url, withhold_token=False, auth=None, **kwargs): diff --git a/authlib/integrations/requests_client/oauth2_session.py b/authlib/integrations/requests_client/oauth2_session.py index 9e2426a2..93586568 100644 --- a/authlib/integrations/requests_client/oauth2_session.py +++ b/authlib/integrations/requests_client/oauth2_session.py @@ -64,6 +64,9 @@ class OAuth2Session(OAuth2Client, Session): values: "header", "body", "uri". :param update_token: A function for you to update token. It accept a :class:`OAuth2Token` as parameter. + :param leeway: Time window in seconds before the actual expiration of the + authentication token, that the token is considered expired and will + be refreshed. :param default_timeout: If settled, every requests will have a default timeout. """ client_auth_class = OAuth2ClientAuth @@ -79,7 +82,7 @@ def __init__(self, client_id=None, client_secret=None, revocation_endpoint_auth_method=None, scope=None, state=None, redirect_uri=None, token=None, token_placement='header', - update_token=None, default_timeout=None, **kwargs): + update_token=None, leeway=60, default_timeout=None, **kwargs): Session.__init__(self) self.default_timeout = default_timeout update_session_configure(self, kwargs) @@ -91,7 +94,7 @@ def __init__(self, client_id=None, client_secret=None, revocation_endpoint_auth_method=revocation_endpoint_auth_method, scope=scope, state=state, redirect_uri=redirect_uri, token=token, token_placement=token_placement, - update_token=update_token, **kwargs + update_token=update_token, leeway=leeway, **kwargs ) def fetch_access_token(self, url=None, **kwargs): diff --git a/authlib/oauth2/client.py b/authlib/oauth2/client.py index e3fd1355..d36d93f0 100644 --- a/authlib/oauth2/client.py +++ b/authlib/oauth2/client.py @@ -38,6 +38,9 @@ class OAuth2Client: values: "header", "body", "uri". :param update_token: A function for you to update token. It accept a :class:`OAuth2Token` as parameter. + :param leeway: Time window in seconds before the actual expiration of the + authentication token, that the token is considered expired and will + be refreshed. """ client_auth_class = ClientAuth token_auth_class = TokenAuth @@ -52,7 +55,8 @@ def __init__(self, session, client_id=None, client_secret=None, token_endpoint_auth_method=None, revocation_endpoint_auth_method=None, scope=None, state=None, redirect_uri=None, code_challenge_method=None, - token=None, token_placement='header', update_token=None, **metadata): + token=None, token_placement='header', update_token=None, leeway=60, + **metadata): self.session = session self.client_id = client_id @@ -97,6 +101,8 @@ def __init__(self, session, client_id=None, client_secret=None, } self._auth_methods = {} + self.leeway = leeway + def register_client_auth_method(self, auth): """Extend client authenticate for token endpoint. @@ -263,7 +269,7 @@ def refresh_token(self, url=None, refresh_token=None, body='', def ensure_active_token(self, token=None): if token is None: token = self.token - if not token.is_expired(): + if not token.is_expired(leeway=self.leeway): return True refresh_token = token.get('refresh_token') url = self.metadata.get('token_endpoint') diff --git a/authlib/oauth2/rfc7521/client.py b/authlib/oauth2/rfc7521/client.py index e7ce2c3c..cf431047 100644 --- a/authlib/oauth2/rfc7521/client.py +++ b/authlib/oauth2/rfc7521/client.py @@ -15,7 +15,7 @@ class AssertionClient: def __init__(self, session, token_endpoint, issuer, subject, audience=None, grant_type=None, claims=None, - token_placement='header', scope=None, **kwargs): + token_placement='header', scope=None, leeway=60, **kwargs): self.session = session @@ -38,6 +38,7 @@ def __init__(self, session, token_endpoint, issuer, subject, if self.token_auth_class is not None: self.token_auth = self.token_auth_class(None, token_placement, self) self._kwargs = kwargs + self.leeway = leeway @property def token(self): diff --git a/docs/client/oauth2.rst b/docs/client/oauth2.rst index a4623ccf..c53f10f7 100644 --- a/docs/client/oauth2.rst +++ b/docs/client/oauth2.rst @@ -280,6 +280,10 @@ it has expired:: >>> openid_configuration = requests.get("https://example.org/.well-known/openid-configuration").json() >>> session = OAuth2Session(…, token_endpoint=openid_configuration["token_endpoint"]) +By default, the token will be refreshed 60 seconds before its actual expiry time, to avoid clock skew issues. +You can control this behaviour by setting the ``leeway`` parameter of the :class:`~requests_client.OAuth2Session` +class. + Manually refreshing tokens ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/clients/test_requests/test_oauth2_session.py b/tests/clients/test_requests/test_oauth2_session.py index 8afc8dea..c6c51c34 100644 --- a/tests/clients/test_requests/test_oauth2_session.py +++ b/tests/clients/test_requests/test_oauth2_session.py @@ -295,6 +295,18 @@ def test_token_status(self): self.assertTrue(sess.token.is_expired) + def test_token_status2(self): + token = dict(access_token='a', token_type='bearer', expires_in=10) + sess = OAuth2Session('foo', token=token, leeway=15) + + self.assertTrue(sess.token.is_expired(sess.leeway)) + + def test_token_status3(self): + token = dict(access_token='a', token_type='bearer', expires_in=10) + sess = OAuth2Session('foo', token=token, leeway=5) + + self.assertFalse(sess.token.is_expired(sess.leeway)) + def test_token_expired(self): token = dict(access_token='a', token_type='bearer', expires_at=100) sess = OAuth2Session('foo', token=token)