Skip to content
This repository was archived by the owner on Apr 19, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions endpoints/test/users_id_token_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import unittest

import mock
import pytest

import endpoints.api_config as api_config

Expand Down Expand Up @@ -369,8 +370,9 @@ def testEmptyAudience(self):
parsed_token, users_id_token._ISSUERS, [], self._SAMPLE_ALLOWED_CLIENT_IDS)
self.assertEqual(False, result)

@mock.patch.object(oauth, 'get_authorized_scopes')
@mock.patch.object(oauth, 'get_client_id')
def AttemptOauth(self, client_id, mock_get_client_id, allowed_client_ids=None):
def AttemptOauth(self, client_id, mock_get_client_id, mock_get_authorized_scopes, allowed_client_ids=None):
if allowed_client_ids is None:
allowed_client_ids = self._SAMPLE_ALLOWED_CLIENT_IDS
# We have four cases:
Expand All @@ -381,20 +383,20 @@ def AttemptOauth(self, client_id, mock_get_client_id, allowed_client_ids=None):
# mock call for every scope.
if client_id is None:
mock_get_client_id.side_effect = oauth.Error
mock_get_authorized_scopes.side_effect = oauth.Error
else:
mock_get_client_id.return_value = client_id
mock_get_authorized_scopes.return_value = self._SAMPLE_OAUTH_SCOPES
users_id_token._set_bearer_user_vars(allowed_client_ids,
self._SAMPLE_OAUTH_SCOPES)
if client_id is None:
for scope in self._SAMPLE_OAUTH_SCOPES:
mock_get_client_id.assert_called_with(scope)
mock_get_authorized_scopes.assert_called_with(self._SAMPLE_OAUTH_SCOPES)
elif (list(allowed_client_ids) == users_id_token.SKIP_CLIENT_ID_CHECK or
client_id in allowed_client_ids):
scope = self._SAMPLE_OAUTH_SCOPES[0]
mock_get_client_id.assert_called_with(scope)
mock_get_client_id.assert_called_with([scope])
else:
for scope in self._SAMPLE_OAUTH_SCOPES:
mock_get_client_id.assert_called_with(scope)
mock_get_client_id.assert_called_with(self._SAMPLE_OAUTH_SCOPES)


def assertOauthSucceeded(self, client_id):
Expand Down Expand Up @@ -487,10 +489,10 @@ def testGetCurrentUserEmailAndAuth(self):
def testGetCurrentUserOauth(self, mock_get_current_user):
mock_get_current_user.return_value = users.User('test@gmail.com')

os.environ['ENDPOINTS_USE_OAUTH_SCOPE'] = 'scope'
os.environ['ENDPOINTS_USE_OAUTH_SCOPE'] = 'scope1 scope2'
user = users_id_token.get_current_user()
self.assertEqual(user.email(), 'test@gmail.com')
mock_get_current_user.assert_called_once_with('scope')
mock_get_current_user.assert_called_once_with(['scope1', 'scope2'])

def testGetTokenQueryParamOauthHeader(self):
os.environ['HTTP_AUTHORIZATION'] = 'OAuth ' + self._SAMPLE_TOKEN
Expand Down Expand Up @@ -631,9 +633,10 @@ def testMethodCallParsesIdToken(self):
self.VerifyIdToken(self.TestApiAnnotatedAtApi(),
message_types.VoidMessage())

@mock.patch.object(oauth, 'get_authorized_scopes')
@mock.patch.object(oauth, 'get_client_id')
@mock.patch.object(users_id_token, '_is_local_dev')
def testMaybeSetVarsWithActualRequestAccessToken(self, mock_local, mock_get_client_id):
def testMaybeSetVarsWithActualRequestAccessToken(self, mock_local, mock_get_client_id, mock_get_authorized_scopes):
dummy_scope = 'scope'
dummy_token = 'dummy_token'
dummy_email = 'test@gmail.com'
Expand All @@ -656,13 +659,15 @@ def method(self, request):

mock_local.return_value = False
mock_get_client_id.return_value = dummy_client_id
mock_get_authorized_scopes.return_value = [dummy_scope]

api_instance = TestApiScopes()
os.environ['HTTP_AUTHORIZATION'] = 'Bearer ' + dummy_token
api_instance.method(message_types.VoidMessage())
self.assertEqual(os.getenv('ENDPOINTS_USE_OAUTH_SCOPE'), dummy_scope)
assert os.getenv('ENDPOINTS_USE_OAUTH_SCOPE') == dummy_scope
mock_local.assert_has_calls([mock.call(), mock.call()])
mock_get_client_id.assert_called_once_with(dummy_scope)
mock_get_client_id.assert_called_once_with([dummy_scope])
mock_get_authorized_scopes.assert_called_once_with([dummy_scope])

@mock.patch.object(users_id_token, '_get_id_token_user')
@mock.patch.object(time, 'time')
Expand Down Expand Up @@ -891,5 +896,28 @@ def testBadBase64(self):
self._SAMPLE_CERT_URI, self.cache)
self.assertIsNone(parsed_token)


@pytest.mark.parametrize(('scopelist', 'all_scopes', 'sufficient_scopes'), [
(('scope1', 'scope2'), {'scope1', 'scope2'}, {frozenset(['scope1']), frozenset(['scope2'])}),
(('scope1', 'scope2 scope3'), {'scope1', 'scope2', 'scope3'}, {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}),
(('scope1 scope2', 'scope1 scope3'), {'scope1', 'scope2', 'scope3'}, {frozenset(['scope1', 'scope2']), frozenset(['scope1', 'scope3'])}),
])
def test_process_scopes(scopelist, all_scopes, sufficient_scopes):
result = users_id_token._process_scopes(scopelist)
assert result == (all_scopes, sufficient_scopes)

@pytest.mark.parametrize(('authorized_scopes', 'sufficient_scopes', 'is_valid'), [
(['scope1'], {frozenset(['scope1'])}, True),
(['scope1'], {frozenset(['scope1', 'scope2'])}, False),
(['scope1', 'scope2'], {frozenset(['scope1'])}, True),
(['scope1', 'scope2'], {frozenset(['scope1']), frozenset(['scope2'])}, True),
(['scope1', 'scope2'], {frozenset(['scope1', 'scope2'])}, True),
(['scope1'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, True),
(['scope2'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, False),
(['scope2', 'scope3'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, True),
])
def test_are_scopes_sufficient(authorized_scopes, sufficient_scopes, is_valid):
assert users_id_token._are_scopes_sufficient(authorized_scopes, sufficient_scopes) is is_valid

if __name__ == '__main__':
unittest.main()
81 changes: 59 additions & 22 deletions endpoints/users_id_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_current_user():
# We can get more information from the oauth.get_current_user function,
# as long as we know what scope to use. Since that scope has been
# cached, we can just return this:
return oauth.get_current_user(os.environ[_ENV_USE_OAUTH_SCOPE])
return oauth.get_current_user(os.environ[_ENV_USE_OAUTH_SCOPE].split())

if (_ENV_AUTH_EMAIL in os.environ and
_ENV_AUTH_DOMAIN in os.environ):
Expand Down Expand Up @@ -316,6 +316,43 @@ def _set_oauth_user_vars(token_info, audiences, allowed_client_ids, scopes,
# pylint: enable=unused-argument


def _process_scopes(scopes):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe document what scopes is?
I believe scopes should be a list of string, and each string is a space-separated list of scope.
Is this correct?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, whoops!

"""Parse a scopes list into a set of all scopes and a set of sufficient scope sets.

scopes: A list of strings, each of which is a space-separated list of scopes.
Examples: ['scope1']
['scope1', 'scope2']
['scope1', 'scope2 scope3']

Returns:
all_scopes: a set of strings, each of which is one scope to check for
sufficient_scopes: a set of sets of strings; each inner set is
a set of scopes which are sufficient for access.
Example: {{'scope1'}, {'scope2', 'scope3'}}
"""
all_scopes = set()
sufficient_scopes = set()
for scope_set in scopes:
scope_set_scopes = frozenset(scope_set.split())
all_scopes.update(scope_set_scopes)
sufficient_scopes.add(scope_set_scopes)
return all_scopes, sufficient_scopes


def _are_scopes_sufficient(authorized_scopes, sufficient_scopes):
"""Check if a list of authorized scopes satisfies any set of sufficient scopes.

Args:
authorized_scopes: a list of strings, return value from oauth.get_authorized_scopes
sufficient_scopes: a set of sets of strings, return value from _process_scopes
"""
for sufficient_scope_set in sufficient_scopes:
if sufficient_scope_set.issubset(authorized_scopes):
return True
return False



def _set_bearer_user_vars(allowed_client_ids, scopes):
"""Validate the oauth bearer token and set endpoints auth user variables.

Expand All @@ -327,27 +364,27 @@ def _set_bearer_user_vars(allowed_client_ids, scopes):
allowed_client_ids: List of client IDs that are acceptable.
scopes: List of acceptable scopes.
"""
for scope in scopes:
try:
client_id = oauth.get_client_id(scope)
except oauth.Error:
# This scope failed. Try the next.
continue

# The client ID must be in allowed_client_ids. If allowed_client_ids is
# empty, don't allow any client ID. If allowed_client_ids is set to
# SKIP_CLIENT_ID_CHECK, all client IDs will be allowed.
if (list(allowed_client_ids) != SKIP_CLIENT_ID_CHECK and
client_id not in allowed_client_ids):
_logger.warning('Client ID is not allowed: %s', client_id)
return
all_scopes, sufficient_scopes = _process_scopes(scopes)
try:
authorized_scopes = oauth.get_authorized_scopes(sorted(all_scopes))
except oauth.Error:
_logger.debug('Unable to get authorized scopes.', exc_info=True)
return
if not _are_scopes_sufficient(authorized_scopes, sufficient_scopes):
_logger.debug('Authorized scopes did not satisfy scope requirements.')
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return or fail here?

Also, would it be useful to include the actual authorized/sufficient scopes in the log message?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we should. Nice catch.

Honestly, I'm considering removing these lines completely, or attaching them to a separate logger which is normally disabled. These log lines will come up a lot for unauthorized users, so people probably won't want to see them unless they're debugging auth issues.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough

return
client_id = oauth.get_client_id(authorized_scopes)

os.environ[_ENV_USE_OAUTH_SCOPE] = scope
_logger.debug('Returning user from matched oauth_user.')
# The client ID must be in allowed_client_ids. If allowed_client_ids is
# empty, don't allow any client ID. If allowed_client_ids is set to
# SKIP_CLIENT_ID_CHECK, all client IDs will be allowed.
if (list(allowed_client_ids) != SKIP_CLIENT_ID_CHECK and
client_id not in allowed_client_ids):
_logger.warning('Client ID is not allowed: %s', client_id)
return

_logger.debug('Oauth framework user didn\'t match oauth token user.')
return None
os.environ[_ENV_USE_OAUTH_SCOPE] = ' '.join(authorized_scopes)
_logger.debug('get_current_user() will return user from matched oauth_user.')


def _set_bearer_user_vars_local(token, allowed_client_ids, scopes):
Expand Down Expand Up @@ -392,15 +429,15 @@ def _set_bearer_user_vars_local(token, allowed_client_ids, scopes):
return

# Verify at least one of the scopes matches.
token_scopes = token_info.get('scope', '').split(' ')
if not any(scope in scopes for scope in token_scopes):
_, sufficient_scopes = _process_scopes(scopes)
authorized_scopes = token_info.get('scope', '').split(' ')
if not _are_scopes_sufficient(authorized_scopes, sufficient_scopes):
_logger.warning('Oauth token scopes don\'t match any acceptable scopes.')
return

os.environ[_ENV_AUTH_EMAIL] = token_info['email']
os.environ[_ENV_AUTH_DOMAIN] = ''
_logger.debug('Local dev returning user from token.')
return


def _is_local_dev():
Expand Down