Skip to content

Commit

Permalink
[WIP]: Add OIDC Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
holesch committed May 5, 2024
1 parent c971193 commit 614850c
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 9 deletions.
7 changes: 7 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ py.install_sources(
subdir: 'not_my_board/_jsonrpc',
)

py.install_sources(
'not_my_board/_auth/__init__.py',
'not_my_board/_auth/_openid.py',
'not_my_board/_auth/_login.py',
subdir: 'not_my_board/_auth',
)

py.install_sources(
'not_my_board/cli/__init__.py',
subdir: 'not_my_board/cli',
Expand Down
1 change: 1 addition & 0 deletions not_my_board/_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._login import LoginFlow, get_id_token
121 changes: 121 additions & 0 deletions not_my_board/_auth/_login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import asyncio
import json
import os
import pathlib

import not_my_board._jsonrpc as jsonrpc
import not_my_board._util as util

from ._openid import AuthRequest, ensure_fresh


class LoginFlow(util.ContextStack):
def __init__(self, hub_url, http_client):
self._hub_url = hub_url
self._http = http_client
self._show_claims = []
self._token_store = _TokenStore()

async def _context_stack(self, stack):
url = f"{self._hub_url}/api/v1/auth-info"
auth_info = await self._http.get_json(url)
redirect_uri = f"{self._hub_url}/oidc-callback"

self._request = await AuthRequest.create(
auth_info["issuer"], auth_info["client_id"], redirect_uri, self._http
)

ready_event = asyncio.Event()
notification_api = _HubNotifications(ready_event)

channel_url = f"{self._hub_url}/ws-login"
hub = jsonrpc.WebsocketChannel(
channel_url, self._http, api_obj=notification_api
)
self._hub = await stack.enter_async_context(hub)

coro = self._hub.get_authentication_response(self._request.state)
self._auth_response_task = await stack.enter_async_context(
util.background_task(coro)
)

await ready_event.wait()

self._show_claims = auth_info.get("show_claims")

async def finish(self):
auth_response = await self._auth_response_task
id_token, refresh_token, claims = await self._request.request_tokens(
auth_response, self._http
)

async with _TokenStore() as token_store:
token_store.save_tokens(self._hub_url, id_token, refresh_token)

if self._show_claims:
# filter claims to only show relevant ones
return {k: v for k, v in claims.items() if k in self._show_claims}
else:
return claims

@property
def login_url(self):
return self._request.login_url


class _HubNotifications:
def __init__(self, ready_event):
self._ready_event = ready_event

async def oidc_callback_registered(self):
self._ready_event.set()


async def get_id_token(hub_url, http_client):
async with _TokenStore() as token_store:
id_token, refresh_token = token_store.get_tokens(hub_url)
id_token, refresh_token = await ensure_fresh(
id_token, refresh_token, http_client
)
token_store.save_tokens(hub_url, id_token, refresh_token)

return id_token


class _TokenStore(util.ContextStack):
_path = pathlib.Path("/var/lib/not-my-board/auth_tokens.json")

def __init__(self):
if not self._path.exists():
self._path.parent.mkdir(parents=True, exist_ok=True)
self._path.touch(mode=0o600)

if not os.access(self._path, os.R_OK | os.W_OK):
raise RuntimeError(f"Not allowed to access {self._path}")

async def _context_stack(self, stack):
# pylint: disable-next=consider-using-with # false positive
self._f = stack.enter_context(self._path.open("r+"))
await stack.enter_async_context(util.flock(self._f))
content = self._f.read()
self._hub_tokens_map = json.loads(content) if content else {}

def get_tokens(self, hub_url):
if hub_url not in self._hub_tokens_map:
raise RuntimeError("Login required")

tokens = self._hub_tokens_map[hub_url]
return tokens["id"], tokens["refresh"]

def save_tokens(self, hub_url, id_token, refresh_token):
new_tokens = {
"id": id_token,
"refresh": refresh_token,
}
old_tokens = self._hub_tokens_map.get(hub_url)

if old_tokens != new_tokens:
self._hub_tokens_map[hub_url] = new_tokens
self._f.seek(0)
self._f.truncate()
self._f.write(json.dumps(self._hub_tokens_map))
170 changes: 170 additions & 0 deletions not_my_board/_auth/_openid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

import base64
import dataclasses
import hashlib
import secrets
import urllib.parse

import jwt


@dataclasses.dataclass
class IdentityProvider:
issuer: str
authorization_endpoint: str
token_endpoint: str
jwks_uri: str

@classmethod
async def from_url(cls, issuer_url, http_client):
config_url = urllib.parse.urljoin(
f"{issuer_url}/", ".well-known/openid-configuration"
)
config = await http_client.get_json(config_url)

init_args = {
field.name: config[field.name] for field in dataclasses.fields(cls)
}
return cls(**init_args)


@dataclasses.dataclass
class AuthRequest:
client_id: str
redirect_uri: str
state: str
nonce: str
code_verifier: str
identity_provider: IdentityProvider

@classmethod
async def create(cls, issuer_url, client_id, redirect_uri, http_client):
identity_provider = await IdentityProvider.from_url(issuer_url, http_client)
state = secrets.token_urlsafe()
nonce = secrets.token_urlsafe()
code_verifier = secrets.token_urlsafe()

return cls(
client_id, redirect_uri, state, nonce, code_verifier, identity_provider
)

@property
def login_url(self):
hashed = hashlib.sha256(self.code_verifier.encode()).digest()
code_challange = base64.urlsafe_b64encode(hashed).rstrip(b"=").decode("ascii")

auth_params = {
"scope": "openid profile offline_access",
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": self.state,
"nonce": self.nonce,
"prompt": "consent",
"code_challenge": code_challange,
"code_challenge_method": "S256",
}

url_parts = list(
urllib.parse.urlparse(self.identity_provider.authorization_endpoint)
)
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update(auth_params)

url_parts[4] = urllib.parse.urlencode(query)

return urllib.parse.urlunparse(url_parts)

async def request_tokens(self, auth_response, http_client):
if "error" in auth_response:
if "error_description" in auth_response:
msg = f'{auth_response["error_description"]} ({auth_response["error"]})'
else:
msg = auth_response["error"]

raise RuntimeError(f"Authentication error: {msg}")

url = self.identity_provider.token_endpoint
params = {
"grant_type": "authorization_code",
"code": auth_response["code"],
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
"code_verifier": self.code_verifier,
}
response = await http_client.post_form(url, params)

if response["token_type"].lower() != "bearer":
raise RuntimeError(
f'Expected token type "Bearer", got "{response["token_type"]}"'
)

claims = await verify(response["id_token"], self.client_id, http_client)
if claims["nonce"] != self.nonce:
raise RuntimeError(
"Nonce in the ID token doesn't match the one in the authorization request"
)

return response["id_token"], response["refresh_token"], claims


async def ensure_fresh(id_token, refresh_token, http_client):
if _needs_refresh(id_token):
claims = jwt.decode(id_token, options={"verify_signature": False})
issuer_url = claims["iss"]
client_id = claims["aud"]
identity_provider = await IdentityProvider.from_url(issuer_url, http_client)

params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
response = await http_client.post_form(identity_provider.token_endpoint, params)
return response["id_token"], response["refresh_token"]
else:
return id_token, refresh_token


def _needs_refresh(id_token):
try:
jwt.decode(
id_token,
options={
"verify_signature": False,
"require": ["iss", "sub", "aud", "exp", "iat"],
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
)
except Exception:
return True
return False


async def verify(token, client_id, http_client):
unverified_token = jwt.api_jwt.decode_complete(
token, options={"verify_signature": False}
)
kid = unverified_token["header"]["kid"]
issuer = unverified_token["payload"]["iss"]

identity_provider = await IdentityProvider.from_url(issuer, http_client)
jwk_set_raw = await http_client.get_json(identity_provider.jwks_uri)
jwk_set = jwt.PyJWKSet.from_dict(jwk_set_raw)

for key in jwk_set.keys:
if key.public_key_use in ["sig", None] and key.key_id == kid:
signing_key = key
break
else:
raise RuntimeError(f'Unable to find a signing key that matches "{kid}"')

return jwt.decode(
token,
key=signing_key.key,
algorithms="RS256",
audience=client_id,
)
Loading

0 comments on commit 614850c

Please sign in to comment.