From f1c2c624acbc30e81421de18e6cbe46fadbbb93a Mon Sep 17 00:00:00 2001 From: EricPai Date: Tue, 27 Jun 2023 17:12:48 +0800 Subject: [PATCH 1/8] Add replace_subgraph with tests --- mars/core/entity/core.py | 4 + mars/optimization/logical/core.py | 78 ++++++- mars/optimization/logical/tests/__init__.py | 13 ++ mars/optimization/logical/tests/test_core.py | 222 +++++++++++++++++++ 4 files changed, 316 insertions(+), 1 deletion(-) create mode 100644 mars/optimization/logical/tests/__init__.py create mode 100644 mars/optimization/logical/tests/test_core.py diff --git a/mars/core/entity/core.py b/mars/core/entity/core.py index 6a27ac65d2..d3234a59d5 100644 --- a/mars/core/entity/core.py +++ b/mars/core/entity/core.py @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs): def op(self): return self._op + @property + def outputs(self): + return self._op.outputs + @property def inputs(self): return self.op.inputs diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index ba49f825f0..b0e3104bc6 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import itertools import weakref from abc import ABC, abstractmethod from collections import defaultdict @@ -19,7 +20,7 @@ from enum import Enum from typing import Dict, List, Optional, Type, Set -from ...core import OperandType, EntityType, enter_mode +from ...core import OperandType, EntityType, enter_mode, Entity from ...core.graph import EntityGraph from ...utils import implements @@ -130,6 +131,77 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType): for succ in successors: self._graph.add_edge(new_node, succ) + def _replace_subgraph( + self, + graph: Optional[EntityGraph], + removed_nodes: Optional[Set[EntityType]], + new_results: Optional[List[Entity]] = None, + ): + """ + Replace the subgraph from the self._graph represented by a list of nodes with input graph. + It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still + existed in self._graph) the nodes and edges of the input graph. + + Parameters + ---------- + graph : EntityGraph, optional + The input graph. If it's none, no new node and edge will be added. + removed_nodes : Set[EntityType], optional + The nodes to be removed. All the edges connected with them are removed as well. + new_results : List[EntityType], optional, default None + The updated results of the graph. If it's None, then the results will not be updated. + + Raises + ------ + ReplaceSubgraphError + If the input key of the removed node's successor can't be found in the subgraph. + Or some of the nodes of the subgraph are in removed ones. + """ + infected_successors = set() + + output_to_node = dict() + removed_nodes = removed_nodes or set() + if graph is not None: + # Add the output key -> node of the subgraph + for node in graph.iter_nodes(): + if node in removed_nodes: + raise ReplaceSubgraphError(f"The node {node} is in the removed set") + for output in node.outputs: + output_to_node[output.key] = node + + for node in removed_nodes: + for infected_successor in self._graph.iter_successors(node): + if infected_successor not in removed_nodes: + infected_successors.add(infected_successor) + # Check whether infected successors' inputs are in subgraph + for infected_successor in infected_successors: + for inp in infected_successor.inputs: + if inp.key not in output_to_node: + raise ReplaceSubgraphError( + f"The output {inp} of node {infected_successor} is missing in the subgraph" + ) + for node in removed_nodes: + self._graph.remove_node(node) + + if graph is None: + return + + # Add the output key -> node of the original graph + for node in self._graph.iter_nodes(): + for output in node.outputs: + output_to_node[output.key] = node + + for node in graph.iter_nodes(): + self._graph.add_node(node) + + for node in itertools.chain(graph.iter_nodes(), infected_successors): + for inp in node.inputs: + pred_node = output_to_node[inp.key] + self._graph.add_edge(pred_node, node) + + if new_results is not None: + self._graph.results = new_results.copy() + def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor) if predecessor not in self._preds_to_remove: @@ -283,3 +355,7 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords: graph.results = new_results return records + + +class ReplaceSubgraphError(Exception): + pass diff --git a/mars/optimization/logical/tests/__init__.py b/mars/optimization/logical/tests/__init__.py new file mode 100644 index 0000000000..c71e83c08e --- /dev/null +++ b/mars/optimization/logical/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py new file mode 100644 index 0000000000..c7da1e0192 --- /dev/null +++ b/mars/optimization/logical/tests/test_core.py @@ -0,0 +1,222 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import pytest + + +from ..core import OptimizationRule, ReplaceSubgraphError +from .... import tensor as mt +from .... import dataframe as md + + +class _MockRule(OptimizationRule): + def apply(self) -> bool: + pass + + def replace_subgraph(self, graph, removed_nodes, new_results=None): + self._replace_subgraph(graph, removed_nodes, new_results) + + +def test_replace_tileable_subgraph(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5 + | ^ + | | + V | + v3 ------| + ^ + | + s2 ---> c2 ---> v2 + + Target Graph: + s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5 + ^ + | + s2 ---> c2 ---> v2 + + The nodes [v3, v4, v6] will be removed. + Subgraph only contains [v7, v8] + """ + s1 = mt.random.randint(0, 100, size=(5, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(5, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v3 = v1.add(v2) + v4 = v3.add(v1) + s5 = mt.random.randint(0, 100, size=(5, 4)) + v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4) + v6 = v5.sub(v4) + g1 = v6.build_graph() + v7 = v1.sub(v2) + v8 = v7.add(v5) + g2 = v8.build_graph() + + # Here we use a trick way to construct the subgraph for test only + key_to_node = dict() + for node in g2.iter_nodes(): + key_to_node[node.key] = node + for key, node in key_to_node.items(): + if key != v7.key and key != v8.key: + g2.remove_node(node) + r = _MockRule(g1, None, None) + for node in g1.iter_nodes(): + key_to_node[node.key] = node + + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + c5 = g1.successors(key_to_node[s5.key])[0] + + expected_results = [v8.outputs[0]] + r.replace_subgraph( + g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results + ) + assert g1.results == expected_results + + expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8} + assert set(g1) == {key_to_node[n.key] for n in expected_nodes} + + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v7], + s2: [c2], + c2: [v2], + v2: [v7], + s5: [c5], + c5: [v5], + v5: [v8], + v7: [v8], + v8: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + +def test_replace_null_subgraph(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2 + + Target Graph: + c1 ---> v1 ---> v3 <--- v2 <--- c2 + + The nodes [s1, s2] will be removed. + Subgraph is None + """ + s1 = mt.random.randint(0, 100, size=(10, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(10, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v3 = v1.add(v2) + g1 = v3.build_graph() + key_to_node = {node.key: node for node in g1.iter_nodes()} + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + r = _MockRule(g1, None, None) + expected_results = [v3.outputs[0]] + # delete c5 s5 will fail + with pytest.raises(ReplaceSubgraphError) as e: + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) + assert g1.results == expected_results + assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v3], + s2: [c2], + c2: [v2], + v2: [v3], + v3: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + c1.inputs.clear() + c2.inputs.clear() + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) + assert g1.results == expected_results + assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} + expected_edges = { + c1: [v1], + v1: [v3], + c2: [v2], + v2: [v3], + v3: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + +def test_replace_subgraph_without_removing_nodes(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 + + Target Graph: + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 + s3 ---> c3 ---> v3 + + Nothing will be removed. + Subgraph only contains [s3, c3, v3] + """ + s1 = mt.random.randint(0, 100, size=(10, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(10, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v4 = v1.add(v2) + g1 = v4.build_graph() + + s3 = mt.random.randint(0, 100, size=(10, 4)) + v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5) + g2 = v3.build_graph() + key_to_node = { + node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes()) + } + expected_results = [v3.outputs[0], v4.outputs[0]] + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + c3 = g2.successors(key_to_node[s3.key])[0] + r = _MockRule(g1, None, None) + r.replace_subgraph(g2, None, expected_results) + assert g1.results == expected_results + assert set(g1) == { + key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4} + } + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v4], + s2: [c2], + c2: [v2], + v2: [v4], + s3: [c3], + c3: [v3], + v3: [], + v4: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) From 32ea951a9990ddce55d37378f37186dc3240f3bd Mon Sep 17 00:00:00 2001 From: EricPai Date: Tue, 27 Jun 2023 17:40:29 +0800 Subject: [PATCH 2/8] Use ValueError instead --- mars/optimization/logical/core.py | 8 ++------ mars/optimization/logical/tests/test_core.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index b0e3104bc6..33bd61c97d 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -165,7 +165,7 @@ def _replace_subgraph( # Add the output key -> node of the subgraph for node in graph.iter_nodes(): if node in removed_nodes: - raise ReplaceSubgraphError(f"The node {node} is in the removed set") + raise ValueError(f"The node {node} is in the removed set") for output in node.outputs: output_to_node[output.key] = node @@ -177,7 +177,7 @@ def _replace_subgraph( for infected_successor in infected_successors: for inp in infected_successor.inputs: if inp.key not in output_to_node: - raise ReplaceSubgraphError( + raise ValueError( f"The output {inp} of node {infected_successor} is missing in the subgraph" ) for node in removed_nodes: @@ -355,7 +355,3 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords: graph.results = new_results return records - - -class ReplaceSubgraphError(Exception): - pass diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index c7da1e0192..b4c0124b48 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -15,7 +15,7 @@ import pytest -from ..core import OptimizationRule, ReplaceSubgraphError +from ..core import OptimizationRule from .... import tensor as mt from .... import dataframe as md @@ -130,7 +130,7 @@ def test_replace_null_subgraph(): r = _MockRule(g1, None, None) expected_results = [v3.outputs[0]] # delete c5 s5 will fail - with pytest.raises(ReplaceSubgraphError) as e: + with pytest.raises(ValueError) as e: r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) assert g1.results == expected_results assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} From e6b53f65e4fd83d92b0228cc5ad77a9bfd1e5897 Mon Sep 17 00:00:00 2001 From: EricPai Date: Tue, 27 Jun 2023 17:45:54 +0800 Subject: [PATCH 3/8] Fix typo: afftected --- mars/optimization/logical/core.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index 33bd61c97d..af9b8f1e91 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -157,7 +157,7 @@ def _replace_subgraph( If the input key of the removed node's successor can't be found in the subgraph. Or some of the nodes of the subgraph are in removed ones. """ - infected_successors = set() + affected_successors = set() output_to_node = dict() removed_nodes = removed_nodes or set() @@ -170,15 +170,15 @@ def _replace_subgraph( output_to_node[output.key] = node for node in removed_nodes: - for infected_successor in self._graph.iter_successors(node): - if infected_successor not in removed_nodes: - infected_successors.add(infected_successor) - # Check whether infected successors' inputs are in subgraph - for infected_successor in infected_successors: - for inp in infected_successor.inputs: + for affected_successor in self._graph.iter_successors(node): + if affected_successor not in removed_nodes: + affected_successors.add(affected_successor) + # Check whether affected successors' inputs are in subgraph + for affected_successor in affected_successors: + for inp in affected_successor.inputs: if inp.key not in output_to_node: raise ValueError( - f"The output {inp} of node {infected_successor} is missing in the subgraph" + f"The output {inp} of node {affected_successor} is missing in the subgraph" ) for node in removed_nodes: self._graph.remove_node(node) @@ -194,7 +194,7 @@ def _replace_subgraph( for node in graph.iter_nodes(): self._graph.add_node(node) - for node in itertools.chain(graph.iter_nodes(), infected_successors): + for node in itertools.chain(graph.iter_nodes(), affected_successors): for inp in node.inputs: pred_node = output_to_node[inp.key] self._graph.add_edge(pred_node, node) From c53e623ebc1bd3b8567dcb947b5390549042b546 Mon Sep 17 00:00:00 2001 From: EricPai Date: Tue, 27 Jun 2023 17:54:59 +0800 Subject: [PATCH 4/8] Refine variable names --- mars/optimization/logical/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index af9b8f1e91..9f99def5d6 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -134,7 +134,7 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType): def _replace_subgraph( self, graph: Optional[EntityGraph], - removed_nodes: Optional[Set[EntityType]], + nodes_to_remove: Optional[Set[EntityType]], new_results: Optional[List[Entity]] = None, ): """ @@ -146,7 +146,7 @@ def _replace_subgraph( ---------- graph : EntityGraph, optional The input graph. If it's none, no new node and edge will be added. - removed_nodes : Set[EntityType], optional + nodes_to_remove : Set[EntityType], optional The nodes to be removed. All the edges connected with them are removed as well. new_results : List[EntityType], optional, default None The updated results of the graph. If it's None, then the results will not be updated. @@ -160,18 +160,18 @@ def _replace_subgraph( affected_successors = set() output_to_node = dict() - removed_nodes = removed_nodes or set() + nodes_to_remove = nodes_to_remove or set() if graph is not None: # Add the output key -> node of the subgraph for node in graph.iter_nodes(): - if node in removed_nodes: + if node in nodes_to_remove: raise ValueError(f"The node {node} is in the removed set") for output in node.outputs: output_to_node[output.key] = node - for node in removed_nodes: + for node in nodes_to_remove: for affected_successor in self._graph.iter_successors(node): - if affected_successor not in removed_nodes: + if affected_successor not in nodes_to_remove: affected_successors.add(affected_successor) # Check whether affected successors' inputs are in subgraph for affected_successor in affected_successors: @@ -180,7 +180,7 @@ def _replace_subgraph( raise ValueError( f"The output {inp} of node {affected_successor} is missing in the subgraph" ) - for node in removed_nodes: + for node in nodes_to_remove: self._graph.remove_node(node) if graph is None: @@ -200,7 +200,7 @@ def _replace_subgraph( self._graph.add_edge(pred_node, node) if new_results is not None: - self._graph.results = new_results.copy() + self._graph.results = list(new_results) def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor) From 9cb8257ff683e6f0f22f4346e9918d6c2a1e2375 Mon Sep 17 00:00:00 2001 From: EricPai Date: Wed, 28 Jun 2023 11:12:52 +0800 Subject: [PATCH 5/8] Using merge instead of replace as the result updating strategy --- mars/optimization/logical/core.py | 42 +++++++++----- mars/optimization/logical/tests/test_core.py | 61 ++++++++++++-------- 2 files changed, 66 insertions(+), 37 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index 9f99def5d6..3693d24f3e 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -135,7 +135,8 @@ def _replace_subgraph( self, graph: Optional[EntityGraph], nodes_to_remove: Optional[Set[EntityType]], - new_results: Optional[List[Entity]] = None, + new_results: Optional[List[Entity]], + results_to_remove: Optional[List[Entity]], ): """ Replace the subgraph from the self._graph represented by a list of nodes with input graph. @@ -148,19 +149,28 @@ def _replace_subgraph( The input graph. If it's none, no new node and edge will be added. nodes_to_remove : Set[EntityType], optional The nodes to be removed. All the edges connected with them are removed as well. - new_results : List[EntityType], optional, default None - The updated results of the graph. If it's None, then the results will not be updated. + new_results : List[Entity], optional + The new results to be added to the graph. + results_to_remove : List[Entity], optional + The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored. Raises ------ - ReplaceSubgraphError - If the input key of the removed node's successor can't be found in the subgraph. - Or some of the nodes of the subgraph are in removed ones. + ValueError + 1. If the input key of the removed node's successor can't be found in the subgraph. + 2. Or some of the nodes of the subgraph are in removed ones. + 3. Or the added result is not a valid output of any node in the updated graph. """ affected_successors = set() output_to_node = dict() nodes_to_remove = nodes_to_remove or set() + results_to_remove = results_to_remove or list() + new_results = new_results or list() + final_results = set( + filter(lambda x: x not in results_to_remove, self._graph.results) + ) + if graph is not None: # Add the output key -> node of the subgraph for node in graph.iter_nodes(): @@ -169,6 +179,17 @@ def _replace_subgraph( for output in node.outputs: output_to_node[output.key] = node + # Add the output key -> node of the original graph + for node in self._graph.iter_nodes(): + if node not in nodes_to_remove: + for output in node.outputs: + output_to_node[output.key] = node + + for result in new_results: + if result.key not in output_to_node: + raise ValueError(f"Unknown result {result} to add") + final_results.update(new_results) + for node in nodes_to_remove: for affected_successor in self._graph.iter_successors(node): if affected_successor not in nodes_to_remove: @@ -180,17 +201,13 @@ def _replace_subgraph( raise ValueError( f"The output {inp} of node {affected_successor} is missing in the subgraph" ) + # Here all the pre-check are passed, we start to replace the subgraph for node in nodes_to_remove: self._graph.remove_node(node) if graph is None: return - # Add the output key -> node of the original graph - for node in self._graph.iter_nodes(): - for output in node.outputs: - output_to_node[output.key] = node - for node in graph.iter_nodes(): self._graph.add_node(node) @@ -199,8 +216,7 @@ def _replace_subgraph( pred_node = output_to_node[inp.key] self._graph.add_edge(pred_node, node) - if new_results is not None: - self._graph.results = list(new_results) + self._graph.results = list(final_results) def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor) diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index b4c0124b48..622a341014 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -24,8 +24,8 @@ class _MockRule(OptimizationRule): def apply(self) -> bool: pass - def replace_subgraph(self, graph, removed_nodes, new_results=None): - self._replace_subgraph(graph, removed_nodes, new_results) + def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove): + self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove) def test_replace_tileable_subgraph(): @@ -78,11 +78,15 @@ def test_replace_tileable_subgraph(): c2 = g1.successors(key_to_node[s2.key])[0] c5 = g1.successors(key_to_node[s5.key])[0] - expected_results = [v8.outputs[0]] + new_results = [v8.outputs[0]] + removed_results = [ + v6.outputs[0], + v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it. + ] r.replace_subgraph( - g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results + g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results ) - assert g1.results == expected_results + assert g1.results == new_results expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8} assert set(g1) == {key_to_node[n.key] for n in expected_nodes} @@ -110,10 +114,10 @@ def test_replace_tileable_subgraph(): def test_replace_null_subgraph(): """ Original Graph: - s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2 + s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2 Target Graph: - c1 ---> v1 ---> v3 <--- v2 <--- c2 + c1 ---> v1 ---> v3 <--- v2(out) <--- c2 The nodes [s1, s2] will be removed. Subgraph is None @@ -129,30 +133,39 @@ def test_replace_null_subgraph(): c2 = g1.successors(key_to_node[s2.key])[0] r = _MockRule(g1, None, None) expected_results = [v3.outputs[0]] + # delete c5 s5 will fail with pytest.raises(ValueError) as e: - r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) - assert g1.results == expected_results - assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} - expected_edges = { - s1: [c1], - c1: [v1], - v1: [v3], - s2: [c2], - c2: [v2], - v2: [v3], - v3: [], - } - for pred, successors in expected_edges.items(): - pred_node = key_to_node[pred.key] + r.replace_subgraph( + None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]] + ) + + assert g1.results == expected_results + assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v3], + s2: [c2], + c2: [v2], + v2: [v3], + v3: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] assert g1.count_successors(pred_node) == len(successors) for successor in successors: assert g1.has_successor(pred_node, key_to_node[successor.key]) c1.inputs.clear() c2.inputs.clear() - r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) - assert g1.results == expected_results + r.replace_subgraph( + None, + {key_to_node[op.key] for op in [s1, s2]}, + [v2.outputs[0]], + [v3.outputs[0]], + ) + assert g1.results == [v2.outputs[0]] assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} expected_edges = { c1: [v1], @@ -198,7 +211,7 @@ def test_replace_subgraph_without_removing_nodes(): c2 = g1.successors(key_to_node[s2.key])[0] c3 = g2.successors(key_to_node[s3.key])[0] r = _MockRule(g1, None, None) - r.replace_subgraph(g2, None, expected_results) + r.replace_subgraph(g2, None, [v3.outputs[0]], None) assert g1.results == expected_results assert set(g1) == { key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4} From e2eaefc69c9b24a5a06908ff44d02d072bd28787 Mon Sep 17 00:00:00 2001 From: EricPai Date: Wed, 28 Jun 2023 11:23:07 +0800 Subject: [PATCH 6/8] Fix flake8 error --- mars/optimization/logical/tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index 622a341014..166b25fceb 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -135,7 +135,7 @@ def test_replace_null_subgraph(): expected_results = [v3.outputs[0]] # delete c5 s5 will fail - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError): r.replace_subgraph( None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]] ) From dd02ba9aa9ca13a35a411ff8bde000e3d4908a97 Mon Sep 17 00:00:00 2001 From: EricPai Date: Mon, 3 Jul 2023 11:54:02 +0800 Subject: [PATCH 7/8] Change results updating algorithm --- mars/optimization/logical/core.py | 29 ++++++++--------- mars/optimization/logical/tests/test_core.py | 34 +++++++++----------- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index 3693d24f3e..1bf2d1e2ca 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -135,8 +135,7 @@ def _replace_subgraph( self, graph: Optional[EntityGraph], nodes_to_remove: Optional[Set[EntityType]], - new_results: Optional[List[Entity]], - results_to_remove: Optional[List[Entity]], + new_results: Optional[List[Entity]] = None, ): """ Replace the subgraph from the self._graph represented by a list of nodes with input graph. @@ -149,27 +148,24 @@ def _replace_subgraph( The input graph. If it's none, no new node and edge will be added. nodes_to_remove : Set[EntityType], optional The nodes to be removed. All the edges connected with them are removed as well. - new_results : List[Entity], optional - The new results to be added to the graph. - results_to_remove : List[Entity], optional - The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored. + new_results : List[Entity], optional, default None + The new results to be replaced to the original by their keys. Raises ------ ValueError 1. If the input key of the removed node's successor can't be found in the subgraph. 2. Or some of the nodes of the subgraph are in removed ones. - 3. Or the added result is not a valid output of any node in the updated graph. + 3. Or some of the removed nodes are also in the results. + 4. Or the key of the new result can't be found in the original results. """ affected_successors = set() - output_to_node = dict() nodes_to_remove = nodes_to_remove or set() - results_to_remove = results_to_remove or list() new_results = new_results or list() - final_results = set( - filter(lambda x: x not in results_to_remove, self._graph.results) - ) + result_indices = { + result.key: idx for idx, result in enumerate(self._graph.results) + } if graph is not None: # Add the output key -> node of the subgraph @@ -185,10 +181,12 @@ def _replace_subgraph( for output in node.outputs: output_to_node[output.key] = node + # Check if the updated result is valid for result in new_results: + if result.key not in result_indices: + raise ValueError(f"Unknown result {result} to replace") if result.key not in output_to_node: - raise ValueError(f"Unknown result {result} to add") - final_results.update(new_results) + raise ValueError(f"The result {result} is missing in the updated graph") for node in nodes_to_remove: for affected_successor in self._graph.iter_successors(node): @@ -216,7 +214,8 @@ def _replace_subgraph( pred_node = output_to_node[inp.key] self._graph.add_edge(pred_node, node) - self._graph.results = list(final_results) + for result in new_results: + self._graph.results[result_indices[result.key]] = result def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor) diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index 166b25fceb..d06b249bd9 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -24,8 +24,8 @@ class _MockRule(OptimizationRule): def apply(self) -> bool: pass - def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove): - self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove) + def replace_subgraph(self, graph, nodes_to_remove, new_results=None): + self._replace_subgraph(graph, nodes_to_remove, new_results) def test_replace_tileable_subgraph(): @@ -61,8 +61,9 @@ def test_replace_tileable_subgraph(): g1 = v6.build_graph() v7 = v1.sub(v2) v8 = v7.add(v5) + v8._key = v6.key + v8.outputs[0]._key = v6.key g2 = v8.build_graph() - # Here we use a trick way to construct the subgraph for test only key_to_node = dict() for node in g2.iter_nodes(): @@ -79,15 +80,12 @@ def test_replace_tileable_subgraph(): c5 = g1.successors(key_to_node[s5.key])[0] new_results = [v8.outputs[0]] - removed_results = [ - v6.outputs[0], - v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it. - ] - r.replace_subgraph( - g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results - ) + r.replace_subgraph(g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results) assert g1.results == new_results - + for node in g1.iter_nodes(): + if node.key == v8.key: + key_to_node[v8.key] = node + break expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8} assert set(g1) == {key_to_node[n.key] for n in expected_nodes} @@ -117,7 +115,7 @@ def test_replace_null_subgraph(): s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2 Target Graph: - c1 ---> v1 ---> v3 <--- v2(out) <--- c2 + c1 ---> v1 ---> v3(out) <--- v2 <--- c2 The nodes [s1, s2] will be removed. Subgraph is None @@ -137,7 +135,7 @@ def test_replace_null_subgraph(): # delete c5 s5 will fail with pytest.raises(ValueError): r.replace_subgraph( - None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]] + None, {key_to_node[op.key] for op in [s1, s2]}, [v2.outputs[0]] ) assert g1.results == expected_results @@ -161,11 +159,9 @@ def test_replace_null_subgraph(): c2.inputs.clear() r.replace_subgraph( None, - {key_to_node[op.key] for op in [s1, s2]}, - [v2.outputs[0]], - [v3.outputs[0]], + {key_to_node[op.key] for op in [s1, s2]} ) - assert g1.results == [v2.outputs[0]] + assert g1.results == expected_results assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} expected_edges = { c1: [v1], @@ -206,12 +202,12 @@ def test_replace_subgraph_without_removing_nodes(): key_to_node = { node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes()) } - expected_results = [v3.outputs[0], v4.outputs[0]] + expected_results = [v4.outputs[0]] c1 = g1.successors(key_to_node[s1.key])[0] c2 = g1.successors(key_to_node[s2.key])[0] c3 = g2.successors(key_to_node[s3.key])[0] r = _MockRule(g1, None, None) - r.replace_subgraph(g2, None, [v3.outputs[0]], None) + r.replace_subgraph(g2, None) assert g1.results == expected_results assert set(g1) == { key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4} From 85058060ac0874d09ee836ec41b11b7622284e41 Mon Sep 17 00:00:00 2001 From: EricPai Date: Fri, 30 Jun 2023 11:35:03 +0800 Subject: [PATCH 8/8] Refine ArithmeticToEval related rules --- mars/optimization/logical/core.py | 32 ----- mars/optimization/logical/tests/test_core.py | 5 +- .../logical/tileable/arithmetic_query.py | 125 ++++++++++++------ 3 files changed, 84 insertions(+), 78 deletions(-) diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index 1bf2d1e2ca..7809ade083 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -13,7 +13,6 @@ # limitations under the License. import functools import itertools -import weakref from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass @@ -92,8 +91,6 @@ def get_original_entity( class OptimizationRule(ABC): - _preds_to_remove = weakref.WeakKeyDictionary() - def __init__( self, graph: EntityGraph, @@ -217,35 +214,6 @@ def _replace_subgraph( for result in new_results: self._graph.results[result_indices[result.key]] = result - def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): - pred_original = self._records.get_original_entity(predecessor, predecessor) - if predecessor not in self._preds_to_remove: - self._preds_to_remove[pred_original] = {node} - else: - self._preds_to_remove[pred_original].add(node) - - def _remove_collapsable_predecessors(self, node: EntityType): - node = self._records.get_optimization_result(node) or node - preds_opt_to_remove = [] - for pred in self._graph.predecessors(node): - pred_original = self._records.get_original_entity(pred, pred) - pred_opt = self._records.get_optimization_result(pred, pred) - - if pred_opt in self._graph.results or pred_original in self._graph.results: - continue - affect_succ = self._preds_to_remove.get(pred_original) or [] - affect_succ_opt = [ - self._records.get_optimization_result(s, s) for s in affect_succ - ] - if all(s in affect_succ_opt for s in self._graph.successors(pred)): - preds_opt_to_remove.append((pred_original, pred_opt)) - - for pred_original, pred_opt in preds_opt_to_remove: - self._graph.remove_node(pred_opt) - self._records.append_record( - OptimizationRecord(pred_original, None, OptimizationRecordType.delete) - ) - class OperandBasedOptimizationRule(OptimizationRule): """ diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index d06b249bd9..6cfef989d7 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -157,10 +157,7 @@ def test_replace_null_subgraph(): c1.inputs.clear() c2.inputs.clear() - r.replace_subgraph( - None, - {key_to_node[op.key] for op in [s1, s2]} - ) + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) assert g1.results == expected_results assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} expected_edges = { diff --git a/mars/optimization/logical/tileable/arithmetic_query.py b/mars/optimization/logical/tileable/arithmetic_query.py index 5ecf4a2945..604d9b7ac7 100644 --- a/mars/optimization/logical/tileable/arithmetic_query.py +++ b/mars/optimization/logical/tileable/arithmetic_query.py @@ -13,20 +13,27 @@ # limitations under the License. import weakref -from typing import NamedTuple, Optional +from abc import ABC +from typing import NamedTuple, Optional, Type, Set import numpy as np from pandas.api.types import is_scalar from .... import dataframe as md -from ....core import Tileable, get_output_types, ENTITY_TYPE +from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph +from ....core.graph import EntityGraph from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc from ....dataframe.base.eval import DataFrameEval from ....dataframe.indexing.getitem import DataFrameIndex from ....dataframe.indexing.setitem import DataFrameSetitem -from ....typing import OperandType +from ....typing import OperandType, EntityType from ....utils import implements -from ..core import OptimizationRecord, OptimizationRecordType +from ..core import ( + OptimizationRecord, + OptimizationRecordType, + OptimizationRecords, + Optimizer, +) from ..tileable.core import register_operand_based_optimization_rule from .core import OperandBasedOptimizationRule @@ -66,8 +73,70 @@ def builder(lhs: str, rhs: str): _extract_result_cache = weakref.WeakKeyDictionary() +class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC): + def __init__( + self, + graph: EntityGraph, + records: OptimizationRecords, + optimizer_cls: Type[Optimizer], + ): + super().__init__(graph, records, optimizer_cls) + self._marked_predecessors = dict() + + def _mark_predecessor(self, node: EntityType, predecessor: EntityType): + pred_original = self._records.get_original_entity(predecessor, predecessor) + if predecessor not in self._marked_predecessors: + self._marked_predecessors[pred_original] = {node} + else: + self._marked_predecessors[pred_original].add(node) + + def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]: + node = self._records.get_optimization_result(node) or node + removed_nodes = {node} + results_set = set(self._graph.results) + removed_pairs = [] + for pred in self._graph.iter_predecessors(node): + pred_original = self._records.get_original_entity(pred, pred) + pred_opt = self._records.get_optimization_result(pred, pred) + + if pred_opt in results_set or pred_original in results_set: + continue + + affect_succ = self._marked_predecessors.get(pred_original) or [] + affect_succ_opt = [ + self._records.get_optimization_result(s, s) for s in affect_succ + ] + if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)): + removed_pairs.append((pred_original, pred_opt)) + + for pred_original, pred_opt in removed_pairs: + removed_nodes.add(pred_opt) + self._records.append_record( + OptimizationRecord(pred_original, None, OptimizationRecordType.delete) + ) + return removed_nodes + + def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType): + # Find all the nodes to remove + nodes_to_remove = self._find_nodes_to_remove(original_node) + + # Build the replaced subgraph + subgraph = TileableGraph() + subgraph.add_node(new_node) + + new_results = [new_node] if new_node in self._graph.results else None + self._replace_subgraph(subgraph, nodes_to_remove, new_results) + self._records.append_record( + OptimizationRecord( + self._records.get_original_entity(original_node, original_node), + new_node, + OptimizationRecordType.replace, + ) + ) + + @register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc]) -class SeriesArithmeticToEval(OperandBasedOptimizationRule): +class SeriesArithmeticToEval(_EvalRewriteOptimizationRule): _var_counter = 0 @classmethod @@ -151,7 +220,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord: if in_tileable is None: return EvalExtractRecord() - self._add_collapsable_predecessor(tileable, op.inputs[0]) + self._mark_predecessor(tileable, op.inputs[0]) return EvalExtractRecord( in_tileable, _func_name_to_builder[func_name](expr), variables ) @@ -164,10 +233,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord: lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs) if lhs_tileable is not None: - self._add_collapsable_predecessor(tileable, op.lhs) + self._mark_predecessor(tileable, op.lhs) rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs) if rhs_tileable is not None: - self._add_collapsable_predecessor(tileable, op.rhs) + self._mark_predecessor(tileable, op.rhs) if lhs_expr is None or rhs_expr is None: return EvalExtractRecord() @@ -204,24 +273,10 @@ def apply_to_operand(self, op: OperandType): new_node = new_op.new_tileable( [opt_in_tileable], _key=node.key, _id=node.id, **node.params ).data + self._replace_with_new_node(node, new_node) - self._remove_collapsable_predecessors(node) - self._replace_node(node, new_node) - self._graph.add_edge(opt_in_tileable, new_node) - self._records.append_record( - OptimizationRecord(node, new_node, OptimizationRecordType.replace) - ) - - # check node if it's in result - try: - i = self._graph.results.index(node) - self._graph.results[i] = new_node - except ValueError: - pass - - -class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule): +class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule): @implements(OperandBasedOptimizationRule.match_operand) def match_operand(self, op: OperandType) -> bool: optimized_eval_op = self._get_optimized_eval_op(op) @@ -245,16 +300,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType: def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE: raise NotImplementedError - def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE): - self._replace_node(old_node, new_node) - for in_tileable in new_node.inputs: - self._graph.add_edge(in_tileable, new_node) - - original_node = self._records.get_original_entity(old_node, old_node) - self._records.append_record( - OptimizationRecord(original_node, new_node, OptimizationRecordType.replace) - ) - @implements(OperandBasedOptimizationRule.apply_to_operand) def apply_to_operand(self, op: DataFrameIndex): node = op.outputs[0] @@ -268,10 +313,8 @@ def apply_to_operand(self, op: DataFrameIndex): new_node = new_op.new_tileable( [opt_in_tileable], _key=node.key, _id=node.id, **node.params ).data - - self._add_collapsable_predecessor(node, in_columnar_node) - self._remove_collapsable_predecessors(node) - self._update_op_node(node, new_node) + self._mark_predecessor(node, in_columnar_node) + self._replace_with_new_node(node, new_node) @register_operand_based_optimization_rule([DataFrameIndex]) @@ -360,7 +403,5 @@ def apply_to_operand(self, op: DataFrameIndex): new_node = new_op.new_tileable( pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params ).data - - self._add_collapsable_predecessor(opt_node, pred_opt_node) - self._remove_collapsable_predecessors(opt_node) - self._update_op_node(opt_node, new_node) + self._mark_predecessor(opt_node, pred_opt_node) + self._replace_with_new_node(opt_node, new_node)