1 change: 1 addition & 0 deletions changes/6498.feature.md
@@ -0,0 +1 @@
Add resource isolation options for multi-agent setup
122 changes: 25 additions & 97 deletions src/ai/backend/agent/agent.py
@@ -30,15 +30,18 @@
from dataclasses import dataclass
from decimal import Decimal
from io import SEEK_END, BytesIO
from itertools import chain
from pathlib import Path
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Final,
Generic,
Literal,
Optional,
ParamSpec,
TypeVar,
cast,
)
@@ -175,7 +178,6 @@
from ai.backend.common.types import (
MODEL_SERVICE_RUNTIME_PROFILES,
AbuseReportValue,
AcceleratorMetadata,
AgentId,
AutoPullBehavior,
BinarySize,
@@ -222,7 +224,6 @@
from ai.backend.logging.formatter import pretty

from . import __version__ as VERSION
from . import alloc_map as alloc_map_mod
from .affinity_map import AffinityMap
from .config.unified import AgentUnifiedConfig, ContainerSandboxType
from .exception import AgentError, ContainerCreationError, ResourceError
@@ -235,12 +236,10 @@
from .observer.heartbeat import HeartbeatObserver
from .observer.host_port import HostPortObserver
from .resources import (
AbstractComputeDevice,
AbstractComputePlugin,
ComputerContext,
KernelResourceSpec,
Mount,
align_memory,
allocate,
known_slot_types,
)
@@ -279,6 +278,7 @@
EVENT_DISPATCHER_CONSUMER_GROUP: Final = "agent"
STAT_COLLECTION_TIMEOUT: Final[float] = 10 * 60 # 10 minutes

P = ParamSpec("P")
KernelObjectType = TypeVar("KernelObjectType", bound=AbstractKernel)
KernelIdContainerPair = tuple[KernelId, Container]

@@ -341,7 +341,7 @@ def __init__(
kernel_config: KernelCreationConfig,
distro: str,
local_config: AgentUnifiedConfig,
computers: MutableMapping[DeviceName, ComputerContext],
computers: Mapping[DeviceName, ComputerContext],
restarting: bool = False,
) -> None:
self.image_labels = kernel_config["image"]["labels"]
@@ -729,18 +729,18 @@ class RestartTracker:
def _observe_stat_task(
stat_scope: StatScope,
) -> Callable[
[Callable[[AbstractAgent, float], Coroutine[Any, Any, None]]],
Callable[[AbstractAgent, float], Coroutine[Any, Any, None]],
[Callable[Concatenate[AbstractAgent, P], Coroutine[Any, Any, None]]],
Callable[Concatenate[AbstractAgent, P], Coroutine[Any, Any, None]],
]:
stat_task_observer = StatTaskObserver.instance()

def decorator(
func: Callable[[AbstractAgent, float], Coroutine[Any, Any, None]],
) -> Callable[[AbstractAgent, float], Coroutine[Any, Any, None]]:
async def wrapper(self: AbstractAgent, interval: float) -> None:
func: Callable[Concatenate[AbstractAgent, P], Coroutine[Any, Any, None]],
) -> Callable[Concatenate[AbstractAgent, P], Coroutine[Any, Any, None]]:
async def wrapper(self: AbstractAgent, *args: P.args, **kwargs: P.kwargs) -> None:
stat_task_observer.observe_stat_task_triggered(agent_id=self.id, stat_scope=stat_scope)
try:
await func(self, interval)
await func(self, *args, **kwargs)
except asyncio.CancelledError:
pass
except Exception as e:
@@ -766,7 +766,8 @@ class AbstractAgent(
etcd: AgentEtcdClientView
local_instance_id: str
kernel_registry: MutableMapping[KernelId, AbstractKernel]
computers: MutableMapping[DeviceName, ComputerContext]
computers: Mapping[DeviceName, ComputerContext]
slots: Mapping[SlotName, Decimal]
images: Mapping[ImageCanonical, ScannedImage]
port_pool: set[int]

@@ -838,6 +839,8 @@ def __init__(
skip_initial_scan: bool = False,
agent_public_key: Optional[PublicKey],
kernel_registry: KernelRegistry,
computers: Mapping[DeviceName, ComputerContext],
slots: Mapping[SlotName, Decimal],
) -> None:
self._skip_initial_scan = skip_initial_scan
self.loop = current_loop()
@@ -847,7 +850,8 @@ def __init__(
self.local_instance_id = generate_local_instance_id(__file__)
self.agent_public_key = agent_public_key
self.kernel_registry = kernel_registry.agent_mapping(self.id)
self.computers = {}
self.computers = computers
self.slots = slots
self.images = {}
self.restarting_kernels = {}
self.stat_ctx = StatContext(
@@ -931,31 +935,21 @@ async def __ainit__(self) -> None:
bgtask_observer=self._metric_registry.bgtask,
)

alloc_map_mod.log_alloc_map = self.local_config.debug.log_alloc_map
computers = await self.load_resources()

all_devices: list[AbstractComputeDevice] = []
metadatas: list[AcceleratorMetadata] = []
for name, computer in computers.items():
devices = await computer.list_devices()
all_devices.extend(devices)
alloc_map = await computer.create_alloc_map()
self.computers[name] = ComputerContext(computer, devices, alloc_map)
metadatas.append(computer.get_metadata())

self.slots = await self.update_slots()
log.info("Resource slots: {!r}", self.slots)
log.info("Slot types: {!r}", known_slot_types)
self.timer_tasks.append(aiotools.create_timer(self.update_slots_periodically, 30.0))

# Use ValkeyStatClient batch operations for better performance
metadatas = [computer.instance.get_metadata() for computer in self.computers.values()]
field_value_map = {}
for metadata in metadatas:
field_value_map[metadata["slot_name"]] = dump_json_str(metadata).encode()

if field_value_map:
await self.valkey_stat_client.store_computer_metadata(field_value_map)

all_devices = list(
chain.from_iterable((computer.devices for computer in self.computers.values()))
)
self.affinity_map = AffinityMap.build(all_devices)

if not self._skip_initial_scan:
@@ -965,9 +959,6 @@ async def __ainit__(self) -> None:
await self.scan_running_kernels()

# Prepare stat collector tasks.
self.timer_tasks.append(
aiotools.create_timer(self.collect_node_stat, UTILIZATION_METRIC_INTERVAL)
)
self.timer_tasks.append(
aiotools.create_timer(self.collect_container_stat, UTILIZATION_METRIC_INTERVAL)
)
@@ -1053,12 +1044,6 @@ async def shutdown(self, stop_signal: signal.Signals) -> None:
"""
await cancel_tasks(self._ongoing_exec_batch_tasks)

for _, computer in self.computers.items():
try:
await computer.instance.cleanup()
except Exception:
log.exception("Failed to clean up computer instance:")

async with self.registry_lock:
# Close all pending kernel runners.
for kernel_obj in self.kernel_registry.values():
@@ -1264,12 +1249,12 @@ async def collect_logs(
chunk_buffer.close()

@_observe_stat_task(stat_scope=StatScope.NODE)
async def collect_node_stat(self, interval: float):
async def collect_node_stat(self, resource_scaling_factors: Mapping[SlotName, Decimal]):
if self.local_config.debug.log_stats:
log.debug("collecting node statistics")
try:
async with asyncio.timeout(STAT_COLLECTION_TIMEOUT):
await self.stat_ctx.collect_node_stat()
await self.stat_ctx.collect_node_stat(resource_scaling_factors)
except Exception:
log.exception("unhandled exception while syncing node stats")
await self.produce_error_event()
@@ -1940,65 +1925,8 @@ def get_cgroup_path(self, controller: str, container_id: str) -> Path:
def get_cgroup_version(self) -> str:
raise NotImplementedError

@abstractmethod
async def load_resources(
self,
) -> Mapping[DeviceName, AbstractComputePlugin]:
"""
Detect available resources attached on the system and load corresponding device plugin.
"""

@abstractmethod
async def scan_available_resources(
self,
) -> Mapping[SlotName, Decimal]:
"""
Scan and define the amount of available resource slots in this node.
"""

async def update_slots(
self,
) -> Mapping[SlotName, Decimal]:
"""
Finalize the resource slots from the resource slots scanned by each device plugin,
excluding reserved capacities for the system and agent itself.
"""
scanned_slots = await self.scan_available_resources()
usable_slots: dict[SlotName, Decimal] = {}
reserved_slots = {
SlotName("cpu"): Decimal(self.local_config.resource.reserved_cpu),
SlotName("mem"): Decimal(self.local_config.resource.reserved_mem),
SlotName("disk"): Decimal(self.local_config.resource.reserved_disk),
}
for slot_name, slot_capacity in scanned_slots.items():
if slot_name == SlotName("mem"):
mem_reserved = int(reserved_slots.get(slot_name, 0))
mem_align = int(self.local_config.resource.memory_align_size)
mem_usable, mem_reserved = align_memory(
int(slot_capacity), mem_reserved, align=mem_align
)
usable_capacity = Decimal(mem_usable)
log.debug(
"usable-mem: {:m}, reserved-mem: {:m} after {:m} alignment",
BinarySize(mem_usable),
BinarySize(mem_reserved),
BinarySize(mem_align),
)
else:
usable_capacity = max(
Decimal(0), slot_capacity - reserved_slots.get(slot_name, Decimal(0))
)
usable_slots[slot_name] = usable_capacity
return usable_slots

async def update_slots_periodically(
self,
interval: float,
) -> None:
"""
A timer function to periodically scan and update the resource slots.
"""
self.slots = await self.update_slots()
def update_slots(self, updated_slots: Mapping[SlotName, Decimal]) -> None:
self.slots = updated_slots
log.debug("slots: {!r}", self.slots)

async def gather_hwinfo(self) -> Mapping[str, HardwareMetadata]:
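The `_observe_stat_task` rewrite above swaps the hard-coded `(self, interval: float)` signature for `ParamSpec`/`Concatenate`, so one observer wrapper now fits stat-collection coroutines with different argument lists (for example, `collect_node_stat` takes a `resource_scaling_factors` mapping instead of an interval). Below is a minimal standalone sketch of that typing pattern; the `DemoAgent`, `observe_task`, and method names are illustrative placeholders, not the Backend.AI classes.

```python
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from typing import Any, Concatenate, ParamSpec

P = ParamSpec("P")


def observe_task(
    func: Callable[Concatenate[DemoAgent, P], Coroutine[Any, Any, None]],
) -> Callable[Concatenate[DemoAgent, P], Coroutine[Any, Any, None]]:
    # The wrapper forwards *args/**kwargs untouched, so the decorator accepts
    # coroutine methods with any argument list after `self`, while keeping
    # their signatures visible to the type checker.
    async def wrapper(self: DemoAgent, *args: P.args, **kwargs: P.kwargs) -> None:
        print(f"[{self.id}] stat task triggered")
        try:
            await func(self, *args, **kwargs)
        except asyncio.CancelledError:
            pass

    return wrapper


class DemoAgent:
    id = "i-demo"

    @observe_task
    async def collect_every(self, interval: float) -> None:
        print(f"collecting every {interval}s")

    @observe_task
    async def rescale(self, factors: dict[str, float]) -> None:
        print(f"rescaling with {factors}")


async def main() -> None:
    agent = DemoAgent()
    await agent.collect_every(30.0)    # original (self, interval) shape
    await agent.rescale({"cpu": 0.9})  # new (self, mapping) shape


asyncio.run(main())
```

Because the wrapper is typed with the same `Concatenate[..., P]` on both sides, a type checker still sees each decorated method's original parameters at the call site.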
11 changes: 11 additions & 0 deletions src/ai/backend/agent/alloc_map.py
@@ -236,6 +236,17 @@ def update_affinity_hint(
hint_for_next_allocation.append(dev)
affinity_hint.devices = hint_for_next_allocation

@final
def update_device_slot_amounts(self, slot_amounts: Mapping[SlotName, Decimal]) -> None:
self.device_slots = {
device_id: DeviceSlotInfo(
slot_type=slot_info.slot_type,
slot_name=slot_info.slot_name,
amount=slot_amounts[slot_info.slot_name],
)
for device_id, slot_info in self.device_slots.items()
}

@abstractmethod
def allocate(
self,
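The new `update_device_slot_amounts()` hook rebuilds `device_slots` so each device's capacity reflects the latest per-slot amounts supplied from outside, presumably after the slots are rescaled for the multi-agent isolation feature. A simplified, self-contained sketch of that rebuild pattern, using a stand-in `DeviceSlotInfo` dataclass rather than the real `ai.backend.agent` types:

```python
from collections.abc import Mapping
from dataclasses import dataclass
from decimal import Decimal


@dataclass
class DeviceSlotInfo:
    # Field names follow the diff above; this is a simplified stand-in,
    # not the real ai.backend.agent alloc_map class.
    slot_type: str
    slot_name: str
    amount: Decimal


device_slots = {
    "cpu:0": DeviceSlotInfo("count", "cpu", Decimal(8)),
    "cuda:0": DeviceSlotInfo("count", "cuda.device", Decimal(1)),
}

# Freshly scanned or rescaled per-slot capacities, keyed by slot name.
slot_amounts: Mapping[str, Decimal] = {
    "cpu": Decimal(6),
    "cuda.device": Decimal(1),
}

# Same rebuild pattern as update_device_slot_amounts(): keep each device's
# slot type and name, swap in the refreshed amount looked up by slot name.
device_slots = {
    device_id: DeviceSlotInfo(info.slot_type, info.slot_name, slot_amounts[info.slot_name])
    for device_id, info in device_slots.items()
}

print(device_slots["cpu:0"].amount)  # Decimal('6')
```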
16 changes: 5 additions & 11 deletions src/ai/backend/agent/docker/agent.py
@@ -138,7 +138,6 @@
update_nested_dict,
)
from .kernel import DockerKernel
from .resources import load_resources, scan_available_resources
from .utils import PersistentServiceContainer

if TYPE_CHECKING:
@@ -294,7 +293,7 @@ def __init__(
kernel_config: KernelCreationConfig,
distro: str,
local_config: AgentUnifiedConfig,
computers: MutableMapping[DeviceName, ComputerContext],
computers: Mapping[DeviceName, ComputerContext],
port_pool: Set[int],
agent_sockpath: Path,
resource_lock: asyncio.Lock,
@@ -1356,6 +1355,8 @@ def __init__(
skip_initial_scan: bool = False,
agent_public_key: Optional[PublicKey],
kernel_registry: KernelRegistry,
computers: Mapping[DeviceName, ComputerContext],
slots: Mapping[SlotName, Decimal],
) -> None:
super().__init__(
etcd,
@@ -1365,6 +1366,8 @@
skip_initial_scan=skip_initial_scan,
agent_public_key=agent_public_key,
kernel_registry=kernel_registry,
computers=computers,
slots=slots,
)
self.checked_invalid_images = set()

@@ -1491,15 +1494,6 @@ def get_cgroup_path(self, controller: str, container_id: str) -> Path:
def get_cgroup_version(self) -> str:
return self.docker_info["CgroupVersion"]

async def load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]:
return await load_resources(self.etcd, self.local_config.model_dump(by_alias=True))

async def scan_available_resources(self) -> Mapping[SlotName, Decimal]:
return await scan_available_resources(
self.local_config.model_dump(by_alias=True),
{name: cctx.instance for name, cctx in self.computers.items()},
)

async def extract_image_command(self, image: str) -> Optional[str]:
async with closing_async(Docker()) as docker:
result = await docker.images.get(image)
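With `load_resources()` and `scan_available_resources()` removed from the agent classes, the computer contexts and slot totals are now built outside the agent and injected through the new `computers=` / `slots=` constructor arguments. Below is a hedged sketch of what that wiring could look like at the call site, reconstructed from the code removed from `__ainit__()` above; the `build_agent_resources` helper name and the untyped `etcd` / `local_config` parameters are assumptions, since the actual bootstrap code is not part of the visible hunks.

```python
from collections.abc import Mapping
from decimal import Decimal

from ai.backend.agent.docker.resources import load_resources, scan_available_resources
from ai.backend.agent.resources import ComputerContext
from ai.backend.common.types import DeviceName, SlotName


async def build_agent_resources(
    etcd, local_config
) -> tuple[Mapping[DeviceName, ComputerContext], Mapping[SlotName, Decimal]]:
    # Mirrors the per-plugin setup that previously lived in AbstractAgent.__ainit__():
    # load the device plugins, enumerate their devices, and create one alloc map each.
    plugins = await load_resources(etcd, local_config.model_dump(by_alias=True))
    computers: dict[DeviceName, ComputerContext] = {}
    for name, plugin in plugins.items():
        devices = await plugin.list_devices()
        alloc_map = await plugin.create_alloc_map()
        computers[name] = ComputerContext(plugin, devices, alloc_map)

    # scan_available_resources() no longer takes the local config
    # (see the docker/resources.py hunk below).
    slots = await scan_available_resources(
        {name: cctx.instance for name, cctx in computers.items()}
    )
    return computers, slots
```

The resulting mappings would then be passed as `DockerAgent(..., computers=computers, slots=slots)`, matching the widened constructor shown above.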
1 change: 0 additions & 1 deletion src/ai/backend/agent/docker/resources.py
@@ -73,7 +73,6 @@ async def load_resources(


async def scan_available_resources(
local_config: Mapping[str, Any],
compute_device_types: Mapping[DeviceName, AbstractComputePlugin],
) -> Mapping[SlotName, Decimal]:
"""
14 changes: 1 addition & 13 deletions src/ai/backend/agent/dummy/agent.py
@@ -59,7 +59,6 @@
from ..types import Container, KernelOwnershipData, MountInfo
from .config import DEFAULT_CONFIG_PATH, dummy_local_config
from .kernel import DummyKernel
from .resources import load_resources, scan_available_resources


class DummyKernelCreationContext(AbstractKernelCreationContext[DummyKernel]):
@@ -73,7 +72,7 @@ def __init__(
kernel_config: KernelCreationConfig,
distro: str,
local_config: AgentUnifiedConfig,
computers: MutableMapping[DeviceName, ComputerContext],
computers: Mapping[DeviceName, ComputerContext],
restarting: bool = False,
*,
dummy_config: Mapping[str, Any],
@@ -284,17 +283,6 @@ def get_cgroup_version(self) -> str:
# Dummy agent does not use cgroups, so we return an empty string.
return ""

async def load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]:
return await load_resources(
self.etcd, self.local_config.model_dump(by_alias=True), self.dummy_config
)

async def scan_available_resources(self) -> Mapping[SlotName, Decimal]:
return await scan_available_resources(
self.local_config.model_dump(by_alias=True),
{name: cctx.instance for name, cctx in self.computers.items()},
)

async def extract_image_command(self, image: str) -> str | None:
delay = self.dummy_agent_cfg["delay"]["scan-image"]
await asyncio.sleep(delay)
1 change: 0 additions & 1 deletion src/ai/backend/agent/dummy/resources.py
@@ -70,7 +70,6 @@ async def load_resources(


async def scan_available_resources(
local_config: Mapping[str, Any],
compute_device_types: Mapping[DeviceName, AbstractComputePlugin],
) -> Mapping[SlotName, Decimal]:
"""