Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a universal function to replace subgraph in OptimizationRule #3353

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions mars/core/entity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
def op(self):
return self._op

@property
def outputs(self):
return self._op.outputs

@property
def inputs(self):
return self.op.inputs
Expand Down
90 changes: 89 additions & 1 deletion mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Set

from ...core import OperandType, EntityType, enter_mode
from ...core import OperandType, EntityType, enter_mode, Entity
from ...core.graph import EntityGraph
from ...utils import implements

Expand Down Expand Up @@ -130,6 +131,93 @@
for succ in successors:
self._graph.add_edge(new_node, succ)

def _replace_subgraph(
self,
graph: Optional[EntityGraph],
nodes_to_remove: Optional[Set[EntityType]],
new_results: Optional[List[Entity]],
results_to_remove: Optional[List[Entity]],
):
"""
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still
existed in self._graph) the nodes and edges of the input graph.

Parameters
----------
graph : EntityGraph, optional
The input graph. If it's none, no new node and edge will be added.
nodes_to_remove : Set[EntityType], optional
The nodes to be removed. All the edges connected with them are removed as well.
new_results : List[Entity], optional
The new results to be added to the graph.
results_to_remove : List[Entity], optional
The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored.

Raises
------
ValueError
1. If the input key of the removed node's successor can't be found in the subgraph.
2. Or some of the nodes of the subgraph are in removed ones.
3. Or the added result is not a valid output of any node in the updated graph.
"""
affected_successors = set()

output_to_node = dict()
nodes_to_remove = nodes_to_remove or set()
results_to_remove = results_to_remove or list()
new_results = new_results or list()
final_results = set(
filter(lambda x: x not in results_to_remove, self._graph.results)
)

if graph is not None:
# Add the output key -> node of the subgraph
for node in graph.iter_nodes():
if node in nodes_to_remove:
raise ValueError(f"The node {node} is in the removed set")

Check warning on line 178 in mars/optimization/logical/core.py

View check run for this annotation

Codecov / codecov/patch

mars/optimization/logical/core.py#L178

Added line #L178 was not covered by tests
for output in node.outputs:
output_to_node[output.key] = node

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
if node not in nodes_to_remove:
for output in node.outputs:
output_to_node[output.key] = node

for result in new_results:
if result.key not in output_to_node:
raise ValueError(f"Unknown result {result} to add")

Check warning on line 190 in mars/optimization/logical/core.py

View check run for this annotation

Codecov / codecov/patch

mars/optimization/logical/core.py#L190

Added line #L190 was not covered by tests
final_results.update(new_results)

for node in nodes_to_remove:
for affected_successor in self._graph.iter_successors(node):
if affected_successor not in nodes_to_remove:
affected_successors.add(affected_successor)
# Check whether affected successors' inputs are in subgraph
for affected_successor in affected_successors:
for inp in affected_successor.inputs:
if inp.key not in output_to_node:
raise ValueError(
f"The output {inp} of node {affected_successor} is missing in the subgraph"
)
# Here all the pre-check are passed, we start to replace the subgraph
for node in nodes_to_remove:
self._graph.remove_node(node)

if graph is None:
return

for node in graph.iter_nodes():
self._graph.add_node(node)

for node in itertools.chain(graph.iter_nodes(), affected_successors):
for inp in node.inputs:
pred_node = output_to_node[inp.key]
self._graph.add_edge(pred_node, node)

self._graph.results = list(final_results)

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
pred_original = self._records.get_original_entity(predecessor, predecessor)
if predecessor not in self._preds_to_remove:
Expand Down
13 changes: 13 additions & 0 deletions mars/optimization/logical/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
235 changes: 235 additions & 0 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import pytest


from ..core import OptimizationRule
from .... import tensor as mt
from .... import dataframe as md


class _MockRule(OptimizationRule):
def apply(self) -> bool:
pass

def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove):
self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove)


def test_replace_tileable_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
| ^
| |
V |
v3 ------|
^
|
s2 ---> c2 ---> v2

Target Graph:
s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
^
|
s2 ---> c2 ---> v2

The nodes [v3, v4, v6] will be removed.
Subgraph only contains [v7, v8]
"""
s1 = mt.random.randint(0, 100, size=(5, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(5, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
v4 = v3.add(v1)
s5 = mt.random.randint(0, 100, size=(5, 4))
v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
v6 = v5.sub(v4)
g1 = v6.build_graph()
v7 = v1.sub(v2)
v8 = v7.add(v5)
g2 = v8.build_graph()

# Here we use a trick way to construct the subgraph for test only
key_to_node = dict()
for node in g2.iter_nodes():
key_to_node[node.key] = node
for key, node in key_to_node.items():
if key != v7.key and key != v8.key:
g2.remove_node(node)
r = _MockRule(g1, None, None)
for node in g1.iter_nodes():
key_to_node[node.key] = node

c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c5 = g1.successors(key_to_node[s5.key])[0]

new_results = [v8.outputs[0]]
removed_results = [
v6.outputs[0],
v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it.
]
r.replace_subgraph(
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results
)
assert g1.results == new_results

expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}

expected_edges = {
s1: [c1],
c1: [v1],
v1: [v7],
s2: [c2],
c2: [v2],
v2: [v7],
s5: [c5],
c5: [v5],
v5: [v8],
v7: [v8],
v8: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_null_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2

Target Graph:
c1 ---> v1 ---> v3 <--- v2(out) <--- c2

The nodes [s1, s2] will be removed.
Subgraph is None
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
g1 = v3.build_graph()
key_to_node = {node.key: node for node in g1.iter_nodes()}
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
r = _MockRule(g1, None, None)
expected_results = [v3.outputs[0]]

# delete c5 s5 will fail
with pytest.raises(ValueError):
r.replace_subgraph(
None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]]
)

assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(
None,
{key_to_node[op.key] for op in [s1, s2]},
[v2.outputs[0]],
[v3.outputs[0]],
)
assert g1.results == [v2.outputs[0]]
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
c1: [v1],
v1: [v3],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_subgraph_without_removing_nodes():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2

Target Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
s3 ---> c3 ---> v3

Nothing will be removed.
Subgraph only contains [s3, c3, v3]
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v4 = v1.add(v2)
g1 = v4.build_graph()

s3 = mt.random.randint(0, 100, size=(10, 4))
v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
g2 = v3.build_graph()
key_to_node = {
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
}
expected_results = [v3.outputs[0], v4.outputs[0]]
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c3 = g2.successors(key_to_node[s3.key])[0]
r = _MockRule(g1, None, None)
r.replace_subgraph(g2, None, [v3.outputs[0]], None)
assert g1.results == expected_results
assert set(g1) == {
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v4],
s2: [c2],
c2: [v2],
v2: [v4],
s3: [c3],
c3: [v3],
v3: [],
v4: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])