Add a pass for FastExp conversion #3839
0612a09 to 11895f1
LogicalResult matchAndRewrite(LLVM::ExpOp op,
                              PatternRewriter &rewriter) const override {
  constexpr float ln2Const = 0.69314718055994529;
float literals only need to have 9 decimal digits to be fully accurate, afaik.
Done!
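For readers following the thread, the point about literal precision can be checked with a small sketch. The function name `literalsAgree` and the 9-digit literal are illustrative and not taken from the PR; the sketch assumes IEEE 754 single-precision `float`.

```cpp
#include <limits>

// Illustrative check (not from the PR): a 17-digit literal and a 9-digit
// literal for ln(2) round to the same single-precision value.
bool literalsAgree() {
  constexpr float ln2Long = 0.69314718055994529;  // literal as written in the diff
  constexpr float ln2Short = 0.693147181f;        // 9 significant decimal digits
  return ln2Long == ln2Short;
}

// Number of decimal digits needed to round-trip any float value.
constexpr int kFloatDigits = std::numeric_limits<float>::max_digits10;
```

On platforms with IEEE 754 floats, `max_digits10` is 9, which is where the "9 decimal digits" figure comes from.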
                              PatternRewriter &rewriter) const override {
  constexpr float ln2Const = 0.69314718055994529;
  constexpr float ln2InvConst = 1.4426950408889634;
  constexpr float cVaues[5] = {0.05924867, 0.15514645, 0.50308552, 0.99968939,
Did you mean cValues?
Done!
  constexpr float cVaues[5] = {0.05924867, 0.15514645, 0.50308552, 0.99968939,
                               1.0000072153251447};
Can you explain in a comment how these cValues are computed?
(Side note: I checked whether current compilers generate good code for static const float k = std::log(2.0f), that is, not only evaluating k at compile time but not even having a lock around it. GCC does... but Clang does not! So you are doing it right (sadly) by using literals here: https://godbolt.org/z/dhd3Yb)
They are a least-squares fit computed with NumPy. I added a comment above.
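As a rough way to sanity-check such a fit, the coefficients from the diff can be evaluated with Horner's rule and compared against `std::exp` on [0, ln(2)]. This sketch does not reproduce the NumPy fit itself; the helper names `polyExp` and `maxAbsError` are hypothetical, and the coefficient order (highest degree first) is inferred from the multiply-then-add sequence in the diff.

```cpp
#include <cmath>

// Coefficients copied from the diff under review, highest degree first
// (ordering inferred from the FMul/FAdd sequence, so treat it as an assumption).
constexpr float kCoeffs[5] = {0.05924867f, 0.15514645f, 0.50308552f,
                              0.99968939f, 1.0000072153251447f};

// Horner evaluation of the degree-4 polynomial at y.
float polyExp(float y) {
  float acc = kCoeffs[0];
  for (int i = 1; i < 5; ++i) acc = acc * y + kCoeffs[i];
  return acc;
}

// Largest absolute deviation from std::exp on a uniform grid over [0, ln 2].
float maxAbsError(int samples) {
  const float ln2 = 0.693147181f;
  float worst = 0.0f;
  for (int i = 0; i <= samples; ++i) {
    float y = ln2 * static_cast<float>(i) / static_cast<float>(samples);
    float err = std::fabs(polyExp(y) - std::exp(y));
    if (err > worst) worst = err;
  }
  return worst;
}
```

Calling `maxAbsError(1000)` gives an estimate of how closely the polynomial tracks exp over the reduced range.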
  expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
  expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[4]);

  // Compute 2^k with integer bitshift, 2^k = (127 + k) << 23
Sorry, I don't understand what is meant in this equation, 2^k = (127 + k) << 23. The letter k appears on both sides of it; I just need help parsing what that means. A longer comment maybe?
It's exp2(k) computed with a left shift. Added a comment.
Oh, thanks! I see what you're doing now: 23 is the number of mantissa bits, and you're punning the float as an integer... so IIUC this relies on k being exactly an integral value. That's fairly reasonable as it was computed as the result of a floor, but I wonder what would happen for very large values which are not exactly representable as integers, i.e. absolute values larger than (1 / epsilon) ~= 1e+7? Are we safe from this case thanks to clamping of arguments somewhere? Or if you want to leave this as a TODO, that's fine too; there is a TODO below about float underflow/overflow, but here we are talking about much more 'tame' values (absolute values 1e+7 vs 1e+38). With clamping this is a non-issue because the exponential function is trivial so far away from 0, but without clamping it looks like there is a risk of this implementation of exponential producing surprising values for large absolute-value arguments. E.g. exp(-1e+8f) should return 0, but might return some weird value here. Maybe add a testcase in the unit test, possibly commented out for now if you want to handle it later?
  expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[4]);

  // Compute 2^k with integer bitshift, 2^k = (127 + k) << 23
  Value fPBias = rewriter.create<LLVM::ConstantOp>(
What does the mnemonic fPBias refer to and why does it start with a lowercase f before the uppercase P?
Floating point bias (1 in this case)
Adding a whole pass to apply one pattern seems wasteful. Can adding the pattern to "ConvertToLLVM" pass work? Maybe adjust the pattern benefit to make it apply preferentially?
ConvertToLLVM is a dialect conversion (Std, Vector, IREE, HAL) -> LLVMIR, whereas this pattern is LLVMIR -> LLVMIR, so it has to be applied after the dialect conversion anyway.
11895f1 to 6e8f011
Oops! You are obviously right. My bad!
This is awesome @asaadaldien and a really good example of how we can specialize things when we need to (the handwritten approach) vs. emitting runtime library calls that are impossible for LLVM to optimize against - here we can expect all those ops to at least get somewhat optimized based on context vs. a blackbox call.
@bjacob this seems like a good approach to start for toying with things like bringing over the things you know are good from xnnpack/etc :)
To make this better I think we'll want to move them into their own dedicated path so that people looking for them find them. I am not familiar with what the most relevant reusable LLVM terminology for this kind of stuff is; it's effectively like an intrinsic expansion (like how llvm.memcpy can become just some loads/stores if the memcpy is small vs. a call out to the CRT/etc).
Maybe iree/compiler/Conversion/LLVM/Optimizations or ../Patterns or something. Then we have a nice place where new ones can get added. They don't have anything to do with linalg so they feel a bit weird here (and if I was a user wanting to plumb a custom op down from the high level into LLVM IR, linalg is not relevant: I just want a place to put my op->llvm or llvm->llvm pattern).
Fine with that as a followup, but definitely worth doing before we end up with a dozen of these all tangled together :)
(we can also have nice tests for such a folder)
actually, just saw mahesh's response which indicates I'm not the only one confused that this was in linalg land next to a conversion pipeline :) let's move it to Conversion/LLVMToLLVM to match the other dirs in there or something now
Yeah, having an LLVMIR -> LLVMIR pass here is confusing. Move it to the LLVMToLLVM directory.
LGTM after the move to LLVM -> LLVM
looks awesome! is there any kind of lit test you could add to Conversion/LLVMToLLVM/test/ that would let us see the IR and verify it without needing to rely on the execution tests? or is it too messy?
I think I can add a test to verify.
Convert llvm.intr.exp into a sequence of ops that computes an approximation to exp(x). This uses the fact that:
exp(x) = 2^k * exp(x - k * ln(2)), with k = floor(x / ln(2)),
so that x - k * ln(2) lies in [0, ln(2)], where exp is approximated with a degree-4 polynomial.
The real number 2^k is computed with integer bitwise arithmetic.
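The description above can be put together as a scalar C++ sketch. This is a reading of the PR description and the diff context, not the op sequence the pass actually emits. The name `fastExpSketch` is hypothetical, the constants are the literals from the diff shortened to float precision, and the sketch assumes IEEE 754 floats and moderate |x| because it does no clamping.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of the approximation: range reduction, degree-4 polynomial,
// then scaling by 2^k built from integer bits. No overflow/underflow handling.
float fastExpSketch(float x) {
  constexpr float ln2 = 0.693147181f;
  constexpr float ln2Inv = 1.44269504f;
  constexpr float c[5] = {0.05924867f, 0.15514645f, 0.50308552f,
                          0.99968939f, 1.0000072153251447f};

  // Range reduction: k = floor(x / ln 2), y = x - k * ln 2 in [0, ln 2].
  float kf = std::floor(x * ln2Inv);
  float y = x - kf * ln2;

  // Polynomial approximation of exp(y), evaluated with Horner's rule.
  float p = c[0];
  for (int i = 1; i < 5; ++i) p = p * y + c[i];

  // 2^k via the biased exponent field; assumes -126 <= k <= 127.
  int32_t k = static_cast<int32_t>(kf);
  uint32_t bits = static_cast<uint32_t>(127 + k) << 23;
  float twoToK;
  std::memcpy(&twoToK, &bits, sizeof(twoToK));

  return p * twoToK;
}
```

For inputs of moderate magnitude the result should track `std::exp` closely; very large or very small inputs would need the clamping discussed in the review.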
MobileBert benchmarks:
Before:
After: