From d6b269bdc8b87b07236571d994a6138086213f73 Mon Sep 17 00:00:00 2001 From: Min RK Date: Wed, 10 Mar 2021 14:32:21 +0100 Subject: [PATCH] Add `.fetch()` utility on base OAuthenticator consolidates: - httpclient instantiation - error logging - json deserialization reduces duplicate code across implementations --- oauthenticator/auth0.py | 18 +- oauthenticator/azuread.py | 16 +- oauthenticator/bitbucket.py | 26 +-- oauthenticator/cilogon.py | 22 +-- oauthenticator/generic.py | 40 ++--- oauthenticator/github.py | 33 ++-- oauthenticator/gitlab.py | 39 ++--- oauthenticator/globus.py | 25 ++- oauthenticator/google.py | 36 ++-- oauthenticator/mediawiki.py | 11 +- oauthenticator/oauth2.py | 67 +++++++- oauthenticator/okpy.py | 19 +-- oauthenticator/openshift.py | 19 +-- oauthenticator/tests/test_generic.py | 245 ++++++++++++++------------- 14 files changed, 278 insertions(+), 338 deletions(-) diff --git a/oauthenticator/auth0.py b/oauthenticator/auth0.py index 12c982c8..9a619de2 100644 --- a/oauthenticator/auth0.py +++ b/oauthenticator/auth0.py @@ -32,15 +32,11 @@ import json import os -from tornado.auth import OAuth2Mixin -from tornado import web -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - -from traitlets import Unicode, default - from jupyterhub.auth import LocalAuthenticator +from tornado.httpclient import HTTPRequest +from traitlets import Unicode, default -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator class Auth0OAuthenticator(OAuthenticator): @@ -68,8 +64,6 @@ def _token_url_default(self): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = AsyncHTTPClient() params = { 'grant_type': 'authorization_code', @@ -87,8 +81,7 @@ async def authenticate(self, handler, data=None): body=json.dumps(params), ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) access_token = resp_json['access_token'] @@ -106,8 +99,7 @@ async def authenticate(self, handler, data=None): method="GET", headers=headers, ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) return { 'name': resp_json["email"], diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index 325e2997..247f1b91 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -2,20 +2,15 @@ Custom Authenticator to use Azure AD with JupyterHub """ -import json -import jwt import os import urllib -from tornado.auth import OAuth2Mixin -from tornado.log import app_log -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - +import jwt from jupyterhub.auth import LocalAuthenticator - +from tornado.httpclient import HTTPRequest from traitlets import Unicode, default -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator class AzureAdOAuthenticator(OAuthenticator): @@ -47,7 +42,6 @@ def _token_url_default(self): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - http_client = AsyncHTTPClient() params = dict( client_id=self.client_id, @@ -72,10 +66,8 @@ async def authenticate(self, handler, data=None): body=data # Body is required for a POST... ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) - # app_log.info("Response %s", resp_json) access_token = resp_json['access_token'] id_token = resp_json['id_token'] diff --git a/oauthenticator/bitbucket.py b/oauthenticator/bitbucket.py index 709d0c3f..aa248f58 100644 --- a/oauthenticator/bitbucket.py +++ b/oauthenticator/bitbucket.py @@ -2,20 +2,14 @@ Custom Authenticator to use Bitbucket OAuth with JupyterHub """ -import json import urllib -from tornado.auth import OAuth2Mixin -from tornado import web - -from tornado.httputil import url_concat -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - from jupyterhub.auth import LocalAuthenticator +from tornado.httpclient import HTTPRequest +from tornado.httputil import url_concat +from traitlets import Set, default -from traitlets import Set, default, observe - -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator def _api_headers(access_token): @@ -60,8 +54,6 @@ def _token_url_default(self): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = AsyncHTTPClient() params = dict( client_id=self.client_id, @@ -83,8 +75,7 @@ async def authenticate(self, handler, data=None): headers=bb_header, ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) access_token = resp_json['access_token'] @@ -94,8 +85,7 @@ async def authenticate(self, handler, data=None): method="GET", headers=_api_headers(access_token), ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) username = resp_json["username"] @@ -113,7 +103,6 @@ async def authenticate(self, handler, data=None): } async def _check_membership_allowed_teams(self, username, access_token): - http_client = AsyncHTTPClient() headers = _api_headers(access_token) # We verify the team membership by calling teams endpoint. @@ -122,8 +111,7 @@ async def _check_membership_allowed_teams(self, username, access_token): ) while next_page: req = HTTPRequest(next_page, method="GET", headers=headers) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) next_page = resp_json.get('next', None) user_teams = set([entry["username"] for entry in resp_json["values"]]) diff --git a/oauthenticator/cilogon.py b/oauthenticator/cilogon.py index 127e9ecf..2e2562fc 100644 --- a/oauthenticator/cilogon.py +++ b/oauthenticator/cilogon.py @@ -13,20 +13,15 @@ email instead of ePPN as the JupyterHub username. """ -import json import os -from tornado.auth import OAuth2Mixin +from jupyterhub.auth import LocalAuthenticator from tornado import web - +from tornado.httpclient import HTTPRequest from tornado.httputil import url_concat -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - -from traitlets import Unicode, List, Bool, default, validate, observe - -from jupyterhub.auth import LocalAuthenticator +from traitlets import Bool, List, Unicode, default, validate -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator, OAuthLoginHandler class CILogonLoginHandler(OAuthLoginHandler): @@ -138,8 +133,6 @@ async def authenticate(self, handler, data=None): receive it. """ code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = AsyncHTTPClient() # Exchange the OAuth code for a CILogon Access Token # See: http://www.cilogon.org/oidc @@ -157,18 +150,15 @@ async def authenticate(self, handler, data=None): req = HTTPRequest(url, headers=headers, method="POST", body='') - resp = await http_client.fetch(req) - token_response = json.loads(resp.body.decode('utf8', 'replace')) + token_response = await self.fetch(req) access_token = token_response['access_token'] - self.log.info("Access token acquired.") # Determine who the logged in user is params = dict(access_token=access_token) req = HTTPRequest( url_concat("https://%s/oauth2/userinfo" % self.cilogon_host, params), headers=headers, ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) claimlist = [self.username_claim] if self.additional_username_claims: diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index c0fc89c3..3310d28b 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -2,20 +2,17 @@ Custom Authenticator to use generic OAuth2 with JupyterHub """ -import json -import os import base64 -import urllib - -from tornado.httputil import url_concat -from tornado.httpclient import HTTPRequest, AsyncHTTPClient +import os +from urllib.parse import urlencode from jupyterhub.auth import LocalAuthenticator - -from traitlets import Unicode, Dict, Bool, Union, List -from .traitlets import Callable +from tornado.httpclient import AsyncHTTPClient, HTTPRequest +from tornado.httputil import url_concat +from traitlets import Bool, Dict, List, Unicode, Union, default from .oauth2 import OAuthenticator +from .traitlets import Callable class GenericOAuthenticator(OAuthenticator): @@ -88,7 +85,8 @@ class GenericOAuthenticator(OAuthenticator): help="Disable basic authentication for access token request", ) - def http_client(self): + @default("http_client") + def _default_http_client(self): return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=self.tls_verify)) def _get_headers(self): @@ -101,7 +99,7 @@ def _get_headers(self): headers.update({"Authorization": "Basic {}".format(b64key.decode("utf8"))}) return headers - async def _get_token(self, http_client, headers, params): + def _get_token(self, headers, params): if self.token_url: url = self.token_url else: @@ -111,13 +109,11 @@ async def _get_token(self, http_client, headers, params): url, method="POST", headers=headers, - body=urllib.parse.urlencode(params), + body=urlencode(params), ) + return self.fetch(req, "fetching access token") - resp = await http_client.fetch(req) - return json.loads(resp.body.decode('utf8', 'replace')) - - async def _get_user_data(self, http_client, token_response): + def _get_user_data(self, token_response): access_token = token_response['access_token'] token_type = token_response['token_type'] @@ -140,9 +136,7 @@ async def _get_user_data(self, http_client, token_response): method=self.userdata_method, headers=headers, ) - resp = await http_client.fetch(req) - - return json.loads(resp.body.decode('utf8', 'replace')) + return self.fetch(req, "fetching user data") @staticmethod def _create_auth_state(token_response, user_data_response): @@ -165,8 +159,6 @@ def check_user_in_groups(member_groups, allowed_groups): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = self.http_client() params = dict( redirect_uri=self.get_callback_url(handler), @@ -177,9 +169,9 @@ async def authenticate(self, handler, data=None): headers = self._get_headers() - token_resp_json = await self._get_token(http_client, headers, params) + token_resp_json = await self._get_token(headers, params) - user_data_resp_json = await self._get_user_data(http_client, token_resp_json) + user_data_resp_json = await self._get_user_data(token_resp_json) if callable(self.username_key): name = self.username_key(user_data_resp_json) @@ -197,7 +189,6 @@ async def authenticate(self, handler, data=None): } if self.allowed_groups: - self.log.info('Validating if user claim groups match any of {}'.format(self.allowed_groups)) if callable(self.claim_groups_key): @@ -223,5 +214,4 @@ async def authenticate(self, handler, data=None): class LocalGenericOAuthenticator(LocalAuthenticator, GenericOAuthenticator): """A version that mixes in local system user creation""" - pass diff --git a/oauthenticator/github.py b/oauthenticator/github.py index 8708329a..93bba2de 100644 --- a/oauthenticator/github.py +++ b/oauthenticator/github.py @@ -5,22 +5,15 @@ import json import os -import re -import string import warnings -from tornado.auth import OAuth2Mixin +from jupyterhub.auth import LocalAuthenticator from tornado import web - +from tornado.httpclient import HTTPRequest from tornado.httputil import url_concat -from tornado.httpclient import HTTPRequest, AsyncHTTPClient, HTTPError - -from jupyterhub.auth import LocalAuthenticator - -from traitlets import List, Set, Unicode, default, observe +from traitlets import Set, Unicode, default -from .common import next_page_from_links -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator def _api_headers(access_token): @@ -124,8 +117,6 @@ async def authenticate(self, handler, data=None): receive it. """ code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = AsyncHTTPClient() # Exchange the OAuth code for a GitHub Access Token # @@ -146,30 +137,28 @@ async def authenticate(self, handler, data=None): validate_cert=self.validate_server_cert, ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) if 'access_token' in resp_json: access_token = resp_json['access_token'] elif 'error_description' in resp_json: - raise HTTPError( + raise web.HTTPError( 403, "An access token was not returned: {}".format( resp_json['error_description'] ), ) else: - raise HTTPError(500, "Bad response: {}".format(resp)) + raise web.HTTPError(500, "Bad response: {}".format(resp_json)) - # Determine who the logged in user is + # Determine who the logged-in user is req = HTTPRequest( self.github_api + "/user", method="GET", headers=_api_headers(access_token), validate_cert=self.validate_server_cert, ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req, "fetching user info") username = resp_json["login"] # username is now the GitHub userid. @@ -205,7 +194,6 @@ async def authenticate(self, handler, data=None): return userdict async def _check_membership_allowed_organizations(self, org, username, access_token): - http_client = AsyncHTTPClient() headers = _api_headers(access_token) # Check membership of user `username` for organization `org` via api [check-membership](https://developer.github.com/v3/orgs/members/#check-membership) # With empty scope (even if authenticated by an org member), this @@ -225,8 +213,7 @@ async def _check_membership_allowed_organizations(self, org, username, access_to self.log.debug( "Checking GitHub organization membership: %s in %s?", username, org ) - resp = await http_client.fetch(req, raise_error=False) - print(resp) + resp = await self.fetch(req, raise_error=False, parse_json=False) if resp.code == 204: self.log.info("Allowing %s as member of %s", username, org) return True diff --git a/oauthenticator/gitlab.py b/oauthenticator/gitlab.py index cf8fd44e..85818c09 100644 --- a/oauthenticator/gitlab.py +++ b/oauthenticator/gitlab.py @@ -5,23 +5,16 @@ import json import os -import re -import sys import warnings from urllib.parse import quote -from tornado.auth import OAuth2Mixin -from tornado import web - +from jupyterhub.auth import LocalAuthenticator from tornado.escape import url_escape +from tornado.httpclient import HTTPRequest from tornado.httputil import url_concat -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - -from jupyterhub.auth import LocalAuthenticator +from traitlets import CUnicode, Set, Unicode, default -from traitlets import Set, CUnicode, Unicode, default, observe - -from .oauth2 import OAuthLoginHandler, OAuthenticator +from .oauth2 import OAuthenticator def _api_headers(access_token): @@ -116,8 +109,6 @@ def _token_url_default(self): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - http_client = AsyncHTTPClient() # Exchange the OAuth code for a GitLab Access Token # @@ -144,9 +135,7 @@ async def authenticate(self, handler, data=None): body='', # Body is required for a POST... ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) - + resp_json = await self.fetch(req, label="getting access token") access_token = resp_json['access_token'] # memoize gitlab version for class lifetime @@ -161,8 +150,7 @@ async def authenticate(self, handler, data=None): validate_cert=validate_server_cert, headers=_api_headers(access_token), ) - resp = await http_client.fetch(req) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req, label="getting gitlab user") username = resp_json["username"] user_id = resp_json["id"] @@ -207,14 +195,12 @@ async def _get_gitlab_version(self, access_token): headers=_api_headers(access_token), validate_cert=self.validate_server_cert, ) - resp = await AsyncHTTPClient().fetch(req, raise_error=True) - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req) version_strings = resp_json['version'].split('-')[0].split('.')[:3] version_ints = list(map(int, version_strings)) return version_ints async def _check_membership_allowed_groups(self, user_id, access_token): - http_client = AsyncHTTPClient() headers = _api_headers(access_token) # Check if user is a member of any group in the allowed list for group in map(url_escape, self.allowed_gitlab_groups): @@ -225,13 +211,12 @@ async def _check_membership_allowed_groups(self, user_id, access_token): user_id, ) req = HTTPRequest(url, method="GET", headers=headers) - resp = await http_client.fetch(req, raise_error=False) + resp = await self.fetch(req, raise_error=False, parse_json=False) if resp.code == 200: return True # user _is_ in group return False async def _check_membership_allowed_project_ids(self, user_id, access_token): - http_client = AsyncHTTPClient() headers = _api_headers(access_token) # Check if user has developer access to any project in the allowed list for project in self.allowed_project_ids: @@ -242,15 +227,13 @@ async def _check_membership_allowed_project_ids(self, user_id, access_token): user_id, ) req = HTTPRequest(url, method="GET", headers=headers) - resp = await http_client.fetch(req, raise_error=False) - - if resp.body: - resp_json = json.loads(resp.body.decode('utf8', 'replace')) + resp_json = await self.fetch(req, raise_error=False) + if resp_json: access_level = resp_json.get('access_level', 0) # We only allow access level Developer and above # Reference: https://docs.gitlab.com/ee/api/members.html - if resp.code == 200 and access_level >= 30: + if access_level >= 30: return True return False diff --git a/oauthenticator/globus.py b/oauthenticator/globus.py index 31970c24..51b96df1 100644 --- a/oauthenticator/globus.py +++ b/oauthenticator/globus.py @@ -1,20 +1,17 @@ """ Custom Authenticator to use Globus OAuth2 with JupyterHub """ +import base64 import os import pickle -import json -import base64 import urllib -from tornado.web import HTTPError -from tornado.httpclient import HTTPRequest, AsyncHTTPClient - -from traitlets import List, Unicode, Bool, default - +from jupyterhub.auth import LocalAuthenticator from jupyterhub.handlers import LogoutHandler from jupyterhub.utils import url_path_join -from jupyterhub.auth import LocalAuthenticator +from tornado.httpclient import HTTPRequest +from tornado.web import HTTPError +from traitlets import Bool, List, Unicode, default from .oauth2 import OAuthenticator @@ -153,7 +150,6 @@ async def authenticate(self, handler, data=None): will have the 'foouser' account in Jupyterhub. """ # Complete login and exchange the code for tokens. - http_client = AsyncHTTPClient() params = dict( redirect_uri=self.get_callback_url(handler), code=handler.get_argument("code"), @@ -163,15 +159,14 @@ async def authenticate(self, handler, data=None): headers=self.get_client_credential_headers(), body=urllib.parse.urlencode(params), ) - token_response = await http_client.fetch(req) - token_json = json.loads(token_response.body.decode('utf8', 'replace')) + token_json = await self.fetch(req) # Fetch user info at Globus's oauth2/userinfo/ HTTP endpoint to get the username user_headers = self.get_default_headers() user_headers['Authorization'] = 'Bearer {}'.format(token_json['access_token']) req = HTTPRequest(self.userdata_url, method='GET', headers=user_headers) - user_resp = await http_client.fetch(req) - username = self.get_username(json.loads(user_resp.body.decode('utf8', 'replace'))) + user_resp = await self.fetch(req) + username = self.get_username(user_resp) # Each token should have these attributes. Resource server is optional, # and likely won't be present. @@ -241,14 +236,14 @@ async def revoke_service_tokens(self, services): access_tokens = [token_dict.get('access_token') for token_dict in services.values()] refresh_tokens = [token_dict.get('refresh_token') for token_dict in services.values()] all_tokens = [tok for tok in access_tokens + refresh_tokens if tok is not None] - http_client = AsyncHTTPClient() + for token in all_tokens: req = HTTPRequest(self.revocation_url, method="POST", headers=self.get_client_credential_headers(), body=urllib.parse.urlencode({'token': token}), ) - await http_client.fetch(req) + await self.fetch(req) def logout_url(self, base_url): return url_path_join(base_url, 'logout') diff --git a/oauthenticator/google.py b/oauthenticator/google.py index 8631070c..4b3a8cbf 100644 --- a/oauthenticator/google.py +++ b/oauthenticator/google.py @@ -5,21 +5,18 @@ """ import os -import json import urllib.parse -from tornado import gen -from tornado.httpclient import HTTPRequest, AsyncHTTPClient +from jupyterhub.auth import LocalAuthenticator +from jupyterhub.crypto import EncryptionUnavailable, InvalidToken, decrypt from tornado.auth import GoogleOAuth2Mixin +from tornado.httpclient import HTTPRequest +from tornado.httputil import url_concat from tornado.web import HTTPError +from traitlets import Dict, List, Unicode, default, validate -from traitlets import Dict, Unicode, List, default, validate, observe - -from jupyterhub.crypto import decrypt, EncryptionUnavailable, InvalidToken -from jupyterhub.auth import LocalAuthenticator -from jupyterhub.utils import url_path_join +from .oauth2 import OAuthenticator -from .oauth2 import OAuthLoginHandler, OAuthCallbackHandler, OAuthenticator def check_user_in_groups(member_groups, allowed_groups): # Check if user is a member of any group in the allowed groups @@ -132,28 +129,23 @@ async def authenticate(self, handler, data=None, google_groups=None): ) ) - http_client = AsyncHTTPClient() - - response = await http_client.fetch( + req = HTTPRequest( self.token_url, method="POST", headers={"Content-Type": "application/x-www-form-urlencoded"}, body=body, ) - - user = json.loads(response.body.decode("utf-8", "replace")) + user = await self.fetch(req, "completing oauth") access_token = str(user['access_token']) refresh_token = user.get('refresh_token', None) - response = await http_client.fetch( - self.user_info_url + '?access_token=' + access_token + req = HTTPRequest( + url_concat( + self.user_info_url, + {'access_token': access_token}, + ) ) - - if not response: - handler.clear_all_cookies() - raise HTTPError(500, 'Google authentication failed') - - bodyjs = json.loads(response.body.decode()) + bodyjs = await self.fetch(req, "fetching user info") user_email = username = bodyjs['email'] user_email_domain = user_email.split('@')[1] diff --git a/oauthenticator/mediawiki.py b/oauthenticator/mediawiki.py index d7728d91..a71548c2 100644 --- a/oauthenticator/mediawiki.py +++ b/oauthenticator/mediawiki.py @@ -4,24 +4,18 @@ Requires `mwoauth` package. """ -import os import json +import os from asyncio import wrap_future from concurrent.futures import ThreadPoolExecutor -from tornado import web - from jupyterhub.handlers import BaseHandler from jupyterhub.utils import url_path_join -from jupyterhub import orm - from mwoauth import ConsumerToken, Handshaker from mwoauth.tokens import RequestToken - from traitlets import Any, Integer, Unicode -from oauthenticator import OAuthenticator, OAuthCallbackHandler - +from oauthenticator import OAuthCallbackHandler, OAuthenticator # Name of cookie used to pass auth token between the oauth # login and authentication phase @@ -141,4 +135,3 @@ async def authenticate(self, handler, data=None): } else: self.log.error("No username found in %s", identity) - diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 40b2087c..35338574 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -7,18 +7,17 @@ import base64 import json import os -from urllib.parse import quote, urlparse import uuid +from urllib.parse import quote, urlparse, urlunparse +from jupyterhub.auth import Authenticator +from jupyterhub.handlers import BaseHandler +from jupyterhub.utils import url_path_join from tornado import web from tornado.auth import OAuth2Mixin +from tornado.httpclient import AsyncHTTPClient, HTTPClientError from tornado.log import app_log - -from jupyterhub.handlers import BaseHandler -from jupyterhub.auth import Authenticator -from jupyterhub.utils import url_path_join - -from traitlets import Unicode, Bool, List, Dict, default, observe +from traitlets import Any, Bool, Dict, List, Unicode, default def guess_callback_uri(protocol, host, hub_server_url): @@ -316,10 +315,62 @@ def _validate_server_cert_default(self): else: return True + http_client = Any() + + @default("http_client") + def _default_http_client(self): + return AsyncHTTPClient() + + async def fetch(self, req, label="fetching", parse_json=True, **kwargs): + """Wrapper for http requests + + logs error responses, parses successful JSON responses + + Args: + req: tornado HTTPRequest + label (str): label describing what is happening, + used in log message when the request fails. + **kwargs: remaining keyword args + passed to underlying `client.fetch(req, **kwargs)` + Returns: + r: parsed JSON response + """ + try: + resp = await self.http_client.fetch(req, **kwargs) + except HTTPClientError as e: + if e.response: + # Log failed response message for debugging purposes + message = e.response.body.decode("utf8", "replace") + try: + # guess json, reformat for readability + json_message = json.loads(message) + except ValueError: + # not json + pass + else: + # reformat json log message for readability + message = json.dumps(json_message, sort_keys=True, indent=1) + else: + # didn't get a response, e.g. connection error + message = str(e) + + # log url without query params + url = urlunparse(urlparse(req.url)._replace(query="")) + app_log.error(f"Error {label} {e.code} {req.method} {url}: {message}") + raise e + else: + if parse_json: + if resp.body: + return json.loads(resp.body.decode('utf8', 'replace')) + else: + # empty body is None + return None + else: + return resp + def login_url(self, base_url): return url_path_join(base_url, 'oauth_login') - def get_callback_url(self, handler=None): """Get my OAuth redirect URL diff --git a/oauthenticator/okpy.py b/oauthenticator/okpy.py index 0127d8b2..eb2d3da1 100644 --- a/oauthenticator/okpy.py +++ b/oauthenticator/okpy.py @@ -1,17 +1,16 @@ """ Custom Authenticator to use okpy OAuth with JupyterHub """ -import json + from binascii import a2b_base64 -from tornado.auth import OAuth2Mixin +from jupyterhub.auth import LocalAuthenticator from tornado import web -from tornado.httpclient import HTTPRequest, AsyncHTTPClient +from tornado.auth import OAuth2Mixin +from tornado.httpclient import HTTPRequest from tornado.httputil import url_concat from traitlets import default -from jupyterhub.auth import LocalAuthenticator - from .oauth2 import OAuthenticator @@ -71,17 +70,13 @@ async def authenticate(self, handler, data=None): code = handler.get_argument("code", False) if not code: raise web.HTTPError(400, "Authentication Cancelled.") - http_client = AsyncHTTPClient() auth_request = self.get_auth_request(code) - response = await http_client.fetch(auth_request) - if not response: + state = await self.fetch(auth_request) + if not state: raise web.HTTPError(500, 'Authentication Failed: Token Not Acquired') - state = json.loads(response.body.decode('utf8', 'replace')) access_token = state['access_token'] info_request = self.get_user_info_request(access_token) - response = await http_client.fetch(info_request) - user = json.loads(response.body.decode('utf8', 'replace')) - # TODO: preserve state in auth_state when JupyterHub supports encrypted auth_state + user = await self.fetch(info_request) return { 'name': user['email'], 'auth_state': {'access_token': access_token, 'okpy_user': user}, diff --git a/oauthenticator/openshift.py b/oauthenticator/openshift.py index a8397589..6bb0479b 100644 --- a/oauthenticator/openshift.py +++ b/oauthenticator/openshift.py @@ -4,19 +4,15 @@ Derived from the GitHub OAuth authenticator. """ - -import json import os import requests from jupyterhub.auth import LocalAuthenticator -from tornado import web -from tornado.auth import OAuth2Mixin -from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest +from tornado.httpclient import HTTPRequest from tornado.httputil import url_concat +from traitlets import Bool, Set, Unicode, default from oauthenticator.oauth2 import OAuthenticator -from traitlets import Bool, Set, Unicode, default class OpenShiftOAuthenticator(OAuthenticator): @@ -94,9 +90,6 @@ def user_in_groups(user_groups: set, allowed_groups: set): async def authenticate(self, handler, data=None): code = handler.get_argument("code") - # TODO: Configure the curl_httpclient for tornado - - http_client = AsyncHTTPClient() # Exchange the OAuth code for a OpenShift Access Token # @@ -120,10 +113,7 @@ async def authenticate(self, handler, data=None): body='', # Body is required for a POST... ) - resp = await http_client.fetch(req) - - resp_json = json.loads(resp.body.decode('utf8', 'replace')) - + resp_json = await self.fetch(req) access_token = resp_json['access_token'] # Determine who the logged in user is @@ -141,8 +131,7 @@ async def authenticate(self, handler, data=None): headers=headers, ) - resp = await http_client.fetch(req) - ocp_user = json.loads(resp.body.decode('utf8', 'replace')) + ocp_user = await self.fetch(req) username = ocp_user['metadata']['name'] diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 99a0ce54..4116d214 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -1,9 +1,9 @@ +from functools import partial + from pytest import fixture from ..generic import GenericOAuthenticator - from .mocks import setup_oauth_mock -from unittest import mock def user_model(username, **kwargs): @@ -16,7 +16,7 @@ def user_model(username, **kwargs): return user -def get_authenticator(**kwargs): +def _get_authenticator(**kwargs): return GenericOAuthenticator( token_url='https://generic.horse/oauth/access_token', userdata_url='https://generic.horse/oauth/userinfo', @@ -39,128 +39,131 @@ def generic_client(client): return client -async def test_generic(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator() - - handler = get_simple_handler(generic_client) - user_info = await authenticator.authenticate(handler) - assert sorted(user_info) == ['auth_state', 'name'] - name = user_info['name'] - assert name == 'wash' - auth_state = user_info['auth_state'] - assert 'access_token' in auth_state - assert 'oauth_user' in auth_state - assert 'refresh_token' in auth_state - assert 'scope' in auth_state - - -async def test_generic_callable_username_key(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - username_key=lambda r: r['alternate_username'] - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe') - ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'zoe' +@fixture +def get_authenticator(generic_client, **kwargs): + return partial(_get_authenticator, http_client=generic_client) -async def test_generic_callable_groups_claim_key_with_allowed_groups(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - scope=['openid', 'profile', 'roles'], - claim_groups_key=lambda r: r.get('policies').get('roles'), - allowed_groups=['super_user'] - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', policies={'roles': ['super_user']}) - ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'wash' +async def test_generic(get_authenticator, generic_client): + authenticator = get_authenticator() + handler = get_simple_handler(generic_client) + user_info = await authenticator.authenticate(handler) + assert sorted(user_info) == ['auth_state', 'name'] + name = user_info['name'] + assert name == 'wash' + auth_state = user_info['auth_state'] + assert 'access_token' in auth_state + assert 'oauth_user' in auth_state + assert 'refresh_token' in auth_state + assert 'scope' in auth_state -async def test_generic_groups_claim_key_with_allowed_groups(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - scope=['openid', 'profile', 'roles'], - claim_groups_key='groups', - allowed_groups=['super_user'] - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', groups=['super_user']) - ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'wash' +async def test_generic_callable_username_key(get_authenticator, generic_client): + authenticator = get_authenticator(username_key=lambda r: r['alternate_username']) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe') + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'zoe' -async def test_generic_groups_claim_key_with_allowed_groups_unauthorized(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - scope=['openid', 'profile', 'roles'], - claim_groups_key='groups', - allowed_groups=['user'] - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', groups=['public']) - ) - user_info = await authenticator.authenticate(handler) - assert user_info is None - - -async def test_generic_groups_claim_key_with_allowed_groups_and_admin_groups(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - scope=['openid', 'profile', 'roles'], - claim_groups_key='groups', - allowed_groups=['user'], - admin_groups=['administrator'], - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', groups=['user', 'administrator']) - ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'wash' - assert user_info['admin'] is True - - -async def test_generic_groups_claim_key_with_allowed_groups_and_admin_groups_not_admin(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - scope=['openid', 'profile', 'roles'], - claim_groups_key='groups', - allowed_groups=['user'], - admin_groups=['administrator'], - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', groups=['user']) - ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'wash' - assert user_info['admin'] is False - - -async def test_generic_callable_groups_claim_key_with_allowed_groups_and_admin_groups(generic_client): - with mock.patch.object(GenericOAuthenticator, 'http_client') as fake_client: - fake_client.return_value = generic_client - authenticator = get_authenticator( - username_key=lambda r: r['alternate_username'], - scope=['openid', 'profile', 'roles'], - claim_groups_key=lambda r: r.get('policies').get('roles'), - allowed_groups=['user', 'public'], - admin_groups=['administrator'], - ) - handler = generic_client.handler_for_user( - user_model('wash', alternate_username='zoe', policies={'roles': ['user', 'administrator']}) + +async def test_generic_callable_groups_claim_key_with_allowed_groups( + get_authenticator, generic_client +): + authenticator = get_authenticator( + scope=['openid', 'profile', 'roles'], + claim_groups_key=lambda r: r.get('policies').get('roles'), + allowed_groups=['super_user'], + ) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe', policies={'roles': ['super_user']}) + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'wash' + + +async def test_generic_groups_claim_key_with_allowed_groups( + get_authenticator, generic_client +): + authenticator = get_authenticator( + scope=['openid', 'profile', 'roles'], + claim_groups_key='groups', + allowed_groups=['super_user'], + ) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe', groups=['super_user']) + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'wash' + + +async def test_generic_groups_claim_key_with_allowed_groups_unauthorized( + get_authenticator, generic_client +): + authenticator = get_authenticator( + scope=['openid', 'profile', 'roles'], + claim_groups_key='groups', + allowed_groups=['user'], + ) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe', groups=['public']) + ) + user_info = await authenticator.authenticate(handler) + assert user_info is None + + +async def test_generic_groups_claim_key_with_allowed_groups_and_admin_groups( + get_authenticator, generic_client +): + authenticator = get_authenticator( + scope=['openid', 'profile', 'roles'], + claim_groups_key='groups', + allowed_groups=['user'], + admin_groups=['administrator'], + ) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe', groups=['user', 'administrator']) + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'wash' + assert user_info['admin'] is True + + +async def test_generic_groups_claim_key_with_allowed_groups_and_admin_groups_not_admin( + get_authenticator, generic_client +): + authenticator = get_authenticator( + scope=['openid', 'profile', 'roles'], + claim_groups_key='groups', + allowed_groups=['user'], + admin_groups=['administrator'], + ) + handler = generic_client.handler_for_user( + user_model('wash', alternate_username='zoe', groups=['user']) + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'wash' + assert user_info['admin'] is False + + +async def test_generic_callable_groups_claim_key_with_allowed_groups_and_admin_groups( + get_authenticator, generic_client +): + authenticator = get_authenticator( + username_key=lambda r: r['alternate_username'], + scope=['openid', 'profile', 'roles'], + claim_groups_key=lambda r: r.get('policies').get('roles'), + allowed_groups=['user', 'public'], + admin_groups=['administrator'], + ) + handler = generic_client.handler_for_user( + user_model( + 'wash', + alternate_username='zoe', + policies={'roles': ['user', 'administrator']}, ) - user_info = await authenticator.authenticate(handler) - assert user_info['name'] == 'zoe' - assert user_info['admin'] is True + ) + user_info = await authenticator.authenticate(handler) + assert user_info['name'] == 'zoe' + assert user_info['admin'] is True