Skip to content

Commit

Permalink
Merge pull request #578 from manics/http-proxy-args
Browse files Browse the repository at this point in the history
[All] Add `http_request_kwargs` config option
  • Loading branch information
consideRatio committed Apr 4, 2023
2 parents 431dd06 + 71c8a9c commit 66c3a34
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 30 deletions.
4 changes: 1 addition & 3 deletions oauthenticator/bitbucket.py
Expand Up @@ -2,7 +2,6 @@
Custom Authenticator to use Bitbucket OAuth with JupyterHub
"""
from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat
from traitlets import Set, default

Expand Down Expand Up @@ -68,8 +67,7 @@ async def _check_membership_allowed_teams(self, username, access_token, token_ty
"https://api.bitbucket.org/2.0/workspaces", {'role': 'member'}
)
while next_page:
req = HTTPRequest(next_page, method="GET", headers=headers)
resp_json = await self.fetch(req)
resp_json = await self.httpfetch(next_page, method="GET", headers=headers)
next_page = resp_json.get('next', None)

user_teams = {entry["name"] for entry in resp_json["values"]}
Expand Down
17 changes: 9 additions & 8 deletions oauthenticator/github.py
Expand Up @@ -7,7 +7,6 @@

from jupyterhub.auth import LocalAuthenticator
from requests.utils import parse_header_links
from tornado.httpclient import HTTPRequest
from traitlets import Bool, Set, Unicode, default

from .oauth2 import OAuthenticator
Expand Down Expand Up @@ -169,13 +168,13 @@ async def update_auth_model(self, auth_model):
if not auth_model["auth_state"]["github_user"]["email"] and (
"user" in granted_scopes or "user:email" in granted_scopes
):
req = HTTPRequest(
resp_json = await self.httpfetch(
self.github_api + "/user/emails",
"fetching user emails",
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req, "fetching user emails")
for val in resp_json:
if val["primary"]:
auth_model["auth_state"]["github_user"]["email"] = val["email"]
Expand Down Expand Up @@ -210,13 +209,14 @@ async def _paginated_fetch(self, api_url, access_token, token_type):
url = api_url
content = []
while True:
req = HTTPRequest(
resp = await self.httpfetch(
url,
"fetching user teams",
parse_json=False,
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)
resp = await self.fetch(req, "fetching user teams", parse_json=False)

resp_json = json.loads(resp.body.decode())
content += resp_json
Expand Down Expand Up @@ -258,14 +258,15 @@ async def _check_membership_allowed_organizations(

check_membership_url = self._build_check_membership_url(org, username)

req = HTTPRequest(
self.log.debug(f"Checking GitHub organization membership: {username} in {org}?")
resp = await self.httpfetch(
check_membership_url,
parse_json=False,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
self.log.debug(f"Checking GitHub organization membership: {username} in {org}?")
resp = await self.fetch(req, raise_error=False, parse_json=False)
if resp.code == 204:
self.log.info(f"Allowing {username} as member of {org}")
return True
Expand Down
13 changes: 8 additions & 5 deletions oauthenticator/gitlab.py
Expand Up @@ -160,13 +160,12 @@ async def user_is_authorized(self, auth_model):

async def _get_gitlab_version(self, access_token):
url = f"{self.gitlab_api}/version"
req = HTTPRequest(
resp_json = await self.httpfetch(
url,
method="GET",
headers=_api_headers(access_token),
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req)
version_strings = resp_json['version'].split('-')[0].split('.')[:3]
version_ints = list(map(int, version_strings))
return version_ints
Expand All @@ -183,11 +182,15 @@ async def _check_membership_allowed_groups(self, user_id, access_token):
)
req = HTTPRequest(
url,
)
resp = await self.httpfetch(
url,
parse_json=False,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
resp = await self.fetch(req, raise_error=False, parse_json=False)
if resp.code == 200:
return True # user _is_ in group
return False
Expand All @@ -202,13 +205,13 @@ async def _check_membership_allowed_project_ids(self, user_id, access_token):
self.member_api_variant,
user_id,
)
req = HTTPRequest(
resp_json = await self.httpfetch(
url,
raise_error=False,
method="GET",
headers=headers,
validate_cert=self.validate_server_cert,
)
resp_json = await self.fetch(req, raise_error=False)
if resp_json:
access_level = resp_json.get('access_level', 0)

Expand Down
9 changes: 4 additions & 5 deletions oauthenticator/globus.py
Expand Up @@ -7,7 +7,6 @@
import urllib

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from tornado.web import HTTPError
from traitlets import Bool, List, Set, Unicode, default

Expand Down Expand Up @@ -247,8 +246,9 @@ async def get_users_groups_ids(self, tokens):
# Get list of user's Groups
groups_headers = self.get_default_headers()
groups_headers['Authorization'] = f'Bearer {groups_token}'
req = HTTPRequest(self.globus_groups_url, method='GET', headers=groups_headers)
groups_resp = await self.fetch(req)
groups_resp = await self.httpfetch(
self.globus_groups_url, method='GET', headers=groups_headers
)
# Build set of Group IDs
for group in groups_resp:
user_group_ids.add(group['id'])
Expand Down Expand Up @@ -345,13 +345,12 @@ async def revoke_service_tokens(self, services):
all_tokens = [tok for tok in access_tokens + refresh_tokens if tok is not None]

for token in all_tokens:
req = HTTPRequest(
await self.httpfetch(
self.revocation_url,
method="POST",
headers=self.get_client_credential_headers(),
body=urllib.parse.urlencode({'token': token}),
)
await self.fetch(req)


class LocalGlobusOAuthenticator(LocalAuthenticator, GlobusOAuthenticator):
Expand Down
53 changes: 46 additions & 7 deletions oauthenticator/oauth2.py
Expand Up @@ -391,6 +391,20 @@ def _validate_server_cert_default(self):
else:
return True

http_request_kwargs = Dict(
{},
help="""Extra default kwargs passed to all HTTPRequests.
For example, to use a HTTP proxy for all requests:
c.OAuthenticator.http_request_kwargs = {
"proxy_host": "proxy.example.com",
"proxy_port": 8080,
}
""",
config=True,
)

http_client = Any()

@default("http_client")
Expand All @@ -406,10 +420,11 @@ async def fetch(self, req, label="fetching", parse_json=True, **kwargs):
req: tornado HTTPRequest
label (str): label describing what is happening,
used in log message when the request fails.
parse_json (bool): whether to parse the response as JSON
**kwargs: remaining keyword args
passed to underlying `client.fetch(req, **kwargs)`
Returns:
r: parsed JSON response
parsed JSON response if `parse_json=True`, else `tornado.HTTPResponse`
"""
try:
resp = await self.http_client.fetch(req, **kwargs)
Expand Down Expand Up @@ -444,6 +459,33 @@ async def fetch(self, req, label="fetching", parse_json=True, **kwargs):
else:
return resp

async def httpfetch(
self, url, label="fetching", parse_json=True, raise_error=True, **kwargs
):
"""Wrapper for creating and fetching http requests
Includes http_request_kwargs in request kwargs
logs error responses, parses successful JSON responses
Args:
url (str): url to fetch
label (str): label describing what is happening,
used in log message when the request fails.
parse_json (bool): whether to parse the response as JSON
raise_error (bool): whether to raise an exception on HTTP errors
**kwargs: remaining keyword args
passed to underlying `tornado.HTTPRequest`, overrides
`http_request_kwargs`
Returns:
parsed JSON response if `parse_json=True`, else `tornado.HTTPResponse`
"""
request_kwargs = self.http_request_kwargs.copy()
request_kwargs.update(kwargs)
req = HTTPRequest(url, **request_kwargs)
return await self.fetch(
req, label=label, parse_json=parse_json, raise_error=raise_error
)

def login_url(self, base_url):
return url_path_join(base_url, "oauth_login")

Expand Down Expand Up @@ -590,16 +632,14 @@ async def get_token_info(self, handler, params):
"""
url = url_concat(self.token_url, params)

req = HTTPRequest(
token_info = await self.httpfetch(
url,
method="POST",
headers=self.build_token_info_request_headers(),
body=json.dumps(params),
validate_cert=self.validate_server_cert,
)

token_info = await self.fetch(req)

if "error_description" in token_info:
raise web.HTTPError(
403,
Expand Down Expand Up @@ -635,15 +675,14 @@ async def token_to_user(self, token_info):
if self.userdata_token_method == "url":
url = url_concat(url, dict(access_token=access_token))

req = HTTPRequest(
return await self.httpfetch(
url,
"Fetching user info...",
method="GET",
headers=self.build_userdata_request_headers(access_token, token_type),
validate_cert=self.validate_server_cert,
)

return await self.fetch(req, "Fetching user info...")

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Expand Down
4 changes: 2 additions & 2 deletions oauthenticator/tests/mocks.py
Expand Up @@ -33,7 +33,7 @@ def add_host(self, host, paths):
Args:
host (str): the host to mock (e.g. 'api.github.com')
paths (list(str|regex, callable)): a list of paths (or regexps for paths)
paths (list[(str|regex, callable)]): a list of paths (or regexps for paths)
and callables to be called for those paths.
The mock handlers will receive the request as their only argument.
Expand All @@ -47,7 +47,7 @@ def add_host(self, host, paths):
Example::
client.add_host('api.github.com', [
('/user': lambda request: {'login': 'name'})
('/user', lambda request: {'login': 'name'})
])
"""
self.hosts[host] = paths
Expand Down
24 changes: 24 additions & 0 deletions oauthenticator/tests/test_oauth2.py
@@ -1,3 +1,4 @@
import re
import uuid
from unittest.mock import Mock

Expand Down Expand Up @@ -41,3 +42,26 @@ async def test_custom_logout(monkeypatch):
await logout_handler.get()
assert logout_handler.clear_login_cookie.called
logout_handler.clear_cookie.assert_called_once_with(STATE_COOKIE_NAME)


async def test_httpfetch(client):
authenticator = OAuthenticator()
authenticator.http_request_kwargs = {
"proxy_host": "proxy.example.org",
"proxy_port": 8080,
}

# Return request fields as the response so we can examine it
client.add_host(
"example.org",
[
(
re.compile(".*"),
lambda req: [req.url, req.method, req.proxy_host, req.proxy_port],
),
],
)
authenticator.http_client = client

r = await authenticator.httpfetch("http://example.org/a")
assert r == ['http://example.org/a', 'GET', "proxy.example.org", 8080]

0 comments on commit 66c3a34

Please sign in to comment.