From 271f43b0cf97783071d1719c54b394c0f71c2fb0 Mon Sep 17 00:00:00 2001 From: Bassam Ojeil Date: Tue, 20 Jul 2021 22:17:34 -0700 Subject: [PATCH 1/2] feat: support refresh handler callable on google.oauth2.credentials.Credentials This is an optional parameter that can be set via the constructor. It is used to provide the credentials with new tokens and their expiration time on `refresh()` call. ``` def refresh_handler(request, scopes): # Generate a new token for the requested scopes by calling # an external process. return ( "ACCESS_TOKEN", _helpers.utcnow() + datetime.timedelta(seconds=3600)) creds = google.oauth2.credentials.Credentials( scopes=scopes, refresh_handler=refresh_handler) creds.refresh(request) ``` It is useful in the following cases: - Useful in general when tokens are obtained by calling some external process on demand. - Useful in particular for retrieving downscoped tokens from a token broker. This should have no impact on existing behavior. Refresh tokens will still have higher priority over refresh handlers. A getter and setter is exposed to make it easy to set the callable on unpickled credentials as the callable cannot be easily serialized. ``` unpickled = pickle.loads(pickle.dumps(oauth_creds)) unpickled.refresh_handler = refresh_handler ``` --- google/oauth2/credentials.py | 69 +++++++- tests/oauth2/test_credentials.py | 285 +++++++++++++++++++++++++++++++ 2 files changed, 351 insertions(+), 3 deletions(-) diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index dcfa5f912..a8caf8eaf 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -74,6 +74,7 @@ def __init__( quota_project_id=None, expiry=None, rapt_token=None, + refresh_handler=None, ): """ Args: @@ -103,6 +104,13 @@ def __init__( This project may be different from the project used to create the credentials. rapt_token (Optional[str]): The reauth Proof Token. + refresh_handler (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + A callable which takes in the HTTP request callable and the list of + OAuth scopes and when called returns an access token string for the + requested scopes and its expiry datetime. This is useful when no + refresh tokens are provided and tokens are obtained by calling + some external process on demand. It is particularly useful for + retrieving downscoped tokens from a token broker. """ super(Credentials, self).__init__() self.token = token @@ -116,13 +124,18 @@ def __init__( self._client_secret = client_secret self._quota_project_id = quota_project_id self._rapt_token = rapt_token + self.refresh_handler = refresh_handler def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called This is identical to the default implementation. See https://docs.python.org/3.7/library/pickle.html#object.__setstate__ """ - return self.__dict__ + state_dict = self.__dict__.copy() + # Remove _refresh_handler function as functions can't be pickled. + # The refresh_handler setter should be used to repopulate this. + del state_dict["_refresh_handler"] + return state_dict def __setstate__(self, d): """Credentials pickled with older versions of the class do not have @@ -138,6 +151,8 @@ def __setstate__(self, d): self._client_secret = d.get("_client_secret") self._quota_project_id = d.get("_quota_project_id") self._rapt_token = d.get("_rapt_token") + # The refresh_handler setter should be used to repopulate this. + self._refresh_handler = None @property def refresh_token(self): @@ -187,6 +202,31 @@ def rapt_token(self): """Optional[str]: The reauth Proof Token.""" return self._rapt_token + @property + def refresh_handler(self): + """Returns the refresh handler if available. + + Returns: + Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]: + The current refresh handler. + """ + return self._refresh_handler + + @refresh_handler.setter + def refresh_handler(self, value): + """Updates the current refresh handler. + + Args: + value (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + The updated value of the refresh handler. + + Raises: + TypeError: If the value is not a callable or None. + """ + if not callable(value) and value is not None: + raise TypeError("The provided refresh_handler is not a callable or None.") + self._refresh_handler = value + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): @@ -205,6 +245,31 @@ def with_quota_project(self, quota_project_id): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Use refresh handler if available and no refresh token is + # available. This is useful in general when tokens are obtained by calling + # some external process on demand. It is particularly useful for retrieving + # downscoped tokens from a token broker. + if self._refresh_token is None and self.refresh_handler: + token, expiry = self.refresh_handler(request, scopes=scopes) + # Validate returned data. + if not isinstance(token, str): + raise exceptions.RefreshError( + "The refresh_handler returned token is not a string." + ) + if not isinstance(expiry, datetime): + raise exceptions.RefreshError( + "The refresh_handler returned expiry is not a datetime object." + ) + if _helpers.utcnow() >= expiry - _helpers.CLOCK_SKEW: + raise exceptions.RefreshError( + "The credentials returned by the refresh_handler are " + "already expired." + ) + self.token = token + self.expiry = expiry + return + if ( self._refresh_token is None or self._token_uri is None @@ -217,8 +282,6 @@ def refresh(self, request): "token_uri, client_id, and client_secret." ) - scopes = self._scopes if self._scopes is not None else self._default_scopes - ( access_token, refresh_token, diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index 4a387a58e..8c2b84ed8 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -66,6 +66,50 @@ def test_default_state(self): assert credentials.client_id == self.CLIENT_ID assert credentials.client_secret == self.CLIENT_SECRET assert credentials.rapt_token == self.RAPT_TOKEN + assert credentials.refresh_handler is None + + def test_refresh_handler_setter_and_getter(self): + scopes = ["email", "profile"] + original_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_1", None)) + updated_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_2", None)) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=None, + refresh_handler=original_refresh_handler, + ) + + assert creds.refresh_handler == original_refresh_handler + + creds.refresh_handler = updated_refresh_handler + + assert creds.refresh_handler == updated_refresh_handler + + creds.refresh_handler = None + + assert creds.refresh_handler is None + + def test_invalid_refresh_handler(self): + scopes = ["email", "profile"] + with pytest.raises(TypeError) as excinfo: + credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=None, + refresh_handler=[mock.Mock()], + ) + + assert excinfo.match("The provided refresh_handler is not a callable or None.") @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( @@ -126,6 +170,221 @@ def test_refresh_no_refresh_token(self): request.assert_not_called() + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + def test_refresh_with_refresh_token_and_refresh_handler( + self, unused_utcnow, refresh_grant + ): + token = "token" + new_rapt_token = "new_rapt_token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + # rapt_token + new_rapt_token, + ) + + refresh_handler = mock.Mock() + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + rapt_token=self.RAPT_TOKEN, + refresh_handler=refresh_handler, + ) + + # Refresh credentials + creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + None, + self.RAPT_TOKEN, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.rapt_token == new_rapt_token + + # Check that the credentials are valid (have a token and are not + # expired) + assert creds.valid + + # Assert refresh handler not called as the refresh token has + # higher priority. + refresh_handler.assert_not_called() + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_success_scopes(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + creds.refresh(request) + + assert creds.token == "ACCESS_TOKEN" + assert creds.expiry == expected_expiry + assert creds.valid + assert not creds.expired + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_success_default_scopes(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + original_refresh_handler = mock.Mock( + return_value=("UNUSED_TOKEN", expected_expiry) + ) + refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry)) + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=None, + default_scopes=default_scopes, + refresh_handler=original_refresh_handler, + ) + + # Test newly set refresh_handler is used instead of the original one. + creds.refresh_handler = refresh_handler + creds.refresh(request) + + assert creds.token == "ACCESS_TOKEN" + assert creds.expiry == expected_expiry + assert creds.valid + assert not creds.expired + # default_scopes should be used since no developer provided scopes + # are provided. + refresh_handler.assert_called_with(request, scopes=default_scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_invalid_token(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + # Simulate refresh handler does not return a valid token. + refresh_handler = mock.Mock(return_value=(None, expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises( + exceptions.RefreshError, match="returned token is not a string" + ): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + def test_refresh_with_refresh_handler_invalid_expiry(self): + # Simulate refresh handler returns expiration time in an invalid unit. + refresh_handler = mock.Mock(return_value=("TOKEN", 2800)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises( + exceptions.RefreshError, match="returned expiry is not a datetime object" + ): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_expired_token(self, unused_utcnow): + expected_expiry = datetime.datetime.min + _helpers.CLOCK_SKEW + # Simulate refresh handler returns an expired token. + refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises(exceptions.RefreshError, match="already expired"): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", @@ -527,6 +786,32 @@ def test_pickle_and_unpickle(self): for attr in list(creds.__dict__): assert getattr(creds, attr) == getattr(unpickled, attr) + def test_pickle_and_unpickle_with_refresh_handler(self): + expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800) + refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry)) + + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + refresh_handler=refresh_handler, + ) + unpickled = pickle.loads(pickle.dumps(creds)) + + # make sure attributes aren't lost during pickling + assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort() + + for attr in list(creds.__dict__): + # For the _refresh_handler property, the unpickled creds should be + # set to None. + if attr == "_refresh_handler": + assert getattr(unpickled, attr) is None + else: + assert getattr(creds, attr) == getattr(unpickled, attr) + def test_pickle_with_missing_attribute(self): creds = self.make_credentials() From a76a365dbf4d156c7dff13c17528f443b4fea2eb Mon Sep 17 00:00:00 2001 From: Bassam Ojeil Date: Wed, 21 Jul 2021 12:08:18 -0700 Subject: [PATCH 2/2] Addresses review comments. --- google/oauth2/credentials.py | 6 ++++-- tests/oauth2/test_credentials.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index a8caf8eaf..158249ed5 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -132,8 +132,10 @@ def __getstate__(self): See https://docs.python.org/3.7/library/pickle.html#object.__setstate__ """ state_dict = self.__dict__.copy() - # Remove _refresh_handler function as functions can't be pickled. - # The refresh_handler setter should be used to repopulate this. + # Remove _refresh_handler function as there are limitations pickling and + # unpickling certain callables (lambda, functools.partial instances) + # because they need to be importable. + # Instead, the refresh_handler setter should be used to repopulate this. del state_dict["_refresh_handler"] return state_dict diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index 8c2b84ed8..4a7f66e7f 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -84,11 +84,11 @@ def test_refresh_handler_setter_and_getter(self): refresh_handler=original_refresh_handler, ) - assert creds.refresh_handler == original_refresh_handler + assert creds.refresh_handler is original_refresh_handler creds.refresh_handler = updated_refresh_handler - assert creds.refresh_handler == updated_refresh_handler + assert creds.refresh_handler is updated_refresh_handler creds.refresh_handler = None @@ -106,7 +106,7 @@ def test_invalid_refresh_handler(self): rapt_token=None, scopes=scopes, default_scopes=None, - refresh_handler=[mock.Mock()], + refresh_handler=object(), ) assert excinfo.match("The provided refresh_handler is not a callable or None.")