diff --git a/oauthlib/oauth2/rfc6749/clients/base.py b/oauthlib/oauth2/rfc6749/clients/base.py index d5eb0cc1..1d12638e 100644 --- a/oauthlib/oauth2/rfc6749/clients/base.py +++ b/oauthlib/oauth2/rfc6749/clients/base.py @@ -589,11 +589,11 @@ def populate_token_attributes(self, response): if 'expires_in' in response: self.expires_in = response.get('expires_in') - self._expires_at = time.time() + int(self.expires_in) + self._expires_at = round(time.time()) + int(self.expires_in) if 'expires_at' in response: try: - self._expires_at = int(response.get('expires_at')) + self._expires_at = round(float(response.get('expires_at'))) except: self._expires_at = None diff --git a/oauthlib/oauth2/rfc6749/parameters.py b/oauthlib/oauth2/rfc6749/parameters.py index 8f6ce2c7..0f0f423a 100644 --- a/oauthlib/oauth2/rfc6749/parameters.py +++ b/oauthlib/oauth2/rfc6749/parameters.py @@ -345,7 +345,7 @@ def parse_implicit_response(uri, state=None, scope=None): params['scope'] = scope_to_list(params['scope']) if 'expires_in' in params: - params['expires_at'] = time.time() + int(params['expires_in']) + params['expires_at'] = round(time.time()) + int(params['expires_in']) if state and params.get('state', None) != state: raise ValueError("Mismatching or missing state in params.") @@ -437,6 +437,9 @@ def parse_token_response(body, scope=None): else: params['expires_at'] = time.time() + int(params['expires_in']) + if isinstance(params.get('expires_at'), float): + params['expires_at'] = round(params['expires_at']) + params = OAuth2Token(params, old_scope=scope) validate_token_parameters(params) return params diff --git a/tests/oauth2/rfc6749/clients/test_base.py b/tests/oauth2/rfc6749/clients/test_base.py index 70a22834..7286b991 100644 --- a/tests/oauth2/rfc6749/clients/test_base.py +++ b/tests/oauth2/rfc6749/clients/test_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import datetime +from unittest.mock import patch from oauthlib import common from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError @@ -353,3 +354,35 @@ def test_create_code_challenge_s256(self): code_verifier = client.create_code_verifier(length=128) code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256') self.assertEqual(code_challenge_s256, client.code_challenge) + + def test_parse_token_response_expires_at_is_int(self): + expected_expires_at = 1661185149 + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_at":1661185148.6437678,' + ' "scope":"/profile",' + ' "example_parameter":"example_value"}') + + client = Client(self.client_id) + + response = client.parse_request_body_response(token_json, scope=["/profile"]) + + self.assertEqual(response['expires_at'], expected_expires_at) + self.assertEqual(client._expires_at, expected_expires_at) + + @patch('time.time') + def test_parse_token_response_generated_expires_at_is_int(self, t): + t.return_value = 1661185148.6437678 + expected_expires_at = round(t.return_value) + 3600 + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_in":3600,' + ' "scope":"/profile",' + ' "example_parameter":"example_value"}') + + client = Client(self.client_id) + + response = client.parse_request_body_response(token_json, scope=["/profile"]) + + self.assertEqual(response['expires_at'], expected_expires_at) + self.assertEqual(client._expires_at, expected_expires_at) diff --git a/tests/oauth2/rfc6749/clients/test_service_application.py b/tests/oauth2/rfc6749/clients/test_service_application.py index b97d8554..84361d8b 100644 --- a/tests/oauth2/rfc6749/clients/test_service_application.py +++ b/tests/oauth2/rfc6749/clients/test_service_application.py @@ -166,7 +166,7 @@ def test_request_body_no_initial_private_key(self, t): @patch('time.time') def test_parse_token_response(self, t): t.return_value = time() - self.token['expires_at'] = self.token['expires_in'] + t.return_value + self.token['expires_at'] = self.token['expires_in'] + round(t.return_value) client = ServiceApplicationClient(self.client_id)