Skip to content

Commit

Permalink
Support logical operators (apache#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 3, 2022
1 parent 0843e5d commit bcbfd3e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 29 deletions.
103 changes: 81 additions & 22 deletions python/tvm/script/parse/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""AST Evaluation"""
import ast
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from . import doc

Expand Down Expand Up @@ -48,6 +48,26 @@ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
return result.value
raise TypeError("Unexpected result type: %s" % type(result))

def _add_intermediate_result(self, value: Any) -> doc.Name:
name = f"__tvm_tmp_value_{self.new_value_count}"
self.new_value_count += 1
self.value_table[name] = value
lineno = 0
col_offset = 0
return doc.Name(
id=name,
ctx=doc.Load(
lineno=lineno,
col_offset=col_offset,
end_lineno=None,
end_col_offset=None,
),
lineno=lineno,
col_offset=col_offset,
end_lineno=None,
end_col_offset=None,
)

def _visit(self, node: doc.AST) -> Any:
if isinstance(node, list):
return [self._visit(n) for n in node]
Expand All @@ -71,34 +91,73 @@ def _visit(self, node: doc.AST) -> Any:
),
):
return node
new_fields = {}
fields = {}
for field in node.__class__._FIELDS: # pylint: disable=protected-access
attr = getattr(node, field)
if isinstance(attr, (doc.AST, tuple, list)):
new_fields[field] = self._visit(attr)
fields[field] = self._visit(attr)
else:
new_fields[field] = attr
fields[field] = attr
try:
new_value = _eval_expr(node.__class__(**new_fields), self.value_table)
if isinstance(node, doc.BoolOp) and isinstance(fields["op"], doc.And):
value = self._eval_binary(
fields["values"],
lhs_func_name="__tvm_logical_and__",
rhs_func_name="__tvm_r_logical_and__",
default_func=lambda lhs, rhs: lhs and rhs,
)
elif isinstance(node, doc.BoolOp) and isinstance(fields["op"], doc.Or):
value = self._eval_binary(
fields["values"],
lhs_func_name="__tvm_logical_or__",
rhs_func_name="__tvm_r_logical_or__",
default_func=lambda lhs, rhs: lhs or rhs,
)
elif isinstance(node, doc.UnaryOp) and isinstance(fields["op"], doc.Not):
value = self._eval_unary(
fields["operand"],
func_name="__tvm_logical_not__",
default_func=lambda v: not v,
)
else:
value = _eval_expr(node.__class__(**fields), self.value_table)
except Exception as e:
self.parser.report_error(node, str(e))
else:
name = f"__tvm_tmp_value_{self.new_value_count}"
self.new_value_count += 1
self.value_table[name] = new_value
return doc.Name(
id=name,
ctx=doc.Load(
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
),
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
)
return self._add_intermediate_result(value)

def _eval_unary(
self,
value: Any,
func_name: str,
default_func: Callable,
):
value = _eval_expr(value, self.value_table)
method = getattr(value, func_name, None)
if method is not None:
return method(value)
return default_func(value)

def _eval_binary(
self,
values: List[Any],
lhs_func_name: str,
rhs_func_name: str,
default_func: Callable,
):
assert len(values) > 0
values = [_eval_expr(v, self.value_table) for v in values if v is not None]
lhs = values[0]
for rhs in values[1:]:
method = getattr(lhs, lhs_func_name, None)
if method is not None:
lhs = method(rhs)
continue
method = getattr(rhs, rhs_func_name, None)
if method is not None:
lhs = method(lhs)
continue
lhs = default_func(lhs, rhs)
return lhs


def eval_expr(
Expand Down
28 changes: 21 additions & 7 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@
assert(y.a == x)
"""
from typing import Optional, Union
from tvm import ir

import tvm._ffi
import tvm.ir._ffi_api
from tvm import ir
from tvm.ir import Op, PrimExpr
from tvm.ir.base import Span
from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const

from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
from tvm.ir import PrimExpr, Op
import tvm.ir._ffi_api
from . import generic as _generic
from . import _ffi_api
from . import generic as _generic


def div_ambiguity_error():
Expand Down Expand Up @@ -66,8 +67,6 @@ def _dtype_is_float(value):
class ExprOp(object):
"""Operator overloading for Expr like expressions."""

# TODO(tkonolige): use inspect to add source information to these objects

def __add__(self, other):
return _generic.add(self, other)

Expand Down Expand Up @@ -184,6 +183,21 @@ def __nonzero__(self):
def __bool__(self):
return self.__nonzero__()

def __tvm_logical_not__(self):
return _ffi_api._OpNot(self, None) # type: ignore

def __tvm_logical_and__(self, rhs):
return _ffi_api._OpAnd(self, rhs, None) # type: ignore

def __tvm_logical_or__(self, rhs):
return _ffi_api._OpOr(self, rhs, None) # type: ignore

def __tvm_r_logical_and__(self, lhs):
return _ffi_api._OpAnd(lhs, self, None) # type: ignore

def __tvm_r_logical_or__(self, lhs):
return _ffi_api._OpOr(lhs, self, None) # type: ignore

def equal(self, other, span=None):
"""Build an equal check expression with other expr.
Expand Down
2 changes: 2 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, right_shift);

TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);

TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
return if_then_else(cond, true_value, false_value, span);
Expand Down

0 comments on commit bcbfd3e

Please sign in to comment.