diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py index b4878c543..e7b9637c8 100644 --- a/google/oauth2/_credentials_async.py +++ b/google/oauth2/_credentials_async.py @@ -75,6 +75,7 @@ async def refresh(self, request): self._client_secret, scopes=self._scopes, rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) self.token = access_token diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py index 09e076090..e20c11f3e 100644 --- a/google/oauth2/_reauth_async.py +++ b/google/oauth2/_reauth_async.py @@ -250,6 +250,7 @@ async def refresh_grant( client_secret, scopes=None, rapt_token=None, + enable_reauth_refresh=False, ): """Implements the reauthentication flow. @@ -267,6 +268,8 @@ async def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should + be used. The default value is False. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The @@ -301,6 +304,11 @@ async def refresh_grant( == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED ) ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthenticatio is needed. Please run `gcloud auth application-default login` to reauthentciate." + ) + rapt_token = await get_rapt_token( request, client_id, client_secret, refresh_token, token_uri, scopes=scopes ) diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 1bea6570c..7f5e6b8ea 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -75,6 +75,7 @@ def __init__( expiry=None, rapt_token=None, refresh_handler=None, + enable_reauth_refresh=False, ): """ Args: @@ -111,6 +112,8 @@ def __init__( refresh tokens are provided and tokens are obtained by calling some external process on demand. It is particularly useful for retrieving downscoped tokens from a token broker. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should + be used. The default value is False. """ super(Credentials, self).__init__() self.token = token @@ -125,6 +128,7 @@ def __init__( self._quota_project_id = quota_project_id self._rapt_token = rapt_token self.refresh_handler = refresh_handler + self._enable_reauth_refresh = enable_reauth_refresh def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called @@ -153,6 +157,7 @@ def __setstate__(self, d): self._client_secret = d.get("_client_secret") self._quota_project_id = d.get("_quota_project_id") self._rapt_token = d.get("_rapt_token") + self._enable_reauth_refresh = d.get("_enable_reauth_refresh") # The refresh_handler setter should be used to repopulate this. self._refresh_handler = None @@ -243,6 +248,7 @@ def with_quota_project(self, quota_project_id): default_scopes=self.default_scopes, quota_project_id=quota_project_id, rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) @_helpers.copy_docstring(credentials.Credentials) @@ -298,6 +304,7 @@ def refresh(self, request): self._client_secret, scopes=scopes, rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) self.token = access_token diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py index d914fe9a7..16cea670a 100644 --- a/google/oauth2/reauth.py +++ b/google/oauth2/reauth.py @@ -277,6 +277,7 @@ def refresh_grant( client_secret, scopes=None, rapt_token=None, + enable_reauth_refresh=False, ): """Implements the reauthentication flow. @@ -294,6 +295,8 @@ def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should + be used. The default value is False. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The @@ -312,7 +315,7 @@ def refresh_grant( } if scopes: body["scope"] = " ".join(scopes) - if rapt_token: + if rapt_token and enable_reauth_refresh: body["rapt"] = rapt_token response_status_ok, response_data = _client._token_endpoint_request_no_throw( @@ -326,6 +329,11 @@ def refresh_grant( or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED ) ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthenticatio is needed. Please run `gcloud auth application-default login` to reauthentciate." + ) + rapt_token = get_rapt_token( request, client_id, client_secret, refresh_token, token_uri, scopes=scopes ) diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index c9dafebf8..b6a80e3d0 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -51,6 +51,7 @@ def make_credentials(cls): client_id=cls.CLIENT_ID, client_secret=cls.CLIENT_SECRET, rapt_token=cls.RAPT_TOKEN, + enable_reauth_refresh=True, ) def test_default_state(self): @@ -149,6 +150,7 @@ def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_SECRET, None, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -219,6 +221,7 @@ def test_refresh_with_refresh_token_and_refresh_handler( self.CLIENT_SECRET, None, self.RAPT_TOKEN, + False, ) # Check that the credentials have the token and expiry @@ -422,6 +425,7 @@ def test_credentials_with_scopes_requested_refresh_success( scopes=scopes, default_scopes=default_scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -436,6 +440,7 @@ def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -484,6 +489,7 @@ def test_credentials_with_only_default_scopes_requested( client_secret=self.CLIENT_SECRET, default_scopes=default_scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -498,6 +504,7 @@ def test_credentials_with_only_default_scopes_requested( self.CLIENT_SECRET, default_scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -549,6 +556,7 @@ def test_credentials_with_scopes_returned_refresh_success( client_secret=self.CLIENT_SECRET, scopes=scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -563,6 +571,7 @@ def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -615,6 +624,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( client_secret=self.CLIENT_SECRET, scopes=scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -632,6 +642,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py index e9ffa8a79..5f44b0ebc 100644 --- a/tests/oauth2/test_reauth.py +++ b/tests/oauth2/test_reauth.py @@ -270,6 +270,7 @@ def test_refresh_grant_failed(): "client_secret", scopes=["foo", "bar"], rapt_token="rapt_token", + enable_reauth_refresh=True, ) assert excinfo.match(r"Bad request") mock_token_request.assert_called_with( @@ -298,7 +299,12 @@ def test_refresh_grant_success(): "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token" ): assert reauth.refresh_grant( - MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + enable_reauth_refresh=True, ) == ( "access_token", "refresh_token", @@ -306,3 +312,18 @@ def test_refresh_grant_success(): {"access_token": "access_token"}, "new_rapt_token", ) + + +def test_refresh_grant_reauth_refresh_disabled(): + with mock.patch( + "google.oauth2._client._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with pytest.raises(exceptions.RefreshError) as excinfo: + reauth.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) + assert excinfo.match(r"Reauthenticatio is needed") diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py index 99cf16f80..bc89392ad 100644 --- a/tests_async/oauth2/test_credentials_async.py +++ b/tests_async/oauth2/test_credentials_async.py @@ -43,6 +43,7 @@ def make_credentials(cls): token_uri=cls.TOKEN_URI, client_id=cls.CLIENT_ID, client_secret=cls.CLIENT_SECRET, + enable_reauth_refresh=True, ) def test_default_state(self): @@ -97,6 +98,7 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_SECRET, None, None, + True, ) # Check that the credentials have the token and expiry @@ -169,6 +171,7 @@ async def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_SECRET, scopes, "old_rapt_token", + False, ) # Check that the credentials have the token and expiry @@ -231,6 +234,7 @@ async def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_SECRET, scopes, None, + False, ) # Check that the credentials have the token and expiry @@ -301,6 +305,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_SECRET, scopes, None, + False, ) # Check that the credentials have the token and expiry diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py index f144d89f5..2aef86e39 100644 --- a/tests_async/oauth2/test_reauth_async.py +++ b/tests_async/oauth2/test_reauth_async.py @@ -318,7 +318,12 @@ async def test_refresh_grant_success(): "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token" ): assert await _reauth_async.refresh_grant( - MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + enable_reauth_refresh=True, ) == ( "access_token", "refresh_token", @@ -326,3 +331,19 @@ async def test_refresh_grant_success(): {"access_token": "access_token"}, "new_rapt_token", ) + + +@pytest.mark.asyncio +async def test_refresh_grant_reauth_refresh_disabled(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with pytest.raises(exceptions.RefreshError) as excinfo: + assert await _reauth_async.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) + assert excinfo.match(r"Reauthenticatio is needed")