From 3fd01618123036d4f6e63ca2b7e32332f6905f92 Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Sat, 6 Apr 2024 19:29:19 +0100 Subject: [PATCH] Add service containers --- betty/service/__init__.py | 111 +++++++++++++++++++ betty/service/_default.py | 165 ++++++++++++++++++++++++++++ betty/tests/service/__init__.py | 1 + betty/tests/service/test_default.py | 128 +++++++++++++++++++++ 4 files changed, 405 insertions(+) create mode 100644 betty/service/__init__.py create mode 100644 betty/service/_default.py create mode 100644 betty/tests/service/__init__.py create mode 100644 betty/tests/service/test_default.py diff --git a/betty/service/__init__.py b/betty/service/__init__.py new file mode 100644 index 000000000..459b6a14a --- /dev/null +++ b/betty/service/__init__.py @@ -0,0 +1,111 @@ +""" +Provide application service management. +""" +from __future__ import annotations + +from collections.abc import Callable, Sequence, Iterable +from types import TracebackType +from typing import TypeAlias, Any, Self, overload, AsyncContextManager, Generic, TypeVar + +ServiceContextT = TypeVar('ServiceContextT') +ServiceId: TypeAlias = str +Service: TypeAlias = Any + + +class CyclicDependency(RuntimeError): + def __init__(self, service_ids: Sequence[ServiceId]): + assert len(service_ids) > 1 + traceback = [] + for index, service_id in enumerate(service_ids): + traceback_line = f'- "{service_id}"' + if index == 0: + traceback_line += ' (requested service)' + if index == len(service_ids) - 1: + traceback_line += ' (cyclic dependency)' + traceback.append(traceback_line) + traceback_str = '\n'.join(traceback) + super().__init__(f''' +Cyclic service dependency detected for "{service_ids[0]}": +{traceback_str} +''') + + +class ServiceNotFound(RuntimeError): + def __init__(self, unknown_service_id: ServiceId, known_service_ids: Iterable[ServiceId]): + message = f'Unknown service "{unknown_service_id}".' + known_service_ids = sorted(known_service_ids) + if known_service_ids: + message += ' Did you mean one of:\n' + message += '\n'.join(( + f'- {known_service_id}' + for known_service_id + in sorted(known_service_ids) + )) + else: + message += ' There are no available services.' + super().__init__(message) + + +class ServiceContainerNotStarted(RuntimeError): + def __init__(self): + super().__init__('This service container has not yet started.') + + +class ServiceContainerStarted(RuntimeError): + def __init__(self): + super().__init__('This service container has already started.') + + +class ServiceContainer(Generic[ServiceContextT]): + """ + Define a service container. + + A service container allows access to whatever services are defined, and manages their resources. + + Implementations must be thread-safe. + """ + + @property + def context(self) -> ServiceContextT: + raise NotImplementedError + + async def get(self, service_id: ServiceId) -> Service: + raise NotImplementedError(type(self)) + + async def start(self) -> None: + raise NotImplementedError(type(self)) + + async def __aenter__(self) -> Self: + await self.start() + return self + + async def stop(self) -> None: + raise NotImplementedError(type(self)) + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: + await self.stop() + + +ServiceFactory: TypeAlias = Callable[[ServiceContainer[ServiceContextT]], AsyncContextManager[Service]] + + +class ServiceContainerBuilder(Generic[ServiceContextT]): + """ + Define a service container builder. + + A service container builder allows you to define the services to build a service container with. + """ + + @overload + def define(self, service_id: ServiceId, *, service_factory: ServiceFactory[ServiceContextT]) -> None: + pass + + @overload + def define(self, service_id: ServiceId, *, service: Service) -> None: + pass + + def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None: + raise NotImplementedError(type(self)) + + def build(self) -> ServiceContainer[ServiceContextT]: + raise NotImplementedError(type(self)) diff --git a/betty/service/_default.py b/betty/service/_default.py new file mode 100644 index 000000000..b180e7dff --- /dev/null +++ b/betty/service/_default.py @@ -0,0 +1,165 @@ +from collections import defaultdict +from collections.abc import Mapping, MutableMapping, MutableSequence, AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, AsyncContextManager, Generic +from typing_extensions import override + +from betty.concurrent import AsynchronizedLock, _Lock +from betty.service import ServiceContainer, ServiceId, ServiceFactory, Service, ServiceNotFound, \ + ServiceContainerBuilder, CyclicDependency, ServiceContainerNotStarted, ServiceContainerStarted, ServiceContextT + + +class _ServiceContainerBase(ServiceContainer[ServiceContextT], Generic[ServiceContextT]): + def __init__( + self, + service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], + entered_service_context_managers: MutableSequence[AsyncContextManager[Service]], + services: MutableMapping[ServiceId, Any], + locks: MutableMapping[ServiceId, _Lock], + locks_lock: _Lock, + *, + service_context: ServiceContextT | None = None, + ): + self._service_factories = service_factories + self._entered_service_context_managers = entered_service_context_managers + self._services = services + self._locks = locks + self._locks_lock = locks_lock + self._started = False + self._context = service_context + + async def _lock(self, service_id: ServiceId) -> _Lock: + async with self._locks_lock: + return self._locks[service_id] + + def _assert_started(self) -> None: + if not self._started: + raise ServiceContainerNotStarted() + + def _assert_not_started(self) -> None: + if self._started: + raise ServiceContainerStarted() + + @override + @property + def context(self) -> ServiceContextT: + return self._context # type: ignore[return-value] + + async def start(self) -> None: + self._assert_not_started() + assert not self._started + self._started = True + + async def stop(self) -> None: + self._assert_started() + + @override + async def get(self, service_id: ServiceId) -> Service: + self._assert_started() + async with await self._lock(service_id): + try: + return self._services[service_id] + except KeyError: + self._services[service_id] = await self._initialize(service_id) + return self._services[service_id] + + async def _initialize(self, service_id: ServiceId) -> Service: + raise NotImplementedError(type(self)) + + +class DefaultServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]): + def __init__( + self, + service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], + *, + service_context: ServiceContextT | None = None, + ): + super().__init__( + service_factories, + [], + {}, + defaultdict(AsynchronizedLock.threading), + AsynchronizedLock.threading(), + service_context=service_context, + ) + + @override + async def _initialize(self, service_id: ServiceId) -> Service: + async with _ServiceInitializingServiceContainer( + self._service_factories, + self._entered_service_context_managers, + self._services, + self._locks, + self._locks_lock, + service_context=self._context, + ) as services: + return await services.initialize(service_id) + + async def stop(self) -> None: + await super().stop() + # @todo We should probably sort these topologically based on dependencies before exiting them + for entered_service_context_manager in self._entered_service_context_managers: + await entered_service_context_manager.__aexit__(None, None, None) + + +class _ServiceInitializingServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]): + def __init__( + self, + service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]], + entered_service_context_managers: MutableSequence[AsyncContextManager[Service]], + services: MutableMapping[ServiceId, Any], + locks: MutableMapping[ServiceId, _Lock], + locks_lock: _Lock, + *, + service_context: ServiceContextT | None = None, + ): + super().__init__( + service_factories, + entered_service_context_managers, + services, + locks, + locks_lock, + service_context=service_context, + ) + self._seen: MutableSequence[ServiceId] = [] + + @override + async def get(self, service_id: ServiceId) -> Service: + if service_id in self._seen: + raise CyclicDependency((*self._seen, service_id)) + self._seen.append(service_id) + return await super().get(service_id) + + @override + async def _initialize(self, service_id: ServiceId) -> Service: + try: + service_factory = self._service_factories[service_id] + except KeyError: + raise ServiceNotFound(service_id, self._service_factories.keys()) + service_context = service_factory(self) + service = await service_context.__aenter__() + self._entered_service_context_managers.append(service_context) + return service + + async def initialize(self, service_id: ServiceId) -> Service: + self._seen.append(service_id) + return await self._initialize(service_id) + + +class DefaultServiceContainerBuilder(ServiceContainerBuilder[ServiceContextT], Generic[ServiceContextT]): + def __init__(self, *, service_context: ServiceContextT | None = None): + self._service_factories: MutableMapping[ServiceId, ServiceFactory[ServiceContextT]] = {} + self._context = service_context + + @override + def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None: + if service_factory is None: + @asynccontextmanager + async def service_factory(_: ServiceContainer[ServiceContextT]) -> AsyncIterator[Service]: + yield service + assert service_factory is not None + self._service_factories[service_id] = service_factory + + @override + def build(self) -> ServiceContainer[ServiceContextT]: + return DefaultServiceContainer(self._service_factories, service_context=self._context) diff --git a/betty/tests/service/__init__.py b/betty/tests/service/__init__.py new file mode 100644 index 000000000..c8d34cebf --- /dev/null +++ b/betty/tests/service/__init__.py @@ -0,0 +1 @@ +"""Test the :py:mod:`betty.service` module.""" diff --git a/betty/tests/service/test_default.py b/betty/tests/service/test_default.py new file mode 100644 index 000000000..8f9751050 --- /dev/null +++ b/betty/tests/service/test_default.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import pytest + +from betty.service import ServiceNotFound, ServiceContainer, CyclicDependency, ServiceContainerStarted, \ + ServiceContainerNotStarted +from betty.service._default import DefaultServiceContainerBuilder + + +class DummyService: + pass + + +class DummyDependentService: + def __init__(self, dependency: DummyService): + pass + + +class DummyLeftHandCyclicDependencyService: + def __init__(self, dependency: DummyRightHandCyclicDependencyService): + pass + + +class DummyRightHandCyclicDependencyService: + def __init__(self, dependency: DummyLeftHandCyclicDependencyService): + pass + + +class TestDefaultServiceContainerBuilderAndDefaultServiceContainer: + async def test_without_services(self) -> None: + builder = DefaultServiceContainerBuilder[None]() + builder.build() + + async def test_starting_a_started_container_should_error(self) -> None: + builder = DefaultServiceContainerBuilder[None]() + services = builder.build() + async with services: + with pytest.raises(ServiceContainerStarted): + await services.start() + + async def test_stopping_a_not_started_container_should_error(self) -> None: + builder = DefaultServiceContainerBuilder[None]() + services = builder.build() + with pytest.raises(ServiceContainerNotStarted): + await services.stop() + + async def test_with_unknown_service(self) -> None: + builder = DefaultServiceContainerBuilder[None]() + async with builder.build() as services: + with pytest.raises(ServiceNotFound): + await services.get('UnknownServiceId') + + async def test_with_as_is_service(self) -> None: + service_id = 'MyFirstService' + service = DummyService() + builder = DefaultServiceContainerBuilder[None]() + builder.define(service_id, service=service) + async with builder.build() as services: + assert await services.get(service_id) is service + + async def test_with_service_factory(self) -> None: + service_id = 'MyFirstService' + service = DummyService() + builder = DefaultServiceContainerBuilder[None]() + setup = False + teardown = False + + @asynccontextmanager + async def _service_factory(_: ServiceContainer[None]) -> AsyncIterator[DummyService]: + nonlocal setup + nonlocal teardown + setup = True + yield service + teardown = True + builder.define(service_id, service_factory=_service_factory) + async with builder.build() as services: + assert await services.get(service_id) is service + assert setup + assert not teardown + assert teardown + + async def test_with_context(self) -> None: + context = object() + service_id = 'MyFirstService' + builder = DefaultServiceContainerBuilder[object](service_context=context) + + @asynccontextmanager + async def _service_factory(services: ServiceContainer[object]) -> AsyncIterator[None]: + assert services.context is context + yield + builder.define(service_id, service_factory=_service_factory) + async with builder.build() as services: + await services.get(service_id) + + async def test_with_dependency(self) -> None: + dependency_service_id = 'MyFirstDependency' + dependent_service_id = 'MyFirstDependent' + dependency = DummyService() + + @asynccontextmanager + async def _new_dummy_dependent_service(services: ServiceContainer[None]) -> AsyncIterator[DummyDependentService]: + yield DummyDependentService(await services.get(dependency_service_id)) + builder = DefaultServiceContainerBuilder[None]() + builder.define(dependency_service_id, service=dependency) + builder.define(dependent_service_id, service_factory=_new_dummy_dependent_service) + async with builder.build() as services: + assert isinstance(await services.get(dependent_service_id), DummyDependentService) + + async def test_with_cyclic_dependency(self) -> None: + left_hand_dependency_service_id = 'MyFirstLeftHandDependency' + right_hand_dependency_service_id = 'MyFirstRightHandDependency' + + @asynccontextmanager + async def _new_dummy_left_hand_dependency_service(services: ServiceContainer[None]) -> AsyncIterator[DummyLeftHandCyclicDependencyService]: + yield DummyLeftHandCyclicDependencyService(await services.get(right_hand_dependency_service_id)) + + @asynccontextmanager + async def _new_dummy_right_hand_dependency_service(services: ServiceContainer[None]) -> AsyncIterator[DummyRightHandCyclicDependencyService]: + yield DummyRightHandCyclicDependencyService(await services.get(left_hand_dependency_service_id)) + builder = DefaultServiceContainerBuilder[None]() + builder.define(left_hand_dependency_service_id, service_factory=_new_dummy_left_hand_dependency_service) + builder.define(right_hand_dependency_service_id, service_factory=_new_dummy_right_hand_dependency_service) + async with builder.build() as services: + with pytest.raises(CyclicDependency): + await services.get(left_hand_dependency_service_id)