Skip to content

Commit

Permalink
Improve numexpr fusion (#3177)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone committed Jul 11, 2022
1 parent 3484e28 commit dbbbcaa
Show file tree
Hide file tree
Showing 6 changed files with 626 additions and 101 deletions.
93 changes: 93 additions & 0 deletions benchmarks/asv_bench/benchmarks/execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

import mars.tensor as mt
from mars import new_session
from mars.core.graph import (
TileableGraph,
TileableGraphBuilder,
ChunkGraphBuilder,
ChunkGraph,
)
from mars.serialization import serialize
from mars.services.task import new_task_id
from mars.services.task.execution.ray.executor import execute_subtask


def _gen_subtask_chunk_graph(t):
    """Build and return a chunk graph for the single tileable *t*.

    Fusion is disabled so the benchmark measures raw (unfused) subtask
    execution; the tileable graph is tiled eagerly via ``next(...)``.
    """
    tileable_graph = TileableGraph([t.data])
    # Materialize the tileable graph before chunking it.
    next(TileableGraphBuilder(tileable_graph).build())
    chunk_builder = ChunkGraphBuilder(tileable_graph, fuse_enabled=False)
    return next(chunk_builder.build())


@dataclasses.dataclass
class _ASVSubtaskInfo:
    """Pre-built subtask payload reused across benchmark iterations."""

    # Fresh id from new_task_id() identifying this benchmark subtask.
    subtask_id: str
    # Value returned by serialize() on the subtask's chunk graph.
    # NOTE(review): the annotation says ChunkGraph but the stored value is
    # the *serialized* form (see setup()); confirm the intended type.
    serialized_subtask_chunk_graph: ChunkGraph


class NumExprExecutionSuite:
    """
    Benchmark that times performance of numexpr execution.
    """

    def _make_subtask_info(self, t):
        """Build an _ASVSubtaskInfo (fresh task id + serialized chunk graph)
        for tileable *t*. Factored out of setup() to remove duplication."""
        return _ASVSubtaskInfo(
            subtask_id=new_task_id(),
            serialized_subtask_chunk_graph=serialize(_gen_subtask_chunk_graph(t)),
        )

    def setup(self):
        # A default session is required for the .execute() benchmark below.
        self.session = new_session(default=True)
        self.asv_subtasks = []
        # Pre-build 200 subtasks (two expressions per iteration) so that
        # time_numexpr_subtask_execution measures execution only, not setup.
        for _ in range(100):
            a = mt.arange(1e6)
            b = mt.arange(1e6) * 0.1
            c = mt.sin(a) + mt.arcsinh(a / b)
            self.asv_subtasks.append(self._make_subtask_info(c))

            c = a * b - 4.1 * a > 2.5 * b
            self.asv_subtasks.append(self._make_subtask_info(c))

    def teardown(self):
        self.session.stop_server()

    def time_numexpr_execution(self):
        """Time end-to-end execution (graph build + schedule + run)."""
        for _ in range(100):
            a = mt.arange(1e6)
            b = mt.arange(1e6) * 0.1
            c = mt.sin(a) + mt.arcsinh(a / b)
            c.execute(show_progress=False)
            c = a * b - 4.1 * a > 2.5 * b
            c.execute(show_progress=False)

    def time_numexpr_subtask_execution(self):
        """Time raw subtask execution using the pre-serialized graphs."""
        for asv_subtask_info in self.asv_subtasks:
            execute_subtask(
                "",
                asv_subtask_info.subtask_id,
                asv_subtask_info.serialized_subtask_chunk_graph,
                set(),
                False,
            )
192 changes: 156 additions & 36 deletions mars/optimization/physical/numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...core import ChunkType
import dataclasses
import functools
import logging
from typing import List, Set

import numpy as np

from ...core import ChunkType, ChunkGraph
from ...tensor import arithmetic
from ...tensor import reduction
from ...tensor.fuse import TensorNeFuseChunk
from ...tensor.fuse.numexpr import NUMEXPR_INSTALLED
from .core import RuntimeOptimizer, register_optimizer


logger = logging.getLogger(__name__)


REDUCTION = object()
REDUCTION_OP = {
reduction.TensorSum,
reduction.TensorProd,
Expand Down Expand Up @@ -64,29 +75,99 @@
arithmetic.TensorRshift,
arithmetic.TensorTreeAdd,
arithmetic.TensorTreeMultiply,
arithmetic.TensorFloor,
arithmetic.TensorCeil,
arithmetic.TensorAnd,
arithmetic.TensorOr,
arithmetic.TensorNot,
reduction.TensorSum,
reduction.TensorProd,
reduction.TensorMax,
reduction.TensorMin,
}


def _check_reduction_axis(node: ChunkType):
    """Return True when the reduction is over exactly one axis or all axes."""
    axis_count = len(node.op.axis)
    return axis_count in (1, node.ndim)
@dataclasses.dataclass
class _Fuse:
    """A candidate numexpr fusion: a chunk subgraph plus its boundary nodes."""

    # Subgraph of chunks selected to be fused into a single numexpr chunk.
    graph: ChunkGraph
    # Entry chunks of the fuse graph (no fusable predecessors).
    heads: List[ChunkType]
    # Exit chunks of the fuse graph (no fusable successors).
    tails: List[ChunkType]


def _can_fuse(node: ChunkType):
    """Classify *node* for numexpr fusion.

    Returns:
        True       -- a fusable element-wise op.
        REDUCTION  -- a fusable reduction (sentinel); a reduction may only
                      terminate a fused subgraph, never sit in its middle.
        False      -- not fusable.
    """
    op = node.op
    op_type = type(op)
    if op_type in REDUCTION_OP:
        # numexpr can express a reduction only over a single axis or all axes.
        if len(op.axis) == 1 or len(op.axis) == node.ndim:
            return REDUCTION
        else:
            return False
    if op_type not in SUPPORT_OP:
        return False
    if op_type in (arithmetic.TensorOr, arithmetic.TensorAnd):
        # numexpr only support logical and or:
        # https://numexpr.readthedocs.io/projects/NumExpr3/en/latest/user_guide.html#supported-operators
        if np.isscalar(op.lhs) or np.isscalar(op.rhs):
            return False
    return True


def _collect_fuse(
    graph: ChunkGraph,
    node: ChunkType,
    graph_results: Set[ChunkType],
    cached_can_fuse,
):
    """Grow the maximal fusable subgraph around *node*.

    Walks predecessors and successors from *node*, adding every chunk that
    ``cached_can_fuse`` accepts. At most one REDUCTION chunk is admitted and
    only as the tail of the fused subgraph.

    Returns:
        _Fuse: the collected subgraph plus its head and tail chunks.
    """
    fuse_graph = ChunkGraph()
    fuse_graph.add_node(node)
    fuse_heads = []
    fuse_tails = []
    # A reduction may only terminate the fuse; remember the single one allowed.
    tail_reduction_node = None

    stack = [node]
    # Do a full search of sub graph even the fuse tails > 1
    while len(stack) != 0:
        node = stack.pop()
        is_head = graph.count_predecessors(node) == 0
        for n in graph.iter_predecessors(node):
            can_fuse = cached_can_fuse(n)
            if can_fuse is False or can_fuse is REDUCTION:
                # Unfusable (or reduction) predecessor: this node starts the fuse.
                is_head = True
            elif not fuse_graph.contains(n):
                # First visit: the edge n->node is added later, when n is
                # popped and iterates its successors.
                stack.append(n)
                fuse_graph.add_node(n)
            else:
                fuse_graph.add_edge(n, node)
        if is_head:
            fuse_heads.append(node)
        # Skip the successors of tail reduction node.
        if node is tail_reduction_node:
            continue
        is_tail = graph.count_successors(node) == 0 or node in graph_results
        for n in graph.iter_successors(node):
            can_fuse = cached_can_fuse(n)
            if can_fuse is False:
                is_tail = True
            elif can_fuse is REDUCTION:
                if tail_reduction_node is None:
                    # Admit exactly one reduction, as the fuse tail.
                    tail_reduction_node = n
                    fuse_tails.append(n)
                    stack.append(n)
                    fuse_graph.add_node(n)
                elif n is tail_reduction_node:
                    fuse_graph.add_edge(node, n)
                else:
                    # A second, different reduction: stop fusing here.
                    is_tail = True
            elif not fuse_graph.contains(n):
                stack.append(n)
                fuse_graph.add_node(n)
            else:
                fuse_graph.add_edge(node, n)
        if is_tail:
            fuse_tails.append(node)

    return _Fuse(fuse_graph, fuse_heads, fuse_tails)


@register_optimizer
Expand All @@ -100,35 +181,74 @@ def is_available(cls) -> bool:
def optimize(self):
    """Find numexpr-fusable subgraphs and replace them with fused chunks.

    Returns the pair produced by ``_fuse_nodes``: the accepted ``_Fuse``
    records and the newly created fused chunks.
    """
    fuses = []
    explored = set()
    # Memoize per-node fusability checks for the duration of this pass.
    cached_can_fuse = functools.lru_cache(maxsize=None)(_can_fuse)

    graph = self._graph
    graph_results = set(graph.results)
    for node in graph.topological_iter():
        if node.op.gpu or node.op.sparse:
            # numexpr handles neither GPU nor sparse chunks: bail out of
            # the whole optimization pass.
            return [], []
        if node in explored or node in graph_results:
            continue

        can_fuse = cached_can_fuse(node)
        if can_fuse is True:
            fuse = _collect_fuse(graph, node, graph_results, cached_can_fuse)
            if len(fuse.graph) > 1:
                explored.update(fuse.graph)
                # Only single-tail fuses can be replaced by one fused chunk.
                if len(fuse.tails) == 1:
                    fuses.append(fuse)
                else:
                    logger.info(
                        "Refused fusing for numexpr because the tail node count > 1."
                    )

    return self._fuse_nodes(fuses, TensorNeFuseChunk)

def _fuse_nodes(self, fuses: List[_Fuse], fuse_cls):
    """Replace each collected fuse subgraph with a single fused chunk.

    For every ``_Fuse`` a ``fuse_cls`` chunk is created that inherits the
    (single) tail chunk's key/params, is wired to the external predecessors
    of the heads and the successors of the tail, and substitutes the tail
    in ``graph.results`` when present.

    Returns the ``fuses`` list together with the list of fused chunks.
    """
    graph = self._graph
    fused_nodes = []

    for fuse in fuses:
        fuse_graph = fuse.graph
        tail_nodes = fuse.tails
        head_nodes = fuse.heads
        # External inputs: head inputs that are not themselves fused.
        inputs = [
            inp for n in head_nodes for inp in n.inputs if inp not in fuse_graph
        ]

        # Single tail guaranteed by the caller (optimize refuses multi-tail).
        tail_chunk = tail_nodes[0]
        tail_chunk_op = tail_chunk.op
        fuse_op = fuse_cls(
            sparse=tail_chunk_op.sparse,
            gpu=tail_chunk_op.gpu,
            _key=tail_chunk_op.key,
            fuse_graph=fuse_graph,
            dtype=tail_chunk.dtype,
        )
        fused_chunk = fuse_op.new_chunk(
            inputs,
            kws=[tail_chunk.params],
            _key=tail_chunk.key,
            _chunk=tail_chunk,
        ).data

        # Splice the fused chunk into the outer graph, then drop the
        # original (now fused) chunks.
        graph.add_node(fused_chunk)
        for node in graph.iter_successors(tail_chunk):
            graph.add_edge(fused_chunk, node)
        for head_chunk in head_nodes:
            for node in graph.iter_predecessors(head_chunk):
                if not fuse_graph.contains(node):
                    graph.add_edge(node, fused_chunk)
        for node in fuse_graph:
            graph.remove_node(node)
        fused_nodes.append(fused_chunk)

        try:
            # check tail node if it's in results
            i = graph.results.index(tail_chunk)
            graph.results[i] = fused_chunk
        except ValueError:
            pass

    return fuses, fused_nodes

0 comments on commit dbbbcaa

Please sign in to comment.