Skip to content

Commit

Permalink
meh
Browse files Browse the repository at this point in the history
  • Loading branch information
bartfeenstra committed Apr 9, 2024
1 parent c6c02ff commit 13332e8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 40 deletions.
21 changes: 14 additions & 7 deletions betty/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from collections.abc import Callable, Sequence
from types import TracebackType
from typing import TypeAlias, Any, Self, overload, AsyncContextManager
from typing import TypeAlias, Any, Self, overload, AsyncContextManager, Generic, TypeVar

ServiceContextT = TypeVar('ServiceContextT')
ServiceId: TypeAlias = str
Service: TypeAlias = Any
ServiceFactory: TypeAlias = Callable[['ServiceContainer'], AsyncContextManager[Service]]


class CyclicDependency(RuntimeError):
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self):
super().__init__('This service container has already started.')


class ServiceContainer:
class ServiceContainer(Generic[ServiceContextT]):
"""
Define a service container.
Expand All @@ -54,6 +54,10 @@ class ServiceContainer:
Implementations must be thread-safe.
"""

@property
def context(self) -> ServiceContextT:
raise NotImplementedError

async def get(self, service_id: ServiceId) -> Service:
raise NotImplementedError(type(self))

Expand All @@ -71,23 +75,26 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc
await self.stop()


class ServiceContainerBuilder:
ServiceFactory: TypeAlias = Callable[[ServiceContainer[ServiceContextT]], AsyncContextManager[Service]]


class ServiceContainerBuilder(Generic[ServiceContextT]):
"""
Define a service container builder.
A service container builder allows you to define the services to build a service container with.
"""

@overload
def define(self, service_id: ServiceId, *, service_factory: ServiceFactory) -> None:
def define(self, service_id: ServiceId, *, service_factory: ServiceFactory[ServiceContextT]) -> None:
pass

@overload
def define(self, service_id: ServiceId, *, service: Service) -> None:
pass

def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory | None = None) -> None:
def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None:
raise NotImplementedError(type(self))

def build(self) -> ServiceContainer:
def build(self) -> ServiceContainer[ServiceContextT]:
raise NotImplementedError(type(self))
63 changes: 40 additions & 23 deletions betty/service/_default.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, MutableSequence, AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, AsyncContextManager
from typing import Any, AsyncContextManager, Generic
from typing_extensions import override

from betty.concurrent import AsynchronizedLock, _Lock
from betty.service import ServiceContainer, ServiceId, ServiceFactory, Service, ServiceNotFound, \
ServiceContainerBuilder, CyclicDependency, ServiceContainerNotStarted, ServiceContainerStarted
ServiceContainerBuilder, CyclicDependency, ServiceContainerNotStarted, ServiceContainerStarted, ServiceContextT


class _ServiceContainerBase(ServiceContainer):
class _ServiceContainerBase(ServiceContainer[ServiceContextT], Generic[ServiceContextT]):
def __init__(
self,
service_factories: Mapping[ServiceId, ServiceFactory],
entered_service_contexts: MutableSequence[AsyncContextManager[Service]],
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]],
entered_service_context_managers: MutableSequence[AsyncContextManager[Service]],
services: MutableMapping[ServiceId, Any],
locks: MutableMapping[ServiceId, _Lock],
locks_lock: _Lock,
*,
service_context: ServiceContextT | None = None,
):
self._service_factories = service_factories
self._entered_service_contexts = entered_service_contexts
self._entered_service_context_managers = entered_service_context_managers
self._services = services
self._locks = locks
self._locks_lock = locks_lock
self._started = False
self._context = service_context

async def _lock(self, service_id: ServiceId) -> _Lock:
async with self._locks_lock:
Expand All @@ -37,6 +40,10 @@ def _assert_not_started(self) -> None:
if self._started:
raise ServiceContainerStarted()

@property
def context(self) -> ServiceContextT:
return self._context

async def start(self) -> None:
self._assert_not_started()
assert not self._started
Expand All @@ -59,46 +66,55 @@ async def _initialize(self, service_id: ServiceId) -> Service:
raise NotImplementedError(type(self))


class DefaultServiceContainer(_ServiceContainerBase):
def __init__(self, service_factories: Mapping[ServiceId, ServiceFactory]):
class DefaultServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]):
def __init__(
self,
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]],
*,
service_context: ServiceContextT | None = None,
):
super().__init__(
service_factories,
[],
{},
defaultdict(AsynchronizedLock.threading),
AsynchronizedLock.threading(),
service_context=service_context,
)

@override
async def _initialize(self, service_id: ServiceId) -> Service:
async with _ServiceInitializingServiceContainer(
self._service_factories,
self._entered_service_contexts,
self._entered_service_context_managers,
self._services,
self._locks,
self._locks_lock,
service_context=self._context,
) as services:
return await services.initialize(service_id)

async def stop(self) -> None:
await super().stop()
# @todo We should probably sort these topologically based on dependencies before exiting them
for entered_service_context in self._entered_service_contexts:
await entered_service_context.__aexit__(None, None, None)
for entered_service_context_manager in self._entered_service_context_managers:
await entered_service_context_manager.__aexit__(None, None, None)


class _ServiceInitializingServiceContainer(_ServiceContainerBase):
class _ServiceInitializingServiceContainer(_ServiceContainerBase[ServiceContextT], Generic[ServiceContextT]):
def __init__(
self,
service_factories: Mapping[ServiceId, ServiceFactory],
entered_service_contexts: MutableSequence[AsyncContextManager[Service]],
service_factories: Mapping[ServiceId, ServiceFactory[ServiceContextT]],
entered_service_context_managers: MutableSequence[AsyncContextManager[Service]],
services: MutableMapping[ServiceId, Any],
locks: MutableMapping[ServiceId, _Lock],
locks_lock: _Lock,
*,
service_context: ServiceContextT | None = None,
):
super().__init__(
service_factories,
entered_service_contexts,
entered_service_context_managers,
services,
locks,
locks_lock,
Expand All @@ -120,27 +136,28 @@ async def _initialize(self, service_id: ServiceId) -> Service:
raise ServiceNotFound(f'Service "{service_id}" is unknown.')
service_context = service_factory(self)
service = await service_context.__aenter__()
self._entered_service_contexts.append(service_context)
self._entered_service_context_managers.append(service_context)
return service

async def initialize(self, service_id: ServiceId) -> Service:
self._seen.append(service_id)
return await self._initialize(service_id)


class DefaultServiceContainerBuilder(ServiceContainerBuilder):
def __init__(self):
self._service_factories: MutableMapping[ServiceId, ServiceFactory] = {}
class DefaultServiceContainerBuilder(ServiceContainerBuilder[ServiceContextT], Generic[ServiceContextT]):
def __init__(self, *, service_context: ServiceContextT | None = None):
self._service_factories: MutableMapping[ServiceId, ServiceFactory[ServiceContextT]] = {}
self._context = service_context

@override
def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory | None = None) -> None:
def define(self, service_id: ServiceId, *, service: Service | None = None, service_factory: ServiceFactory[ServiceContextT] | None = None) -> None:
if service_factory is None:
@asynccontextmanager
async def service_factory(_: ServiceContainer) -> AsyncIterator[Service]:
async def service_factory(_: ServiceContainer[ServiceContextT]) -> AsyncIterator[Service]:
yield service
assert service_factory is not None
self._service_factories[service_id] = service_factory

@override
def build(self) -> ServiceContainer:
return DefaultServiceContainer(self._service_factories)
def build(self) -> ServiceContainer[ServiceContextT]:
return DefaultServiceContainer(self._service_factories, service_context=self._context)
28 changes: 18 additions & 10 deletions betty/tests/service/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ def __init__(self, dependency: DummyLeftHandCyclicDependencyService):

class TestDefaultServiceContainerBuilderAndDefaultServiceContainer:
async def test_without_services(self) -> None:
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
builder.build()

async def test_with_unknown_service(self) -> None:
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
async with builder.build() as services:
with pytest.raises(ServiceNotFound):
await services.get('UnknownServiceId')

async def test_with_as_is_service(self) -> None:
service_id = 'MyFirstService'
service = DummyService()
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
builder.define(service_id, service=service)
async with builder.build() as services:
assert await services.get(service_id) is service

async def test_with_service_factory(self) -> None:
service_id = 'MyFirstService'
service = DummyService()
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
setup = False
teardown = False

@asynccontextmanager
async def _service_factory(_: ServiceContainer) -> AsyncIterator[DummyService]:
async def _service_factory(_: ServiceContainer[None]) -> AsyncIterator[DummyService]:
nonlocal setup
nonlocal teardown
setup = True
Expand All @@ -74,9 +74,9 @@ async def test_with_dependency(self) -> None:
dependency = DummyService()

@asynccontextmanager
async def _new_dummy_dependent_service(services: ServiceContainer) -> AsyncIterator[DummyDependentService]:
async def _new_dummy_dependent_service(services: ServiceContainer[None]) -> AsyncIterator[DummyDependentService]:
yield DummyDependentService(await services.get(dependency_service_id))
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
builder.define(dependency_service_id, service=dependency)
builder.define(dependent_service_id, service_factory=_new_dummy_dependent_service)
async with builder.build() as services:
Expand All @@ -87,15 +87,23 @@ async def test_with_cyclic_dependency(self) -> None:
right_hand_dependency_service_id = 'MyFirstRightHandDependency'

@asynccontextmanager
async def _new_dummy_left_hand_dependency_service(services: ServiceContainer) -> AsyncIterator[DummyLeftHandCyclicDependencyService]:
async def _new_dummy_left_hand_dependency_service(services: ServiceContainer[None]) -> AsyncIterator[DummyLeftHandCyclicDependencyService]:
yield DummyLeftHandCyclicDependencyService(await services.get(right_hand_dependency_service_id))

@asynccontextmanager
async def _new_dummy_right_hand_dependency_service(services: ServiceContainer) -> AsyncIterator[DummyRightHandCyclicDependencyService]:
async def _new_dummy_right_hand_dependency_service(services: ServiceContainer[None]) -> AsyncIterator[DummyRightHandCyclicDependencyService]:
yield DummyRightHandCyclicDependencyService(await services.get(left_hand_dependency_service_id))
builder = DefaultServiceContainerBuilder()
builder = DefaultServiceContainerBuilder[None]()
builder.define(left_hand_dependency_service_id, service_factory=_new_dummy_left_hand_dependency_service)
builder.define(right_hand_dependency_service_id, service_factory=_new_dummy_right_hand_dependency_service)
async with builder.build() as services:
with pytest.raises(CyclicDependency):
await services.get(left_hand_dependency_service_id)

async def test_with_context(self) -> None:
class _Context:
pass
context = _Context()
builder = DefaultServiceContainerBuilder[_Context](service_context=context)
async with builder.build() as services:
assert services.context is context

0 comments on commit 13332e8

Please sign in to comment.