Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support callable for GenericOAuthenticator username_key #305

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from jupyterhub.auth import LocalAuthenticator

from traitlets import Unicode, Dict, Bool
from traitlets import Unicode, Dict, Bool, Union
from .traitlets import Callable

from .oauth2 import OAuthLoginHandler, OAuthenticator

Expand Down Expand Up @@ -53,11 +54,22 @@ class GenericOAuthenticator(OAuthenticator):
help="Extra parameters for first POST request"
).tag(config=True)

username_key = Unicode(
os.environ.get('OAUTH2_USERNAME_KEY', 'username'),
username_key = Union(
[
Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')),
Callable()
],
config=True,
help="Userdata username key from returned json for USERDATA_URL"
help="""
Userdata username key from returned json for USERDATA_URL.

Can be a string key name or a callable that accepts the returned
json (as a dict) and returns the username. The callable is useful
e.g. for extracting the username from a nested object in the
response.
"""
)

userdata_params = Dict(
help="Userdata params to get user data login information"
).tag(config=True)
Expand All @@ -70,7 +82,7 @@ class GenericOAuthenticator(OAuthenticator):
userdata_token_method = Unicode(
os.environ.get('OAUTH2_USERDATA_REQUEST_TYPE', 'header'),
config=True,
help="Method for sending access token in userdata request. Supported methods: header, url. Default: header"
help="Method for sending access token in userdata request. Supported methods: header, url. Default: header"
)

tls_verify = Bool(
Expand Down Expand Up @@ -156,12 +168,16 @@ async def authenticate(self, handler, data=None):
resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))

if not resp_json.get(self.username_key):
self.log.error("OAuth user contains no key %s: %s", self.username_key, resp_json)
return
if callable(self.username_key):
name = self.username_key(resp_json)
else:
name = resp_json.get(self.username_key)
if not name:
self.log.error("OAuth user contains no key %s: %s", self.username_key, resp_json)
return

return {
'name': resp_json.get(self.username_key),
'name': name,
'auth_state': {
'access_token': access_token,
'refresh_token': refresh_token,
Expand Down
23 changes: 19 additions & 4 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
from .mocks import setup_oauth_mock


def user_model(username):
def user_model(username, **kwargs):
"""Return a user model"""
return {
user = {
'username': username,
'scope': 'basic',
}
user.update(kwargs)
return user

def Authenticator():
def Authenticator(**kwargs):
return GenericOAuthenticator(
token_url='https://generic.horse/oauth/access_token',
userdata_url='https://generic.horse/oauth/userinfo'
userdata_url='https://generic.horse/oauth/userinfo',
**kwargs
)

@fixture
def generic_client(client):
setup_oauth_mock(client,
Expand All @@ -39,3 +43,14 @@ async def test_generic(generic_client):
assert 'oauth_user' in auth_state
assert 'refresh_token' in auth_state
assert 'scope' in auth_state


async def test_generic_callable_username_key(generic_client):
authenticator = Authenticator(
username_key=lambda r: r['alternate_username']
)
handler = generic_client.handler_for_user(
user_model('wash', alternate_username='zoe')
)
user_info = await authenticator.authenticate(handler)
assert user_info['name'] == 'zoe'
16 changes: 16 additions & 0 deletions oauthenticator/traitlets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from traitlets import TraitType
yuvipanda marked this conversation as resolved.
Show resolved Hide resolved

class Callable(TraitType):
"""
A trait which is callable.
Classes are callable, as are instances
with a __call__() method.
"""

info_text = 'a callable'

def validate(self, obj, value):
if callable(value):
return value
else:
self.error(obj, value)