[Docs] Add sub-graph rewrite tutorial #29

Merged (4 commits) on Dec 16, 2022
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -4,4 +4,5 @@ sphinx-copybutton
autodocsumm
sphinx-book-theme
matplotlib
sphinxcontrib-bibtex
git+https://github.com/sphinx-contrib/googleanalytics@master#egg=sphinxcontrib-googleanalytics
4 changes: 4 additions & 0 deletions docs/source/_static/custom.css
@@ -60,4 +60,8 @@ code {

.heading-style, h1, h2, h3, h4, h5, h6 {
font-family: Ubuntu, system-ui;
}

dl.class, dl.function {
margin-bottom: 3em;
}
4 changes: 4 additions & 0 deletions docs/source/_static/img/subgraph-rewrite-example.svg
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -62,6 +62,7 @@
'sphinxcontrib.googleanalytics',
'sphinx_copybutton',
'autodocsumm',
'sphinxcontrib.bibtex',
]

# Add any paths that contain templates here, relative to this directory.
@@ -88,6 +89,9 @@

autodoc_typehints = 'description'

bibtex_default_style = 'unsrt'
bibtex_bibfiles = ['references.bib']

intersphinx_mapping = {
'torch': ('https://pytorch.org/docs/stable', None),
'torchvision': ('https://pytorch.org/vision/stable', None),
1 change: 1 addition & 0 deletions docs/source/python_api/graph/transforms/index.rst
@@ -4,6 +4,7 @@ hidet.graph.transforms
.. toctree::
:caption: Transforms

subgraph_rewrite
resolve_variant


14 changes: 14 additions & 0 deletions docs/source/python_api/graph/transforms/subgraph_rewrite.rst
@@ -0,0 +1,14 @@
Sub-graph Rewrite Pass
----------------------


.. autoclass:: hidet.graph.transforms.subgraph_rewrite.TensorPattern
:members:

.. autoclass:: hidet.graph.transforms.subgraph_rewrite.OperatorPattern
:members:

.. autoclass:: hidet.graph.transforms.subgraph_rewrite.SubgraphRewriteRule
:members:

.. autofunction:: hidet.graph.transforms.subgraph_rewrite.register_rewrite_rule
7 changes: 7 additions & 0 deletions docs/source/references.bib
@@ -0,0 +1,7 @@
@inproceedings{taso,
title={TASO: optimizing deep learning computation with automatic generation of graph substitutions},
author={Jia, Zhihao and Padon, Oded and Thomas, James and Warszawski, Todd and Zaharia, Matei and Aiken, Alex},
booktitle={Proceedings of the 27th ACM Symposium on Operating Systems Principles},
pages={47--62},
year={2019}
}
2 changes: 2 additions & 0 deletions gallery/how-to-guides/add-operator-resolve-rule.py
@@ -1,4 +1,6 @@
"""
.. _add-operator-resolve-rule:

Add Operator Resolve Rule
=========================

176 changes: 173 additions & 3 deletions gallery/how-to-guides/add-subgraph-rewrite-rule.py
@@ -2,6 +2,176 @@
Add Sub-Graph Rewrite Rule
==========================

.. todo::
Will come soon.
"""
This tutorial shows how to add a sub-graph rewrite rule to the graph optimization pipeline. Sub-graph rewriting is an
important technique in graph optimization: it replaces a sub-graph with another, usually more efficient, sub-graph.
For example, we can replace a sub-graph that contains two matrix multiplications sharing the same input, followed by a
concatenation of their results, with a concatenation of the weights and a single matrix multiplication:

.. figure:: /_static/img/subgraph-rewrite-example.svg
:align: center
:scale: 70%

The sub-graph rewrite rule that fuses two matrix multiplications.

.. seealso::
:class: margin

TASO :cite:`taso` systematically studies the sub-graph rewrite optimization for deep learning workloads.

After the rewrite, the graph becomes more efficient as we only need to run a single kernel and the `fused` matrix
multiplication usually exposes more parallelism to utilize the underlying hardware. We can also fuse multiple
convolutions into a single one, or do other sub-graph rewrites.
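
The rewrite is valid because matrix multiplication distributes over column-wise concatenation. We can verify this
identity numerically with NumPy (a quick sketch, independent of Hidet):

.. code-block:: python

    import numpy as np

    x = np.random.randn(3, 3)
    w1, w2 = np.random.randn(3, 4), np.random.randn(3, 5)
    lhs = np.concatenate([x @ w1, x @ w2], axis=1)  # two matmuls, then concatenate
    rhs = x @ np.concatenate([w1, w2], axis=1)      # concatenate weights, single matmul
    assert np.allclose(lhs, rhs)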

Sub-graph rewrite in Hidet
--------------------------

In Hidet, we use a *sub-graph rewrite rule* to describe the rewrite. A sub-graph rewrite rule contains two parts:

- **Sub-graph pattern**: a sub-graph pattern that we use to match sub-graphs in the graph. The pattern is a directed
  acyclic graph (DAG) in which each node is an operator pattern and each edge is a tensor pattern. We only specify the
  operator type for each node and whether each (input) tensor is symbolic or concrete.

- **Target sub-graph constructor**: when we find a sub-graph that matches the pattern, we use the constructor to
construct the target sub-graph that replaces the matched sub-graph. When constructing the target sub-graph, we can
use the matched tensors and nodes to further determine whether the rewrite is applicable. If applicable, the
constructor returns the target sub-graph, otherwise it returns ``None``.

In the above example, the sub-graph pattern contains three input tensors, where ``x1`` is a symbolic tensor and ``w1``,
``w2`` are two concrete tensors (i.e., we know the concrete values of ``w1`` and ``w2``). There are three operators in
the pattern: the first two are matrix multiplications and the last one is a concatenation. The output tensor of the
concatenation is the output tensor of the pattern. When we find a sub-graph that matches the pattern, we use the
constructor to build the target sub-graph and replace the matched sub-graph with it.
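
Schematically, a rewrite rule is a class with two methods on top of a constructor (a sketch with the bodies elided;
the full, working implementation follows later in this tutorial):

.. code-block:: python

    class MyRewriteRule(SubgraphRewriteRule):
        def __init__(self):
            super().__init__(name='my rule')
            # build the sub-graph pattern with TensorPattern and op_pattern
            ...

        def source(self):
            # return the output tensors of the sub-graph pattern
            ...

        def target(self, matched):
            # return the output tensors of the target sub-graph,
            # or None if the rewrite is not applicable to this match
            ...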

.. note::

    **Difference between sub-graph rewrite and operator resolving**. Although
    :ref:`operator resolving <add-operator-resolve-rule>` can conceptually be considered a special case of sub-graph
    rewriting, the two are implemented with different mechanisms and run at different points in the pipeline. First,
    operator resolving can be performed efficiently, so we can register an arbitrary number of operator resolve rules,
    whereas sub-graph rewriting is usually more expensive. Second, we run the sub-graph rewrite pass before the
    operator resolving pass, so that we can use the generic operators in the sub-graph patterns without worrying about
    operator resolving.


Add a sub-graph rewrite rule
----------------------------

Let's implement the sub-graph rewrite rule shown in the above example. Before we start, we create a model that
contains the sub-graph we want to rewrite:

"""
from typing import Optional, List

import hidet
from hidet import Tensor, FlowGraph, Operator
from hidet import ops
from hidet.graph.transforms.graph_patterns import MatchDict


def example_model(x: Tensor, w0: Tensor, w1: Tensor, w2: Tensor):
x = ops.matmul(x, w0)
y1 = ops.matmul(x, w1)
y2 = ops.matmul(x, w2)
y = ops.relu(ops.concat([y1, y2], axis=1))
return y


x = hidet.symbol([3, 3])
w0, w1, w2 = hidet.randn([3, 3]), hidet.randn([3, 3]), hidet.randn([3, 3])
y = example_model(x, w0, w1, w2)
graph: FlowGraph = hidet.trace_from(y, inputs=[x])
print(graph)

# %%
# Then, we define and register the sub-graph rewrite rule.
#
from hidet.graph.ops.definitions import MatmulOp, ConcatOp
from hidet.graph.transforms import TensorPattern, SubgraphRewriteRule
from hidet.graph.transforms import op_pattern, register_rewrite_rule
from hidet.utils import same_list


# Register the rewrite rule; only registered rewrite rules will be applied.
@register_rewrite_rule
class FuseTwoMatmulRewriteRule(SubgraphRewriteRule):
def __init__(self):
super().__init__(name="new: [matmul(x, c1), matmul(x,c2)] => matmul(x, [c1, c2])")
self.x = TensorPattern() # x can match either a symbolic or concrete tensor
self.c1 = TensorPattern(is_const=True) # c1 can only match a concrete tensor
self.c2 = TensorPattern(is_const=True) # c2 can only match a concrete tensor
self.y1 = op_pattern(MatmulOp, [self.x, self.c1]) # pattern: y1 = matmul(x, c1)
self.y2 = op_pattern(MatmulOp, [self.x, self.c2]) # pattern: y2 = matmul(x, c2)
self.y = op_pattern(ConcatOp, [self.y1, self.y2]) # pattern: y = concat(y1, y2)

def source(self) -> List[TensorPattern]:
# Return the output tensors of the source sub-graph pattern.
return [self.y]

def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
# The target sub-graph constructor
# The matched dictionary has type Dict[TensorPattern, Tensor]
# that maps the patterns to the matched tensors.
x, c1, c2, y = [matched[t] for t in [self.x, self.c1, self.c2, self.y]]
concat: Operator = y.op

        # We can use the matched tensors to determine whether the rewrite is applicable.
        # In this case, we check that the two weight matrices have the same shape except
        # for the last dimension, and that the concatenation is along the last axis.
if (
2 <= len(c1.shape) == len(c2.shape) and
same_list(c1.shape[:-1], c2.shape[:-1]) and
concat.attrs["axis"] == len(y.shape) - 1
):
# If applicable, we construct the target sub-graph and return the output tensors.
c = ops.concat([c1, c2], axis=-1)
y = ops.matmul(x, c)
return [y]
else:
# If not, we return None to indicate that the rewrite is not applicable.
return None


# %%
# We can check that the rewrite rule has been registered:
from hidet.graph.transforms import registered_rewrite_rules

print('Registered rewrite rules:')
for rule in registered_rewrite_rules:
assert isinstance(rule, SubgraphRewriteRule)
print(rule.name)

# %%
# Apply the rewrite rule
# ----------------------
# Besides the predefined rewrite rules, the rule we just registered appears at the end of the printed list. In this
# tutorial, to prevent the default rewrite rules from being applied, we first clear the registered rewrite rules and
# then register only the rule we just defined:
registered_rewrite_rules.clear()
register_rewrite_rule(FuseTwoMatmulRewriteRule()) # a second way to register the rewrite rule

# %%
# The rewrite process is done in a graph optimization pass called ``subgraph_rewrite_pass``.
from hidet.graph.transforms import subgraph_rewrite_pass

rewrite_pass = subgraph_rewrite_pass()
rewritten_graph: FlowGraph = rewrite_pass(graph)
print(rewritten_graph)
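
# %%
# As a quick sanity check, we can list the operators in the rewritten graph (a sketch that assumes the
# ``FlowGraph.nodes`` attribute holds the graph's operators, as printed above):
print([type(op).__name__ for op in rewritten_graph.nodes])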

# %%
# We can see that the rewritten graph contains 2 matmul operators instead of 3. There is no concat operator in the
# rewritten graph because the inputs of the concat operator we constructed are constant tensors, so the concatenation
# has been folded into a single constant tensor.
#
# We do not need to call the rewrite pass explicitly. It will be called automatically when we call
# :func:`hidet.graph.optimize`, together with other graph optimization passes.
graph_opt: FlowGraph = hidet.graph.optimize(graph)
print(graph_opt)

# %%
# Summary
# -------
# In this tutorial, we have learned how to define and register a sub-graph rewrite rule. Sub-graph rewriting is an
# important component of the graph optimization framework; Hidet uses it to implement some horizontal fusion rules.

# %%
# References
# ----------
# .. bibliography::
24 changes: 10 additions & 14 deletions python/hidet/graph/ops/definitions/__init__.py
@@ -1,30 +1,26 @@
# pylint: disable=redefined-builtin
from .conv2d import conv2d, conv2d_winograd, conv2d_gemm
from .conv2d import conv2d_gemm_image_transform, conv2d_gemm_filter_transform, conv2d_gemm_inverse_transform
from .conv2d import conv2d_winograd_image_transform, conv2d_winograd_filter_transform, conv2d_winograd_inverse_transform

from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm

from .matmul import batch_matmul, matmul
from .pool import max_pool2d, avg_pool2d
from .softmax import softmax
from .activation import relu, sigmoid, relu6, clip, prelu
from .norm import batch_norm_infer, instance_norm
from .image import resize2d
from .arithmetic import add, sub, multiply, divide, neg, sqrt, rsqrt, where, max, min, reciprocal, exp, log, abs
from .arithmetic import bitwise_and, bitwise_not, bitwise_or, bitwise_xor, ceil, rightshift, leftshift
from .compare import equal, less_than, greater_than, less_or_equal, greater_or_equal, cond_not, cond_and
from .reduce import reduce_mean, reduce_min, reduce_max, reduce_sum, reduce_var, argmin, argmax
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, split, pad, conv_pad
from .pool import max_pool2d, avg_pool2d
from .softmax import softmax
from .activation import relu, sigmoid, relu6, clip, prelu
from .norm import batch_norm_infer, instance_norm
from .image import resize2d
from .cumulative import cumsum
from .special import barrier
from .conv2d import conv2d
from .conv2d_transpose import conv2d_transpose
from .matmul import batch_matmul, matmul

from .matmul import BatchMatmulOp
from .matmul import BatchMatmulOp, MatmulOp
from .conv2d import Conv2dOp
from .arithmetic import ErfOp, PowOp, AddOp, SubOp, MultiplyOp, DivideOp, WhereOp
from .compare import EqualOp
from .reduce import ReduceSumOp, ReduceMeanOp
from .transform import PadOp
from .transform import PadOp, ConcatOp

from . import utils
from . import arithmetic_resolve
2 changes: 2 additions & 0 deletions python/hidet/graph/ops/definitions/transform.py
@@ -581,6 +581,8 @@ def transpose(x: Tensor, axes: Optional[List[int]] = None) -> Tensor:


def concat(tensors: List[Tensor], axis: int) -> Tensor:
if not isinstance(tensors, (list, tuple)) or any(not isinstance(t, Tensor) for t in tensors):
raise ValueError('concat: expect a sequence of tensors, but got: {}'.format(type(tensors)))
return ConcatOp(*tensors, axis=axis).get_output(0)


6 changes: 4 additions & 2 deletions python/hidet/graph/transforms/__init__.py
@@ -3,13 +3,15 @@
from .base import GraphPass, PassContext, logger
from .instruments import GraphPassInstrument, SaveGraphInstrument, ProfileInstrument
from .fold_const import fold_const_pass
from .pattern_transform import pattern_transform_pass
from .subgraph_rewrite import subgraph_rewrite_pass
from .automatic_mix_precision import automatic_mix_precision_pass
from .resolve_variant import resolve_variant_pass
from .fuse_operator import fuse_operator_pass
from .eliminate_barrier import eliminate_barrier_pass

from .resolve_variant import ResolveRule, register_resolve_rule, get_resolve_chain
from .graph_patterns import TensorPattern, OperatorPattern, SubgraphRewriteRule, register_rewrite_rule, op_pattern
from .graph_patterns import registered_rewrite_rules


def optimize(graph: FlowGraph) -> FlowGraph:
@@ -35,7 +37,7 @@ def optimize(graph: FlowGraph) -> FlowGraph:
"""
passes = [
fold_const_pass(),
pattern_transform_pass(),
subgraph_rewrite_pass(),
automatic_mix_precision_pass(),
resolve_variant_pass(),
fuse_operator_pass(),
8 changes: 2 additions & 6 deletions python/hidet/graph/transforms/graph_patterns/__init__.py
@@ -1,10 +1,6 @@
from typing import List
from .base import TensorPattern, OperatorPattern, GraphPattern, MatchDict, Usage, graph_pattern_match
from .base import TensorPattern, OperatorPattern, SubgraphRewriteRule, MatchDict, Usage, graph_pattern_match
from .base import register_rewrite_rule, op_pattern, registered_rewrite_rules
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
from .conv2d_patterns import conv2d_patterns
from .matmul_patterns import matmul_patterns


def all_graph_patterns() -> List[GraphPattern]:
return arithmetic_patterns() + transform_patterns() + conv2d_patterns() + matmul_patterns()
python/hidet/graph/transforms/graph_patterns/arithmetic_patterns.py
@@ -1,8 +1,9 @@
from typing import List, Optional
from .base import GraphPattern, TensorPattern, MatchDict
from hidet.utils.py import initialize
from .base import SubgraphRewriteRule, TensorPattern, MatchDict, register_rewrite_rule


class arithmeticGraphPattern(GraphPattern):
class ArithmeticSubgraphRewriteRule(SubgraphRewriteRule):
def __init__(self, name, fsrc, fdst):
super().__init__(name)
x, y = TensorPattern.tensors(2, is_symbolic=True) # can not be const
@@ -24,7 +25,8 @@ def target(self, matched: MatchDict) -> Optional[List[TensorPattern]]:
# return [constructor.visit(self.tgt)]


def arithmetic_patterns() -> List[GraphPattern]:
@initialize()
def arithmetic_patterns():
# # tensors can be used as pattern inputs
# x, y, z = TensorPattern.tensors(3, is_symbolic=True) # can not be const
# a, b, c = TensorPattern.tensors(3, is_const=True) # can not be symbolic
@@ -41,4 +43,6 @@ def arithmetic_patterns() -> List[GraphPattern]:
lambda x, y, a, b: (x + y) + (a + b),
],
]
return [arithmeticGraphPattern(name, src, tgt) for name, src, tgt in pairs]
rules = [ArithmeticSubgraphRewriteRule(name, src, tgt) for name, src, tgt in pairs]
for rule in rules:
register_rewrite_rule(rule)