Skip to content

Commit

Permalink
[IR] Creation-time constant fold for constant expressions (#209)
Browse files Browse the repository at this point in the history
* creation-time constant fold for constant expressions

* rename UnaryOp to UnaryExpr, BinaryOp to BinaryExpr, fold unary expr

* .

* .

* .
  • Loading branch information
yaoyaoding committed May 5, 2023
1 parent c0d597b commit daae22e
Show file tree
Hide file tree
Showing 49 changed files with 730 additions and 739 deletions.
10 changes: 5 additions & 5 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Optional, List, Tuple, Dict, Union
import os
import numpy as np
from hidet.ir.dialects.pattern import AnyExpr
from hidet.ir.dialects.pattern import PlaceholderExpr
from hidet.ir import dtypes
from hidet.ir.node import Node
from hidet.ir.type import DataType, PointerType, TensorPointerType, ReferenceType, TensorType, FuncType
Expand Down Expand Up @@ -251,13 +251,13 @@ def visit_BitwiseXor(self, e: BitwiseXor):
return '(' + self(e.a) + ' ^ ' + self(e.b) + ')'

def visit_BitwiseNot(self, e: BitwiseNot):
return '(~' + self(e.base) + ')'
return '(~' + self(e.a) + ')'

def visit_LeftShift(self, e: LeftShift):
return '(' + self(e.base) + ' << ' + self(e.cnt) + ')'
return '(' + self(e.a) + ' << ' + self(e.b) + ')'

def visit_RightShift(self, e: RightShift):
return '(' + self(e.base) + ' >> ' + self(e.cnt) + ')'
return '(' + self(e.a) + ' >> ' + self(e.b) + ')'

def visit_TensorElement(self, e: TensorElement):
if e.protected:
Expand Down Expand Up @@ -562,7 +562,7 @@ def visit_ScalarNode(self, e: ScalarNode):
def visit_TensorNode(self, e: TensorNode):
raise ValueError()

def visit_AnyExpr(self, e: AnyExpr):
def visit_AnyExpr(self, e: PlaceholderExpr):
raise ValueError()

def visit_NotDispatchedNode(self, n: Node):
Expand Down
16 changes: 8 additions & 8 deletions python/hidet/graph/ops/definitions/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

class EqualOp(BinaryElementwiseOp):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(x, y, lambda a, b: expr.Equal(a, b), name='eq')
super().__init__(x, y, lambda a, b: a == b, name='eq')


class NotEqualOp(BinaryElementwiseOp):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(x, y, lambda a, b: expr.NotEqual(a, b), name='ne')
super().__init__(x, y, lambda a, b: a != b, name='ne')


class LessOp(BinaryElementwiseOp):
Expand All @@ -46,25 +46,25 @@ def __init__(self, x: Tensor, y: Tensor):

class LogicalNotOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(x, lambda a: expr.LogicalNot(a), name='not')
super().__init__(x, lambda a: expr.logical_not(a), name='not')


class LogicalAndOp(BinaryElementwiseOp):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(x, y, lambda a, b: expr.LogicalAnd(a, b), name='and')
super().__init__(x, y, lambda a, b: expr.logical_and(a, b), name='and')


class LogicalOrOp(BinaryElementwiseOp):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(x, y, lambda a, b: expr.LogicalOr(a, b), name='or')
super().__init__(x, y, lambda a, b: expr.logical_or(a, b), name='or')


class LogicalXorOp(BinaryElementwiseOp):
def __init__(self, x: Tensor, y: Tensor):
def expr_logical_xor(a: Expr, b: Expr) -> Expr:
x = expr.LogicalAnd(a, expr.LogicalNot(b))
y = expr.LogicalAnd(expr.LogicalNot(a), b)
return expr.LogicalOr(x, y)
x = expr.logical_and(a, expr.logical_not(b))
y = expr.logical_and(expr.logical_not(a), b)
return expr.logical_or(x, y)

super().__init__(x, y, lambda a, b: expr_logical_xor(a, b), name='xor')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from hidet.ir.expr import if_then_else, LogicalAnd
from hidet.ir.expr import if_then_else, logical_and
from hidet.ir.compute import compute, reduce
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding, normalize_kernel
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
fcompute=lambda ni, ci, li: reduce(
shape=[og, kernel_size],
fcompute=lambda ogi, ki: if_then_else(
cond=LogicalAnd.join(li + p >= ki, li + p < length_in * s + ki, (li + p - ki) % s == 0),
cond=logical_and(li + p >= ki, li + p < length_in * s + ki, (li + p - ki) % s == 0),
then_expr=(
data[ni, (ci // wc) * og + ogi, (li + p - ki) // s] * weight[(ci // wc) * og + ogi, ci % wc, ki]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Union, Tuple
from hidet.ir.expr import if_then_else, LogicalAnd
from hidet.ir.expr import if_then_else, logical_and
from hidet.ir.compute import compute, reduce
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
fcompute=lambda ni, ci, hi, wi: reduce(
shape=[og, kx, ky],
fcompute=lambda ogi, kxi, kyi: if_then_else(
cond=LogicalAnd.join(
cond=logical_and(
hi + px0 >= kxi,
hi + px0 < p * sx + kxi,
(hi + px0 - kxi) % sx == 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Union, Tuple
from hidet.ir.expr import if_then_else, LogicalAnd
from hidet.ir.expr import if_then_else, logical_and
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding
from hidet.ir.compute import compute
Expand Down Expand Up @@ -42,7 +42,7 @@ def fcompute(b, i, k):
xx = hi + px0 - kxi
yy = wi + py0 - kyi
return if_then_else(
cond=LogicalAnd.join(xx >= 0, xx < p * sx, xx % sx == 0, yy >= 0, yy < q * sy, yy % sy == 0),
cond=logical_and(xx >= 0, xx < p * sx, xx % sx == 0, yy >= 0, yy < q * sy, yy % sy == 0),
then_expr=data[ni, gi * og + ogi, xx // sx, yy // sy],
else_expr=0.0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Union, Tuple
from hidet.ir.expr import if_then_else, LogicalAnd
from hidet.ir.expr import if_then_else, logical_and
from hidet.ir.compute import compute, reduce
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
fcompute=lambda ni, ci, zi, hi, wi: reduce(
shape=[og, kz, kx, ky],
fcompute=lambda ogi, kzi, kxi, kyi: if_then_else(
cond=LogicalAnd.join(
cond=logical_and(
zi + pz0 >= kzi,
zi + pz0 < r * sz + kzi,
(zi + pz0 - kzi) % sz == 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import hidet.option
from hidet.ir.compute import TensorNode, GridCompute, TensorInput
from hidet.ir.expr import Expr, Var, TensorElement, tensor_var
from hidet.ir.expr import Expr, Var, TensorElement, tensor_var, tensor_element
from hidet.ir.stmt import BufferStoreStmt
from hidet.ir.func import Function, IRModule
from hidet.ir.task import Task, InverseMap
Expand Down Expand Up @@ -119,15 +119,15 @@ def visit_TensorElement(self, e: TensorElement):
if e.base in self.anchor_inputs:
# access an input tensor in the anchor operator, replace it with the task input (i.e., InputTensor)
input_index = self.anchor_inputs.index(e.base)
return self.visit(TensorElement(self.anchor_task.inputs[input_index], e.indices))
return self.visit(tensor_element(self.anchor_task.inputs[input_index], e.indices))
elif isinstance(e.base, TensorNode):
# apply prologue
buf: TensorNode = e.base
indices = [self.visit(v) for v in e.indices]
indices = tuple(self.visit(v) for v in e.indices)
if isinstance(buf, TensorInput):
if buf in self.graph_input_to_var:
# buf is an input tensor of the fused graph
return TensorElement(self.graph_input_to_var[buf], indices)
return tensor_element(self.graph_input_to_var[buf], indices)
elif buf in self.consume:
# buf is an input tensor of an inner task of the fused graph,
# but not an input tensor of fused graph.
Expand Down Expand Up @@ -205,9 +205,9 @@ def visit_BufferStoreStmt(self, stmt: BufferStoreStmt):
assert len(tensor_elements) == 1, (
'Epilogue can only index one time of the input tensor ' 'with inverse map'
)
tensor_element: TensorElement = tensor_elements[0]
te: TensorElement = tensor_elements[0]
# in the context of above example, we replace 'out[i + 3, i + j]' by 'value'
self.memo[tensor_element] = self.visit(stmt.value)
self.memo[te] = self.visit(stmt.value)

# step 3
return self.visit(BufferStoreStmt(consumer_output, out_indices, value, stmt.protected))
Expand Down
8 changes: 4 additions & 4 deletions python/hidet/graph/ops/definitions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Optional, List, Sequence, Union

from hidet.ir.dtypes import int32
from hidet.ir.expr import Expr, if_then_else, convert, cast, LogicalAnd, LogicalOr
from hidet.ir.expr import Expr, if_then_else, convert, cast, logical_or, logical_and
from hidet.ir import primitives as prim
from .utils import Task, Operator, Tensor, TensorNode, compute, input_like

Expand Down Expand Up @@ -164,10 +164,10 @@ def fmap(n, c, h, w):
if cubic_exclude:
for i in range(4):
weight_w[i] = if_then_else(
LogicalOr.join((w_int - 1 + i) < 0, (w_int + i) > image_width), 0.0, weight_w[i]
logical_or((w_int - 1 + i) < 0, (w_int + i) > image_width), 0.0, weight_w[i]
)
weight_h[i] = if_then_else(
LogicalOr.join((h_int - 1 + i) < 0, (h_int + i) > image_height), 0.0, weight_h[i]
logical_or((h_int - 1 + i) < 0, (h_int + i) > image_height), 0.0, weight_h[i]
)
sum_weight_w = sum(weight_w)
sum_weight_h = sum(weight_h)
Expand All @@ -185,7 +185,7 @@ def fmap(n, c, h, w):
)
if coordinate_transformation_mode == 'tf_half_pixel_for_nn':
value = if_then_else(
LogicalAnd.join(0 <= h, h < image_size[0], 0 <= w, w < image_size[1]), value, extrapolation_value
logical_and(0 <= h, h < image_size[0], 0 <= w, w < image_size[1]), value, extrapolation_value
)
return value

Expand Down
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/definitions/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
from typing import Union, Sequence, List, Dict, Any

from hidet.ir.expr import Expr, LogicalAnd, convert, if_then_else
from hidet.ir.expr import Expr, convert, if_then_else, logical_and

from .utils import Task, Operator, Tensor, TensorNode, compute, reduce, input_like, normalize_stride, normalize_kernel
from .utils import normalize_padding, normalize_output
Expand All @@ -31,7 +31,7 @@ def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str):
name='pad',
shape=[batch_size, channels, height + padding[0] + padding[2], width + padding[1] + padding[3]],
fcompute=lambda n, c, h, w: if_then_else(
LogicalAnd.join(padding[0] <= h, h < height + padding[0], padding[1] <= w, w < width + padding[1]),
logical_and(padding[0] <= h, h < height + padding[0], padding[1] <= w, w < width + padding[1]),
x[n, c, h - padding[0], w - padding[1]],
pad_value,
),
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str):
],
fcompute=lambda n, c, d, h, w: (
if_then_else(
LogicalAnd.join(
logical_and(
padding[0] <= d,
d < depth + padding[0],
padding[1] <= h,
Expand Down
9 changes: 4 additions & 5 deletions python/hidet/graph/ops/definitions/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
from typing import List, Optional, Union, Sequence
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import LogicalAnd, if_then_else, convert
from hidet.ir.expr import if_then_else, convert, cast as ir_cast, logical_and
from hidet.ir.layout import RowMajorLayout, ColumnMajorLayout
from hidet.ir.utils import index_deserialize, index_serialize
from hidet.utils import prod
Expand Down Expand Up @@ -273,11 +273,11 @@ def __init__(self, data: TensorNode, pads: List[int], value: float):
assert rank * 2 == len(pads)
out_shape = [a + b + c for a, b, c in zip(pads[:rank], shape, pads[rank:])]

value = convert(value, dtype=data.ttype.dtype.name)
value = convert(value, dtype=data.type.dtype.name)

def fmap(*indices):
indices = [idx - beg for idx, beg in zip(indices, pads[:rank])]
cond = LogicalAnd.join_list([LogicalAnd(0 <= idx, idx < shape[i]) for i, idx in enumerate(indices)])
cond = logical_and(*[logical_and(0 <= idx, idx < shape[i]) for i, idx in enumerate(indices)])
return if_then_else(cond, data[indices], value)

out = compute('out', shape=out_shape, fcompute=fmap)
Expand Down Expand Up @@ -430,13 +430,12 @@ def __init__(self, x: Tensor, axes: Optional[List[int]] = None):

class CastOp(Operator):
def __init__(self, x: Tensor, dtype: DataType):
from hidet.ir.expr import Cast
from .arithmetic import UnaryElementwiseTask

super().__init__(
inputs=[x],
attributes={'dtype': dtype},
task=UnaryElementwiseTask('cast', input_like(x, 'x'), op=lambda v: Cast(v, dtype)),
task=UnaryElementwiseTask('cast', input_like(x, 'x'), op=lambda v: ir_cast(v, dtype)),
)


Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/schedules/cpu/auto_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hidet.ir.builders import FunctionBuilder
from hidet.ir.compute import TensorNode, GridCompute
from hidet.ir.expr import Call, Var, convert
from hidet.ir.expr import Var, convert, call
from hidet.ir.tools import rewrite
from hidet.ir.stmt import Stmt, BufferStoreStmt, EvaluateStmt
from ..auto_scheduler import AutoScheduler, ComputeExprLower
Expand Down Expand Up @@ -46,4 +46,4 @@ def schedule_grid_compute(
func_var = self.add_function(func)

# call the created function in the launch function
return EvaluateStmt(Call(func_var, args=call_args))
return EvaluateStmt(call(func_var, args=call_args))
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/schedules/cuda/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hidet.ir import IRModule
from hidet.ir.builders import FunctionBuilder
from hidet.ir.expr import scalar_var, convert, Expr, LogicalAnd, cast
from hidet.ir.expr import scalar_var, convert, Expr, cast, equal, logical_and
from hidet.ir.mapping import TaskMapping
from hidet.ir.primitives import block_idx, thread_idx
from hidet.ir.compute import ReduceOperation
Expand Down Expand Up @@ -91,7 +91,7 @@ def cuda_schedule_reduce_by_warp_reduce(task: ReduceTask) -> IRModule:
for (r,) in block_layout.worker2task(thread_idx()):
with fb.if_then(r < reduce_extent):
reduce_indices = index_deserialize(r, shape=reduce_shape)
with fb.if_then(LogicalAnd.join_list([reduce_index.equals(0) for reduce_index in reduce_indices])):
with fb.if_then(logical_and(*[equal(reduce_index, 0) for reduce_index in reduce_indices])):
reduce_indices = [convert(0) for _ in task.dims]
if task.keep_dim:
output_indices = merge_indices(grid_indices, reduce_indices, reduce_dims=task.dims)
Expand Down Expand Up @@ -140,7 +140,7 @@ def cuda_schedule_reduce_by_default(task: ReduceTask) -> IRModule:
# body
remain_indices = remain_layout.worker2task(thread_idx() + block_idx() * block_size)[0]
with fb.if_then(
LogicalAnd.join_list([remain_index < remain_shape[i] for i, remain_index in enumerate(remain_indices)])
logical_and(*[remain_index < remain_shape[i] for i, remain_index in enumerate(remain_indices)])
):
# get the reduced value along reduce dimensions
for reduce_indices in reduce_layout.worker2task(0):
Expand Down
3 changes: 2 additions & 1 deletion python/hidet/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
from .type import data_type, tensor_type, tensor_pointer_type

from .expr import Expr, Var, Constant
from .expr import BinaryOp, Condition, LessThan, LessEqual, Equal, NotEqual, Add, Sub, Multiply, Div, Mod, FloorDiv
from .expr import BinaryExpr, Condition, LessThan, LessEqual, Equal, NotEqual, Add, Sub, Multiply, Div, Mod, FloorDiv
from .expr import Let, Cast, LogicalAnd, LogicalOr, TensorElement, Call, TensorSlice, LogicalNot, Neg
from .expr import BitwiseXor, BitwiseAnd, BitwiseNot, BitwiseOr, Dereference
from .expr import var, scalar_var, tensor_var, is_one, is_zero, convert
from .expr import logical_and, logical_or, logical_not, equal, less_equal, less_than, not_equal

from .layout import DataLayout

Expand Down

0 comments on commit daae22e

Please sign in to comment.