Skip to content

Commit

Permalink
Merge pull request #719 from yuvipanda/required_scopes
Browse files Browse the repository at this point in the history
Add `allowed_scopes` to all authenticators to allow some users based on granted scopes
  • Loading branch information
minrk committed Apr 26, 2024
2 parents 1f0cbc0 + 00360fa commit ed6b97e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
41 changes: 41 additions & 0 deletions oauthenticator/oauth2.py
Expand Up @@ -483,6 +483,36 @@ def _logout_redirect_url_default(self):
""",
)

allowed_scopes = List(
Unicode(),
config=True,
help="""
Allow users who have been granted *all* these scopes to log in.
We request all the scopes listed in the 'scope' config, but only a
subset of these may be granted by the authorization server. This may
happen if the user does not have permissions to access a requested
scope, or has chosen to not give consent for a particular scope. If the
scopes listed in this config are not granted, the user will not be
allowed to log in.
The granted scopes will be part of the access token (fetched from self.token_url).
See https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 for more
information.
See the OAuth documentation of your OAuth provider for various options.
""",
)

@validate('allowed_scopes')
def _allowed_scopes_validation(self, proposal):
# allowed scopes must be a subset of requested scopes
if set(proposal.value) - set(self.scope):
raise ValueError(
f"Allowed scopes must be a subset of requested scopes. {self.scope} is requested but {proposal.value} is allowed"
)
return proposal.value

extra_authorize_params = Dict(
config=True,
help="""
Expand Down Expand Up @@ -1060,6 +1090,8 @@ async def check_allowed(self, username, auth_model):
"""
Returns True for users allowed to be authorized
If a user must be *disallowed*, raises a 403 exception.
Overrides Authenticator.check_allowed that is called from
`Authenticator.get_authenticated_user` after
`OAuthenticator.authenticate` has been called, and therefore also after
Expand All @@ -1074,6 +1106,15 @@ async def check_allowed(self, username, auth_model):
if auth_model is None:
return True

# Allow users who have been granted specific scopes that grant them entry
if self.allowed_scopes:
granted_scopes = auth_model.get('auth_state', {}).get('scope', [])
missing_scopes = set(self.allowed_scopes) - set(granted_scopes)
if not missing_scopes:
message = f"Granting access to user {username}, as they had {self.allowed_scopes}"
self.log.info(message)
return True

if self.allow_all:
return True

Expand Down
6 changes: 6 additions & 0 deletions oauthenticator/tests/mocks.py
Expand Up @@ -104,6 +104,8 @@ def setup_oauth_mock(
access_token_path,
user_path=None,
token_type='Bearer',
token_request_style='post',
scope="",
):
"""setup the mock client for OAuth
Expand All @@ -125,6 +127,7 @@ def setup_oauth_mock(
access_token_path (str): The path for the access token request (e.g. /access_token)
user_path (str): The path for requesting (e.g. /user)
token_type (str): the token_type field for the provider
scope (str): The scope field returned by the provider
"""

client.oauth_codes = oauth_codes = {}
Expand Down Expand Up @@ -161,6 +164,8 @@ def access_token(request):
'access_token': token,
'token_type': token_type,
}
if scope:
model['scope'] = scope
if 'id_token' in user:
model['id_token'] = user['id_token']
return model
Expand All @@ -172,6 +177,7 @@ def get_user(request):
token = auth_header.split(None, 1)[1]
else:
query = parse_qs(urlparse(request.url).query)

if 'access_token' in query:
token = query['access_token'][0]
else:
Expand Down
35 changes: 34 additions & 1 deletion oauthenticator/tests/test_generic.py
@@ -1,8 +1,9 @@
import json
import re
from functools import partial

import jwt
from pytest import fixture, mark
from pytest import fixture, mark, raises
from traitlets.config import Config

from ..generic import GenericOAuthenticator
Expand Down Expand Up @@ -35,6 +36,7 @@ def generic_client(client):
host='generic.horse',
access_token_path='/oauth/access_token',
user_path='/oauth/userinfo',
scope='basic',
)
return client

Expand Down Expand Up @@ -293,6 +295,37 @@ async def test_generic_data(get_authenticator, generic_client):
assert auth_model


@mark.parametrize(
["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)]
)
async def test_allowed_scopes(
get_authenticator, generic_client, allowed_scopes, allowed
):
c = Config()
c.GenericOAuthenticator.allowed_scopes = allowed_scopes
c.GenericOAuthenticator.scope = list(allowed_scopes)
authenticator = get_authenticator(config=c)

handled_user_model = user_model("user1")
handler = generic_client.handler_for_user(handled_user_model)
auth_model = await authenticator.authenticate(handler)
assert allowed == await authenticator.check_allowed(auth_model["name"], auth_model)


async def test_allowed_scopes_validation_scope_subset(get_authenticator):
c = Config()
# Test that if we require more scopes than we request, validation fails
c.GenericOAuthenticator.allowed_scopes = ["a", "b"]
c.GenericOAuthenticator.scope = ["a"]
with raises(
ValueError,
match=re.escape(
"Allowed scopes must be a subset of requested scopes. ['a'] is requested but ['a', 'b'] is allowed"
),
):
get_authenticator(config=c)


async def test_generic_callable_username_key(get_authenticator, generic_client):
c = Config()
c.GenericOAuthenticator.allow_all = True
Expand Down

0 comments on commit ed6b97e

Please sign in to comment.