Skip to content

Commit

Permalink
Fix duplicate execution (#3301)
Browse files Browse the repository at this point in the history
* Fix duplicate execute

* Fix

Co-authored-by: 刘宝 <po.lb@antgroup.com>
  • Loading branch information
fyrestone and 刘宝 committed Dec 12, 2022
1 parent 4b15d0d commit 4b06c1c
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 13 deletions.
19 changes: 10 additions & 9 deletions mars/core/graph/builder/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
self,
tileable_graph: TileableGraph,
tile_context: TileContext,
processed_chunks: Set[ChunkType],
processed_chunks: Set[str],
chunk_to_fetch: Dict[ChunkType, ChunkType],
add_nodes: Callable,
):
Expand Down Expand Up @@ -301,11 +301,12 @@ def _iter(self):

if chunk_graph is not None:
# last tiled chunks, add them to processed
# so that fetch chunk can be generated
processed_chunks = [
c.chunk if isinstance(c, FUSE_CHUNK_TYPE) else c
# so that fetch chunk can be generated.
            # Use the chunk key as the set key so that a copied chunk
            # (same key, different identity) can still be built into a fetch.
processed_chunks = (
c.chunk.key if isinstance(c, FUSE_CHUNK_TYPE) else c.key
for c in chunk_graph.result_chunks
]
)
self._processed_chunks.update(processed_chunks)

result_chunks = []
Expand Down Expand Up @@ -389,7 +390,7 @@ def __init__(
self.tile_context = TileContext() if tile_context is None else tile_context
self.tile_context.set_tileables(set(graph))

self._processed_chunks: Set[ChunkType] = set()
self._processed_chunks: Set[str] = set()
self._chunk_to_fetch: Dict[ChunkType, ChunkType] = dict()

tiler_cls = Tiler if tiler_cls is None else tiler_cls
Expand All @@ -402,7 +403,7 @@ def __init__(
)

def _process_node(self, entity: EntityType):
if entity in self._processed_chunks:
if entity.key in self._processed_chunks:
if entity not in self._chunk_to_fetch:
# gen fetch
fetch_chunk = build_fetch(entity).data
Expand All @@ -413,7 +414,7 @@ def _process_node(self, entity: EntityType):
def _select_inputs(self, inputs: List[ChunkType]):
new_inputs = []
for inp in inputs:
if inp in self._processed_chunks:
if inp.key in self._processed_chunks:
# gen fetch
if inp not in self._chunk_to_fetch:
fetch_chunk = build_fetch(inp).data
Expand All @@ -424,7 +425,7 @@ def _select_inputs(self, inputs: List[ChunkType]):
return new_inputs

def _if_add_node(self, node: EntityType, visited: Set):
return node not in visited and node not in self._processed_chunks
return node not in visited and node.key not in self._processed_chunks

def _build(self) -> Iterable[Union[TileableGraph, ChunkGraph]]:
tile_iterator = iter(self.tiler)
Expand Down
2 changes: 1 addition & 1 deletion mars/dataframe/base/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def tile(cls, op: "DataFrameRechunk"):
params["dtypes"] = pd.concat([c.dtypes for c in inp_chunks_arr[0]])
if len(inp_slice_chunks) == 1:
c = inp_slice_chunks[0]
cc = c.op.copy().reset_key().new_chunk(c.op.inputs, kws=[params])
cc = c.op.copy().new_chunk(c.op.inputs, kws=[params])
out_chunks.append(cc)
else:
out_chunk = DataFrameConcat(
Expand Down
9 changes: 9 additions & 0 deletions mars/dataframe/base/tests/test_base_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def test_to_cpu_execution(setup_gpu):


def test_rechunk_execution(setup):
ns = np.random.RandomState(0)
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])

# test rechunk after sort
mdf = DataFrame(df, chunk_size=10)
result = mdf.sort_values("a0").rechunk(chunk_size=10).execute().fetch()
expected = df.sort_values("a0")
pd.testing.assert_frame_equal(result, expected)

data = pd.DataFrame(np.random.rand(8, 10))
df = from_pandas_df(pd.DataFrame(data), chunk_size=3)
df2 = df.rechunk((3, 4))
Expand Down
2 changes: 1 addition & 1 deletion mars/dataframe/sort/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _tile_head(cls, op: "DataFrameSortOperand"):
shape = tuple(shape)
concat_params["shape"] = shape
if len(to_combine_chunks) == 1:
c = to_combine_chunks[0]
c = to_combine_chunks[0].copy()
c._index = chunk_index
else:
c = DataFrameConcat(
Expand Down
2 changes: 1 addition & 1 deletion mars/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ def _concat_chunks(merge_chunks: List[ChunkType], output_index: int):
# concat previous chunks
if len(to_merge_chunks) == 1:
# do not generate concat op for 1 input.
c = to_merge_chunks[0]
c = to_merge_chunks[0].copy()
c._index = (
(len(n_split),) if df_or_series.ndim == 1 else (len(n_split), 0)
)
Expand Down
2 changes: 1 addition & 1 deletion mars/services/task/supervisor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self,
tileable_graph: TileableGraph,
tile_context: TileContext,
processed_chunks: Set[ChunkType],
processed_chunks: Set[str],
chunk_to_fetch: Dict[ChunkType, ChunkType],
add_nodes: Callable,
cancelled: asyncio.Event = None,
Expand Down

0 comments on commit 4b06c1c

Please sign in to comment.