-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
edb1ae4
commit 3fd0161
Showing
4 changed files
with
405 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
Provide application service management. | ||
""" | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable, Sequence, Iterable | ||
from types import TracebackType | ||
from typing import TypeAlias, Any, Self, overload, AsyncContextManager, Generic, TypeVar | ||
|
||
ServiceContextT = TypeVar('ServiceContextT') | ||
ServiceId: TypeAlias = str | ||
Service: TypeAlias = Any | ||
|
||
|
||
class CyclicDependency(RuntimeError): | ||
def __init__(self, service_ids: Sequence[ServiceId]): | ||
assert len(service_ids) > 1 | ||
traceback = [] | ||
for index, service_id in enumerate(service_ids): | ||
traceback_line = f'- "{service_id}"' | ||
if index == 0: | ||
traceback_line += ' (requested service)' | ||
if index == len(service_ids) - 1: | ||
traceback_line += ' (cyclic dependency)' | ||
traceback.append(traceback_line) | ||
traceback_str = '\n'.join(traceback) | ||
super().__init__(f''' | ||
Cyclic service dependency detected for "{service_ids[0]}": | ||
{traceback_str} | ||
''') | ||
|
||
|
||
class ServiceNotFound(RuntimeError): | ||
def __init__(self, unknown_service_id: ServiceId, known_service_ids: Iterable[ServiceId]): | ||
message = f'Unknown service "{unknown_service_id}".' | ||
known_service_ids = sorted(known_service_ids) | ||
if known_service_ids: | ||
message += ' Did you mean one of:\n' | ||
message += '\n'.join(( | ||
f'- {known_service_id}' | ||
for known_service_id | ||
in sorted(known_service_ids) | ||
)) | ||
else: | ||
message += ' There are no available services.' | ||
super().__init__(message) | ||
|
||
|
||
class ServiceContainerNotStarted(RuntimeError): | ||
def __init__(self): | ||
super().__init__('This service container has not yet started.') | ||
|
||
|
||
class ServiceContainerStarted(RuntimeError): | ||
def __init__(self): | ||
super().__init__('This service container has already started.') | ||
|
||
|
||
class ServiceContainer(Generic[ServiceContextT]): | ||
""" | ||
Define a service container. | ||
A service container allows access to whatever services are defined, and manages their resources. | ||
Implementations must be thread-safe. | ||
""" | ||
|
||
@property | ||
def context(self) -> ServiceContextT: | ||
raise NotImplementedError | ||
|
||
async def get(self, service_id: ServiceId) -> Service: | ||
raise NotImplementedError(type(self)) | ||
|
||
async def start(self) -> None: | ||
raise NotImplementedError(type(self)) | ||
|
||
async def __aenter__(self) -> Self: | ||
await self.start() | ||
return self | ||
|
||
async def stop(self) -> None: | ||
raise NotImplementedError(type(self)) | ||
|
||
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: | ||
await self.stop() | ||
|
||
|
||
ServiceFactory: TypeAlias = Callable[[ServiceContainer[ServiceContextT]], AsyncContextManager[Service]] | ||
|
||
|
||
class ServiceContainerBuilder(Generic[ServiceContextT]): | ||
""" | ||
Define a service container builder. | ||
A service container builder allows you to define the services to build a service container with. | ||
""" | ||
|
||
@overload | ||
def define(self, service_id: ServiceId, *, service_factory: ServiceFactory[ServiceContextT]) -> None: | ||
pass | ||
|
||
@overload | ||
def define(self, service_id: ServiceId, *, service: Service) -> None: | ||
pass | ||
|
||
def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None: | ||
raise NotImplementedError(type(self)) | ||
|
||
def build(self) -> ServiceContainer[ServiceContextT]: | ||
raise NotImplementedError(type(self)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from collections import defaultdict | ||
from collections.abc import Mapping, MutableMapping, MutableSequence, AsyncIterator | ||
from contextlib import asynccontextmanager | ||
from typing import Any, AsyncContextManager, Generic | ||
from typing_extensions import override | ||
|
||
from betty.concurrent import AsynchronizedLock, _Lock | ||
from betty.service import ServiceContainer, ServiceId, ServiceFactory, Service, ServiceNotFound, \ | ||
ServiceContainerBuilder, CyclicDependency, ServiceContainerNotStarted, ServiceContainerStarted, ServiceContextT | ||
|
||
|
||
class _ServiceContainerBase(ServiceContainer[ServiceContextT], Generic[ServiceContextT]): | ||
def __init__( | ||
self, | ||
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], | ||
entered_service_context_managers: MutableSequence[AsyncContextManager[Service]], | ||
services: MutableMapping[ServiceId, Any], | ||
locks: MutableMapping[ServiceId, _Lock], | ||
locks_lock: _Lock, | ||
*, | ||
service_context: ServiceContextT | None = None, | ||
): | ||
self._service_factories = service_factories | ||
self._entered_service_context_managers = entered_service_context_managers | ||
self._services = services | ||
self._locks = locks | ||
self._locks_lock = locks_lock | ||
self._started = False | ||
self._context = service_context | ||
|
||
async def _lock(self, service_id: ServiceId) -> _Lock: | ||
async with self._locks_lock: | ||
return self._locks[service_id] | ||
|
||
def _assert_started(self) -> None: | ||
if not self._started: | ||
raise ServiceContainerNotStarted() | ||
|
||
def _assert_not_started(self) -> None: | ||
if self._started: | ||
raise ServiceContainerStarted() | ||
|
||
@override | ||
@property | ||
def context(self) -> ServiceContextT: | ||
return self._context # type: ignore[return-value] | ||
|
||
async def start(self) -> None: | ||
self._assert_not_started() | ||
assert not self._started | ||
self._started = True | ||
|
||
async def stop(self) -> None: | ||
self._assert_started() | ||
|
||
@override | ||
async def get(self, service_id: ServiceId) -> Service: | ||
self._assert_started() | ||
async with await self._lock(service_id): | ||
try: | ||
return self._services[service_id] | ||
except KeyError: | ||
self._services[service_id] = await self._initialize(service_id) | ||
return self._services[service_id] | ||
|
||
async def _initialize(self, service_id: ServiceId) -> Service: | ||
raise NotImplementedError(type(self)) | ||
|
||
|
||
class DefaultServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]): | ||
def __init__( | ||
self, | ||
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], | ||
*, | ||
service_context: ServiceContextT | None = None, | ||
): | ||
super().__init__( | ||
service_factories, | ||
[], | ||
{}, | ||
defaultdict(AsynchronizedLock.threading), | ||
AsynchronizedLock.threading(), | ||
service_context=service_context, | ||
) | ||
|
||
@override | ||
async def _initialize(self, service_id: ServiceId) -> Service: | ||
async with _ServiceInitializingServiceContainer( | ||
self._service_factories, | ||
self._entered_service_context_managers, | ||
self._services, | ||
self._locks, | ||
self._locks_lock, | ||
service_context=self._context, | ||
) as services: | ||
return await services.initialize(service_id) | ||
|
||
async def stop(self) -> None: | ||
await super().stop() | ||
# @todo We should probably sort these topologically based on dependencies before exiting them | ||
for entered_service_context_manager in self._entered_service_context_managers: | ||
await entered_service_context_manager.__aexit__(None, None, None) | ||
|
||
|
||
class _ServiceInitializingServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]): | ||
def __init__( | ||
self, | ||
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], | ||
entered_service_context_managers: MutableSequence[AsyncContextManager[Service]], | ||
services: MutableMapping[ServiceId, Any], | ||
locks: MutableMapping[ServiceId, _Lock], | ||
locks_lock: _Lock, | ||
*, | ||
service_context: ServiceContextT | None = None, | ||
): | ||
super().__init__( | ||
service_factories, | ||
entered_service_context_managers, | ||
services, | ||
locks, | ||
locks_lock, | ||
service_context=service_context, | ||
) | ||
self._seen: MutableSequence[ServiceId] = [] | ||
|
||
@override | ||
async def get(self, service_id: ServiceId) -> Service: | ||
if service_id in self._seen: | ||
raise CyclicDependency((*self._seen, service_id)) | ||
self._seen.append(service_id) | ||
return await super().get(service_id) | ||
|
||
@override | ||
async def _initialize(self, service_id: ServiceId) -> Service: | ||
try: | ||
service_factory = self._service_factories[service_id] | ||
except KeyError: | ||
raise ServiceNotFound(service_id, self._service_factories.keys()) | ||
service_context = service_factory(self) | ||
service = await service_context.__aenter__() | ||
self._entered_service_context_managers.append(service_context) | ||
return service | ||
|
||
async def initialize(self, service_id: ServiceId) -> Service: | ||
self._seen.append(service_id) | ||
return await self._initialize(service_id) | ||
|
||
|
||
class DefaultServiceContainerBuilder(ServiceContainerBuilder[ServiceContextT], Generic[ServiceContextT]): | ||
def __init__(self, *, service_context: ServiceContextT | None = None): | ||
self._service_factories: MutableMapping[ServiceId, ServiceFactory[ServiceContextT]] = {} | ||
self._context = service_context | ||
|
||
@override | ||
def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None: | ||
if service_factory is None: | ||
@asynccontextmanager | ||
async def service_factory(_: ServiceContainer[ServiceContextT]) -> AsyncIterator[Service]: | ||
yield service | ||
assert service_factory is not None | ||
self._service_factories[service_id] = service_factory | ||
|
||
@override | ||
def build(self) -> ServiceContainer[ServiceContextT]: | ||
return DefaultServiceContainer(self._service_factories, service_context=self._context) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Test the :py:mod:`betty.service` module.""" |
Oops, something went wrong.