New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] eager constant folding in operator overloading #1789
Conversation
@yzhliu @MarisaKirisame @junrushao1994 can you help do a quick round of review? |
src/lang/ir_operator.cc
Outdated
const IntImm* pa = a.as<IntImm>(); \ | ||
const IntImm* pb = b.as<IntImm>(); \ | ||
const Type& ta = a.type(); \ | ||
const Type& tb = a.type(); \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tb = b.type()?
* \tparam ValueType The constant value type | ||
*/ | ||
template<typename ValueType, | ||
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is std::is_arithmetic enough here? Seems std::is_pod will be deprecated.
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) | ||
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); | ||
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); | ||
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, do we need ~
op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this macro defines operator for binary ops, ~ is already defined as single operand op
src/arithmetic/compute_expr.h
Outdated
|
||
template<> | ||
inline bool GetConst<int64_t>(Expr e, int64_t *out) { | ||
inline bool GetConst(Expr e, int64_t *out) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: int64_t* out
src/lang/ir_operator.cc
Outdated
/*! | ||
* \brief Check whether type is used to represent index. | ||
* | ||
* Index type are frequently used in shape computation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/are/is
src/lang/ir_operator.cc
Outdated
} | ||
|
||
// The public function with a quick checking path. | ||
void BinaryOpMatchTypes(Expr &lhs, Expr &rhs) { // NOLINT(*) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Expr&
Thanks @zhiics for very helpful comments, i have fixed them, can you please take another round of look? |
include/tvm/ir_operator.h
Outdated
|
||
/*! | ||
* \brief Check whether x is a constant power of two | ||
* \note This only return true for integer types. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is shift? I assume it is 'if it is const pow of 2, write to the pointer'
if so please add document
TVM_CONST_PROPAGATION({ | ||
Type rtype = ta.bits() >= tb.bits() ? ta : tb; | ||
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); | ||
if (pb && pb->value == 0) return SimpleCast(rtype, a); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if pa is 0, use unary neg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no unary neg in the expression
if (pb) { | ||
if (pb->value == 1) return SimpleCast(rtype, a); | ||
} | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check for x/0?
src/lang/ir_operator.cc
Outdated
Expr operator/(Expr a, Expr b) { | ||
TVM_CONST_PROPAGATION({ | ||
Type rtype = ta.bits() >= tb.bits() ? ta : tb; | ||
// due to division and mode can have different modes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
division and mod
src/lang/ir_operator.cc
Outdated
Expr operator%(Expr a, Expr b) { | ||
TVM_CONST_PROPAGATION({ | ||
Type rtype = ta.bits() >= tb.bits() ? ta : tb; | ||
// due to division and mode can have different modes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
division and mod
if (pa->value == 0) return SimpleCast(rtype, a); | ||
} | ||
if (pb) { | ||
if (pb->value == 1) return make_zero(rtype); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check for 0?
src/lang/ir_operator.cc
Outdated
return false_value; | ||
} | ||
} | ||
CHECK(cond.type().is_bool()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move check before if (so early return still get checked)?
src/lang/ir_operator.cc
Outdated
} else if (x.type().is_uint()) { | ||
return x; | ||
} else { | ||
LOG(WARNING) << "Warning: Data type " << x.type() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why warning? what other case is there?
src/pass/split_pipeline.cc
Outdated
} | ||
alloc_size = ir::Simplify(alloc_size); | ||
alloc_size = alloc_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
redundant line
Thanks @MarisaKirisame for the comments, I have fixed accordingly |
@tqchen LGTM. I have the same question again, will/should we have reach-def and/or liveness analysis in Relay? We might want them to help some data flow related opts. |
@zhiics it might make sense to do so, but I would like to see such pass driven by a real end to end usecase, e.g. memory allocation pass |
Thanks @MarisaKirisame @junrushao1994 @zhiics this is now merged |
*/ | ||
inline bool is_zero(const Expr& x) { | ||
return is_const_int(x, 0); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is confusing. HalideIR contains a function with the same name that works for floats too. Moreover, it used to be reexported into the tvm namespace, so this change might have broken something in a very subtle way, like performance regression. (Hopefully, it didn't, because in tvm is_zero
seems to be used only for ints indeed, but anyway). I would consider renaming it or throwing an exception if the argument is not integer.
Make sure Expr now eargerly folds constant when it is integer expression, removes the need to call simplify in all concrete shape inference case so that we can reuse symbolic integer for concrete inference.
Example