Skip to content

Commit

Permalink
Merge pull request from GHSA-487p-qx68-5vjw
Browse files Browse the repository at this point in the history
  • Loading branch information
danking committed Dec 29, 2023
1 parent 08909ac commit 0dcc17f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
9 changes: 5 additions & 4 deletions auth/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
uvloop.install()

CLOUD = get_global_config()['cloud']
ORGANIZATION_DOMAIN = os.environ['HAIL_ORGANIZATION_DOMAIN']
DEFAULT_NAMESPACE = os.environ['HAIL_DEFAULT_NAMESPACE']

is_test_deployment = DEFAULT_NAMESPACE != 'default'
Expand Down Expand Up @@ -333,7 +332,8 @@ async def callback(request) -> web.Response:
cleanup_session(session)

try:
flow_result = request.app[AppKeys.FLOW_CLIENT].receive_callback(request, flow_dict)
flow_client = request.app[AppKeys.FLOW_CLIENT]
flow_result = flow_client.receive_callback(request, flow_dict)
login_id = flow_result.login_id
except asyncio.CancelledError:
raise
Expand All @@ -352,10 +352,11 @@ async def callback(request) -> web.Response:

assert caller == 'signup'

username, domain = flow_result.email.split('@')
username, _ = flow_result.unverified_email.split('@')
username = ''.join(c for c in username if c.isalnum())

if domain != ORGANIZATION_DOMAIN:
assert flow_client.organization_id() is not None
if flow_result.organization_id != flow_client.organization_id():
raise web.HTTPUnauthorized()

try:
Expand Down
49 changes: 40 additions & 9 deletions hail/python/hailtop/auth/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cryptography.hazmat.primitives import serialization
import json
import logging
import os
import urllib.parse
from typing import Any, Dict, List, Mapping, Optional, TypedDict, ClassVar

Expand All @@ -21,13 +22,26 @@


class FlowResult:
def __init__(self, login_id: str, email: str, token: Mapping[Any, Any]):
def __init__(self,
login_id: str,
unverified_email: str,
organization_id: Optional[str],
token: Mapping[Any, Any]):
self.login_id = login_id
self.email = email
self.unverified_email = unverified_email
self.organization_id = organization_id # In Azure, a Tenant ID. In Google, a domain name.
self.token = token


class Flow(abc.ABC):
@abc.abstractmethod
def organization_id(self) -> str:
    """
    Return the unique identifier of the organization (e.g. Azure Tenant, Google Organization)
    in which this Hail Batch instance lives.

    This is a plain (non-async) method: the concrete flows return the value directly and
    callers compare the result against ``FlowResult.organization_id`` without awaiting it.
    Declaring it ``async`` here would make the abstract contract disagree with every
    implementation and turn that comparison into one against a coroutine object.

    Raises:
        NotImplementedError: always, in this abstract base; subclasses must override.
    """
    raise NotImplementedError

@abc.abstractmethod
def initiate_flow(self, redirect_uri: str) -> dict:
"""
Expand Down Expand Up @@ -64,7 +78,6 @@ async def get_identity_uid_from_access_token(session: httpx.ClientSession, acces
"""
raise NotImplementedError


class GoogleFlow(Flow):
scopes: ClassVar[List[str]] = [
'https://www.googleapis.com/auth/userinfo.profile',
Expand All @@ -75,6 +88,11 @@ class GoogleFlow(Flow):
def __init__(self, credentials_file: str):
self._credentials_file = credentials_file

def organization_id(self) -> str:
    """
    Return the organization identifier read from the ``HAIL_ORGANIZATION_DOMAIN``
    environment variable.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    domain = os.environ.get('HAIL_ORGANIZATION_DOMAIN')
    # An empty string is treated the same as a missing variable.
    if not domain:
        raise ValueError('Only available in the auth pod')
    return domain

def initiate_flow(self, redirect_uri: str) -> dict:
flow = google_auth_oauthlib.flow.Flow.from_client_secrets_file(
self._credentials_file, scopes=GoogleFlow.scopes, state=None
Expand All @@ -98,7 +116,7 @@ def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> Flo
flow.credentials.id_token, google.auth.transport.requests.Request() # type: ignore
)
email = token['email']
return FlowResult(email, email, token)
return FlowResult(email, email, token.get('hd'), token)

@staticmethod
def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -134,12 +152,12 @@ async def get_identity_uid_from_access_token(session: httpx.ClientSession, acces
if not (is_human_with_hail_audience or is_service_account):
return None

email = userinfo['email']
if email.endswith('iam.gserviceaccount.com'):
domain = userinfo.get('hd')
if domain == 'iam.gserviceaccount.com':
return userinfo['sub']
# We don't currently track the user's unique GCP IAM ID (sub) in the database, just their email,
# but we should eventually use the sub, as it is guaranteed to be unique to the user.
return email
return userinfo['email']
except httpx.ClientResponseError as e:
if e.status in (400, 401):
return None
Expand All @@ -163,8 +181,16 @@ def __init__(self, credentials_file: str):
self._client = msal.ConfidentialClientApplication(data['appId'], data['password'], authority)
self._tenant_id = tenant_id

def organization_id(self) -> str:
    """Return the tenant ID stored on this flow as its organization identifier."""
    tenant_id = self._tenant_id
    return tenant_id

def initiate_flow(self, redirect_uri: str) -> dict:
flow = self._client.initiate_auth_code_flow(scopes=[], redirect_uri=redirect_uri)
flow = self._client.initiate_auth_code_flow(
scopes=[], # confusingly, scopes=[] is the only way to get the openid, profile, and
# offline_access scopes
# https://github.com/AzureAD/microsoft-authentication-library-for-python/blob/dev/msal/application.py#L568-L580
redirect_uri=redirect_uri
)
return {
'flow': flow,
'authorization_url': flow['auth_uri'],
Expand All @@ -184,7 +210,12 @@ def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> Flo
if tid != self._tenant_id:
raise ValueError('invalid tenant id')

return FlowResult(token['id_token_claims']['oid'], token['id_token_claims']['preferred_username'], token)
return FlowResult(
token['id_token_claims']['oid'],
token['id_token_claims']['preferred_username'],
token['id_token_claims']['tid'],
token
)

@staticmethod
def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 0dcc17f

Please sign in to comment.