Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add partial support for setitem with fancy indexing #2453

Merged
merged 11 commits into from
Oct 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mars/services/scheduling/worker/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ async def _get_band_quota_ref(self, band: str) -> Union[mo.ActorRef, QuotaActor]

async def _prepare_input_data(self, subtask: Subtask, band_name: str):
queries = []
shuffle_queries = []
storage_api = await StorageAPI.create(
subtask.session_id, address=self.address, band_name=band_name
)
Expand All @@ -167,13 +168,17 @@ async def _prepare_input_data(self, subtask: Subtask, band_name: str):
)
elif isinstance(chunk.op, FetchShuffle):
for key in chunk_key_to_data_keys[chunk.key]:
queries.append(
shuffle_queries.append(
storage_api.fetch.delay(
key, band_name=to_fetch_band, error="ignore"
)
)
if queries:
await storage_api.fetch.batch(*queries)
if shuffle_queries:
qinxuye marked this conversation as resolved.
Show resolved Hide resolved
# TODO(hks): The batch method doesn't accept different error arguments,
# combine them when it can.
await storage_api.fetch.batch(*shuffle_queries)

async def _collect_input_sizes(
self, subtask: Subtask, supervisor_address: str, band_name: str
Expand Down
22 changes: 16 additions & 6 deletions mars/services/subtask/worker/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ async def _load_input_data(self):
accept_nones.append(True)
elif isinstance(chunk.op, FetchShuffle):
for key in self._chunk_key_to_data_keys[chunk.key]:
keys.append(key)
gets.append(self._storage_api.get.delay(key, error="ignore"))
accept_nones.append(False)
if key not in keys:
keys.append(key)
gets.append(self._storage_api.get.delay(key, error="ignore"))
accept_nones.append(False)
if keys:
logger.debug(
"Start getting input data, keys: %s, " "subtask id: %s",
Expand Down Expand Up @@ -239,19 +240,28 @@ def cb(fut):
if ref_counts[inp.key] == 0:
# ref count reaches 0, remove it
for key in self._chunk_key_to_data_keys[inp.key]:
del self._datastore[key]
if key in self._datastore:
del self._datastore[key]

async def _unpin_data(self, data_keys):
    """Unpin input data keys in storage once they are no longer needed.

    Keys are split into two batches because shuffle keys (tuples) must
    tolerate missing data: some shuffle output is ``None`` and is never
    written to storage, so their unpin uses ``error="ignore"``.
    """
    regular_delays, shuffle_delays = [], []
    for data_key in data_keys:
        if isinstance(data_key, tuple):
            # A tuple key means it's a shuffle key; the data may be
            # absent from storage, so ignore unpin errors for it.
            delayed = self._storage_api.unpin.delay(data_key, error="ignore")
            shuffle_delays.append(delayed)
        else:
            regular_delays.append(self._storage_api.unpin.delay(data_key))
    # TODO(hks): The batch method doesn't accept different error arguments,
    # combine them when it can.
    if regular_delays:
        await self._storage_api.unpin.batch(*regular_delays)
    if shuffle_delays:
        await self._storage_api.unpin.batch(*shuffle_delays)

async def _store_data(self, chunk_graph: ChunkGraph):
# skip virtual operands for result chunks
Expand Down
Loading