Skip to content

Commit

Permalink
[WIP]: OIDC test
Browse files Browse the repository at this point in the history
  • Loading branch information
holesch committed Feb 9, 2024
1 parent b556e95 commit 3012817
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 6 deletions.
183 changes: 183 additions & 0 deletions not_my_board/_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#!/usr/bin/env python3

import not_my_board._util as util
import not_my_board._http as http
import secrets
import urllib.parse
import asyncio
import base64
import hashlib


class IdentityProvider:
@classmethod
async def from_url(cls, url):
separator = "" if url.endswith("/") else "/"
config_url = url + separator + ".well-known/openid-configuration"
config = await http.get_json(config_url)

self = cls()
self._authorization_endpoint = config["authorization_endpoint"]
self._token_endpoint = config["token_endpoint"]
return self

@property
def authorization_endpoint(self):
return self._authorization_endpoint

@property
def token_endpoint(self):
return self._token_endpoint


class OidcClient:
class RedirectFlow:
def __init__(self, client_id, identity_provider, redirect_uri):
self._identity_provider = identity_provider
self._redirect_uri = redirect_uri
self._state = secrets.token_urlsafe()
self._nonce = secrets.token_urlsafe()
self._code_verifier = secrets.token_urlsafe()
self._callback_event = asyncio.Event()
self._client_id = client_id

hashed = hashlib.sha256(self._code_verifier.encode()).digest()
code_challange = base64.urlsafe_b64encode(hashed).rstrip(b"=").decode('ascii')

auth_params = {
"scope": "openid profile offline_access",
"response_type": "code",
"client_id": client_id,
"redirect_uri": redirect_uri,
"state": self._state,
"nonce": self._nonce,
"prompt": "consent",
"code_challenge": code_challange,
"code_challenge_method": "S256",
}

url_parts = list(urllib.parse.urlparse(self._identity_provider.authorization_endpoint))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update(auth_params)

url_parts[4] = urllib.parse.urlencode(query)

self._auth_url = urllib.parse.urlunparse(url_parts)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
pass

@property
def auth_url(self):
return self._auth_url

def callback(self, params):
if params["state"] != self._state:
raise ProtocolError("State in redirect URI doesn't match the state from the authorization request")

if self._callback_event.is_set():
raise ProtocolError("Callback already received")

self._code = params["code"]
self._callback_event.set()

async def request_token(self):
await self._callback_event.wait()
url = self._identity_provider.token_endpoint
params = {
"grant_type": "authorization_code",
"code": self._code,
"redirect_uri": self._redirect_uri,
"client_id": self._client_id,
"code_verifier": self._code_verifier,
}
response = await http.post_form(url, params)
import pprint
pprint.pprint(response)
# {'access_token': 'eyJhbGci',
# 'expires_in': 60,
# 'id_token': 'eyJhbGciOiJSU',
# 'not-before-policy': 0,
# 'refresh_expires_in': 0,
# 'refresh_token': 'eyJhbGci',
# 'scope': 'openid offline_access profile email',
# 'session_state': 'bbd34547-b12e-4f6a-8366-d41b6d6626d2',
# 'token_type': 'Bearer'}

def __init__(self, client_id, identity_provider):
self._client_id = client_id
self._identity_provider = identity_provider

def redirect_flow(self, redirect_uri):
return self.RedirectFlow(self._client_id, self._identity_provider, redirect_uri)

async def refresh_token(self, refresh_token):
url = self._identity_provider.token_endpoint
params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": self._client_id,
}
response = await http.post_form(url, params)
# import pprint
# pprint.pprint(response)
# {'access_token': 'eyJhbGci',
# 'expires_in': 60,
# 'id_token': 'eyJhbG',
# 'not-before-policy': 0,
# 'refresh_expires_in': 0,
# 'refresh_token': 'eyJhbGc',
# 'scope': 'openid offline_access profile email',
# 'session_state': 'bbd34547-b12e-4f6a-8366-d41b6d6626d2',
# 'token_type': 'Bearer'}



class CallbackHandler:
def set_flow(self, flow):
self._flow = flow

async def callback(self, reader, writer):
req = await reader.read(4096)
url = req.split()[1].decode("utf-8")
params = dict(urllib.parse.parse_qsl(urllib.parse.urlparse(url)[4]))
# print(params)
# {'state': 'QX3FljMU_CaVQMEc8Utu8iO2Z250g7zUXlhXBXQidUY', 'session_state': 'bbd34547-b12e-4f6a-8366-d41b6d6626d2', 'code': 'f7001289-9a30-4f06-8e2c-9c2d0fee9244.bbd34547-b12e-4f6a-8366-d41b6d6626d2.e0502b64-60ff-40c7-85de-c155713bf34e'}
self._flow.callback(params)

writer.write(
b"HTTP/1.1 200 OK\r\n"
b"Content-Type: text/plain\r\n"
b"Connection: Close\r\n"
b"\r\n"
b"Success\n")
await writer.drain()


class ProtocolError(Exception):
pass


async def _main():
url = "http://localhost:8080/realms/master"
identity_provider = await IdentityProvider.from_url(url)
client = OidcClient("not-my-board-cli", identity_provider)
callback_handler = CallbackHandler()

async with util.Server(callback_handler.callback, port=9090) as server:
async with client.redirect_flow("http://localhost:9090/callback") as flow:
callback_handler.set_flow(flow)
print(flow.auth_url)
await flow.request_token()


# await client.refresh_token(refresh_token)




if __name__ == "__main__":
util.run(_main())
38 changes: 32 additions & 6 deletions not_my_board/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,37 @@ class ProtocolError(Exception):


async def get_json(url):
return await _request_json("GET", url)


async def post_form(url, params):
content_type="application/x-www-form-urlencoded"
body = urllib.parse.urlencode(params).encode()
return await _request_json("POST", url, content_type, body)


async def _request_json(method, url, content_type=None, body=None):
url = urllib.parse.urlsplit(url)
headers = [
("Host", url.netloc),
("User-Agent", h11.PRODUCT_ID),
("Accept", "application/json"),
("Connection", "close"),
]
if body is not None:
headers += [
("Content-Type", content_type),
("Content-Length", str(len(body))),
]

conn = h11.Connection(our_role=h11.CLIENT)

to_send = conn.send(
h11.Request(method="GET", target=url.path or "/", headers=headers)
h11.Request(method=method, target=url.path or "/", headers=headers)
)
if body is not None:
to_send += conn.send(h11.Data(body))
to_send += conn.send(h11.EndOfMessage())

if url.scheme == "https":
default_port = 443
Expand All @@ -40,24 +59,31 @@ async def get_json(url):

async with util.connect(url.hostname, port, ssl=ssl) as (reader, writer):
writer.write(to_send)
writer.write(conn.send(h11.EndOfMessage()))
await writer.drain()

async def receive_all():
error_status = None
while True:
event = conn.next_event()
if event is h11.NEED_DATA:
conn.receive_data(await reader.read(4096))
elif isinstance(event, h11.Response):
if event.status_code != 200:
raise ProtocolError(
f"Expected status code 200, got {event.status_code}"
)
error_status = event.status_code
error_data = b""
elif isinstance(event, h11.Data):
yield event.data
if error_status is None:
yield event.data
else:
error_data += event.data
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
break

if error_status is not None:
raise ProtocolError(
f"Expected status code 200, got {error_status}: {error_data}"
)

content = b"".join([data async for data in receive_all()])

return json.loads(content)
Expand Down

0 comments on commit 3012817

Please sign in to comment.