Skip to content

Commit

Permalink
[Graph] Cast optimizations (#135)
Browse files Browse the repository at this point in the history
* Add 3 new graph patterns:
 * an operator x whose output feeds 2 or 3 casts to the same dtype -> combine them into a single cast
 * two redundant casts (a->b->a) -> noop

Since such casts are most likely to be introduced by the automatic mixed
precision pass, the subgraph rewrite pass is run a second time after it

* lint

---------

Co-authored-by: Xin Li <xin@centml.ai>
  • Loading branch information
xinli-git and xinli-centml committed Mar 21, 2023
1 parent 40bd792 commit 0cea056
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/hidet/graph/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def optimize(graph: FlowGraph) -> FlowGraph:
fold_const_pass(),
subgraph_rewrite_pass(),
automatic_mix_precision_pass(),
subgraph_rewrite_pass(),
resolve_variant_pass(),
fuse_operator_pass(),
eliminate_barrier_pass(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hidet.graph import ops
from hidet.graph.ir.flow_graph import Tensor
from hidet.graph.ops.definitions.transform import ReshapeOp, SqueezeOp
from hidet.graph.ops.definitions.transform import ReshapeOp, SqueezeOp, CastOp
from hidet.utils import prod, initialize
from .base import SubgraphRewriteRule, TensorPattern, MatchDict, op_pattern, register_rewrite_rule

Expand Down Expand Up @@ -114,8 +114,65 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
return [ops.squeeze(x * c, dims=dims)]


class FanoutTwoCast(SubgraphRewriteRule):
    """Rewrite rule that merges two casts of the same tensor into one.

    The source pattern is a pair of ``CastOp`` outputs that both consume the
    same input pattern ``x``. When the two matched casts produce the same
    dtype, both are replaced by a single shared cast of ``x``.
    """

    def __init__(self):
        super().__init__('y1 = cast(x), y2 = cast(x) => y1 = y2 = z = cast(x)')
        # One input pattern shared by both cast patterns.
        self.x = TensorPattern.tensor(is_symbolic=True)
        self.c1 = op_pattern(CastOp, [self.x])
        self.c2 = op_pattern(CastOp, [self.x])

    def source(self) -> List[TensorPattern]:
        """Return the two cast output patterns to match."""
        return [self.c1, self.c2]

    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
        """Build the replacement outputs for a match.

        Returns a list holding the same merged cast tensor twice, or ``None``
        when the two casts have different dtypes (presumably telling the
        rewriter to leave this match untouched -- confirm against the base
        class).
        """
        source_tensor, first_cast, second_cast = (
            matched[pattern] for pattern in (self.x, self.c1, self.c2)
        )
        if first_cast.dtype != second_cast.dtype:
            # Casts to different dtypes cannot share one output.
            return None
        merged = ops.cast(source_tensor, dtype=first_cast.dtype)
        return [merged] * 2


class FanoutThreeCast(SubgraphRewriteRule):
    """Rewrite rule that merges three casts of the same tensor into one.

    The source pattern is three ``CastOp`` outputs that all consume the same
    input pattern ``x``. When all three matched casts produce the same dtype,
    they are replaced by a single shared cast of ``x``.
    """

    def __init__(self):
        super().__init__('y1,2,3 = cast(x) => y1 = y2 = y3 = z = cast(x)')
        # One input pattern shared by all three cast patterns.
        self.x = TensorPattern.tensor(is_symbolic=True)
        self.c1 = op_pattern(CastOp, [self.x])
        self.c2 = op_pattern(CastOp, [self.x])
        self.c3 = op_pattern(CastOp, [self.x])

    def source(self) -> List[TensorPattern]:
        """Return the three cast output patterns to match."""
        return [self.c1, self.c2, self.c3]

    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
        """Build the replacement outputs for a match.

        Returns a list holding the same merged cast tensor three times, or
        ``None`` when the three casts do not all share one dtype (presumably
        telling the rewriter to leave this match untouched -- confirm against
        the base class).
        """
        source_tensor, first_cast, second_cast, third_cast = (
            matched[pattern] for pattern in (self.x, self.c1, self.c2, self.c3)
        )
        if first_cast.dtype == second_cast.dtype and second_cast.dtype == third_cast.dtype:
            merged = ops.cast(source_tensor, dtype=first_cast.dtype)
            return [merged] * 3
        # At least one cast targets a different dtype, so nothing is merged.
        return None


class DoubleCast(SubgraphRewriteRule):
    """Rewrite rule that removes two chained casts that end on the input dtype.

    The source pattern is ``cast(cast(x))``. When the outer cast's dtype equals
    the dtype of ``x``, the whole chain is replaced by ``x`` itself.

    NOTE(review): the dtype of the inner cast is not inspected. If that
    intermediate dtype cannot represent every value of ``x``, dropping the
    round trip changes numerical results -- confirm this is acceptable for
    the intended mixed-precision use.
    """

    def __init__(self):
        super().__init__('cast(cast(x)) => x')
        self.x = TensorPattern.tensor(is_symbolic=True)
        # Inner cast consumes x; outer cast consumes the inner cast's output.
        self.c1 = op_pattern(CastOp, [self.x])
        self.c2 = op_pattern(CastOp, [self.c1])

    def source(self) -> List[TensorPattern]:
        """Return the outer cast output pattern to match."""
        return [self.c2]

    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
        """Build the replacement output for a match.

        Returns ``[x]`` when the outer cast restores the dtype of ``x``,
        otherwise ``None`` (presumably telling the rewriter to leave this
        match untouched -- confirm against the base class).
        """
        source_tensor, _inner_cast, outer_cast = (
            matched[pattern] for pattern in (self.x, self.c1, self.c2)
        )
        if outer_cast.dtype == source_tensor.dtype:
            return [source_tensor]
        return None


@initialize()
def transform_patterns():
    """Instantiate and register every rewrite rule listed below, in order."""
    rule_classes = (
        ReshapeScalePattern,
        ReshapeBiasPattern,
        SqueezeMultiplyPattern,
        FanoutTwoCast,
        FanoutThreeCast,
        DoubleCast,
    )
    # Each rule is constructed immediately before it is registered, so the
    # construct/register sequence is the same as spelling the calls out.
    for rule_class in rule_classes:
        register_rewrite_rule(rule_class())

0 comments on commit 0cea056

Please sign in to comment.