Skip to content

Commit

Permalink
Add replace_subgraph with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpai committed Jun 27, 2023
1 parent 3418861 commit b9e6430
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 1 deletion.
4 changes: 4 additions & 0 deletions mars/core/entity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
def op(self):
return self._op

@property
def outputs(self):
return self._op.outputs

@property
def inputs(self):
return self.op.inputs
Expand Down
78 changes: 77 additions & 1 deletion mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Set

from ...core import OperandType, EntityType, enter_mode
from ...core import OperandType, EntityType, enter_mode, Entity
from ...core.graph import EntityGraph
from ...utils import implements

Expand Down Expand Up @@ -130,6 +131,77 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
for succ in successors:
self._graph.add_edge(new_node, succ)

def _replace_subgraph(
self,
graph: Optional[EntityGraph],
removed_nodes: Optional[Set[EntityType]],
new_results: Optional[List[Entity]] = None,
):
"""
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still
existed in self._graph) the nodes and edges of the input graph.
Parameters
----------
graph: EntityGraph, optional.
The input graph. If it's none, no new node and edge will be added.
removed_nodes: Set[EntityType], optional.
The nodes to be removed. All the edges connected with them are removed as well.
new_results: List[EntityType], optional, default is None
The updated results of the graph. If it's None, then the results will not be updated.
Raises
------
ReplaceSubgraphError:
If the input key of the removed node's successor can't be found in the subgraph.
Or some of the nodes of the subgraph are in removed ones.
"""
infected_successors = set()

output_to_node = dict()
removed_nodes = removed_nodes or set()
if graph is not None:
# Add the output key -> node of the subgraph
for node in graph.iter_nodes():
if node in removed_nodes:
raise ReplaceSubgraphError(f"The node {node} is in the removed set")
for output in node.outputs:
output_to_node[output.key] = node

for node in removed_nodes:
for infected_successor in self._graph.iter_successors(node):
if infected_successor not in removed_nodes:
infected_successors.add(infected_successor)
# Check whether infected successors' inputs are in subgraph
for infected_successor in infected_successors:
for inp in infected_successor.inputs:
if inp.key not in output_to_node:
raise ReplaceSubgraphError(
f"The output {inp} of node {infected_successor} is missing in the subgraph"
)
for node in removed_nodes:
self._graph.remove_node(node)

if graph is None:
return

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
for output in node.outputs:
output_to_node[output.key] = node

for node in graph.iter_nodes():
self._graph.add_node(node)

for node in itertools.chain(graph.iter_nodes(), infected_successors):
for inp in node.inputs:
pred_node = output_to_node[inp.key]
self._graph.add_edge(pred_node, node)

if new_results is not None:
self._graph.results = new_results.copy()

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
pred_original = self._records.get_original_entity(predecessor, predecessor)
if predecessor not in self._preds_to_remove:
Expand Down Expand Up @@ -283,3 +355,7 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords:
graph.results = new_results

return records


class ReplaceSubgraphError(Exception):
pass
13 changes: 13 additions & 0 deletions mars/optimization/logical/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
221 changes: 221 additions & 0 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import pytest

import mars.tensor as mt
from ..core import OptimizationRule, ReplaceSubgraphError
from .... import dataframe as md


class _MockRule(OptimizationRule):
def apply(self) -> bool:
pass

def replace_subgraph(self, graph, removed_nodes, new_results=None):
self._replace_subgraph(graph, removed_nodes, new_results)


def test_replace_tileable_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
| ^
| |
V |
v3 ------|
^
|
s2 ---> c2 ---> v2
Target Graph:
s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
^
|
s2 ---> c2 ---> v2
The nodes [v3, v4, v6] will be removed.
Subgraph only contains [v7, v8]
"""
s1 = mt.random.randint(0, 100, size=(5, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(5, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
v4 = v3.add(v1)
s5 = mt.random.randint(0, 100, size=(5, 4))
v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
v6 = v5.sub(v4)
g1 = v6.build_graph()
v7 = v1.sub(v2)
v8 = v7.add(v5)
g2 = v8.build_graph()

# Here we use a trick way to construct the subgraph for test only
key_to_node = dict()
for node in g2.iter_nodes():
key_to_node[node.key] = node
for key, node in key_to_node.items():
if key != v7.key and key != v8.key:
g2.remove_node(node)
r = _MockRule(g1, None, None)
for node in g1.iter_nodes():
key_to_node[node.key] = node

c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c5 = g1.successors(key_to_node[s5.key])[0]

expected_results = [v8.outputs[0]]
r.replace_subgraph(
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
)
assert g1.results == expected_results

expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}

expected_edges = {
s1: [c1],
c1: [v1],
v1: [v7],
s2: [c2],
c2: [v2],
v2: [v7],
s5: [c5],
c5: [v5],
v5: [v8],
v7: [v8],
v8: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_null_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
Target Graph:
c1 ---> v1 ---> v3 <--- v2 <--- c2
The nodes [s1, s2] will be removed.
Subgraph is None
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
g1 = v3.build_graph()
key_to_node = {node.key: node for node in g1.iter_nodes()}
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
r = _MockRule(g1, None, None)
expected_results = [v3.outputs[0]]
# delete c5 s5 will fail
with pytest.raises(ReplaceSubgraphError) as e:
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
c1: [v1],
v1: [v3],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_subgraph_without_removing_nodes():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
Target Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
s3 ---> c3 ---> v3
Nothing will be removed.
Subgraph only contains [s3, c3, v3]
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v4 = v1.add(v2)
g1 = v4.build_graph()

s3 = mt.random.randint(0, 100, size=(10, 4))
v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
g2 = v3.build_graph()
key_to_node = {
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
}
expected_results = [v3.outputs[0], v4.outputs[0]]
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c3 = g2.successors(key_to_node[s3.key])[0]
r = _MockRule(g1, None, None)
r.replace_subgraph(g2, None, expected_results)
assert g1.results == expected_results
assert set(g1) == {
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v4],
s2: [c2],
c2: [v2],
v2: [v4],
s3: [c3],
c3: [v3],
v3: [],
v4: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

0 comments on commit b9e6430

Please sign in to comment.