Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JupyterHub integration #36

Merged
merged 2 commits into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion continuous_integration/before_install.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
gimme 1.12 \
&& source ~/.gimme/envs/go1.12.env \
&& go get github.com/stretchr/testify/assert
&& go get github.com/stretchr/testify/assert \
&& nvm install 6 \
&& nvm use 6
4 changes: 4 additions & 0 deletions continuous_integration/install.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
set -xe

npm install -g configurable-http-proxy

pip install \
dask \
distributed \
Expand All @@ -10,6 +12,8 @@ pip install \
pytest \
pytest-asyncio \
trustme \
jupyterhub \
notebook \
black \
flake8

Expand Down
22 changes: 16 additions & 6 deletions dask-gateway-server/dask_gateway_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,15 @@ class DaskGateway(Application):
)

@validate("public_url", "gateway_url", "private_url")
def _resolve_hostname(self, proposal):
def _normalize_url(self, proposal):
url = proposal.value
parsed = urlparse(url)
if parsed.hostname in {"", "0.0.0.0"}:
# Resolve hostname
host = socket.gethostname()
parsed = parsed._replace(netloc="%s:%i" % (host, parsed.port))
url = urlunparse(parsed)
# Ensure no trailing slash
url = urlunparse(parsed._replace(path=parsed.path.rstrip("/")))
return url

tls_key = Unicode(
Expand Down Expand Up @@ -361,6 +363,10 @@ def _cookie_secret_validate(self, proposal):
def api_url(self):
return self.public_url + "/gateway/api"

@property
def public_url_prefix(self):
return urlparse(self.public_url).path

def create_task(self, task):
out = asyncio.ensure_future(task)
self.pending_tasks.add(out)
Expand Down Expand Up @@ -554,7 +560,9 @@ async def start_tornado_application(self):
private_url.port, address=private_url.hostname
)
self.log.info("Gateway API listening on %s", self.private_url)
await self.web_proxy.add_route("/gateway/", self.private_url)
await self.web_proxy.add_route(
self.public_url_prefix + "/gateway/", self.private_url
)

async def start_or_exit(self):
try:
Expand Down Expand Up @@ -713,9 +721,11 @@ async def start_cluster(self, cluster):
return True

async def add_cluster_to_proxies(self, cluster):
await self.web_proxy.add_route(
"/gateway/clusters/" + cluster.name, cluster.dashboard_address
)
if cluster.dashboard_address:
await self.web_proxy.add_route(
self.public_url_prefix + "/gateway/clusters/" + cluster.name,
cluster.dashboard_address,
)
await self.scheduler_proxy.add_route(
"/" + cluster.name, cluster.scheduler_address
)
Expand Down
139 changes: 137 additions & 2 deletions dask-gateway-server/dask_gateway_server/auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import json
import os
import base64
from urllib.parse import quote

from tornado import web
from traitlets import Unicode
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from traitlets import Unicode, default
from traitlets.config import LoggingConfigurable

__all__ = ("Authenticator", "KerberosAuthenticator", "DummyAuthenticator")
__all__ = (
"Authenticator",
"KerberosAuthenticator",
"DummyAuthenticator",
"JupyterHubAuthenticator",
)


class Authenticator(LoggingConfigurable):
Expand Down Expand Up @@ -146,3 +154,130 @@ def authenticate(self, handler):
self.raise_auth_required(handler)

return user


class JupyterHubAuthenticator(Authenticator):
"""An authenticator that uses JupyterHub to perform authentication"""

jupyterhub_api_token = Unicode(
help="""
Dask Gateway's JupyterHub API Token, used for authenticating the
gateway's API requests to JupyterHub.

By default this is determined from the ``JUPYTERHUB_API_TOKEN``
environment variable.
""",
config=True,
)

@default("jupyterhub_api_token")
def _default_jupyterhub_api_token(self):
out = os.environ.get("JUPYTERHUB_API_TOKEN")
if not out:
raise ValueError("JUPYTERHUB_API_TOKEN must be set")
return out

jupyterhub_api_url = Unicode(
help="""
The API URL for the JupyterHub server.

By default this is determined from the ``JUPYTERHUB_API_URL``
environment variable.
""",
config=True,
)

@default("jupyterhub_api_url")
def _default_jupyterhub_api_url(self):
out = os.environ.get("JUPYTERHUB_API_URL")
if not out:
raise ValueError("JUPYTERHUB_API_URL must be set")
return out

tls_key = Unicode(
"",
help="""
Path to TLS key file for making API requests to JupyterHub.

When setting this, you should also set tls_cert.
""",
config=True,
)

tls_cert = Unicode(
"",
help="""
Path to TLS certficate file for making API requests to JupyterHub.

When setting this, you should also set tls_cert.
""",
config=True,
)

tls_ca = Unicode(
"",
help="""
Path to TLS CA file for verifying API requests to JupyterHub.

When setting this, you should also set tls_key and tls_cert.
""",
config=True,
)

def raise_auth_required(self, handler):
handler.set_status(401)
handler.write("Authentication required")
handler.set_header("WWW-Authenticate", "jupyterhub")
raise web.Finish()

def get_token(self, handler):
auth_header = handler.request.headers.get("Authorization")
if auth_header:
auth_type, auth_key = auth_header.split(" ", 1)
if auth_type == "jupyterhub":
return auth_key
return None

async def authenticate(self, handler):
token = self.get_token(handler)
if token is None:
self.raise_auth_required(handler)

url = "%s/authorizations/token/%s" % (
self.jupyterhub_api_url,
quote(token, safe=""),
)

req = HTTPRequest(
url,
method="GET",
headers={"Authorization": "token %s" % self.jupyterhub_api_token},
)

kwargs = {}
if self.tls_cert and self.tls_key:
kwargs.update({"client_cert": self.tls_cert, "client_key": self.tls_key})
if self.tls_ca:
kwargs["ca_certs"] = self.tls_ca

client = AsyncHTTPClient()
resp = await client.fetch(req, raise_error=False, **kwargs)

if resp.code < 400:
return json.loads(resp.body)["name"]
elif resp.code == 404:
self.log.debug("Token for non-existant user requested")
self.raise_auth_required(handler)
else:
if resp.code == 403:
msg = "Permission failure verifying user's JupyterHub API token"
code = 500
elif resp.code >= 500:
msg = "Upstream failure verifying user's JupyterHub API token"
code = 502
else:
msg = "Failure verifying user's JupyterHub API token"
code = 500

self.log.error("%s - code: %s, reason: %s", msg, resp.code, resp.reason)
raise web.HTTPError(code, msg)
39 changes: 25 additions & 14 deletions dask-gateway-server/dask_gateway_server/handlers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import json
from inspect import isawaitable
from urllib.parse import urlparse, unquote

from tornado import web
Expand All @@ -16,11 +17,12 @@ def user_authenticated(method):
"""Ensure this method is authenticated via a user"""

@functools.wraps(method)
def inner(self, *args, **kwargs):
# Trigger authentication mechanism
if self.current_user is None:
async def inner(self, *args, **kwargs):
username = await self.current_user_from_login()
if username is None:
raise web.HTTPError(401)
return method(self, *args, **kwargs)
self.current_user = username
return await method(self, *args, **kwargs)

return inner

Expand All @@ -30,7 +32,7 @@ def token_authenticated(method):

@functools.wraps(method)
def inner(self, *args, **kwargs):
username = self.get_current_user_from_token()
username = self.current_user_from_token()
if username is None:
raise web.HTTPError(401)
self.current_user = username
Expand Down Expand Up @@ -66,7 +68,7 @@ def check_cluster(self, cluster_name):
if self.dask_cluster.name != cluster_name:
raise web.HTTPError(403)

def get_current_user_from_token(self):
def current_user_from_token(self):
auth_header = self.request.headers.get("Authorization")
if auth_header:
auth_type, auth_key = auth_header.split(" ", 1)
Expand All @@ -84,7 +86,7 @@ def get_current_user_from_token(self):
)
return None

def get_current_user(self):
async def current_user_from_login(self):
cookie = self.get_secure_cookie(
DASK_GATEWAY_COOKIE, max_age_days=self.cookie_max_age_days
)
Expand All @@ -101,6 +103,8 @@ def get_current_user(self):

# Finally, fall back to using the authenticator
username = self.authenticator.authenticate(self)
if isawaitable(username):
username = await username
user = self.gateway.db.get_or_create_user(username)
self.set_secure_cookie(
DASK_GATEWAY_COOKIE, user.cookie, expires_days=self.cookie_max_age_days
Expand All @@ -111,25 +115,32 @@ def get_current_user(self):


def cluster_model(gateway, cluster, full=True):
if cluster.scheduler_address:
if cluster.status == ClusterStatus.RUNNING:
scheduler = "gateway://%s/%s" % (
urlparse(gateway.gateway_url).netloc,
cluster.name,
)
dashboard = "%s/gateway/clusters/%s" % (gateway.public_url, cluster.name)
dashboard = (
"/gateway/clusters/%s" % cluster.name if cluster.dashboard_address else ""
)
else:
scheduler = dashboard = ""
scheduler = dashboard = None
out = {
"name": cluster.name,
"scheduler_address": scheduler or None,
"dashboard_address": dashboard or None,
"scheduler_address": scheduler,
"dashboard_route": dashboard,
"status": cluster.status.name,
"start_time": cluster.start_time,
"stop_time": cluster.stop_time,
}
if full:
out["tls_cert"] = cluster.tls_cert.decode() or None
out["tls_key"] = cluster.tls_key.decode() or None
if cluster.status == ClusterStatus.RUNNING:
tls_cert = cluster.tls_cert.decode()
tls_key = cluster.tls_key.decode()
else:
tls_cert = tls_key = None
out["tls_cert"] = tls_cert
out["tls_key"] = tls_key
return out


Expand Down
2 changes: 1 addition & 1 deletion dask-gateway-server/dask_gateway_server/proxy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def wait_until_up(self):
"the proxy process and the gateway"
% (self._subcommand, self.api_url)
)
else:
elif e.code != 599:
raise RuntimeError(
"Error while connecting to %s proxy api at %s: %s"
% (self._subcommand, self.api_url, e)
Expand Down
22 changes: 21 additions & 1 deletion dask-gateway/dask_gateway/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import getpass
import os
import re
from base64 import b64encode
from urllib.parse import urlparse

import dask


__all__ = ("GatewayAuth", "BasicAuth", "KerberosAuth", "get_auth")
__all__ = ("GatewayAuth", "BasicAuth", "KerberosAuth", "JupyterHubAuth", "get_auth")


def _import_object(name):
Expand Down Expand Up @@ -39,6 +40,8 @@ def get_auth(auth=None):
auth = KerberosAuth
elif auth == "basic":
auth = BasicAuth
elif auth == "jupyterhub":
auth = JupyterHubAuth
else:
auth = _import_object(auth)
elif not callable(auth):
Expand Down Expand Up @@ -144,3 +147,20 @@ def post_response(self, req, resp, context):
if not token:
raise Exception("Kerberos negotiation failed")
kerberos.authGSSClientStep(context, token)


class JupyterHubAuth(GatewayAuth):
"""Uses JupyterHub API tokens to authenticate"""

def __init__(self, api_token=None):
if api_token is None:
api_token = os.environ.get("JUPYTERHUB_API_TOKEN")
if api_token is None:
raise ValueError(
"No JupyterHub API token provided, and JUPYTERHUB_API_TOKEN "
"environment variable not found"
)
self.api_token = api_token

def pre_request(self, req, resp):
req.headers["Authorization"] = "jupyterhub " + self.api_token
Loading