diff --git a/endpoints/test/users_id_token_test.py b/endpoints/test/users_id_token_test.py index 861282a..ee667ee 100644 --- a/endpoints/test/users_id_token_test.py +++ b/endpoints/test/users_id_token_test.py @@ -22,6 +22,7 @@ import unittest import mock +import pytest import endpoints.api_config as api_config @@ -369,8 +370,9 @@ def testEmptyAudience(self): parsed_token, users_id_token._ISSUERS, [], self._SAMPLE_ALLOWED_CLIENT_IDS) self.assertEqual(False, result) + @mock.patch.object(oauth, 'get_authorized_scopes') @mock.patch.object(oauth, 'get_client_id') - def AttemptOauth(self, client_id, mock_get_client_id, allowed_client_ids=None): + def AttemptOauth(self, client_id, mock_get_client_id, mock_get_authorized_scopes, allowed_client_ids=None): if allowed_client_ids is None: allowed_client_ids = self._SAMPLE_ALLOWED_CLIENT_IDS # We have four cases: @@ -381,20 +383,20 @@ def AttemptOauth(self, client_id, mock_get_client_id, allowed_client_ids=None): # mock call for every scope. if client_id is None: mock_get_client_id.side_effect = oauth.Error + mock_get_authorized_scopes.side_effect = oauth.Error else: mock_get_client_id.return_value = client_id + mock_get_authorized_scopes.return_value = self._SAMPLE_OAUTH_SCOPES users_id_token._set_bearer_user_vars(allowed_client_ids, self._SAMPLE_OAUTH_SCOPES) if client_id is None: - for scope in self._SAMPLE_OAUTH_SCOPES: - mock_get_client_id.assert_called_with(scope) + mock_get_authorized_scopes.assert_called_with(self._SAMPLE_OAUTH_SCOPES) elif (list(allowed_client_ids) == users_id_token.SKIP_CLIENT_ID_CHECK or client_id in allowed_client_ids): scope = self._SAMPLE_OAUTH_SCOPES[0] - mock_get_client_id.assert_called_with(scope) + mock_get_client_id.assert_called_with([scope]) else: - for scope in self._SAMPLE_OAUTH_SCOPES: - mock_get_client_id.assert_called_with(scope) + mock_get_client_id.assert_called_with(self._SAMPLE_OAUTH_SCOPES) def assertOauthSucceeded(self, client_id): @@ -487,10 +489,10 @@ def testGetCurrentUserEmailAndAuth(self): def testGetCurrentUserOauth(self, mock_get_current_user): mock_get_current_user.return_value = users.User('test@gmail.com') - os.environ['ENDPOINTS_USE_OAUTH_SCOPE'] = 'scope' + os.environ['ENDPOINTS_USE_OAUTH_SCOPE'] = 'scope1 scope2' user = users_id_token.get_current_user() self.assertEqual(user.email(), 'test@gmail.com') - mock_get_current_user.assert_called_once_with('scope') + mock_get_current_user.assert_called_once_with(['scope1', 'scope2']) def testGetTokenQueryParamOauthHeader(self): os.environ['HTTP_AUTHORIZATION'] = 'OAuth ' + self._SAMPLE_TOKEN @@ -631,9 +633,10 @@ def testMethodCallParsesIdToken(self): self.VerifyIdToken(self.TestApiAnnotatedAtApi(), message_types.VoidMessage()) + @mock.patch.object(oauth, 'get_authorized_scopes') @mock.patch.object(oauth, 'get_client_id') @mock.patch.object(users_id_token, '_is_local_dev') - def testMaybeSetVarsWithActualRequestAccessToken(self, mock_local, mock_get_client_id): + def testMaybeSetVarsWithActualRequestAccessToken(self, mock_local, mock_get_client_id, mock_get_authorized_scopes): dummy_scope = 'scope' dummy_token = 'dummy_token' dummy_email = 'test@gmail.com' @@ -656,13 +659,15 @@ def method(self, request): mock_local.return_value = False mock_get_client_id.return_value = dummy_client_id + mock_get_authorized_scopes.return_value = [dummy_scope] api_instance = TestApiScopes() os.environ['HTTP_AUTHORIZATION'] = 'Bearer ' + dummy_token api_instance.method(message_types.VoidMessage()) - self.assertEqual(os.getenv('ENDPOINTS_USE_OAUTH_SCOPE'), dummy_scope) + assert os.getenv('ENDPOINTS_USE_OAUTH_SCOPE') == dummy_scope mock_local.assert_has_calls([mock.call(), mock.call()]) - mock_get_client_id.assert_called_once_with(dummy_scope) + mock_get_client_id.assert_called_once_with([dummy_scope]) + mock_get_authorized_scopes.assert_called_once_with([dummy_scope]) @mock.patch.object(users_id_token, '_get_id_token_user') @mock.patch.object(time, 'time') @@ -891,5 +896,28 @@ def testBadBase64(self): self._SAMPLE_CERT_URI, self.cache) self.assertIsNone(parsed_token) + +@pytest.mark.parametrize(('scopelist', 'all_scopes', 'sufficient_scopes'), [ + (('scope1', 'scope2'), {'scope1', 'scope2'}, {frozenset(['scope1']), frozenset(['scope2'])}), + (('scope1', 'scope2 scope3'), {'scope1', 'scope2', 'scope3'}, {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}), + (('scope1 scope2', 'scope1 scope3'), {'scope1', 'scope2', 'scope3'}, {frozenset(['scope1', 'scope2']), frozenset(['scope1', 'scope3'])}), +]) +def test_process_scopes(scopelist, all_scopes, sufficient_scopes): + result = users_id_token._process_scopes(scopelist) + assert result == (all_scopes, sufficient_scopes) + +@pytest.mark.parametrize(('authorized_scopes', 'sufficient_scopes', 'is_valid'), [ + (['scope1'], {frozenset(['scope1'])}, True), + (['scope1'], {frozenset(['scope1', 'scope2'])}, False), + (['scope1', 'scope2'], {frozenset(['scope1'])}, True), + (['scope1', 'scope2'], {frozenset(['scope1']), frozenset(['scope2'])}, True), + (['scope1', 'scope2'], {frozenset(['scope1', 'scope2'])}, True), + (['scope1'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, True), + (['scope2'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, False), + (['scope2', 'scope3'], {frozenset(['scope1']), frozenset(['scope2', 'scope3'])}, True), +]) +def test_are_scopes_sufficient(authorized_scopes, sufficient_scopes, is_valid): + assert users_id_token._are_scopes_sufficient(authorized_scopes, sufficient_scopes) is is_valid + if __name__ == '__main__': unittest.main() diff --git a/endpoints/users_id_token.py b/endpoints/users_id_token.py index ae25b4d..21cbd34 100644 --- a/endpoints/users_id_token.py +++ b/endpoints/users_id_token.py @@ -125,7 +125,7 @@ def get_current_user(): # We can get more information from the oauth.get_current_user function, # as long as we know what scope to use. Since that scope has been # cached, we can just return this: - return oauth.get_current_user(os.environ[_ENV_USE_OAUTH_SCOPE]) + return oauth.get_current_user(os.environ[_ENV_USE_OAUTH_SCOPE].split()) if (_ENV_AUTH_EMAIL in os.environ and _ENV_AUTH_DOMAIN in os.environ): @@ -316,6 +316,43 @@ def _set_oauth_user_vars(token_info, audiences, allowed_client_ids, scopes, # pylint: enable=unused-argument +def _process_scopes(scopes): + """Parse a scopes list into a set of all scopes and a set of sufficient scope sets. + + scopes: A list of strings, each of which is a space-separated list of scopes. + Examples: ['scope1'] + ['scope1', 'scope2'] + ['scope1', 'scope2 scope3'] + + Returns: + all_scopes: a set of strings, each of which is one scope to check for + sufficient_scopes: a set of sets of strings; each inner set is + a set of scopes which are sufficient for access. + Example: {{'scope1'}, {'scope2', 'scope3'}} + """ + all_scopes = set() + sufficient_scopes = set() + for scope_set in scopes: + scope_set_scopes = frozenset(scope_set.split()) + all_scopes.update(scope_set_scopes) + sufficient_scopes.add(scope_set_scopes) + return all_scopes, sufficient_scopes + + +def _are_scopes_sufficient(authorized_scopes, sufficient_scopes): + """Check if a list of authorized scopes satisfies any set of sufficient scopes. + + Args: + authorized_scopes: a list of strings, return value from oauth.get_authorized_scopes + sufficient_scopes: a set of sets of strings, return value from _process_scopes + """ + for sufficient_scope_set in sufficient_scopes: + if sufficient_scope_set.issubset(authorized_scopes): + return True + return False + + + def _set_bearer_user_vars(allowed_client_ids, scopes): """Validate the oauth bearer token and set endpoints auth user variables. @@ -327,27 +364,27 @@ def _set_bearer_user_vars(allowed_client_ids, scopes): allowed_client_ids: List of client IDs that are acceptable. scopes: List of acceptable scopes. """ - for scope in scopes: - try: - client_id = oauth.get_client_id(scope) - except oauth.Error: - # This scope failed. Try the next. - continue - - # The client ID must be in allowed_client_ids. If allowed_client_ids is - # empty, don't allow any client ID. If allowed_client_ids is set to - # SKIP_CLIENT_ID_CHECK, all client IDs will be allowed. - if (list(allowed_client_ids) != SKIP_CLIENT_ID_CHECK and - client_id not in allowed_client_ids): - _logger.warning('Client ID is not allowed: %s', client_id) - return + all_scopes, sufficient_scopes = _process_scopes(scopes) + try: + authorized_scopes = oauth.get_authorized_scopes(sorted(all_scopes)) + except oauth.Error: + _logger.debug('Unable to get authorized scopes.', exc_info=True) + return + if not _are_scopes_sufficient(authorized_scopes, sufficient_scopes): + _logger.debug('Authorized scopes did not satisfy scope requirements.') + return + client_id = oauth.get_client_id(authorized_scopes) - os.environ[_ENV_USE_OAUTH_SCOPE] = scope - _logger.debug('Returning user from matched oauth_user.') + # The client ID must be in allowed_client_ids. If allowed_client_ids is + # empty, don't allow any client ID. If allowed_client_ids is set to + # SKIP_CLIENT_ID_CHECK, all client IDs will be allowed. + if (list(allowed_client_ids) != SKIP_CLIENT_ID_CHECK and + client_id not in allowed_client_ids): + _logger.warning('Client ID is not allowed: %s', client_id) return - _logger.debug('Oauth framework user didn\'t match oauth token user.') - return None + os.environ[_ENV_USE_OAUTH_SCOPE] = ' '.join(authorized_scopes) + _logger.debug('get_current_user() will return user from matched oauth_user.') def _set_bearer_user_vars_local(token, allowed_client_ids, scopes): @@ -392,15 +429,15 @@ def _set_bearer_user_vars_local(token, allowed_client_ids, scopes): return # Verify at least one of the scopes matches. - token_scopes = token_info.get('scope', '').split(' ') - if not any(scope in scopes for scope in token_scopes): + _, sufficient_scopes = _process_scopes(scopes) + authorized_scopes = token_info.get('scope', '').split(' ') + if not _are_scopes_sufficient(authorized_scopes, sufficient_scopes): _logger.warning('Oauth token scopes don\'t match any acceptable scopes.') return os.environ[_ENV_AUTH_EMAIL] = token_info['email'] os.environ[_ENV_AUTH_DOMAIN] = '' _logger.debug('Local dev returning user from token.') - return def _is_local_dev():