Skip to content

Commit

Permalink
[BACKPORT] Make Proxima work with latest Mars (#2599) (#2605)
Browse files Browse the repository at this point in the history
Co-authored-by: yuyiming <36940796+yuyiming@users.noreply.github.com>
  • Loading branch information
wjsi and yuyiming committed Dec 8, 2021
1 parent d53df46 commit b31a3ad
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mars/learn/proxima/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def rechunk_tensor(tensor, chunk_size):
if start_chunk_index == end_chunk_index:
t = tensor.chunks[start_chunk_index]
slice_op = TensorSlice(
(
[
slice(
offset - tensor_cumnrows[start_chunk_index],
split + offset - tensor_cumnrows[end_chunk_index],
),
slice(None),
),
],
dtype=t.dtype,
)
out_groups.append(
Expand All @@ -93,7 +93,7 @@ def rechunk_tensor(tensor, chunk_size):
start_chunk = tensor.chunks[start_chunk_index]
start_slice = int(offset - tensor_cumnrows[start_chunk_index])
slice_op = TensorSlice(
(slice(start_slice, None), slice(None)), dtype=start_chunk.dtype
[slice(start_slice, None), slice(None)], dtype=start_chunk.dtype
)
chunks.append(
slice_op.new_chunk(
Expand All @@ -107,7 +107,7 @@ def rechunk_tensor(tensor, chunk_size):
end_chunk = tensor.chunks[end_chunk_index]
end_slice = int(split + offset - tensor_cumnrows[end_chunk_index])
slice_op_end = TensorSlice(
(slice(None, end_slice), slice(None)), dtype=start_chunk.dtype
[slice(None, end_slice), slice(None)], dtype=start_chunk.dtype
)
chunks.append(
slice_op_end.new_chunk(
Expand Down
6 changes: 6 additions & 0 deletions mars/services/task/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ def gen_subtask_graph(self) -> SubtaskGraph:
"Assigned %s start chunks for task %s", len(start_ops), self._task.task_id
)

# assign expect workers for those specified with `expect_worker`
# skip `start_ops`, which have been assigned before
for chunk in self._chunk_graph:
if chunk not in start_ops and chunk.op.expect_worker is not None:
chunk_to_bands[chunk] = self._to_band(chunk.op.expect_worker)

# fuse node
if self._fuse_enabled:
logger.debug("Start to fuse chunks for task %s", self._task.task_id)
Expand Down

0 comments on commit b31a3ad

Please sign in to comment.