Skip to content

Commit

Permalink
feat: Mutual-exclusion of manager/agent embedded plugins (#453)
Browse files Browse the repository at this point in the history
* Properly define and use the `blocklist` class attribute of plugin
  context objects
* test: Update mocks with blocklist
  • Loading branch information
achimnol committed Jun 8, 2022
1 parent 8e9e059 commit d38ae45
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 22 deletions.
1 change: 1 addition & 0 deletions changes/453.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement a plugin blocklist and use it to mutually exclude the self-embedded plugins of the manager and the agent when both are executed under a unified virtualenv
9 changes: 9 additions & 0 deletions src/ai/backend/agent/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext


# Error-monitor plugin context used by the agent (instantiated in agent/server.py).
# The class-level blocklist is merged into plugin discovery by
# BasePluginContext.discover_plugins(); each item is matched as a package-path
# prefix, so entry points whose module lives under "ai.backend.manager" are
# skipped and the manager's embedded plugins are not loaded by the agent.
class AgentErrorPluginContext(ErrorPluginContext):
blocklist = {"ai.backend.manager"}


# Stats-monitor plugin context used by the agent (instantiated in agent/server.py).
# The class-level blocklist is merged into plugin discovery by
# BasePluginContext.discover_plugins(); each item is matched as a package-path
# prefix, so entry points whose module lives under "ai.backend.manager" are
# skipped and the manager's embedded plugins are not loaded by the agent.
class AgentStatsPluginContext(StatsPluginContext):
blocklist = {"ai.backend.manager"}
3 changes: 1 addition & 2 deletions src/ai/backend/agent/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import (
Any,
Collection,
Container,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -359,7 +358,7 @@ class ComputePluginContext(BasePluginContext[AbstractComputePlugin]):
def discover_plugins(
cls,
plugin_group: str,
blocklist: Container[str] = None,
blocklist: set[str] = None,
) -> Iterator[Tuple[str, Type[AbstractComputePlugin]]]:
scanned_plugins = [*super().discover_plugins(plugin_group, blocklist)]

Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from ai.backend.common import config, utils, identity, msgpack
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext
from ai.backend.common.types import (
HardwareMetadata, aobject,
ClusterInfo,
Expand All @@ -61,6 +60,7 @@
container_etcd_config_iv,
)
from .exception import ResourceError
from .monitor import AgentErrorPluginContext, AgentStatsPluginContext
from .types import AgentBackend, VolumeInfo, LifecycleEvent
from .utils import get_subnet_ip

Expand Down Expand Up @@ -187,8 +187,8 @@ async def __ainit__(self) -> None:
await self.read_agent_config()
await self.read_agent_config_container()

self.stats_monitor = StatsPluginContext(self.etcd, self.local_config)
self.error_monitor = ErrorPluginContext(self.etcd, self.local_config)
self.stats_monitor = AgentStatsPluginContext(self.etcd, self.local_config)
self.error_monitor = AgentErrorPluginContext(self.etcd, self.local_config)
await self.stats_monitor.init()
await self.error_monitor.init()

Expand Down
13 changes: 6 additions & 7 deletions src/ai/backend/common/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from typing import (
Any,
ClassVar,
Container,
Dict,
Generic,
Iterator,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -106,6 +106,7 @@ class BasePluginContext(Generic[P]):
local_config: Mapping[str, Any]
plugins: Dict[str, P]
plugin_group: ClassVar[str] = 'backendai_XXX_v10'
blocklist: ClassVar[Optional[set[str]]] = None

_config_watchers: WeakSet[asyncio.Task]

Expand All @@ -126,13 +127,11 @@ def __init__(self, etcd: AsyncEtcd, local_config: Mapping[str, Any]) -> None:
def discover_plugins(
cls,
plugin_group: str,
blocklist: Container[str] = None,
blocklist: set[str] = None,
) -> Iterator[Tuple[str, Type[P]]]:
if blocklist is None:
blocklist = set()
for entrypoint in scan_entrypoints(plugin_group):
if entrypoint.name in blocklist:
continue
cls_blocklist = set() if cls.blocklist is None else cls.blocklist
arg_blocklist = set() if blocklist is None else blocklist
for entrypoint in scan_entrypoints(plugin_group, cls_blocklist | arg_blocklist):
log.info('loading plugin (group:{}): {}', plugin_group, entrypoint.name)
yield entrypoint.name, entrypoint.load()

Expand Down
12 changes: 12 additions & 0 deletions src/ai/backend/manager/plugin/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ai.backend.common.plugin.monitor import (
ErrorPluginContext,
StatsPluginContext,
)


# Error-monitor plugin context used by the manager (instantiated in
# manager/server.py's monitoring_ctx).
# The class-level blocklist is merged into plugin discovery by
# BasePluginContext.discover_plugins(); each item is matched as a package-path
# prefix, so entry points whose module lives under "ai.backend.agent" are
# skipped and the agent's embedded plugins are not loaded by the manager.
class ManagerErrorPluginContext(ErrorPluginContext):
blocklist = {"ai.backend.agent"}


# Stats-monitor plugin context used by the manager (instantiated in
# manager/server.py's monitoring_ctx).
# The class-level blocklist is merged into plugin discovery by
# BasePluginContext.discover_plugins(); each item is matched as a package-path
# prefix, so entry points whose module lives under "ai.backend.agent" are
# skipped and the agent's embedded plugins are not loaded by the manager.
class ManagerStatsPluginContext(StatsPluginContext):
blocklist = {"ai.backend.agent"}
10 changes: 6 additions & 4 deletions src/ai/backend/manager/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.plugin.hook import HookPluginContext, ALL_COMPLETED, PASSED
from ai.backend.common.plugin.monitor import (
ErrorPluginContext,
StatsPluginContext,
INCREMENT,
)

Expand Down Expand Up @@ -71,6 +69,10 @@
from .idle import init_idle_checkers
from .models.storage import StorageSessionManager
from .models.utils import connect_database
from .plugin.monitor import (
ManagerErrorPluginContext,
ManagerStatsPluginContext,
)
from .plugin.webapp import WebappPluginContext
from .registry import AgentRegistry
from .scheduler.dispatcher import SchedulerDispatcher
Expand Down Expand Up @@ -411,8 +413,8 @@ async def sched_dispatcher_ctx(root_ctx: RootContext) -> AsyncIterator[None]:

@actxmgr
async def monitoring_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
ectx = ErrorPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config)
sctx = StatsPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config)
ectx = ManagerErrorPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config)
sctx = ManagerStatsPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config)
await ectx.init(context={'_root.context': root_ctx})
await sctx.init()
root_ctx.error_monitor = ectx
Expand Down
19 changes: 16 additions & 3 deletions src/ai/backend/plugin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import logging
from importlib.metadata import EntryPoint, entry_points
from pathlib import Path
from typing import Container, Iterator, Optional
from typing import Iterator, Optional

log = logging.getLogger(__name__)


def scan_entrypoints(
group_name: str,
blocklist: Container[str] = None,
blocklist: Optional[set[str]] = None,
) -> Iterator[EntryPoint]:
if blocklist is None:
blocklist = set()
Expand All @@ -21,7 +21,7 @@ def scan_entrypoints(
scan_entrypoint_from_plugin_checkouts(group_name),
scan_entrypoint_from_package_metadata(group_name),
):
if entrypoint.name in blocklist:
if match_blocklist(entrypoint.value, blocklist):
continue
if existing_entrypoint := existing_names.get(entrypoint.name, None):
if existing_entrypoint.value == entrypoint.value:
Expand All @@ -43,6 +43,19 @@ def scan_entrypoints(
yield entrypoint


def match_blocklist(entry_path: str, blocklist: set[str]) -> bool:
    """
    Checks if the given module attribute reference is in the blocklist.

    The blocklist items are assumed to be prefixes of package import paths
    or the package namespaces.  A path matches when its module part is equal
    to a blocklist item or is a sub-package/sub-module of it; a mere string
    prefix such as ``ai.backend.managerx`` does not match ``ai.backend.manager``.

    :param entry_path: The entry point value, e.g. ``"pkg.mod:attr"``.
        Optional whitespace around ``":"`` and a trailing ``"[extras]"``
        specifier are tolerated.
    :param blocklist: The package path prefixes to reject.
    :return: ``True`` if the module part of ``entry_path`` is blocked.
    """
    # Entry point values may be written as "pkg.mod : attr [extras]" or even
    # "pkg.mod [extras]", so drop the attribute, the extras and any padding
    # before comparing; otherwise exact matches would be silently missed.
    mod_path = entry_path.partition(":")[0].partition("[")[0].strip()
    return any(
        mod_path == block_pattern or mod_path.startswith(block_pattern + ".")
        for block_pattern in blocklist
    )


def scan_entrypoint_from_package_metadata(group_name: str) -> Iterator[EntryPoint]:
yield from entry_points().select(group=group_name)

Expand Down
10 changes: 9 additions & 1 deletion tests/common/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def load(self):
return self.load_result


def mock_entrypoints_with_instance(plugin_group_name: str, *, mocked_plugin):
def mock_entrypoints_with_instance(
plugin_group_name: str,
blocklist: set[str] = None,
*,
mocked_plugin,
):
# Since mocked_plugin is already an instance constructed via AsyncMock,
# we emulate the original constructor using a lambda function.
yield DummyEntrypoint(
Expand All @@ -69,6 +74,7 @@ def mock_entrypoints_with_instance(plugin_group_name: str, *, mocked_plugin):
@overload
def mock_entrypoints_with_class(
plugin_group_name: str,
blocklist: set[str] = None,
*,
plugin_cls: list[Type[AbstractPlugin]],
) -> Iterator[DummyEntrypoint]:
Expand All @@ -78,6 +84,7 @@ def mock_entrypoints_with_class(
@overload
def mock_entrypoints_with_class(
plugin_group_name: str,
blocklist: set[str] = None,
*,
plugin_cls: Type[AbstractPlugin],
) -> DummyEntrypoint:
Expand All @@ -86,6 +93,7 @@ def mock_entrypoints_with_class(

def mock_entrypoints_with_class(
plugin_group_name: str,
blocklist: set[str] = None,
*,
plugin_cls,
):
Expand Down
2 changes: 1 addition & 1 deletion tests/manager/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def test_handle_heartbeat(
mock_redis_wrapper.execute = AsyncMock()
mocker.patch('ai.backend.manager.registry.redis', mock_redis_wrapper)

def mocked_entrypoints(entry_point_group: str):
def mocked_entrypoints(entry_point_group: str, blocklist: set[str] = None):
return []

mocker.patch('ai.backend.common.plugin.scan_entrypoints', mocked_entrypoints)
Expand Down
9 changes: 8 additions & 1 deletion tests/plugin/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import textwrap as tw
from pathlib import Path

from ai.backend.plugin.entrypoint import extract_entrypoints_from_buildscript
from ai.backend.plugin.entrypoint import extract_entrypoints_from_buildscript, match_blocklist


def test_parse_build():
Expand Down Expand Up @@ -61,3 +61,10 @@ def test_parse_build():
items = [*extract_entrypoints_from_buildscript("backendai_error_monitor_v20", p)]
assert (items[0].name, items[0].module, items[0].attr) == \
("intrinsic", "ai.backend.manager.plugin.error_monitor", "ErrorMonitor")


def test_match_blocklist():
    """Verify package-prefix matching of entry point paths against a blocklist."""
    manager_only = {"ai.backend.manager"}
    agent_only = {"ai.backend.agent"}
    # Each case: (entry point value, blocklist, whether it must be blocked)
    cases = [
        ("ai.backend.manager:abc", manager_only, True),
        ("ai.backend.manager:abc", agent_only, False),
        ("ai.backend.manager.scheduler.fifo:FIFOScheduler", manager_only, True),
        ("ai.backend.common.monitor:ErrorMonitor", manager_only, False),
    ]
    for entry_path, blocklist, should_block in cases:
        assert bool(match_blocklist(entry_path, blocklist)) is should_block

0 comments on commit d38ae45

Please sign in to comment.