Skip to content

Commit

Permalink
[BACKPORT] Assign bands given devices of subtasks (#2278)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi authored Aug 2, 2021
1 parent c864ce3 commit 5c1ecde
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 15 deletions.
4 changes: 4 additions & 0 deletions mars/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,7 @@ def serialize(self, obj: Serializable, context: Dict):


BaseSerializer.register(Base)


class MarsError(Exception):
    """Common base class for Mars-specific exceptions."""
4 changes: 3 additions & 1 deletion mars/oscar/backends/communication/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ....core.base import MarsError

class ChannelClosed(Exception):

class ChannelClosed(MarsError):
pass
4 changes: 3 additions & 1 deletion mars/oscar/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from mars.core.base import MarsError

class ActorPoolNotStarted(Exception):

class ActorPoolNotStarted(MarsError):
pass


Expand Down
4 changes: 3 additions & 1 deletion mars/services/lifecycle/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...core.base import MarsError

class TileableNotTracked(Exception):

class TileableNotTracked(MarsError):
pass
23 changes: 23 additions & 0 deletions mars/services/scheduling/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...core.base import MarsError


class NoMatchingSlots(MarsError):
    """Error signalling that no slot of the requested kind could be found.

    The ``slot_prefix`` passed at construction is kept on the instance and
    is also what ``str()`` of the exception returns.
    """

    def __init__(self, slot_prefix):
        super().__init__(slot_prefix)
        self.slot_prefix = slot_prefix

    def __str__(self):
        prefix = self.slot_prefix
        return str(prefix)
50 changes: 46 additions & 4 deletions mars/services/scheduling/supervisor/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
# limitations under the License.

import asyncio
import itertools
import random
from collections import defaultdict
from typing import List

from .... import oscar as mo
from ....core.operand import Fetch, FetchShuffle
from ....typing import BandType
from ...core import NodeRole
from ...subtask import Subtask
from ..errors import NoMatchingSlots


class AssignerActor(mo.Actor):
_bands: List[BandType]

@classmethod
def gen_uid(cls, session_id: str):
return f'{session_id}_assigner'
Expand All @@ -36,6 +41,8 @@ def __init__(self, session_id: str):
self._meta_api = None

self._bands = []
self._address_to_bands = dict()
self._device_type_to_bands = dict()
self._band_watch_task = None

async def __post_create__(self):
Expand All @@ -47,18 +54,42 @@ async def __post_create__(self):

async def watch_bands():
async for bands in self._cluster_api.watch_all_bands(NodeRole.WORKER):
self._bands = list(bands)
self._update_bands(list(bands))

self._band_watch_task = asyncio.create_task(watch_bands())

async def __pre_destroy__(self):
if self._band_watch_task is not None: # pragma: no branch
self._band_watch_task.cancel()

def _update_bands(self, bands: List[BandType]):
    """Store ``bands`` and rebuild the per-address and per-device indexes.

    A band is indexed as ``(address, band_name)``. Bands whose name starts
    with ``'numa'`` are filed under ``'numa'``; every other band is filed
    under ``'gpu'``.
    """
    self._bands = bands

    # Index by worker address, walking bands in sorted order so that both
    # the keys and each per-address list come out sorted.
    by_address = defaultdict(list)
    for band in sorted(bands):
        by_address[band[0]].append(band)
    # Store a plain dict so that unknown addresses still raise KeyError.
    self._address_to_bands = dict(by_address)

    # Index by device type, sorted by (device type, band).
    typed_bands = sorted(
        ('numa' if band[1].startswith('numa') else 'gpu', band)
        for band in bands
    )
    by_device = defaultdict(list)
    for device_type, band in typed_bands:
        by_device[device_type].append(band)
    self._device_type_to_bands = dict(by_device)

def _get_device_bands(self, is_gpu: bool):
band_prefix = 'numa' if not is_gpu else 'gpu'
filtered_bands = self._device_type_to_bands.get(band_prefix) or []
if not filtered_bands:
raise NoMatchingSlots('gpu' if is_gpu else 'cpu')
return filtered_bands

def _get_random_band(self, is_gpu: bool):
avail_bands = self._get_device_bands(is_gpu)
return random.choice(avail_bands)

async def assign_subtasks(self, subtasks: List[Subtask]):
inp_keys = set()
selected_bands = dict()
for subtask in subtasks:
is_gpu = any(c.op.gpu for c in subtask.chunk_graph.result_chunks)
if subtask.expect_bands:
selected_bands[subtask.subtask_id] = subtask.expect_bands
continue
Expand All @@ -67,9 +98,9 @@ async def assign_subtasks(self, subtasks: List[Subtask]):
inp_keys.add(indep_chunk.key)
elif isinstance(indep_chunk.op, FetchShuffle):
if not self._bands:
self._bands = list(await self._cluster_api.get_all_bands(
NodeRole.WORKER))
selected_bands[subtask.subtask_id] = [random.choice(self._bands)]
self._update_bands(list(await self._cluster_api.get_all_bands(
NodeRole.WORKER)))
selected_bands[subtask.subtask_id] = [self._get_random_band(is_gpu)]
break

fields = ['store_size', 'bands']
Expand All @@ -81,6 +112,10 @@ async def assign_subtasks(self, subtasks: List[Subtask]):
inp_metas = dict(zip(inp_keys, metas))
assigns = []
for subtask in subtasks:
is_gpu = any(c.op.gpu for c in subtask.chunk_graph.result_chunks)
band_prefix = 'numa' if not is_gpu else 'gpu'
filtered_bands = self._get_device_bands(is_gpu)

if subtask.subtask_id in selected_bands:
bands = selected_bands[subtask.subtask_id]
else:
Expand All @@ -90,6 +125,13 @@ async def assign_subtasks(self, subtasks: List[Subtask]):
continue
meta = inp_metas[inp.key]
for band in meta['bands']:
if not band[1].startswith(band_prefix):
sel_bands = [b for b in self._address_to_bands[band[0]]
if b[1].startswith(band_prefix)]
if sel_bands:
band = (band[0], random.choice(sel_bands))
if band not in filtered_bands:
band = self._get_random_band(is_gpu)
band_sizes[band] += meta['store_size']
bands = []
max_size = -1
Expand Down
104 changes: 100 additions & 4 deletions mars/services/scheduling/supervisor/tests/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,88 @@
# limitations under the License.

import numpy as np
import asyncio
import pytest

import mars.oscar as mo
from mars.core import ChunkGraph
from mars.services.cluster import MockClusterAPI
from mars.services.cluster import ClusterAPI
from mars.services.cluster.core import NodeRole, NodeStatus
from mars.services.cluster.uploader import NodeInfoUploaderActor
from mars.services.cluster.supervisor.locator import SupervisorPeerLocatorActor
from mars.services.cluster.supervisor.node_info import NodeInfoCollectorActor
from mars.services.meta import MockMetaAPI
from mars.services.session import MockSessionAPI
from mars.services.scheduling.supervisor import AssignerActor
from mars.services.scheduling.errors import NoMatchingSlots
from mars.services.subtask import Subtask
from mars.tensor.fetch import TensorFetch
from mars.tensor.arithmetic import TensorTreeAdd


class MockNodeInfoCollectorActor(NodeInfoCollectorActor):
    """Node info collector preloaded with a fixed set of bands for tests.

    Three ``numa-0`` bands (``address0`` .. ``address2``) are always present,
    each mapped to the value 2. With ``with_gpu=True`` an extra
    ``('address0', 'gpu-0')`` band mapped to 1 is added.
    """

    def __init__(self, timeout=None, check_interval=None, with_gpu=False):
        super().__init__(timeout=timeout, check_interval=check_interval)
        ready = {(f'address{idx}', 'numa-0'): 2 for idx in range(3)}
        if with_gpu:
            ready[('address0', 'gpu-0')] = 1
        self.ready_bands = ready
        # Independent snapshot: later removals from ready_bands do not
        # affect all_bands.
        self.all_bands = ready.copy()

    async def update_node_info(self, address, role, env=None,
                               resource=None, detail=None, status=None):
        # A stopping fake worker loses its CPU band from the ready set
        # before the update is forwarded to the real collector.
        if 'address' in address and status == NodeStatus.STOPPING:
            del self.ready_bands[(address, 'numa-0')]
        await super().update_node_info(
            address, role, env, resource, detail, status)

    def get_all_bands(self, role=None, statuses=None):
        # Only a request for exactly the READY status gets the ready set.
        only_ready = statuses == {NodeStatus.READY}
        return self.ready_bands if only_ready else self.all_bands


class FakeClusterAPI(ClusterAPI):
    """Cluster API for tests, backed by ``MockNodeInfoCollectorActor``."""

    @classmethod
    async def create(cls, address: str, **kw):
        """Create the supporting actors at ``address`` and return a ready API.

        Recognised keyword arguments: ``with_gpu`` (forwarded to the mock
        collector), ``upload_interval``, ``band_to_slots`` and ``use_gpu``
        (forwarded to the uploader).

        Actors that already exist are tolerated; any other creation error
        is re-raised.
        """
        # asyncio.wait() stopped accepting bare coroutines (deprecated in
        # Python 3.8, TypeError from 3.11), so wrap each creation call in a
        # future before waiting on it.
        creations = [
            asyncio.ensure_future(
                mo.create_actor(SupervisorPeerLocatorActor, 'fixed', address,
                                uid=SupervisorPeerLocatorActor.default_uid(),
                                address=address)),
            asyncio.ensure_future(
                mo.create_actor(MockNodeInfoCollectorActor,
                                with_gpu=kw.get('with_gpu', False),
                                uid=NodeInfoCollectorActor.default_uid(),
                                address=address)),
            asyncio.ensure_future(
                mo.create_actor(NodeInfoUploaderActor, NodeRole.WORKER,
                                interval=kw.get('upload_interval'),
                                band_to_slots=kw.get('band_to_slots'),
                                use_gpu=kw.get('use_gpu', False),
                                uid=NodeInfoUploaderActor.default_uid(),
                                address=address)),
        ]
        dones, _ = await asyncio.wait(creations)

        for task in dones:
            try:
                # Surface any failure raised while creating the actor.
                task.result()
            except mo.ActorAlreadyExist:  # pragma: no cover
                pass

        api = await super().create(address=address)
        await api.mark_node_ready()
        return api


@pytest.fixture
async def actor_pool():
async def actor_pool(request):
pool = await mo.create_actor_pool('127.0.0.1', n_process=0)
with_gpu = request.param

async with pool:
session_id = 'test_session'
await MockClusterAPI.create(pool.external_address)
await FakeClusterAPI.create(
pool.external_address, with_gpu=with_gpu)
await MockSessionAPI.create(pool.external_address, session_id=session_id)
meta_api = await MockMetaAPI.create(session_id, pool.external_address)
assigner_ref = await mo.create_actor(
Expand All @@ -45,7 +107,8 @@ async def actor_pool():


@pytest.mark.asyncio
async def test_assigner(actor_pool):
@pytest.mark.parametrize('actor_pool', [False], indirect=True)
async def test_assign_cpu_tasks(actor_pool):
pool, session_id, assigner_ref, meta_api = actor_pool

input1 = TensorFetch(key='a', source_key='a', dtype=np.dtype(int)).new_chunk([])
Expand Down Expand Up @@ -73,3 +136,36 @@ async def test_assigner(actor_pool):
subtask = Subtask('test_task', session_id, chunk_graph=chunk_graph)
[result] = await assigner_ref.assign_subtasks([subtask])
assert result in (('address1', 'numa-0'), ('address2', 'numa-0'))

result_chunk.op.gpu = True
subtask = Subtask('test_task', session_id, chunk_graph=chunk_graph)
with pytest.raises(NoMatchingSlots) as err:
await assigner_ref.assign_subtasks([subtask])
assert 'gpu' in str(err.value)


@pytest.mark.asyncio
@pytest.mark.parametrize('actor_pool', [True], indirect=True)
async def test_assign_gpu_tasks(actor_pool):
    """A GPU subtask whose inputs sit on a CPU band is assigned a GPU band."""
    _pool, session_id, assigner_ref, meta_api = actor_pool

    fetch_chunks = [
        TensorFetch(key=key, source_key=key, dtype=np.dtype(int)).new_chunk([])
        for key in ('a', 'b')
    ]
    add_op = TensorTreeAdd(args=list(fetch_chunks), gpu=True)
    result_chunk = add_op.new_chunk(list(fetch_chunks))

    chunk_graph = ChunkGraph([result_chunk])
    for node in (*fetch_chunks, result_chunk):
        chunk_graph.add_node(node)
    for fetch_chunk in fetch_chunks:
        chunk_graph.add_edge(fetch_chunk, result_chunk)

    # Both inputs are recorded on the CPU band of address0.
    for fetch_chunk in fetch_chunks:
        await meta_api.set_chunk_meta(fetch_chunk, memory_size=200, store_size=200,
                                      bands=[('address0', 'numa-0')])

    subtask = Subtask('test_task', session_id, chunk_graph=chunk_graph)
    [assigned_band] = await assigner_ref.assign_subtasks([subtask])
    assert assigned_band[1].startswith('gpu')
5 changes: 3 additions & 2 deletions mars/services/storage/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...core.base import MarsError
from ...storage.errors import DataNotExist


DataNotExist = DataNotExist


class NoDataToSpill(Exception):
class NoDataToSpill(MarsError):
pass


class StorageFull(Exception):
class StorageFull(MarsError):
pass
4 changes: 3 additions & 1 deletion mars/services/task/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...core.base import MarsError

class TaskNotExist(Exception):

class TaskNotExist(MarsError):
pass
4 changes: 3 additions & 1 deletion mars/storage/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.base import MarsError

class DataNotExist(Exception):

class DataNotExist(MarsError):
pass

0 comments on commit 5c1ecde

Please sign in to comment.