Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[All] Add http_request_kwargs config option #578

Merged
merged 6 commits into from Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 1 addition & 3 deletions oauthenticator/bitbucket.py
Expand Up @@ -2,7 +2,6 @@
Custom Authenticator to use Bitbucket OAuth with JupyterHub
"""
from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat
from traitlets import Set, default

Expand Down Expand Up @@ -68,8 +67,7 @@ async def _check_membership_allowed_teams(self, username, access_token, token_ty
"https://api.bitbucket.org/2.0/workspaces", {'role': 'member'}
)
while next_page:
req = HTTPRequest(next_page, method="GET", headers=headers)
resp_json = await self.fetch(req)
resp_json = await self.httpfetch(next_page, method="GET", headers=headers)
next_page = resp_json.get('next', None)

user_teams = {entry["name"] for entry in resp_json["values"]}
Expand Down
17 changes: 9 additions & 8 deletions oauthenticator/github.py
Expand Up @@ -7,7 +7,6 @@

from jupyterhub.auth import LocalAuthenticator
from requests.utils import parse_header_links
from tornado.httpclient import HTTPRequest
from traitlets import Bool, Set, Unicode, default

from .oauth2 import OAuthenticator
Expand Down Expand Up @@ -169,13 +168,13 @@ async def update_auth_model(self, auth_model):
if not auth_model["auth_state"]["github_user"]["email"] and (
"user" in granted_scopes or "user:email" in granted_scopes
):
req = HTTPRequest(
resp_json = await self.httpfetch(
self.github_api + "/user/emails",
"fetching user emails",
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req, "fetching user emails")
for val in resp_json:
if val["primary"]:
auth_model["auth_state"]["github_user"]["email"] = val["email"]
Expand Down Expand Up @@ -210,13 +209,14 @@ async def _paginated_fetch(self, api_url, access_token, token_type):
url = api_url
content = []
while True:
req = HTTPRequest(
resp = await self.httpfetch(
url,
"fetching user teams",
parse_json=False,
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)
resp = await self.fetch(req, "fetching user teams", parse_json=False)

resp_json = json.loads(resp.body.decode())
content += resp_json
Expand Down Expand Up @@ -258,14 +258,15 @@ async def _check_membership_allowed_organizations(

check_membership_url = self._build_check_membership_url(org, username)

req = HTTPRequest(
self.log.debug(f"Checking GitHub organization membership: {username} in {org}?")
resp = await self.httpfetch(
check_membership_url,
parse_json=False,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
self.log.debug(f"Checking GitHub organization membership: {username} in {org}?")
resp = await self.fetch(req, raise_error=False, parse_json=False)
if resp.code == 204:
self.log.info(f"Allowing {username} as member of {org}")
return True
Expand Down
13 changes: 8 additions & 5 deletions oauthenticator/gitlab.py
Expand Up @@ -160,13 +160,12 @@ async def user_is_authorized(self, auth_model):

async def _get_gitlab_version(self, access_token):
url = f"{self.gitlab_api}/version"
req = HTTPRequest(
resp_json = await self.httpfetch(
url,
method="GET",
headers=_api_headers(access_token),
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req)
version_strings = resp_json['version'].split('-')[0].split('.')[:3]
version_ints = list(map(int, version_strings))
return version_ints
Expand All @@ -183,11 +182,15 @@ async def _check_membership_allowed_groups(self, user_id, access_token):
)
req = HTTPRequest(
url,
)
resp = await self.httpfetch(
url,
parse_json=False,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
resp = await self.fetch(req, raise_error=False, parse_json=False)
if resp.code == 200:
return True # user _is_ in group
return False
Expand All @@ -202,13 +205,13 @@ async def _check_membership_allowed_project_ids(self, user_id, access_token):
self.member_api_variant,
user_id,
)
req = HTTPRequest(
resp_json = await self.httpfetch(
url,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req, raise_error=False)
if resp_json:
access_level = resp_json.get('access_level', 0)

Expand Down
9 changes: 4 additions & 5 deletions oauthenticator/globus.py
Expand Up @@ -7,7 +7,6 @@
import urllib

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from tornado.web import HTTPError
from traitlets import Bool, List, Set, Unicode, default

Expand Down Expand Up @@ -247,8 +246,9 @@ async def get_users_groups_ids(self, tokens):
# Get list of user's Groups
groups_headers = self.get_default_headers()
groups_headers['Authorization'] = f'Bearer {groups_token}'
req = HTTPRequest(self.globus_groups_url, method='GET', headers=groups_headers)
groups_resp = await self.fetch(req)
groups_resp = await self.httpfetch(
self.globus_groups_url, method='GET', headers=groups_headers
)
# Build set of Group IDs
for group in groups_resp:
user_group_ids.add(group['id'])
Expand Down Expand Up @@ -345,13 +345,12 @@ async def revoke_service_tokens(self, services):
all_tokens = [tok for tok in access_tokens + refresh_tokens if tok is not None]

for token in all_tokens:
req = HTTPRequest(
await self.httpfetch(
self.revocation_url,
method="POST",
headers=self.get_client_credential_headers(),
body=urllib.parse.urlencode({'token': token}),
)
await self.fetch(req)


class LocalGlobusOAuthenticator(LocalAuthenticator, GlobusOAuthenticator):
Expand Down
53 changes: 46 additions & 7 deletions oauthenticator/oauth2.py
Expand Up @@ -391,6 +391,20 @@ def _validate_server_cert_default(self):
else:
return True

http_request_kwargs = Dict(
{},
help="""Extra default kwargs passed to all HTTPRequests.

For example, to use a HTTP proxy for all requests:

c.OAuthenticator.http_request_kwargs = {
"proxy_host": "proxy.example.com",
"proxy_port": 8080,
}
""",
config=True,
)

http_client = Any()

@default("http_client")
Expand All @@ -406,10 +420,11 @@ async def fetch(self, req, label="fetching", parse_json=True, **kwargs):
req: tornado HTTPRequest
label (str): label describing what is happening,
used in log message when the request fails.
parse_json (bool): whether to parse the response as JSON
**kwargs: remaining keyword args
passed to underlying `client.fetch(req, **kwargs)`
Returns:
r: parsed JSON response
parsed JSON response if `parse_json=True`, else `tornado.HTTPResponse`
"""
try:
resp = await self.http_client.fetch(req, **kwargs)
Expand Down Expand Up @@ -444,6 +459,33 @@ async def fetch(self, req, label="fetching", parse_json=True, **kwargs):
else:
return resp

async def httpfetch(
self, url, label="fetching", parse_json=True, raise_error=True, **kwargs
):
"""Wrapper for creating and fetching http requests

Includes http_request_kwargs in request kwargs
logs error responses, parses successful JSON responses

Args:
url (str): url to fetch
label (str): label describing what is happening,
used in log message when the request fails.
parse_json (bool): whether to parse the response as JSON
raise_error (bool): whether to raise an exception on HTTP errors
**kwargs: remaining keyword args
passed to underlying `tornado.HTTPRequest`, overrides
`http_request_kwargs`
Returns:
parsed JSON response if `parse_json=True`, else `tornado.HTTPResponse`
"""
request_kwargs = self.http_request_kwargs.copy()
request_kwargs.update(kwargs)
req = HTTPRequest(url, **request_kwargs)
return await self.fetch(
req, label=label, parse_json=parse_json, raise_error=raise_error
)

def login_url(self, base_url):
return url_path_join(base_url, "oauth_login")

Expand Down Expand Up @@ -590,16 +632,14 @@ async def get_token_info(self, handler, params):
"""
url = url_concat(self.token_url, params)

req = HTTPRequest(
token_info = await self.httpfetch(
url,
method="POST",
headers=self.build_token_info_request_headers(),
body=json.dumps(params),
validate_cert=self.validate_server_cert,
)

token_info = await self.fetch(req)

if "error_description" in token_info:
raise web.HTTPError(
403,
Expand Down Expand Up @@ -635,15 +675,14 @@ async def token_to_user(self, token_info):
if self.userdata_token_method == "url":
url = url_concat(url, dict(access_token=access_token))

req = HTTPRequest(
return await self.httpfetch(
url,
"Fetching user info...",
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)

return await self.fetch(req, "Fetching user info...")

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Expand Down
4 changes: 2 additions & 2 deletions oauthenticator/tests/mocks.py
Expand Up @@ -33,7 +33,7 @@ def add_host(self, host, paths):

Args:
host (str): the host to mock (e.g. 'api.github.com')
paths (list(str|regex, callable)): a list of paths (or regexps for paths)
paths (list[(str|regex, callable)]): a list of paths (or regexps for paths)
and callables to be called for those paths.
The mock handlers will receive the request as their only argument.

Expand All @@ -47,7 +47,7 @@ def add_host(self, host, paths):
Example::

client.add_host('api.github.com', [
('/user': lambda request: {'login': 'name'})
('/user', lambda request: {'login': 'name'})
])
"""
self.hosts[host] = paths
Expand Down
24 changes: 24 additions & 0 deletions oauthenticator/tests/test_oauth2.py
@@ -1,3 +1,4 @@
import re
import uuid
from unittest.mock import Mock

Expand Down Expand Up @@ -41,3 +42,26 @@ async def test_custom_logout(monkeypatch):
await logout_handler.get()
assert logout_handler.clear_login_cookie.called
logout_handler.clear_cookie.assert_called_once_with(STATE_COOKIE_NAME)


async def test_httpfetch(client):
authenticator = OAuthenticator()
authenticator.http_request_kwargs = {
"proxy_host": "proxy.example.org",
"proxy_port": 8080,
}

# Return request fields as the response so we can examine it
client.add_host(
"example.org",
[
(
re.compile(".*"),
lambda req: [req.url, req.method, req.proxy_host, req.proxy_port],
),
],
)
authenticator.http_client = client

r = await authenticator.httpfetch("http://example.org/a")
assert r == ['http://example.org/a', 'GET', "proxy.example.org", 8080]