Skip to content

Commit

Permalink
Hub: Merge WebSocket endpoints again
Browse files Browse the repository at this point in the history
With the bidirectional jsonrpc channel it's now possible to use one
endpoint again. This makes it easier to integrate authentication,
because then there's only one connection handler.

Drop the dummy WebSocket authentication. Since ID tokens need to be
refreshed, a one time authentication is not enough. Without
authentication, all endpoints are the same, even for login.
  • Loading branch information
holesch committed Jun 15, 2024
1 parent f21441d commit d7b681c
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 149 deletions.
5 changes: 2 additions & 3 deletions not_my_board/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ def __init__(self, hub_url, http_client):

@contextlib.asynccontextmanager
async def hub_rpc(self):
auth = "Bearer dummy-token-1"
url = f"{self._hub_url}/ws-agent"
async with jsonrpc.WebsocketChannel(url, self._http, auth=auth) as rpc:
url = f"{self._hub_url}/ws"
async with jsonrpc.WebsocketChannel(url, self._http) as rpc:
yield rpc

@contextlib.asynccontextmanager
Expand Down
2 changes: 1 addition & 1 deletion not_my_board/_auth/_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def _context_stack(self, stack):
ready_event = asyncio.Event()
notification_api = _HubNotifications(ready_event)

channel_url = f"{self._hub_url}/ws-login"
channel_url = f"{self._hub_url}/ws"
hub = jsonrpc.WebsocketChannel(
channel_url, self._http, api_obj=notification_api
)
Expand Down
21 changes: 11 additions & 10 deletions not_my_board/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
async def export(hub_url, place, ca_files):
http_client = http.Client(ca_files)
async with Exporter(hub_url, place, http_client) as exporter:
await exporter.register_place()
await exporter.serve_forever()


Expand Down Expand Up @@ -54,22 +55,22 @@ async def _context_stack(self, stack):
util.Server(self._handle_client, port=self._place.port)
)

url = f"{self._hub_url}/ws-exporter"
auth = "Bearer dummy-token-1"
self._ws_server = await stack.enter_async_context(
url = f"{self._hub_url}/ws"
self._hub = await stack.enter_async_context(
jsonrpc.WebsocketChannel(
url, self._http, start=False, auth=auth, api_obj=self
url, self._http, api_obj=self
)
)

@jsonrpc.hidden
async def serve_forever(self):
await util.run_concurrently(
self._http_server.serve_forever(), self._ws_server.communicate_forever()
)
async def register_place(self):
place_id = await self._hub.register_place(self._place.dict())
logger.info("Place registered with ID %d", place_id)
return place_id

async def get_place(self):
return self._place.dict()
@jsonrpc.hidden
async def serve_forever(self):
await self._http_server.serve_forever()

async def set_allowed_ips(self, ips):
new_ips = set(map(ipaddress.ip_address, ips))
Expand Down
9 changes: 3 additions & 6 deletions not_my_board/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def open_tunnel(
raise ProtocolError(f"Unexpected event: {event}")

@contextlib.asynccontextmanager
async def websocket(self, url, auth=None):
async def websocket(self, url):
url = self._parse_url(url)

ws_scheme = "ws" if url.scheme == "http" else "wss"
Expand All @@ -206,26 +206,23 @@ async def websocket(self, url, auth=None):
protocol = websockets.ClientProtocol(ws_uri)

async with self._connect(url) as (reader, writer):
async with _WebsocketConnection(protocol, reader, writer, auth) as con:
async with _WebsocketConnection(protocol, reader, writer) as con:
yield con


class _WebsocketConnection:
_close_timeout = 10

def __init__(self, protocol, reader, writer, auth=None):
def __init__(self, protocol, reader, writer):
self._protocol = protocol
self._reader = reader
self._writer = writer
self._chunks = []
self._decoder = None
self._send_lock = asyncio.Lock()
self._auth = auth

async def __aenter__(self):
request = self._protocol.connect()
if self._auth:
request.headers["Authorization"] = self._auth

# sending handshake request
self._protocol.send_request(request)
Expand Down
150 changes: 53 additions & 97 deletions not_my_board/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

logger = logging.getLogger(__name__)
client_ip_var = contextvars.ContextVar("client_ip")
reservation_context_var = contextvars.ContextVar("reservation_context")
valid_tokens = ("dummy-token-1", "dummy-token-2")
connection_id_var = contextvars.ContextVar("connection_id")


def run_hub():
Expand Down Expand Up @@ -73,12 +72,8 @@ async def _handle_request(request):
response = (404, {}, "Page not found")

if isinstance(request, asgineer.WebsocketRequest):
if request.path == "/ws-agent":
await _handle_agent(hub, request)
elif request.path == "/ws-exporter":
await _handle_exporter(hub, request)
elif request.path == "/ws-login":
await _handle_login(hub, request)
if request.path == "/ws":
await _handle_websocket(hub, request)
else:
await request.close()
response = None
Expand All @@ -92,41 +87,11 @@ async def _handle_request(request):
return response


async def _handle_agent(hub, ws):
await _authorize_ws(ws)
client_ip = ws.scope["client"][0]
server = jsonrpc.Channel(ws.send, ws.receive_iter())
await hub.agent_communicate(client_ip, server)


async def _handle_exporter(hub, ws):
await _authorize_ws(ws)
client_ip = ws.scope["client"][0]
exporter = jsonrpc.Channel(ws.send, ws.receive_iter())
await hub.exporter_communicate(client_ip, exporter)


async def _handle_login(hub, ws):
async def _handle_websocket(hub, ws):
await ws.accept()
client_ip = ws.scope["client"][0]
channel = jsonrpc.Channel(ws.send, ws.receive_iter())
await hub.login_communicate(client_ip, channel)


async def _authorize_ws(ws):
try:
auth = ws.headers["authorization"]
scheme, token = auth.split(" ", 1)
if scheme != "Bearer":
raise ProtocolError(f"Invalid Authorization Scheme: {scheme}")
if token not in valid_tokens:
raise ProtocolError("Invalid token")
except Exception:
traceback.print_exc()
await ws.close()
return

await ws.accept()
await hub.communicate(client_ip, channel)


class Hub:
Expand Down Expand Up @@ -167,9 +132,11 @@ def __init__(self, config=None):

self._id_generator = itertools.count(start=1)

@jsonrpc.hidden
async def startup(self):
pass

@jsonrpc.hidden
async def shutdown(self):
pass

Expand All @@ -178,66 +145,55 @@ async def get_places(self):
return {"places": [p.dict() for p in self._places.values()]}

@jsonrpc.hidden
async def agent_communicate(self, client_ip, rpc):
async def communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
async with self._register_agent():
async with self._connection_context():
rpc.set_api_object(self)
await rpc.communicate_forever()

@jsonrpc.hidden
async def exporter_communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
async with util.background_task(rpc.communicate_forever()) as com_task:
export_desc = await rpc.get_place()
with self._register_place(export_desc, rpc, client_ip):
await com_task

@jsonrpc.hidden
async def login_communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
rpc.set_api_object(self)
await rpc.communicate_forever()

@contextlib.contextmanager
def _register_place(self, export_desc, rpc, client_ip):
id_ = next(self._id_generator)
place = models.Place(id=id_, host=_unmap_ip(client_ip), **export_desc)

try:
logger.info("New place registered: %d", id_)
self._places[id_] = place
self._exporters[id_] = rpc
self._available.add(id_)
yield self
finally:
logger.info("Place disappeared: %d", id_)
del self._places[id_]
del self._exporters[id_]
self._available.discard(id_)
for candidates, _, future in self._wait_queue:
candidates.discard(id_)
if not candidates and not future.done():
future.set_exception(Exception("All candidate places are gone"))

@contextlib.asynccontextmanager
async def _register_agent(self):
ctx = object()
reservation_context_var.set(ctx)
async def _connection_context(self):
id_ = next(self._id_generator)
connection_id_var.set(id_)
self._reservations[id_] = set()

try:
self._reservations[ctx] = set()
yield
finally:
coros = [self.return_reservation(id_) for id_ in self._reservations[ctx]]
if id_ in self._places:
logger.info("Place disappeared: %d", id_)
del self._places[id_]
del self._exporters[id_]
self._available.discard(id_)
for candidates, _, future in self._wait_queue:
candidates.discard(id_)
if not candidates and not future.done():
future.set_exception(Exception("All candidate places are gone"))

coros = [self.return_reservation(id_) for id_ in self._reservations[id_]]
results = await asyncio.gather(*coros, return_exceptions=True)
del self._reservations[ctx]
del self._reservations[id_]
for result in results:
if isinstance(result, Exception):
logger.warning("Error while deregistering agent: %s", result)

async def register_place(self, export_desc):
id_ = connection_id_var.get()
client_ip = client_ip_var.get()
place = models.Place(id=id_, host=_unmap_ip(client_ip), **export_desc)

if id_ in self._places:
raise RuntimeError("Place already registered")

self._places[id_] = place
self._exporters[id_] = jsonrpc.get_current_channel()
self._available.add(id_)
logger.info("New place registered: %d", id_)
return id_

async def reserve(self, candidate_ids):
ctx = reservation_context_var.get()
existing_candidates = {id_ for id_ in candidate_ids if id_ in self._places}
id_ = connection_id_var.get()
existing_candidates = {c_id for c_id in candidate_ids if c_id in self._places}
if not existing_candidates:
raise RuntimeError("None of the candidates exist anymore")

Expand All @@ -247,15 +203,15 @@ async def reserve(self, candidate_ids):
reserved_id = random.choice(list(available_candidates))

self._available.remove(reserved_id)
self._reservations[ctx].add(reserved_id)
logger.info("Place reserved: %d", reserved_id)
self._reservations[id_].add(reserved_id)
logger.info("Place %d reserved by %d", reserved_id, id_)
else:
logger.debug(
"No places available, adding request to queue: %s",
str(existing_candidates),
)
future = asyncio.get_running_loop().create_future()
entry = (existing_candidates, ctx, future)
entry = (existing_candidates, id_, future)
self._wait_queue.append(entry)
try:
reserved_id = await future
Expand All @@ -264,28 +220,28 @@ async def reserve(self, candidate_ids):

client_ip = client_ip_var.get()
async with util.on_error(self.return_reservation, reserved_id):
rpc = self._exporters[reserved_id]
await rpc.set_allowed_ips([_unmap_ip(client_ip)])
exporter = self._exporters[reserved_id]
await exporter.set_allowed_ips([_unmap_ip(client_ip)])

return reserved_id

async def return_reservation(self, place_id):
ctx = reservation_context_var.get()
self._reservations[ctx].remove(place_id)
id_ = connection_id_var.get()
self._reservations[id_].remove(place_id)
if place_id in self._places:
for candidates, new_ctx, future in self._wait_queue:
for candidates, agent_id, future in self._wait_queue:
if place_id in candidates and not future.done():
self._reservations[new_ctx].add(place_id)
logger.info("Place returned and reserved again: %d", place_id)
self._reservations[agent_id].add(place_id)
logger.info("Place %d returned by %d was reserved by %d", place_id, id_, agent_id)
future.set_result(place_id)
break
else:
logger.info("Place returned: %d", place_id)
logger.info("Place %d returned by %d", place_id, id_)
self._available.add(place_id)
rpc = self._exporters[place_id]
await rpc.set_allowed_ips([])
else:
logger.info("Place returned, but it doesn't exist: %d", place_id)
logger.info("Place %d returned, but it doesn't exist", place_id)

async def get_authentication_response(self, state):
future = asyncio.get_running_loop().create_future()
Expand Down
5 changes: 2 additions & 3 deletions not_my_board/_jsonrpc/_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@


class WebsocketChannel(Channel, util.ContextStack):
def __init__(self, url, http_client, start=True, auth=None, api_obj=None):
def __init__(self, url, http_client, start=True, api_obj=None):
self._url = url
self._http = http_client
self._ws = None
self._start = start
self._auth = auth

super().__init__(self._ws_send, self._ws_receive_iter(), api_obj)

async def _context_stack(self, stack):
ws = self._http.websocket(self._url, self._auth)
ws = self._http.websocket(self._url)
self._ws = await stack.enter_async_context(ws)

if self._start:
Expand Down
Loading

0 comments on commit d7b681c

Please sign in to comment.