Skip to content

Commit

Permalink
Hub: Refactor to drop the Place class
Browse files Browse the repository at this point in the history
Move all the logic from the Place classmethods to the Hub singleton.
After that, there's no reason to keep the Place class at all.

Introduce contextvars for the reservation context and client IP address.
Now the _ReservationContext class is unnecessary.
  • Loading branch information
holesch committed Feb 2, 2024
1 parent 6aecc87 commit ca5e0c4
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 106 deletions.
179 changes: 74 additions & 105 deletions not_my_board/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@

import asyncio
import contextlib
import contextvars
import itertools
import logging
import random
import traceback

import asgineer

import not_my_board._jsonrpc as jsonrpc
import not_my_board._models as models
import not_my_board._util as util

logger = logging.getLogger(__name__)
client_ip_var = contextvars.ContextVar("client_ip")
reservation_context_var = contextvars.ContextVar("reservation_context")
valid_tokens = ("dummy-token-1", "dummy-token-2")


Expand Down Expand Up @@ -65,158 +70,122 @@ async def _authorize_ws(ws):


class Hub:
_places = {}
_exporters = {}
_available = set()
_wait_queue = []
_reservations = {}

def __init__(self):
self._id_generator = itertools.count(start=1)

async def get_places(self):
return {"places": [p.desc for p in Place.all()]}
return {"places": [p.dict() for p in self._places.values()]}

async def agent_communicate(self, client_ip, rpc):
async with Place.reservation_context(client_ip) as ctx:
api = AgentAPI(ctx)
rpc.set_api_object(api)
client_ip_var.set(client_ip)
async with self._register_agent():
rpc.set_api_object(self)
await rpc.serve_forever()

async def exporter_communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
async with util.background_task(rpc.io_loop()) as io_loop:
place = await rpc.get_place()
with Place.register(place, rpc, client_ip):
export_desc = await rpc.get_place()
with self._register_place(export_desc, rpc, client_ip):
await io_loop


_hub = Hub()


class AgentAPI:
def __init__(self, reservation_context):
self._reservation_context = reservation_context

async def reserve(self, candidate_ids):
place = await Place.reserve(candidate_ids, self._reservation_context)
return place.desc["id"]

async def return_reservation(self, place_id):
await Place.return_by_id(place_id, self._reservation_context)


class Place:
_all_places = {}
_next_id = 1
_available = set()
_wait_queue = []
_reservations = {}

@classmethod
def all(cls):
return cls._all_places.values()

@classmethod
def _new_id(cls):
id_ = cls._next_id
cls._next_id += 1
return id_

@classmethod
@contextlib.contextmanager
def register(cls, desc, exporter, client_ip):
self = cls()
self._desc = desc
self._exporter = exporter

self._id = cls._new_id()
self._desc["id"] = self._id
self._desc["host"] = client_ip
def _register_place(self, export_desc, rpc, client_ip):
id_ = next(self._id_generator)
place = models.Place(id=id_, host=client_ip, **export_desc)

try:
logger.info("New place registered: %d", self._id)
cls._all_places[self._id] = self
cls._available.add(self._id)
logger.info("New place registered: %d", id_)
self._places[id_] = place
self._exporters[id_] = rpc
self._available.add(id_)
yield self
finally:
logger.info("Place disappeared: %d", self._id)
del cls._all_places[self._id]
cls._available.discard(self._id)
for candidates, _, future in cls._wait_queue:
candidates.discard(self._id)
logger.info("Place disappeared: %d", id_)
del self._places[id_]
del self._exporters[id_]
self._available.discard(id_)
for candidates, _, future in self._wait_queue:
candidates.discard(id_)
if not candidates and not future.done():
future.set_exception(Exception("All candidate places are gone"))

@property
def desc(self):
return self._desc

@classmethod
@contextlib.asynccontextmanager
async def reservation_context(cls, client_ip):
ctx = _ReservationContext(client_ip)
async def _register_agent(self):
ctx = object()
reservation_context_var.set(ctx)

try:
cls._reservations[ctx] = set()
yield ctx
self._reservations[ctx] = set()
yield
finally:
for place in cls._reservations[ctx].copy():
await cls.return_by_id(place, ctx)
del cls._reservations[ctx]
coros = [self.return_reservation(id_) for id_ in self._reservations[ctx]]
results = await asyncio.gather(*coros, return_exceptions=True)
del self._reservations[ctx]
for result in results:
if isinstance(result, Exception):
logger.warning("Error while deregistering agent: %s", result)

@classmethod
async def reserve(cls, candidate_ids, ctx):
existing_candidates = {id_ for id_ in candidate_ids if id_ in cls._all_places}
async def reserve(self, candidate_ids):
ctx = reservation_context_var.get()
existing_candidates = {id_ for id_ in candidate_ids if id_ in self._places}
if not existing_candidates:
raise RuntimeError("None of the candidates exist anymore")

available_candidates = existing_candidates & cls._available
available_candidates = existing_candidates & self._available
if available_candidates:
# TODO do something smart to get the best candidate
reserved_id = random.choice(list(available_candidates))

cls._available.remove(reserved_id)
cls._reservations[ctx].add(reserved_id)
self._available.remove(reserved_id)
self._reservations[ctx].add(reserved_id)
logger.info("Place reserved: %d", reserved_id)
place = cls._all_places[reserved_id]
else:
logger.debug(
"No places available, adding request to queue: %s",
str(existing_candidates),
)
future = asyncio.get_running_loop().create_future()
entry = (existing_candidates, ctx, future)
cls._wait_queue.append(entry)
self._wait_queue.append(entry)
try:
place = await future
reserved_id = await future
finally:
cls._wait_queue.remove(entry)
self._wait_queue.remove(entry)

# TODO refactor Place class
# pylint: disable=protected-access
try:
await place._exporter.set_allowed_ips([ctx.client_ip])
except Exception:
await cls.return_by_id(place._id, ctx)
raise

return place

@classmethod
async def return_by_id(cls, place_id, ctx):
cls._reservations[ctx].remove(place_id)
if place_id in cls._all_places:
for candidates, new_ctx, future in cls._wait_queue:
client_ip = client_ip_var.get()
async with util.on_error(self.return_reservation, reserved_id):
rpc = self._exporters[reserved_id]
await rpc.set_allowed_ips([client_ip])

return reserved_id

async def return_reservation(self, place_id):
ctx = reservation_context_var.get()
self._reservations[ctx].remove(place_id)
if place_id in self._places:
for candidates, new_ctx, future in self._wait_queue:
if place_id in candidates and not future.done():
cls._reservations[new_ctx].add(place_id)
self._reservations[new_ctx].add(place_id)
logger.info("Place returned and reserved again: %d", place_id)
future.set_result(cls._all_places[place_id])
future.set_result(place_id)
break
else:
logger.info("Place returned: %d", place_id)
cls._available.add(place_id)
# pylint: disable=protected-access
await cls._all_places[place_id]._exporter.set_allowed_ips([])
self._available.add(place_id)
rpc = self._exporters[place_id]
await rpc.set_allowed_ips([])
else:
logger.info("Place returned, but it doesn't exist: %d", place_id)


class _ReservationContext:
def __init__(self, client_ip):
self._client_ip = client_ip

@property
def client_ip(self):
return self._client_ip
_hub = Hub()


class ProtocolError(Exception):
Expand Down
5 changes: 4 additions & 1 deletion not_my_board/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@ class ExportDesc(pydantic.BaseModel):

class Place(ExportDesc):
id: pydantic.PositiveInt
host: pydantic.IPvAnyAddress
# host: pydantic.IPvAnyAddress
# can't serialize IP address with json.dumps()
# TODO: maybe drop pydantic as a dependency
host: str

0 comments on commit ca5e0c4

Please sign in to comment.