Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose leeway in clients #643

Merged
merged 1 commit into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions authlib/integrations/httpx_client/oauth2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, **kwargs):
update_token=None, leeway=60, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why default to 60?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because that's the default we've used at the lower-level function as well:

def is_expired(self, leeway=60):

Other than that, I don't have a strong opinion.


# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
Expand All @@ -75,7 +75,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
update_token=update_token, leeway=leeway, **kwargs
)

async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
Expand Down Expand Up @@ -106,7 +106,7 @@ async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAUL

async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired():
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
if refresh_token and url:
Expand Down
7 changes: 4 additions & 3 deletions authlib/integrations/requests_client/assertion_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class AssertionAuth(OAuth2Auth):
def ensure_active_token(self):
if not self.token or self.token.is_expired() and self.client:
if self.client and (not self.token or self.token.is_expired(self.client.leeway)):
return self.client.refresh_token()


Expand All @@ -25,15 +25,16 @@ class AssertionSession(AssertionClient, Session):
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, default_timeout=None, **kwargs):
claims=None, token_placement='header', scope=None, default_timeout=None,
leeway=60, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
token_placement=token_placement, scope=scope, leeway=leeway, **kwargs
)

def request(self, method, url, withhold_token=False, auth=None, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions authlib/integrations/requests_client/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class OAuth2Session(OAuth2Client, Session):
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
:param default_timeout: If settled, every requests will have a default timeout.
"""
client_auth_class = OAuth2ClientAuth
Expand All @@ -79,7 +82,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=None,
scope=None, state=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, default_timeout=None, **kwargs):
update_token=None, leeway=60, default_timeout=None, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
Expand All @@ -91,7 +94,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, state=state, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
update_token=update_token, leeway=leeway, **kwargs
)

def fetch_access_token(self, url=None, **kwargs):
Expand Down
10 changes: 8 additions & 2 deletions authlib/oauth2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class OAuth2Client:
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
"""
client_auth_class = ClientAuth
token_auth_class = TokenAuth
Expand All @@ -52,7 +55,8 @@ def __init__(self, session, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, state=None, redirect_uri=None, code_challenge_method=None,
token=None, token_placement='header', update_token=None, **metadata):
token=None, token_placement='header', update_token=None, leeway=60,
**metadata):

self.session = session
self.client_id = client_id
Expand Down Expand Up @@ -97,6 +101,8 @@ def __init__(self, session, client_id=None, client_secret=None,
}
self._auth_methods = {}

self.leeway = leeway

def register_client_auth_method(self, auth):
"""Extend client authenticate for token endpoint.

Expand Down Expand Up @@ -263,7 +269,7 @@ def refresh_token(self, url=None, refresh_token=None, body='',
def ensure_active_token(self, token=None):
if token is None:
token = self.token
if not token.is_expired():
if not token.is_expired(leeway=self.leeway):
return True
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
Expand Down
3 changes: 2 additions & 1 deletion authlib/oauth2/rfc7521/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AssertionClient:

def __init__(self, session, token_endpoint, issuer, subject,
audience=None, grant_type=None, claims=None,
token_placement='header', scope=None, **kwargs):
token_placement='header', scope=None, leeway=60, **kwargs):

self.session = session

Expand All @@ -38,6 +38,7 @@ def __init__(self, session, token_endpoint, issuer, subject,
if self.token_auth_class is not None:
self.token_auth = self.token_auth_class(None, token_placement, self)
self._kwargs = kwargs
self.leeway = leeway

@property
def token(self):
Expand Down
4 changes: 4 additions & 0 deletions docs/client/oauth2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ it has expired::
>>> openid_configuration = requests.get("https://example.org/.well-known/openid-configuration").json()
>>> session = OAuth2Session(…, token_endpoint=openid_configuration["token_endpoint"])

By default, the token will be refreshed 60 seconds before its actual expiry time, to avoid clock skew issues.
You can control this behaviour by setting the ``leeway`` parameter of the :class:`~requests_client.OAuth2Session`
class.

Manually refreshing tokens
~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
12 changes: 12 additions & 0 deletions tests/clients/test_requests/test_oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ def test_token_status(self):

self.assertTrue(sess.token.is_expired)

def test_token_status2(self):
token = dict(access_token='a', token_type='bearer', expires_in=10)
sess = OAuth2Session('foo', token=token, leeway=15)

self.assertTrue(sess.token.is_expired(sess.leeway))

def test_token_status3(self):
token = dict(access_token='a', token_type='bearer', expires_in=10)
sess = OAuth2Session('foo', token=token, leeway=5)

self.assertFalse(sess.token.is_expired(sess.leeway))

def test_token_expired(self):
token = dict(access_token='a', token_type='bearer', expires_at=100)
sess = OAuth2Session('foo', token=token)
Expand Down