Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a universal function to replace subgraph in OptimizationRule #3353

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mars/core/entity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
def op(self):
return self._op

@property
def outputs(self):
return self._op.outputs

@property
def inputs(self):
return self.op.inputs
Expand Down
119 changes: 87 additions & 32 deletions mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import weakref
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Set

from ...core import OperandType, EntityType, enter_mode
from ...core import OperandType, EntityType, enter_mode, Entity
from ...core.graph import EntityGraph
from ...utils import implements

Expand Down Expand Up @@ -91,8 +91,6 @@ def get_original_entity(


class OptimizationRule(ABC):
_preds_to_remove = weakref.WeakKeyDictionary()

def __init__(
self,
graph: EntityGraph,
Expand Down Expand Up @@ -130,34 +128,91 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
for succ in successors:
self._graph.add_edge(new_node, succ)

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
pred_original = self._records.get_original_entity(predecessor, predecessor)
if predecessor not in self._preds_to_remove:
self._preds_to_remove[pred_original] = {node}
else:
self._preds_to_remove[pred_original].add(node)

def _remove_collapsable_predecessors(self, node: EntityType):
node = self._records.get_optimization_result(node) or node
preds_opt_to_remove = []
for pred in self._graph.predecessors(node):
pred_original = self._records.get_original_entity(pred, pred)
pred_opt = self._records.get_optimization_result(pred, pred)

if pred_opt in self._graph.results or pred_original in self._graph.results:
continue
affect_succ = self._preds_to_remove.get(pred_original) or []
affect_succ_opt = [
self._records.get_optimization_result(s, s) for s in affect_succ
]
if all(s in affect_succ_opt for s in self._graph.successors(pred)):
preds_opt_to_remove.append((pred_original, pred_opt))

for pred_original, pred_opt in preds_opt_to_remove:
self._graph.remove_node(pred_opt)
self._records.append_record(
OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
)
def _replace_subgraph(
self,
graph: Optional[EntityGraph],
nodes_to_remove: Optional[Set[EntityType]],
new_results: Optional[List[Entity]] = None,
):
"""
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still
existed in self._graph) the nodes and edges of the input graph.

Parameters
----------
graph : EntityGraph, optional
The input graph. If it's none, no new node and edge will be added.
nodes_to_remove : Set[EntityType], optional
The nodes to be removed. All the edges connected with them are removed as well.
new_results : List[Entity], optional, default None
The new results to be replaced to the original by their keys.

Raises
------
ValueError
1. If the input key of the removed node's successor can't be found in the subgraph.
2. Or some of the nodes of the subgraph are in removed ones.
3. Or some of the removed nodes are also in the results.
4. Or the key of the new result can't be found in the original results.
"""
affected_successors = set()
output_to_node = dict()
nodes_to_remove = nodes_to_remove or set()
new_results = new_results or list()
result_indices = {
result.key: idx for idx, result in enumerate(self._graph.results)
}

if graph is not None:
# Add the output key -> node of the subgraph
for node in graph.iter_nodes():
if node in nodes_to_remove:
raise ValueError(f"The node {node} is in the removed set")
for output in node.outputs:
output_to_node[output.key] = node

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
if node not in nodes_to_remove:
for output in node.outputs:
output_to_node[output.key] = node

# Check if the updated result is valid
for result in new_results:
if result.key not in result_indices:
raise ValueError(f"Unknown result {result} to replace")
if result.key not in output_to_node:
raise ValueError(f"The result {result} is missing in the updated graph")

for node in nodes_to_remove:
for affected_successor in self._graph.iter_successors(node):
if affected_successor not in nodes_to_remove:
affected_successors.add(affected_successor)
# Check whether affected successors' inputs are in subgraph
for affected_successor in affected_successors:
for inp in affected_successor.inputs:
if inp.key not in output_to_node:
raise ValueError(
f"The output {inp} of node {affected_successor} is missing in the subgraph"
)
# Here all the pre-check are passed, we start to replace the subgraph
for node in nodes_to_remove:
self._graph.remove_node(node)

if graph is None:
return

for node in graph.iter_nodes():
self._graph.add_node(node)

for node in itertools.chain(graph.iter_nodes(), affected_successors):
for inp in node.inputs:
pred_node = output_to_node[inp.key]
self._graph.add_edge(pred_node, node)

for result in new_results:
self._graph.results[result_indices[result.key]] = result


class OperandBasedOptimizationRule(OptimizationRule):
Expand Down
13 changes: 13 additions & 0 deletions mars/optimization/logical/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
228 changes: 228 additions & 0 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import pytest


from ..core import OptimizationRule
from .... import tensor as mt
from .... import dataframe as md


class _MockRule(OptimizationRule):
def apply(self) -> bool:
pass

def replace_subgraph(self, graph, nodes_to_remove, new_results=None):
self._replace_subgraph(graph, nodes_to_remove, new_results)


def test_replace_tileable_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
| ^
| |
V |
v3 ------|
^
|
s2 ---> c2 ---> v2

Target Graph:
s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
^
|
s2 ---> c2 ---> v2

The nodes [v3, v4, v6] will be removed.
Subgraph only contains [v7, v8]
"""
s1 = mt.random.randint(0, 100, size=(5, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(5, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
v4 = v3.add(v1)
s5 = mt.random.randint(0, 100, size=(5, 4))
v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
v6 = v5.sub(v4)
g1 = v6.build_graph()
v7 = v1.sub(v2)
v8 = v7.add(v5)
v8._key = v6.key
v8.outputs[0]._key = v6.key
g2 = v8.build_graph()
# Here we use a trick way to construct the subgraph for test only
key_to_node = dict()
for node in g2.iter_nodes():
key_to_node[node.key] = node
for key, node in key_to_node.items():
if key != v7.key and key != v8.key:
g2.remove_node(node)
r = _MockRule(g1, None, None)
for node in g1.iter_nodes():
key_to_node[node.key] = node

c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c5 = g1.successors(key_to_node[s5.key])[0]

new_results = [v8.outputs[0]]
r.replace_subgraph(g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results)
assert g1.results == new_results
for node in g1.iter_nodes():
if node.key == v8.key:
key_to_node[v8.key] = node
break
expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}

expected_edges = {
s1: [c1],
c1: [v1],
v1: [v7],
s2: [c2],
c2: [v2],
v2: [v7],
s5: [c5],
c5: [v5],
v5: [v8],
v7: [v8],
v8: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_null_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2

Target Graph:
c1 ---> v1 ---> v3(out) <--- v2 <--- c2

The nodes [s1, s2] will be removed.
Subgraph is None
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
g1 = v3.build_graph()
key_to_node = {node.key: node for node in g1.iter_nodes()}
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
r = _MockRule(g1, None, None)
expected_results = [v3.outputs[0]]

# delete c5 s5 will fail
with pytest.raises(ValueError):
r.replace_subgraph(
None, {key_to_node[op.key] for op in [s1, s2]}, [v2.outputs[0]]
)

assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
c1: [v1],
v1: [v3],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_subgraph_without_removing_nodes():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2

Target Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
s3 ---> c3 ---> v3

Nothing will be removed.
Subgraph only contains [s3, c3, v3]
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v4 = v1.add(v2)
g1 = v4.build_graph()

s3 = mt.random.randint(0, 100, size=(10, 4))
v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
g2 = v3.build_graph()
key_to_node = {
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
}
expected_results = [v4.outputs[0]]
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c3 = g2.successors(key_to_node[s3.key])[0]
r = _MockRule(g1, None, None)
r.replace_subgraph(g2, None)
assert g1.results == expected_results
assert set(g1) == {
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v4],
s2: [c2],
c2: [v2],
v2: [v4],
s3: [c3],
c3: [v3],
v3: [],
v4: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])
Loading