Have conversion take an "is nonzero" predicate #54179

Open
wrengr opened this issue Mar 3, 2022 · 1 comment
Labels
enhancement Improving things as opposed to bug fixing, e.g. new or missing feature mlir:sparse Sparse compiler in MLIR

Comments

@wrengr
Contributor

wrengr commented Mar 3, 2022

This is primarily for sparse⇒sparse ConvertOp, though it also applies to dense⇒sparse ConvertOp and NewOp.

There is an unintended infelicity with the semantics of the current sparse⇒sparse conversion as implemented by the runtime library. In general, if the source tensor has any dense dimensions then it will be forced to store some zero elements. That's unavoidable, but the problem is that when enumerating the elements of the source tensor we currently do not perform any filtering of these zeros, which means the target tensor will unnecessarily store them since it can't tell them apart from nonzero values.

If all the element types were integral, then we could simply filter them out in the runtime library. However, since we also have floating-point types, a simple `if (v != 0)` isn't enough. There's a huge range of options for how to parameterize testing for "nonzero" floating-point numbers, so any choice we make is bound to be bad for some user. A nice alternative is to hand the decision back to the user: the conversion op should take an extra argument, a predicate for deciding whether a value counts as "nonzero" or not. Then the runtime library can simply check `if (isNonzero(v))`.
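To make the "huge range of options" concrete, here is a minimal sketch of three plausible "nonzero" tests for `double`. All three names are hypothetical and exist nowhere in the runtime library, and the `1e-12` tolerance is an arbitrary illustrative choice. They disagree on negative zero, NaN, and very small magnitudes.

```cpp
#include <cmath>

// Hypothetical candidates for illustration only.

// Plain comparison: -0.0 counts as zero, NaN counts as nonzero (NaN != 0 is true).
inline bool isNonzeroExact(double v) { return v != 0.0; }

// Sign-aware: only positive zero is dropped, so -0.0 is kept as a stored value.
inline bool isNonzeroKeepNegZero(double v) {
  return v != 0.0 || std::signbit(v);
}

// Tolerance-based: drops anything with magnitude at most 1e-12.
// NaN is dropped as well, because every ordered comparison with NaN is false.
inline bool isNonzeroEps(double v) { return std::fabs(v) > 1e-12; }
```

Each of these is a defensible default for some workload, which is the argument for letting the caller supply the predicate.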

In addition to avoiding the problems around floating-point types, taking a user-defined predicate for what counts as a "nonzero" value has some other nice benefits. For example, we can implement relu as a convert op by using the predicate `[](V v){ return v > 0; }`. The same goes for any other function that filters/selects elements based only on their value.
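A rough sketch of the relu idea, using a plain vector of (index, value) pairs as a stand-in for the enumerated source tensor. `Elements`, `convertFiltered`, and `reluViaConvert` are hypothetical names for illustration, not the runtime library's API, and the real conversion would insert into a sparse storage scheme rather than a vector.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for an enumerated sparse tensor:
// (linearized index, value) pairs.
template <typename V>
using Elements = std::vector<std::pair<std::uint64_t, V>>;

// Copies only the elements the predicate accepts, mimicking what the
// conversion would do when inserting into the target tensor.
template <typename V, typename Pred>
Elements<V> convertFiltered(const Elements<V> &src, Pred isNonzero) {
  Elements<V> dst;
  for (const auto &elem : src)
    if (isNonzero(elem.second))
      dst.push_back(elem);
  return dst;
}

// relu expressed purely as a filter: non-positive values are not stored,
// so they become implicit zeros in the target.
inline Elements<double> reluViaConvert(const Elements<double> &src) {
  return convertFiltered(src, [](double v) { return v > 0; });
}
```

Given source elements `{0, -1.5}, {1, 0.0}, {2, 2.0}, {5, 3.5}`, only the entries at indices 2 and 5 survive. This works because relu leaves positive values unchanged; a function that also rewrites the kept values would need more than a predicate.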

The big question is how exactly to pass these predicate functions across the boundary from MLIR to C++, since it effectively requires the C++ runtime library to make a callback into MLIR land. A simple working proposal is to pass the predicate as a C function pointer: `bool(*)(V)`, where the null pointer indicates the always-true function (i.e., the current behavior). This has the limitation that the user can't use closures or other callable objects, but it's not clear to me how we would pass those across the boundary (let alone define them on the MLIR side).
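The function-pointer proposal could look roughly like the sketch below on the C++ side. `IsNonzeroF64`, `countKeptF64`, and `kPositive` are hypothetical names; the actual runtime entry points have different signatures, and this only counts surviving values instead of building a tensor.

```cpp
#include <cstdint>

// Hypothetical predicate type for f64 elements: bool(*)(V) with V = double.
using IsNonzeroF64 = bool (*)(double);

// Counts how many of `n` values a conversion would keep.
// A null predicate means "always true", i.e. the current keep-everything behavior.
extern "C" std::uint64_t countKeptF64(const double *values, std::uint64_t n,
                                      IsNonzeroF64 isNonzero) {
  std::uint64_t kept = 0;
  for (std::uint64_t i = 0; i < n; ++i)
    if (!isNonzero || isNonzero(values[i]))
      ++kept;
  return kept;
}

// A captureless lambda converts implicitly to a plain function pointer;
// a capturing one does not, which is the closure limitation noted above.
static const IsNonzeroF64 kPositive = [](double v) { return v > 0; };
```

For the values `{0.0, -1.0, 2.0, 0.0, 3.0}`, passing `nullptr` keeps all five, while passing `kPositive` keeps two. Because the parameter is a bare pointer, existing callers can pass null and see no change in behavior.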

@wrengr wrengr added enhancement Improving things as opposed to bug fixing, e.g. new or missing feature mlir:sparse Sparse compiler in MLIR labels Mar 3, 2022
@llvmbot
Collaborator

llvmbot commented Mar 3, 2022

@llvm/issue-subscribers-mlir-sparse
