Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add .fetch(req) method to base OAuthenticator #415

Merged
merged 1 commit into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions oauthenticator/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@
import json
import os

from tornado.auth import OAuth2Mixin
from tornado import web
from tornado.httpclient import HTTPRequest, AsyncHTTPClient

from traitlets import Unicode, default

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from traitlets import Unicode, default

from .oauth2 import OAuthLoginHandler, OAuthenticator
from .oauth2 import OAuthenticator


class Auth0OAuthenticator(OAuthenticator):
Expand Down Expand Up @@ -68,8 +64,6 @@ def _token_url_default(self):

async def authenticate(self, handler, data=None):
code = handler.get_argument("code")
# TODO: Configure the curl_httpclient for tornado
http_client = AsyncHTTPClient()

params = {
'grant_type': 'authorization_code',
Expand All @@ -87,8 +81,7 @@ async def authenticate(self, handler, data=None):
body=json.dumps(params),
)

resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

access_token = resp_json['access_token']

Expand All @@ -106,8 +99,7 @@ async def authenticate(self, handler, data=None):
method="GET",
headers=headers,
)
resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

return {
'name': resp_json["email"],
Expand Down
16 changes: 4 additions & 12 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@
Custom Authenticator to use Azure AD with JupyterHub
"""

import json
import jwt
import os
import urllib

from tornado.auth import OAuth2Mixin
from tornado.log import app_log
from tornado.httpclient import HTTPRequest, AsyncHTTPClient

import jwt
from jupyterhub.auth import LocalAuthenticator

from tornado.httpclient import HTTPRequest
from traitlets import Unicode, default

from .oauth2 import OAuthLoginHandler, OAuthenticator
from .oauth2 import OAuthenticator


class AzureAdOAuthenticator(OAuthenticator):
Expand Down Expand Up @@ -47,7 +42,6 @@ def _token_url_default(self):

async def authenticate(self, handler, data=None):
code = handler.get_argument("code")
http_client = AsyncHTTPClient()

params = dict(
client_id=self.client_id,
Expand All @@ -72,10 +66,8 @@ async def authenticate(self, handler, data=None):
body=data # Body is required for a POST...
)

resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

# app_log.info("Response %s", resp_json)
access_token = resp_json['access_token']

id_token = resp_json['id_token']
Expand Down
26 changes: 7 additions & 19 deletions oauthenticator/bitbucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@
Custom Authenticator to use Bitbucket OAuth with JupyterHub
"""

import json
import urllib

from tornado.auth import OAuth2Mixin
from tornado import web

from tornado.httputil import url_concat
from tornado.httpclient import HTTPRequest, AsyncHTTPClient

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat
from traitlets import Set, default

from traitlets import Set, default, observe

from .oauth2 import OAuthLoginHandler, OAuthenticator
from .oauth2 import OAuthenticator


def _api_headers(access_token):
Expand Down Expand Up @@ -60,8 +54,6 @@ def _token_url_default(self):

async def authenticate(self, handler, data=None):
code = handler.get_argument("code")
# TODO: Configure the curl_httpclient for tornado
http_client = AsyncHTTPClient()

params = dict(
client_id=self.client_id,
Expand All @@ -83,8 +75,7 @@ async def authenticate(self, handler, data=None):
headers=bb_header,
)

resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

access_token = resp_json['access_token']

Expand All @@ -94,8 +85,7 @@ async def authenticate(self, handler, data=None):
method="GET",
headers=_api_headers(access_token),
)
resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

username = resp_json["username"]

Expand All @@ -113,7 +103,6 @@ async def authenticate(self, handler, data=None):
}

async def _check_membership_allowed_teams(self, username, access_token):
http_client = AsyncHTTPClient()

headers = _api_headers(access_token)
# We verify the team membership by calling teams endpoint.
Expand All @@ -122,8 +111,7 @@ async def _check_membership_allowed_teams(self, username, access_token):
)
while next_page:
req = HTTPRequest(next_page, method="GET", headers=headers)
resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)
next_page = resp_json.get('next', None)

user_teams = set([entry["username"] for entry in resp_json["values"]])
Expand Down
22 changes: 6 additions & 16 deletions oauthenticator/cilogon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,15 @@
email instead of ePPN as the JupyterHub username.
"""

import json
import os

from tornado.auth import OAuth2Mixin
from jupyterhub.auth import LocalAuthenticator
from tornado import web

from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat
from tornado.httpclient import HTTPRequest, AsyncHTTPClient

from traitlets import Unicode, List, Bool, default, validate, observe

from jupyterhub.auth import LocalAuthenticator
from traitlets import Bool, List, Unicode, default, validate

from .oauth2 import OAuthLoginHandler, OAuthenticator
from .oauth2 import OAuthenticator, OAuthLoginHandler


class CILogonLoginHandler(OAuthLoginHandler):
Expand Down Expand Up @@ -138,8 +133,6 @@ async def authenticate(self, handler, data=None):
receive it.
"""
code = handler.get_argument("code")
# TODO: Configure the curl_httpclient for tornado
http_client = AsyncHTTPClient()

# Exchange the OAuth code for a CILogon Access Token
# See: http://www.cilogon.org/oidc
Expand All @@ -157,18 +150,15 @@ async def authenticate(self, handler, data=None):

req = HTTPRequest(url, headers=headers, method="POST", body='')

resp = await http_client.fetch(req)
token_response = json.loads(resp.body.decode('utf8', 'replace'))
token_response = await self.fetch(req)
access_token = token_response['access_token']
self.log.info("Access token acquired.")
# Determine who the logged in user is
params = dict(access_token=access_token)
req = HTTPRequest(
url_concat("https://%s/oauth2/userinfo" % self.cilogon_host, params),
headers=headers,
)
resp = await http_client.fetch(req)
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
resp_json = await self.fetch(req)

claimlist = [self.username_claim]
if self.additional_username_claims:
Expand Down
40 changes: 15 additions & 25 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,17 @@
Custom Authenticator to use generic OAuth2 with JupyterHub
"""

import json
import os
import base64
import urllib

from tornado.httputil import url_concat
from tornado.httpclient import HTTPRequest, AsyncHTTPClient
import os
from urllib.parse import urlencode

from jupyterhub.auth import LocalAuthenticator

from traitlets import Unicode, Dict, Bool, Union, List
from .traitlets import Callable
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.httputil import url_concat
from traitlets import Bool, Dict, List, Unicode, Union, default

from .oauth2 import OAuthenticator
from .traitlets import Callable


class GenericOAuthenticator(OAuthenticator):
Expand Down Expand Up @@ -88,7 +85,8 @@ class GenericOAuthenticator(OAuthenticator):
help="Disable basic authentication for access token request",
)

def http_client(self):
@default("http_client")
def _default_http_client(self):
return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=self.tls_verify))

def _get_headers(self):
Expand All @@ -101,7 +99,7 @@ def _get_headers(self):
headers.update({"Authorization": "Basic {}".format(b64key.decode("utf8"))})
return headers

async def _get_token(self, http_client, headers, params):
def _get_token(self, headers, params):
if self.token_url:
url = self.token_url
else:
Expand All @@ -111,13 +109,11 @@ async def _get_token(self, http_client, headers, params):
url,
method="POST",
headers=headers,
body=urllib.parse.urlencode(params),
body=urlencode(params),
)
return self.fetch(req, "fetching access token")

resp = await http_client.fetch(req)
return json.loads(resp.body.decode('utf8', 'replace'))

async def _get_user_data(self, http_client, token_response):
def _get_user_data(self, token_response):
access_token = token_response['access_token']
token_type = token_response['token_type']

Expand All @@ -140,9 +136,7 @@ async def _get_user_data(self, http_client, token_response):
method=self.userdata_method,
headers=headers,
)
resp = await http_client.fetch(req)

return json.loads(resp.body.decode('utf8', 'replace'))
return self.fetch(req, "fetching user data")

@staticmethod
def _create_auth_state(token_response, user_data_response):
Expand All @@ -165,8 +159,6 @@ def check_user_in_groups(member_groups, allowed_groups):

async def authenticate(self, handler, data=None):
code = handler.get_argument("code")
# TODO: Configure the curl_httpclient for tornado
http_client = self.http_client()

params = dict(
redirect_uri=self.get_callback_url(handler),
Expand All @@ -177,9 +169,9 @@ async def authenticate(self, handler, data=None):

headers = self._get_headers()

token_resp_json = await self._get_token(http_client, headers, params)
token_resp_json = await self._get_token(headers, params)

user_data_resp_json = await self._get_user_data(http_client, token_resp_json)
user_data_resp_json = await self._get_user_data(token_resp_json)

if callable(self.username_key):
name = self.username_key(user_data_resp_json)
Expand All @@ -197,7 +189,6 @@ async def authenticate(self, handler, data=None):
}

if self.allowed_groups:

self.log.info('Validating if user claim groups match any of {}'.format(self.allowed_groups))

if callable(self.claim_groups_key):
Expand All @@ -223,5 +214,4 @@ async def authenticate(self, handler, data=None):

class LocalGenericOAuthenticator(LocalAuthenticator, GenericOAuthenticator):
"""A version that mixes in local system user creation"""

pass
Loading