From a43ed5d5b0a2ae1be67935a701b48ba83b4d9893 Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Fri, 18 Nov 2016 15:39:04 +0000 Subject: [PATCH 01/10] Move custom validator registration onto GrantTypeBase --- .../rfc6749/grant_types/authorization_code.py | 43 +++++------------- oauthlib/oauth2/rfc6749/grant_types/base.py | 44 +++++++++++++++++++ .../rfc6749/grant_types/client_credentials.py | 13 +++--- .../oauth2/rfc6749/grant_types/implicit.py | 43 ++++++++++-------- .../rfc6749/grant_types/refresh_token.py | 20 ++++++--- .../resource_owner_password_credentials.py | 18 +++----- 6 files changed, 105 insertions(+), 76 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py index 09756a1b..1ac40f3b 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py +++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py @@ -13,7 +13,6 @@ from .base import GrantTypeBase from .. import errors -from ..request_validator import RequestValidator log = logging.getLogger(__name__) @@ -96,31 +95,7 @@ class AuthorizationCodeGrant(GrantTypeBase): """ default_response_mode = 'query' - - def __init__(self, request_validator=None, refresh_token=True): - self.request_validator = request_validator or RequestValidator() - self.refresh_token = refresh_token - - self._authorization_validators = [] - self._token_validators = [] - self._code_modifiers = [] - self._token_modifiers = [] - self.response_types = ['code'] - - def register_response_type(self, response_type): - self.response_types.append(response_type) - - def register_authorization_validator(self, validator): - self._authorization_validators.append(validator) - - def register_token_validator(self, validator): - self._token_validators.append(validator) - - def register_code_modifier(self, modifier): - self._code_modifiers.append(modifier) - - def register_token_modifier(self, modifier): - self._token_modifiers.append(modifier) + response_types = ['code'] def create_authorization_code(self, request): """Generates an authorization grant represented as a dictionary.""" @@ -347,6 +322,10 @@ def validate_authorization_request(self, request): # Note that the correct parameters to be added are automatically # populated through the use of specific exceptions. + request_info = {} + for validator in self._auth_validators_run_before_standard_ones: + request_info.update(validator(request)) + # REQUIRED. if request.response_type is None: raise errors.MissingResponseTypeError(request=request) @@ -367,15 +346,15 @@ def validate_authorization_request(self, request): # http://tools.ietf.org/html/rfc6749#section-3.3 self.validate_scopes(request) - request_info = { + request_info.update({ 'client_id': request.client_id, 'redirect_uri': request.redirect_uri, 'response_type': request.response_type, 'state': request.state, 'request': request - } + }) - for validator in self._authorization_validators: + for validator in self._auth_validators_run_after_standard_ones: request_info.update(validator(request)) return request.scopes, request_info @@ -385,6 +364,9 @@ def validate_token_request(self, request): if request.grant_type not in ('authorization_code', 'openid'): raise errors.UnsupportedGrantTypeError(request=request) + for validator in self._token_validators_run_before_standard_ones: + validator(request) + if request.code is None: raise errors.InvalidRequestError( description='Missing code parameter.', request=request) @@ -441,6 +423,5 @@ def validate_token_request(self, request): request.redirect_uri, request.client_id, request.client) raise errors.AccessDeniedError(request=request) - for validator in self._token_validators: + for validator in self._token_validators_run_after_standard_ones: validator(request) - diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py index 36b06eb8..5a3888ad 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/base.py +++ b/oauthlib/oauth2/rfc6749/grant_types/base.py @@ -9,6 +9,7 @@ from oauthlib.common import add_params_to_uri from oauthlib.oauth2.rfc6749 import errors, utils +from ..request_validator import RequestValidator log = logging.getLogger(__name__) @@ -17,6 +18,49 @@ class GrantTypeBase(object): error_uri = None request_validator = None default_response_mode = 'fragment' + refresh_token = True + response_types = ['code'] + + def __init__(self, request_validator=None, **kwargs): + self.request_validator = request_validator or RequestValidator() + + # Transforms class variables into instance variables: + self.response_types = self.response_types + self.refresh_token = self.refresh_token + + self._setup_validator_hooks() + for kw, val in kwargs.items(): + setattr(self, kw, val) + + def _setup_validator_hooks(self): + self._auth_validators_run_before_standard_ones = [] + self._auth_validators_run_after_standard_ones = [] + self._token_validators_run_before_standard_ones = [] + self._token_validators_run_after_standard_ones = [] + self._code_modifiers = [] + self._token_modifiers = [] + + def register_response_type(self, response_type): + self.response_types.append(response_type) + + def register_authorization_validator(self, validator, after_standard=True): + if after_standard: + self._auth_validators_run_after_standard_ones.append(validator) + else: + self._auth_validators_run_before_standard_ones.append(validator) + + def register_token_validator(self, validator, after_standard=True): + if after_standard: + self._token_validators_run_after_standard_ones.append(validator) + else: + self._token_validators_run_before_standard_ones.append(validator) + + def register_code_modifier(self, modifier): + self._code_modifiers.append(modifier) + + def register_token_modifier(self, modifier): + self._token_modifiers.append(modifier) + def create_authorization_response(self, request, token_handler): raise NotImplementedError('Subclasses must implement this method.') diff --git a/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py b/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py index 91c17a69..1306dafd 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py +++ b/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py @@ -50,13 +50,6 @@ class ClientCredentialsGrant(GrantTypeBase): .. _`Client Credentials Grant`: http://tools.ietf.org/html/rfc6749#section-4.4 """ - def __init__(self, request_validator=None): - self.request_validator = request_validator or RequestValidator() - self._token_modifiers = [] - - def register_token_modifier(self, modifier): - self._token_modifiers.append(modifier) - def create_token_response(self, request, token_handler): """Return token or error in JSON format. @@ -92,6 +85,9 @@ def create_token_response(self, request, token_handler): return headers, json.dumps(token), 200 def validate_token_request(self, request): + for validator in self._token_validators_run_before_standard_ones: + validator(request) + if not getattr(request, 'grant_type', None): raise errors.InvalidRequestError('Request is missing grant type.', request=request) @@ -119,3 +115,6 @@ def validate_token_request(self, request): log.debug('Authorizing access to user %r.', request.user) request.client_id = request.client_id or request.client.client_id self.validate_scopes(request) + + for validator in self._token_validators_run_after_standard_ones: + validator(request) diff --git a/oauthlib/oauth2/rfc6749/grant_types/implicit.py b/oauthlib/oauth2/rfc6749/grant_types/implicit.py index 7366b947..72bb94ca 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/implicit.py +++ b/oauthlib/oauth2/rfc6749/grant_types/implicit.py @@ -5,6 +5,7 @@ """ from __future__ import unicode_literals, absolute_import +from itertools import chain import logging from oauthlib import common @@ -117,20 +118,8 @@ class ImplicitGrant(GrantTypeBase): .. _`Section 10.16`: http://tools.ietf.org/html/rfc6749#section-10.16 """ - def __init__(self, request_validator=None): - self.request_validator = request_validator or RequestValidator() - self._authorization_validators = [] - self._token_modifiers = [] - self.response_types = ['token'] - - def register_response_type(self, response_type): - self.response_types.append(response_type) - - def register_authorization_validator(self, validator): - self._authorization_validators.append(validator) - - def register_token_modifier(self, modifier): - self._token_modifiers.append(modifier) + response_types = ['token'] + grant_allows_refresh_token = False def create_authorization_response(self, request, token_handler): """Create an authorization response. @@ -328,6 +317,16 @@ def validate_token_request(self, request): # Then check for normal errors. + request_info = {} + # For implicit grant, auth_validators and token_validators are + # basically equivalent since the token is returned from the + # authorization endpoint. + for validator in chain(self._token_validators_run_before_standard_ones, + self._auth_validators_run_before_standard_ones): + result = validator(request) + if result is not None: + request_info.update(result) + # If the resource owner denies the access request or if the request # fails for reasons other than a missing or invalid redirection URI, # the authorization server informs the client by adding the following @@ -359,15 +358,21 @@ def validate_token_request(self, request): # http://tools.ietf.org/html/rfc6749#section-3.3 self.validate_scopes(request) - request_info = { + request_info.update({ 'client_id': request.client_id, 'redirect_uri': request.redirect_uri, 'response_type': request.response_type, 'state': request.state, 'request': request, - } - - for validator in self._authorization_validators: - request_info.update(validator(request)) + }) + + # For implicit grant, auth_validators and token_validators are + # basically equivalent since the token is returned from the + # authorization endpoint. + for validator in chain(self._auth_validators_run_after_standard_ones, + self._token_validators_run_after_standard_ones): + result = validator(request) + if result is not None: + request_info.update(result) return request.scopes, request_info diff --git a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py index cb26880f..def630ce 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py +++ b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py @@ -22,13 +22,13 @@ class RefreshTokenGrant(GrantTypeBase): .. _`Refresh token grant`: http://tools.ietf.org/html/rfc6749#section-6 """ - def __init__(self, request_validator=None, issue_new_refresh_tokens=True): - self.request_validator = request_validator or RequestValidator() - self.issue_new_refresh_tokens = issue_new_refresh_tokens - self._token_modifiers = [] - - def register_token_modifier(self, modifier): - self._token_modifiers.append(modifier) + def __init__(self, request_validator=None, + issue_new_refresh_tokens=True, + **kwargs): + super(RefreshTokenGrant, self).__init__( + request_validator, + issue_new_refresh_tokens=issue_new_refresh_tokens, + **kwargs) def create_token_response(self, request, token_handler): """Create a new access token from a refresh_token. @@ -76,6 +76,9 @@ def validate_token_request(self, request): if request.grant_type != 'refresh_token': raise errors.UnsupportedGrantTypeError(request=request) + for validator in self._token_validators_run_before_standard_ones: + validator(request) + if request.refresh_token is None: raise errors.InvalidRequestError( description='Missing refresh token parameter.', @@ -123,3 +126,6 @@ def validate_token_request(self, request): raise errors.InvalidScopeError(request=request) else: request.scopes = original_scopes + + for validator in self._token_validators_run_after_standard_ones: + validator(request) diff --git a/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py b/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py index 0f4d65e4..2ef6e163 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py +++ b/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py @@ -70,18 +70,6 @@ class ResourceOwnerPasswordCredentialsGrant(GrantTypeBase): .. _`Resource Owner Password Credentials Grant`: http://tools.ietf.org/html/rfc6749#section-4.3 """ - def __init__(self, request_validator=None, refresh_token=True): - """ - If the refresh_token keyword argument is False, do not return - a refresh token in the response. - """ - self.request_validator = request_validator or RequestValidator() - self.refresh_token = refresh_token - self._token_modifiers = [] - - def register_token_modifier(self, modifier): - self._token_modifiers.append(modifier) - def create_token_response(self, request, token_handler): """Return token or error in json format. @@ -168,6 +156,9 @@ def validate_token_request(self, request): .. _`Section 3.3`: http://tools.ietf.org/html/rfc6749#section-3.3 .. _`Section 3.2.1`: http://tools.ietf.org/html/rfc6749#section-3.2.1 """ + for validator in self._token_validators_run_before_standard_ones: + validator(request) + for param in ('grant_type', 'username', 'password'): if not getattr(request, param, None): raise errors.InvalidRequestError( @@ -201,3 +192,6 @@ def validate_token_request(self, request): if request.client: request.client_id = request.client_id or request.client.client_id self.validate_scopes(request) + + for validator in self._token_validators_run_after_standard_ones: + validator(request) From 28cf20b3ad64b568bff8507ea68a231651bd132e Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Mon, 21 Nov 2016 14:19:30 +0000 Subject: [PATCH 02/10] OIDC grants are proxies --- .../rfc6749/grant_types/openid_connect.py | 174 +++++------------- 1 file changed, 51 insertions(+), 123 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py b/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py index 6cc37729..4dfe934a 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py +++ b/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py @@ -77,10 +77,33 @@ def validate_authorization_request(self, request): return self._handler_for_request(request).validate_authorization_request(request) -class OpenIDConnectBase(GrantTypeBase): +class OpenIDConnectBase(object): + + # Just proxy the majority of method calls through to the + # proxy_target grant type handler, which will usually be either + # the standard OAuth2 AuthCode or Implicit grant types. + def __getattr__(self, attr): + return getattr(self.proxy_target, attr) + + def __setattr__(self, attr, value): + proxied_attrs = set(('refresh_token', 'response_types')) + if attr in proxied_attrs: + setattr(self.proxy_target, attr, value) + else: + super(OpenIDConnectBase, self).__setattr__(attr, value) - def __init__(self, request_validator=None): - self.request_validator = request_validator or RequestValidator() + def validate_authorization_request(self, request): + """Validates the OpenID Connect authorization request parameters. + + :returns: (list of scopes, dict of request info) + """ + # If request.prompt is 'none' then no login/authorization form should + # be presented to the user. Instead, a silent login/authorization + # should be performed. + if request.prompt == 'none': + raise OIDCNoPrompt() + else: + return self.proxy_target.validate_authorization_request(request) def _inflate_claims(self, request): # this may be called multiple times in a single request so make sure we only de-serialize the claims once @@ -309,135 +332,40 @@ def openid_implicit_authorization_validator(self, request): class OpenIDConnectAuthCode(OpenIDConnectBase): - def __init__(self, request_validator=None): - self.request_validator = request_validator or RequestValidator() - super(OpenIDConnectAuthCode, self).__init__( - request_validator=self.request_validator) - self.auth_code = AuthorizationCodeGrant( - request_validator=self.request_validator) - self.auth_code.register_authorization_validator( + def __init__(self, request_validator=None, **kwargs): + self.proxy_target = AuthorizationCodeGrant( + request_validator=request_validator, **kwargs) + self.register_authorization_validator( self.openid_authorization_validator) - self.auth_code.register_token_modifier(self.add_id_token) - - @property - def refresh_token(self): - return self.auth_code.refresh_token - - @refresh_token.setter - def refresh_token(self, value): - self.auth_code.refresh_token = value - - def create_authorization_code(self, request): - return self.auth_code.create_authorization_code(request) - - def create_authorization_response(self, request, token_handler): - return self.auth_code.create_authorization_response( - request, token_handler) - - def create_token_response(self, request, token_handler): - return self.auth_code.create_token_response(request, token_handler) - - def validate_authorization_request(self, request): - """Validates the OpenID Connect authorization request parameters. - - :returns: (list of scopes, dict of request info) - """ - # If request.prompt is 'none' then no login/authorization form should - # be presented to the user. Instead, a silent login/authorization - # should be performed. - if request.prompt == 'none': - raise OIDCNoPrompt() - else: - return self.auth_code.validate_authorization_request(request) - - def validate_token_request(self, request): - return self.auth_code.validate_token_request(request) - + self.register_token_modifier(self.add_id_token) class OpenIDConnectImplicit(OpenIDConnectBase): - def __init__(self, request_validator=None): - self.request_validator = request_validator or RequestValidator() - super(OpenIDConnectImplicit, self).__init__( - request_validator=self.request_validator) - self.implicit = ImplicitGrant( - request_validator=request_validator) - self.implicit.register_response_type('id_token') - self.implicit.register_response_type('id_token token') - self.implicit.register_authorization_validator( + def __init__(self, request_validator=None, **kwargs): + self.proxy_target = ImplicitGrant( + request_validator=request_validator, **kwargs) + self.register_response_type('id_token') + self.register_response_type('id_token token') + self.register_authorization_validator( self.openid_authorization_validator) - self.implicit.register_authorization_validator( + self.register_authorization_validator( self.openid_implicit_authorization_validator) - self.implicit.register_token_modifier(self.add_id_token) - - def create_authorization_response(self, request, token_handler): - return self.create_token_response(request, token_handler) - - def create_token_response(self, request, token_handler): - return self.implicit.create_authorization_response( - request, token_handler) - - def validate_authorization_request(self, request): - """Validates the OpenID Connect authorization request parameters. - - :returns: (list of scopes, dict of request info) - """ - # If request.prompt is 'none' then no login/authorization form should - # be presented to the user. Instead, a silent login/authorization - # should be performed. - if request.prompt == 'none': - raise OIDCNoPrompt() - else: - return self.implicit.validate_authorization_request(request) - + self.register_token_modifier(self.add_id_token) class OpenIDConnectHybrid(OpenIDConnectBase): - def __init__(self, request_validator=None): + def __init__(self, request_validator=None, **kwargs): self.request_validator = request_validator or RequestValidator() - self.auth_code = AuthorizationCodeGrant( - request_validator=request_validator) - self.auth_code.register_response_type('code id_token') - self.auth_code.register_response_type('code token') - self.auth_code.register_response_type('code id_token token') - self.auth_code.register_authorization_validator( + self.proxy_target = AuthorizationCodeGrant( + request_validator=request_validator, **kwargs) + self.register_response_type('code id_token') + self.register_response_type('code token') + self.register_response_type('code id_token token') + self.register_authorization_validator( self.openid_authorization_validator) - self.auth_code.register_code_modifier(self.add_token) - self.auth_code.register_code_modifier(self.add_id_token) - self.auth_code.register_token_modifier(self.add_id_token) - - @property - def refresh_token(self): - return self.auth_code.refresh_token - - @refresh_token.setter - def refresh_token(self, value): - self.auth_code.refresh_token = value - - def create_authorization_code(self, request): - return self.auth_code.create_authorization_code(request) - - def create_authorization_response(self, request, token_handler): - return self.auth_code.create_authorization_response( - request, token_handler) - - def create_token_response(self, request, token_handler): - return self.auth_code.create_token_response(request, token_handler) - - def validate_authorization_request(self, request): - """Validates the OpenID Connect authorization request parameters. - - :returns: (list of scopes, dict of request info) - """ - # If request.prompt is 'none' then no login/authorization form should - # be presented to the user. Instead, a silent login/authorization - # should be performed. - if request.prompt == 'none': - raise OIDCNoPrompt() - else: - return self.auth_code.validate_authorization_request(request) - - def validate_token_request(self, request): - return self.auth_code.validate_token_request(request) - + # Hybrid flows can return the id_token from the authorization + # endpoint as part of the 'code' response + self.register_code_modifier(self.add_token) + self.register_code_modifier(self.add_id_token) + self.register_token_modifier(self.add_id_token) From 41f853f9d56bfb403b40f00054a56242e1be52ed Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Mon, 21 Nov 2016 14:47:59 +0000 Subject: [PATCH 03/10] Adds tests for custom grant validators --- .../grant_types/test_authorization_code.py | 32 +++++++++++++++++++ .../grant_types/test_client_credentials.py | 14 ++++++++ .../rfc6749/grant_types/test_implicit.py | 19 +++++++++++ .../rfc6749/grant_types/test_refresh_token.py | 15 +++++++++ .../test_resource_owner_password.py | 15 +++++++++ 5 files changed, 95 insertions(+) diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py index 18cd3f2a..216d1cd4 100644 --- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py +++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py @@ -32,6 +32,38 @@ def set_client(self, request): request.client.client_id = 'mocked' return True + def test_custom_auth_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.authval1.return_value = {} + self.authval2.return_value = {} + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + + bearer = BearerToken(self.mock_validator) + self.auth.create_authorization_response(self.request, bearer) + self.assertTrue(self.authval1.called) + self.assertTrue(self.authval2.called) + self.assertFalse(self.tknval1.called) + self.assertFalse(self.tknval2.called) + + def test_custom_token_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertFalse(self.authval1.called) + self.assertFalse(self.authval2.called) + def test_create_authorization_grant(self): bearer = BearerToken(self.mock_validator) h, b, s = self.auth.create_authorization_response(self.request, bearer) diff --git a/tests/oauth2/rfc6749/grant_types/test_client_credentials.py b/tests/oauth2/rfc6749/grant_types/test_client_credentials.py index 0865c7ef..b52e265b 100644 --- a/tests/oauth2/rfc6749/grant_types/test_client_credentials.py +++ b/tests/oauth2/rfc6749/grant_types/test_client_credentials.py @@ -22,6 +22,20 @@ def setUp(self): self.auth = ClientCredentialsGrant( request_validator=self.mock_validator) + def test_custom_token_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertFalse(self.authval1.called) + self.assertFalse(self.authval2.called) + def test_create_token_response(self): bearer = BearerToken(self.mock_validator) headers, body, status_code = self.auth.create_token_response( diff --git a/tests/oauth2/rfc6749/grant_types/test_implicit.py b/tests/oauth2/rfc6749/grant_types/test_implicit.py index cdeecb75..35862e49 100644 --- a/tests/oauth2/rfc6749/grant_types/test_implicit.py +++ b/tests/oauth2/rfc6749/grant_types/test_implicit.py @@ -40,5 +40,24 @@ def test_create_token_response(self, generate_token): h, b, s = self.auth.create_token_response(self.request, bearer) self.assertURLEqual(h['Location'], correct_uri) + def test_custom_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + for val in (self.authval1, self.authval2): + val.return_value = {} + for val in (self.tknval1, self.tknval2): + val.return_value = None + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertTrue(self.authval1.called) + self.assertTrue(self.authval2.called) + def test_error_response(self): pass diff --git a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py index 125dc2b9..99e05d6d 100644 --- a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py +++ b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py @@ -36,6 +36,21 @@ def test_create_token_response(self): self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'foo') + def test_custom_token_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertFalse(self.authval1.called) + self.assertFalse(self.authval2.called) + def test_create_token_inherit_scope(self): self.request.scope = None self.mock_validator.get_original_scopes.return_value = ['foo', 'bar'] diff --git a/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py b/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py index c6377534..4747a690 100644 --- a/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py +++ b/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py @@ -89,6 +89,21 @@ def test_create_token_response_without_refresh_token(self): self.assertEqual(status_code, 401) self.assertEqual(self.mock_validator.save_token.call_count, 0) + def test_custom_token_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.auth.register_token_validator(self.tknval1, after_standard=False) + self.auth.register_token_validator(self.tknval2, after_standard=True) + self.auth.register_authorization_validator(self.authval1, after_standard=False) + self.auth.register_authorization_validator(self.authval2, after_standard=True) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertFalse(self.authval1.called) + self.assertFalse(self.authval2.called) + def test_error_response(self): pass From ff0e40884f07fde16b542b883a3c5703092d6b86 Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Mon, 19 Dec 2016 15:53:55 +0000 Subject: [PATCH 04/10] Adds Brendan McCollam to AUTHORS --- AUTHORS | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 329eb97e..811679e0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -24,4 +24,5 @@ Josh Turmel David Baumgold Juan Fabio GarcĂ­a Solero Omer Katz -Joel Stevenson \ No newline at end of file +Joel Stevenson +Brendan McCollam From b63590317fc353da5e39ec9eef9a1494eddb925e Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Mon, 19 Dec 2016 16:03:22 +0000 Subject: [PATCH 05/10] Docs for custom validator registration --- docs/oauth2/grants/authcode.rst | 1 + docs/oauth2/grants/credentials.rst | 1 + docs/oauth2/grants/grants.rst | 6 +++++ docs/oauth2/grants/implicit.rst | 1 + docs/oauth2/grants/password.rst | 1 + oauthlib/oauth2/rfc6749/grant_types/base.py | 26 +++++++++++++++++++++ 6 files changed, 36 insertions(+) diff --git a/docs/oauth2/grants/authcode.rst b/docs/oauth2/grants/authcode.rst index c83e5a90..abf5d322 100644 --- a/docs/oauth2/grants/authcode.rst +++ b/docs/oauth2/grants/authcode.rst @@ -3,3 +3,4 @@ Authorization Code Grant .. autoclass:: oauthlib.oauth2.AuthorizationCodeGrant :members: + :inherited-members: diff --git a/docs/oauth2/grants/credentials.rst b/docs/oauth2/grants/credentials.rst index 790bce82..179c0890 100644 --- a/docs/oauth2/grants/credentials.rst +++ b/docs/oauth2/grants/credentials.rst @@ -3,3 +3,4 @@ Client Credentials Grant .. autoclass:: oauthlib.oauth2.ClientCredentialsGrant :members: + :inherited-members: diff --git a/docs/oauth2/grants/grants.rst b/docs/oauth2/grants/grants.rst index 95d3ffef..f4fcb56b 100644 --- a/docs/oauth2/grants/grants.rst +++ b/docs/oauth2/grants/grants.rst @@ -24,6 +24,12 @@ resources in various ways with different security credentials. Naturally, OAuth 2 allows for extension grant types to be defined and OAuthLib attempts to cater for easy inclusion of this as much as possible. +OAuthlib also offers hooks for registering your own custom validations for use +with the existing grant type handlers +(:py:meth:`oauthlib.oauth2.AuthorizationCodeGrant.register_authorization_validator`). +In some situations, this may be more convenient than subclassing or writing +your own extension grant type. + Certain grant types allow the issuing of refresh tokens which will allow a client to request new tokens for as long as you as provider allow them too. In general, OAuth 2 tokens should expire quickly and rather than annoying the user diff --git a/docs/oauth2/grants/implicit.rst b/docs/oauth2/grants/implicit.rst index 90490e34..6ffdc6e2 100644 --- a/docs/oauth2/grants/implicit.rst +++ b/docs/oauth2/grants/implicit.rst @@ -3,3 +3,4 @@ Implicit Grant .. autoclass:: oauthlib.oauth2.ImplicitGrant :members: + :inherited-members: diff --git a/docs/oauth2/grants/password.rst b/docs/oauth2/grants/password.rst index 0230c099..48b42f6b 100644 --- a/docs/oauth2/grants/password.rst +++ b/docs/oauth2/grants/password.rst @@ -3,3 +3,4 @@ Resource Owner Password Credentials Grant .. autoclass:: oauthlib.oauth2.ResourceOwnerPasswordCredentialsGrant :members: + :inherited-members: diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py index 5a3888ad..98b75e7a 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/base.py +++ b/oauthlib/oauth2/rfc6749/grant_types/base.py @@ -44,12 +44,38 @@ def register_response_type(self, response_type): self.response_types.append(response_type) def register_authorization_validator(self, validator, after_standard=True): + """ + Register a validator callable to be invoked during calls to the + authorization endpoint. + + :param callable validator: callable that takes a request object and returns a + mapping of items to be included with the authorization validation + response. + :param bool after_standard: default: True, If True, the custom + validator is called after the standard authorization + validations. If False, the custom valdator is called before + the standard validations. + :returns: None + """ if after_standard: self._auth_validators_run_after_standard_ones.append(validator) else: self._auth_validators_run_before_standard_ones.append(validator) def register_token_validator(self, validator, after_standard=True): + """ + Register a validator callable to be invoked during calls to the + token endpoint (or the authorization endpoint during the implicit grant, + flow which returns tokens directly from the authorization endpoint). + + + :param callable validator: callable that takes a request object and returns None or + raises an exception if appropriate. + :param bool after_standard: default: True, If True, the custom + validator is called after the standard token validations. If False, + the custom valdator is called before the standard validations. + :returns: None + """ if after_standard: self._token_validators_run_after_standard_ones.append(validator) else: From a33c82e7be21429b5c376f2912a6a553cf87018b Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Tue, 20 Dec 2016 11:15:44 +0000 Subject: [PATCH 06/10] Helper for implicit grant custom validators --- .../oauth2/rfc6749/grant_types/implicit.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/implicit.py b/oauthlib/oauth2/rfc6749/grant_types/implicit.py index 72bb94ca..59f02bdc 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/implicit.py +++ b/oauthlib/oauth2/rfc6749/grant_types/implicit.py @@ -318,14 +318,11 @@ def validate_token_request(self, request): # Then check for normal errors. request_info = {} - # For implicit grant, auth_validators and token_validators are - # basically equivalent since the token is returned from the - # authorization endpoint. - for validator in chain(self._token_validators_run_before_standard_ones, - self._auth_validators_run_before_standard_ones): - result = validator(request) - if result is not None: - request_info.update(result) + + self._run_custom_validators(request, request_info, + self._auth_validators_run_before_standard_ones, + self._token_validators_run_before_standard_ones) + # If the resource owner denies the access request or if the request # fails for reasons other than a missing or invalid redirection URI, @@ -366,13 +363,23 @@ def validate_token_request(self, request): 'request': request, }) + self._run_custom_validators(request, request_info, + self._auth_validators_run_after_standard_ones, + self._token_validators_run_after_standard_ones) + + return request.scopes, request_info + + + def _run_custom_validators(self, + request, + request_info, + auth_validators, + token_validators): # For implicit grant, auth_validators and token_validators are # basically equivalent since the token is returned from the # authorization endpoint. - for validator in chain(self._auth_validators_run_after_standard_ones, - self._token_validators_run_after_standard_ones): + for validator in chain(auth_validators, token_validators): result = validator(request) if result is not None: request_info.update(result) - - return request.scopes, request_info + return request_info From c6d84da649789151a274e091df3f23e99e92dd0c Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Tue, 20 Dec 2016 14:55:04 +0000 Subject: [PATCH 07/10] Per code review, _run_custom_validators() doesn't mutate request_info --- oauthlib/oauth2/rfc6749/grant_types/implicit.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/implicit.py b/oauthlib/oauth2/rfc6749/grant_types/implicit.py index 59f02bdc..5d72367b 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/implicit.py +++ b/oauthlib/oauth2/rfc6749/grant_types/implicit.py @@ -317,9 +317,7 @@ def validate_token_request(self, request): # Then check for normal errors. - request_info = {} - - self._run_custom_validators(request, request_info, + request_info = self._run_custom_validators(request, self._auth_validators_run_before_standard_ones, self._token_validators_run_before_standard_ones) @@ -363,18 +361,21 @@ def validate_token_request(self, request): 'request': request, }) - self._run_custom_validators(request, request_info, + request_info = self._run_custom_validators(request, self._auth_validators_run_after_standard_ones, - self._token_validators_run_after_standard_ones) + self._token_validators_run_after_standard_ones, + request_info) return request.scopes, request_info def _run_custom_validators(self, request, - request_info, auth_validators, - token_validators): + token_validators, + request_info=None): + # Make a copy so we don't modify the existing request_info dict + request_info = {} if request_info is None else request_info.copy() # For implicit grant, auth_validators and token_validators are # basically equivalent since the token is returned from the # authorization endpoint. From d0a59bd8a4fbb44f2ad0218becfbccf49d91cb4d Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Thu, 22 Dec 2016 13:27:55 +0000 Subject: [PATCH 08/10] Refactor custom validators registration --- .../rfc6749/grant_types/authorization_code.py | 8 +- oauthlib/oauth2/rfc6749/grant_types/base.py | 85 +++++++++---------- .../rfc6749/grant_types/client_credentials.py | 4 +- .../oauth2/rfc6749/grant_types/implicit.py | 12 +-- .../rfc6749/grant_types/openid_connect.py | 8 +- .../rfc6749/grant_types/refresh_token.py | 4 +- .../resource_owner_password_credentials.py | 4 +- 7 files changed, 57 insertions(+), 68 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py index 1ac40f3b..3d427ab8 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py +++ b/oauthlib/oauth2/rfc6749/grant_types/authorization_code.py @@ -323,7 +323,7 @@ def validate_authorization_request(self, request): # populated through the use of specific exceptions. request_info = {} - for validator in self._auth_validators_run_before_standard_ones: + for validator in self.custom_validators.pre_auth: request_info.update(validator(request)) # REQUIRED. @@ -354,7 +354,7 @@ def validate_authorization_request(self, request): 'request': request }) - for validator in self._auth_validators_run_after_standard_ones: + for validator in self.custom_validators.post_auth: request_info.update(validator(request)) return request.scopes, request_info @@ -364,7 +364,7 @@ def validate_token_request(self, request): if request.grant_type not in ('authorization_code', 'openid'): raise errors.UnsupportedGrantTypeError(request=request) - for validator in self._token_validators_run_before_standard_ones: + for validator in self.custom_validators.pre_token: validator(request) if request.code is None: @@ -423,5 +423,5 @@ def validate_token_request(self, request): request.redirect_uri, request.client_id, request.client) raise errors.AccessDeniedError(request=request) - for validator in self._token_validators_run_after_standard_ones: + for validator in self.custom_validators.post_token: validator(request) diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py index 98b75e7a..7d3befd0 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/base.py +++ b/oauthlib/oauth2/rfc6749/grant_types/base.py @@ -4,6 +4,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ from __future__ import unicode_literals, absolute_import +from itertools import chain import logging @@ -13,6 +14,28 @@ log = logging.getLogger(__name__) +class ValidatorsContainer(object): + + """ + Container object for holding validator callables to be invoked. + """ + + def __init__(self, + post_auth=None, post_token=None, + pre_auth=None, pre_token=None): + self.pre_auth = pre_auth + self.post_auth = post_auth + self.pre_token = pre_token + self.post_token = post_token + + @property + def all_pre(self): + return chain(self.pre_auth, self.pre_token) + + @property + def all_post(self): + return chain(self.post_auth, self.post_token) + class GrantTypeBase(object): error_uri = None @@ -27,60 +50,30 @@ def __init__(self, request_validator=None, **kwargs): # Transforms class variables into instance variables: self.response_types = self.response_types self.refresh_token = self.refresh_token + self._setup_custom_validators(kwargs) + self._code_modifiers = [] + self._token_modifiers = [] - self._setup_validator_hooks() for kw, val in kwargs.items(): setattr(self, kw, val) - def _setup_validator_hooks(self): - self._auth_validators_run_before_standard_ones = [] - self._auth_validators_run_after_standard_ones = [] - self._token_validators_run_before_standard_ones = [] - self._token_validators_run_after_standard_ones = [] - self._code_modifiers = [] - self._token_modifiers = [] + def _setup_custom_validators(self, kwargs): + post_auth = kwargs.get('post_auth', []) + post_token = kwargs.get('post_token', []) + pre_auth = kwargs.get('pre_auth', []) + pre_token = kwargs.get('pre_token', []) + if not hasattr(self, 'validate_authorization_request'): + if post_auth or pre_auth: + msg = ("{} does not support authorization validators. Use " + "token validators instead.").format(self.__class__.__name__) + raise ValueError(msg) + post_auth, pre_auth = (), () + self.custom_validators = ValidatorsContainer(post_auth, post_token, + pre_auth, pre_token) def register_response_type(self, response_type): self.response_types.append(response_type) - def register_authorization_validator(self, validator, after_standard=True): - """ - Register a validator callable to be invoked during calls to the - authorization endpoint. - - :param callable validator: callable that takes a request object and returns a - mapping of items to be included with the authorization validation - response. - :param bool after_standard: default: True, If True, the custom - validator is called after the standard authorization - validations. If False, the custom valdator is called before - the standard validations. - :returns: None - """ - if after_standard: - self._auth_validators_run_after_standard_ones.append(validator) - else: - self._auth_validators_run_before_standard_ones.append(validator) - - def register_token_validator(self, validator, after_standard=True): - """ - Register a validator callable to be invoked during calls to the - token endpoint (or the authorization endpoint during the implicit grant, - flow which returns tokens directly from the authorization endpoint). - - - :param callable validator: callable that takes a request object and returns None or - raises an exception if appropriate. - :param bool after_standard: default: True, If True, the custom - validator is called after the standard token validations. If False, - the custom valdator is called before the standard validations. - :returns: None - """ - if after_standard: - self._token_validators_run_after_standard_ones.append(validator) - else: - self._token_validators_run_before_standard_ones.append(validator) - def register_code_modifier(self, modifier): self._code_modifiers.append(modifier) diff --git a/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py b/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py index 1306dafd..dd0a39a9 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py +++ b/oauthlib/oauth2/rfc6749/grant_types/client_credentials.py @@ -85,7 +85,7 @@ def create_token_response(self, request, token_handler): return headers, json.dumps(token), 200 def validate_token_request(self, request): - for validator in self._token_validators_run_before_standard_ones: + for validator in self.custom_validators.pre_token: validator(request) if not getattr(request, 'grant_type', None): @@ -116,5 +116,5 @@ def validate_token_request(self, request): request.client_id = request.client_id or request.client.client_id self.validate_scopes(request) - for validator in self._token_validators_run_after_standard_ones: + for validator in self.custom_validators.post_token: validator(request) diff --git a/oauthlib/oauth2/rfc6749/grant_types/implicit.py b/oauthlib/oauth2/rfc6749/grant_types/implicit.py index 5d72367b..51e95af6 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/implicit.py +++ b/oauthlib/oauth2/rfc6749/grant_types/implicit.py @@ -5,7 +5,6 @@ """ from __future__ import unicode_literals, absolute_import -from itertools import chain import logging from oauthlib import common @@ -318,8 +317,7 @@ def validate_token_request(self, request): # Then check for normal errors. request_info = self._run_custom_validators(request, - self._auth_validators_run_before_standard_ones, - self._token_validators_run_before_standard_ones) + self.custom_validators.all_pre) # If the resource owner denies the access request or if the request @@ -362,8 +360,7 @@ def validate_token_request(self, request): }) request_info = self._run_custom_validators(request, - self._auth_validators_run_after_standard_ones, - self._token_validators_run_after_standard_ones, + self.custom_validators.all_post, request_info) return request.scopes, request_info @@ -371,15 +368,14 @@ def validate_token_request(self, request): def _run_custom_validators(self, request, - auth_validators, - token_validators, + validations, request_info=None): # Make a copy so we don't modify the existing request_info dict request_info = {} if request_info is None else request_info.copy() # For implicit grant, auth_validators and token_validators are # basically equivalent since the token is returned from the # authorization endpoint. - for validator in chain(auth_validators, token_validators): + for validator in validations: result = validator(request) if result is not None: request_info.update(result) diff --git a/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py b/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py index 4dfe934a..e59b8a04 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py +++ b/oauthlib/oauth2/rfc6749/grant_types/openid_connect.py @@ -335,7 +335,7 @@ class OpenIDConnectAuthCode(OpenIDConnectBase): def __init__(self, request_validator=None, **kwargs): self.proxy_target = AuthorizationCodeGrant( request_validator=request_validator, **kwargs) - self.register_authorization_validator( + self.custom_validators.post_auth.append( self.openid_authorization_validator) self.register_token_modifier(self.add_id_token) @@ -346,9 +346,9 @@ def __init__(self, request_validator=None, **kwargs): request_validator=request_validator, **kwargs) self.register_response_type('id_token') self.register_response_type('id_token token') - self.register_authorization_validator( + self.custom_validators.post_auth.append( self.openid_authorization_validator) - self.register_authorization_validator( + self.custom_validators.post_auth.append( self.openid_implicit_authorization_validator) self.register_token_modifier(self.add_id_token) @@ -362,7 +362,7 @@ def __init__(self, request_validator=None, **kwargs): self.register_response_type('code id_token') self.register_response_type('code token') self.register_response_type('code id_token token') - self.register_authorization_validator( + self.custom_validators.post_auth.append( self.openid_authorization_validator) # Hybrid flows can return the id_token from the authorization # endpoint as part of the 'code' response diff --git a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py index def630ce..396668b7 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py +++ b/oauthlib/oauth2/rfc6749/grant_types/refresh_token.py @@ -76,7 +76,7 @@ def validate_token_request(self, request): if request.grant_type != 'refresh_token': raise errors.UnsupportedGrantTypeError(request=request) - for validator in self._token_validators_run_before_standard_ones: + for validator in self.custom_validators.pre_token: validator(request) if request.refresh_token is None: @@ -127,5 +127,5 @@ def validate_token_request(self, request): else: request.scopes = original_scopes - for validator in self._token_validators_run_after_standard_ones: + for validator in self.custom_validators.post_token: validator(request) diff --git a/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py b/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py index 2ef6e163..f7552409 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py +++ b/oauthlib/oauth2/rfc6749/grant_types/resource_owner_password_credentials.py @@ -156,7 +156,7 @@ def validate_token_request(self, request): .. _`Section 3.3`: http://tools.ietf.org/html/rfc6749#section-3.3 .. _`Section 3.2.1`: http://tools.ietf.org/html/rfc6749#section-3.2.1 """ - for validator in self._token_validators_run_before_standard_ones: + for validator in self.custom_validators.pre_token: validator(request) for param in ('grant_type', 'username', 'password'): @@ -193,5 +193,5 @@ def validate_token_request(self, request): request.client_id = request.client_id or request.client.client_id self.validate_scopes(request) - for validator in self._token_validators_run_after_standard_ones: + for validator in self.custom_validators.post_token: validator(request) From 790310a6e1e7e1ce434ef478507d9a1c9464e298 Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Thu, 22 Dec 2016 15:00:15 +0000 Subject: [PATCH 09/10] Update custom validator tests --- .../grant_types/test_authorization_code.py | 22 ++++++------- .../grant_types/test_client_credentials.py | 31 +++++++++++------ .../rfc6749/grant_types/test_implicit.py | 8 ++--- .../rfc6749/grant_types/test_refresh_token.py | 30 +++++++++++------ .../test_resource_owner_password.py | 33 +++++++++++++------ 5 files changed, 79 insertions(+), 45 deletions(-) diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py index 216d1cd4..c5e68699 100644 --- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py +++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py @@ -32,15 +32,20 @@ def set_client(self, request): request.client.client_id = 'mocked' return True - def test_custom_auth_validators(self): + def setup_validators(self): self.authval1, self.authval2 = mock.Mock(), mock.Mock() self.authval1.return_value = {} self.authval2.return_value = {} self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + self.tknval1.return_value = None + self.tknval2.return_value = None + self.auth.custom_validators.pre_token.append(self.tknval1) + self.auth.custom_validators.post_token.append(self.tknval2) + self.auth.custom_validators.pre_auth.append(self.authval1) + self.auth.custom_validators.post_auth.append(self.authval2) + + def test_custom_auth_validators(self): + self.setup_validators() bearer = BearerToken(self.mock_validator) self.auth.create_authorization_response(self.request, bearer) @@ -50,12 +55,7 @@ def test_custom_auth_validators(self): self.assertFalse(self.tknval2.called) def test_custom_token_validators(self): - self.authval1, self.authval2 = mock.Mock(), mock.Mock() - self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + self.setup_validators() bearer = BearerToken(self.mock_validator) self.auth.create_token_response(self.request, bearer) diff --git a/tests/oauth2/rfc6749/grant_types/test_client_credentials.py b/tests/oauth2/rfc6749/grant_types/test_client_credentials.py index b52e265b..1698abf1 100644 --- a/tests/oauth2/rfc6749/grant_types/test_client_credentials.py +++ b/tests/oauth2/rfc6749/grant_types/test_client_credentials.py @@ -22,19 +22,30 @@ def setUp(self): self.auth = ClientCredentialsGrant( request_validator=self.mock_validator) + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('ClientCredentialsGrant does not support authorization ' + 'validators. Use token validators instead.') + with self.assertRaises(ValueError) as caught: + ClientCredentialsGrant(self.mock_validator, pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + ClientCredentialsGrant(self.mock_validator, post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + def test_custom_token_validators(self): - self.authval1, self.authval2 = mock.Mock(), mock.Mock() - self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) + bearer = BearerToken(self.mock_validator) self.auth.create_token_response(self.request, bearer) - self.assertTrue(self.tknval1.called) - self.assertTrue(self.tknval2.called) - self.assertFalse(self.authval1.called) - self.assertFalse(self.authval2.called) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) def test_create_token_response(self): bearer = BearerToken(self.mock_validator) diff --git a/tests/oauth2/rfc6749/grant_types/test_implicit.py b/tests/oauth2/rfc6749/grant_types/test_implicit.py index 35862e49..53f4ac34 100644 --- a/tests/oauth2/rfc6749/grant_types/test_implicit.py +++ b/tests/oauth2/rfc6749/grant_types/test_implicit.py @@ -47,10 +47,10 @@ def test_custom_validators(self): val.return_value = {} for val in (self.tknval1, self.tknval2): val.return_value = None - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + self.auth.custom_validators.pre_token.append(self.tknval1) + self.auth.custom_validators.post_token.append(self.tknval2) + self.auth.custom_validators.pre_auth.append(self.authval1) + self.auth.custom_validators.post_auth.append(self.authval2) bearer = BearerToken(self.mock_validator) self.auth.create_token_response(self.request, bearer) diff --git a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py index 99e05d6d..56711559 100644 --- a/tests/oauth2/rfc6749/grant_types/test_refresh_token.py +++ b/tests/oauth2/rfc6749/grant_types/test_refresh_token.py @@ -36,20 +36,30 @@ def test_create_token_response(self): self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'foo') + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('RefreshTokenGrant does not support authorization ' + 'validators. Use token validators instead.') + with self.assertRaises(ValueError) as caught: + RefreshTokenGrant(self.mock_validator, pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + RefreshTokenGrant(self.mock_validator, post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + def test_custom_token_validators(self): - self.authval1, self.authval2 = mock.Mock(), mock.Mock() - self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) bearer = BearerToken(self.mock_validator) self.auth.create_token_response(self.request, bearer) - self.assertTrue(self.tknval1.called) - self.assertTrue(self.tknval2.called) - self.assertFalse(self.authval1.called) - self.assertFalse(self.authval2.called) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) def test_create_token_inherit_scope(self): self.request.scope = None diff --git a/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py b/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py index 4747a690..006ed7cb 100644 --- a/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py +++ b/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py @@ -89,20 +89,33 @@ def test_create_token_response_without_refresh_token(self): self.assertEqual(status_code, 401) self.assertEqual(self.mock_validator.save_token.call_count, 0) + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('ResourceOwnerPasswordCredentialsGrant does not ' + 'support authorization validators. Use token ' + 'validators instead.') + with self.assertRaises(ValueError) as caught: + ResourceOwnerPasswordCredentialsGrant(self.mock_validator, + pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + ResourceOwnerPasswordCredentialsGrant(self.mock_validator, + post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + def test_custom_token_validators(self): - self.authval1, self.authval2 = mock.Mock(), mock.Mock() - self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() - self.auth.register_token_validator(self.tknval1, after_standard=False) - self.auth.register_token_validator(self.tknval2, after_standard=True) - self.auth.register_authorization_validator(self.authval1, after_standard=False) - self.auth.register_authorization_validator(self.authval2, after_standard=True) + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) bearer = BearerToken(self.mock_validator) self.auth.create_token_response(self.request, bearer) - self.assertTrue(self.tknval1.called) - self.assertTrue(self.tknval2.called) - self.assertFalse(self.authval1.called) - self.assertFalse(self.authval2.called) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) def test_error_response(self): pass From cf414259b2b03664cf51e12532035e2a4ef130f3 Mon Sep 17 00:00:00 2001 From: Brendan McCollam Date: Thu, 22 Dec 2016 16:39:50 +0000 Subject: [PATCH 10/10] Updated docs for custom validators --- docs/oauth2/grants/custom_validators.rst | 5 +++ docs/oauth2/grants/grants.rst | 3 +- oauthlib/oauth2/rfc6749/grant_types/base.py | 45 ++++++++++++++++++--- 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 docs/oauth2/grants/custom_validators.rst diff --git a/docs/oauth2/grants/custom_validators.rst b/docs/oauth2/grants/custom_validators.rst new file mode 100644 index 00000000..4629e6f4 --- /dev/null +++ b/docs/oauth2/grants/custom_validators.rst @@ -0,0 +1,5 @@ +Custom Validators +----------------- + +.. autoclass:: oauthlib.oauth2.rfc6749.grant_types.base.ValidatorsContainer + :members: diff --git a/docs/oauth2/grants/grants.rst b/docs/oauth2/grants/grants.rst index f4fcb56b..16b17be5 100644 --- a/docs/oauth2/grants/grants.rst +++ b/docs/oauth2/grants/grants.rst @@ -9,6 +9,7 @@ Grant types implicit password credentials + custom_validators jwt Grant types are what make OAuth 2 so flexible. The Authorization Code grant is @@ -26,7 +27,7 @@ attempts to cater for easy inclusion of this as much as possible. OAuthlib also offers hooks for registering your own custom validations for use with the existing grant type handlers -(:py:meth:`oauthlib.oauth2.AuthorizationCodeGrant.register_authorization_validator`). +(:py:class:`oauthlib.oauth2.rfc6749.grant_types.base.ValidatorsContainer`). In some situations, this may be more convenient than subclassing or writing your own extension grant type. diff --git a/oauthlib/oauth2/rfc6749/grant_types/base.py b/oauthlib/oauth2/rfc6749/grant_types/base.py index 7d3befd0..1128388d 100644 --- a/oauthlib/oauth2/rfc6749/grant_types/base.py +++ b/oauthlib/oauth2/rfc6749/grant_types/base.py @@ -15,14 +15,48 @@ log = logging.getLogger(__name__) class ValidatorsContainer(object): - """ - Container object for holding validator callables to be invoked. + Container object for holding custom validator callables to be invoked + as part of the grant type `validate_authorization_request()` or + `validate_authorization_request()` methods on the various grant types. + + Authorization validators must be callables that take a request object and + return a dict, which may contain items to be added to the `request_info` + returned from the grant_type after validation. + + Token validators must be callables that take a request object and + return None. + + Both authorization validators and token validators may raise OAuth2 + exceptions if validation conditions fail. + + Authorization validators added to `pre_auth` will be run BEFORE + the standard validations (but after the critical ones that raise + fatal errors) as part of `validate_authorization_request()` + + Authorization validators added to `post_auth` will be run AFTER + the standard validations as part of `validate_authorization_request()` + + Token validators added to `pre_token` will be run BEFORE + the standard validations as part of `validate_token_request()` + + Token validators added to `post_token` will be run AFTER + the standard validations as part of `validate_token_request()` + + For example: + + >>> def my_auth_validator(request): + ... return {'myval': True} + >>> auth_code_grant = AuthorizationCodeGrant(request_validator) + >>> auth_code_grant.custom_validators.pre_auth.append(my_auth_validator) + >>> def my_token_validator(request): + ... if not request.everything_okay: + ... raise errors.OAuth2Error("uh-oh") + >>> auth_code_grant.custom_validators.post_token.append(my_token_validator) """ - def __init__(self, - post_auth=None, post_token=None, - pre_auth=None, pre_token=None): + def __init__(self, post_auth, post_token, + pre_auth, pre_token): self.pre_auth = pre_auth self.post_auth = post_auth self.pre_token = pre_token @@ -67,6 +101,7 @@ def _setup_custom_validators(self, kwargs): msg = ("{} does not support authorization validators. Use " "token validators instead.").format(self.__class__.__name__) raise ValueError(msg) + # Using tuples here because they can't be appended to: post_auth, pre_auth = (), () self.custom_validators = ValidatorsContainer(post_auth, post_token, pre_auth, pre_token)