Add a pass for FastExp conversion #3839

Merged
merged 4 commits into from Nov 18, 2020

@asaadaldien asaadaldien commented Nov 14, 2020

Convert llvm.intr.exp into a sequence of ops that computes an approximation to exp(x).
This uses the fact that:

exp(x) = exp(x - floor(x / ln(2)) * ln(2)) * 2^(floor(x / ln(2)))
       = exp(x - k * ln(2)) * 2^k,  where k = floor(x / ln(2))

The reduced argument x - k * ln(2) lies in [0, ln(2)), where exp is approximated with a 4th-degree polynomial.
The real number 2^k is computed with integer bitwise arithmetic.
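
For readers who want to try the idea outside the compiler, here is a minimal scalar sketch of the decomposition described above. It is not the MLIR pattern from this PR: the function name `fastExpSketch` is hypothetical, the constants are copied from the snippets quoted in the review below (shortened to float precision), and it assumes the coefficients are ordered highest degree first, as the quoted Horner steps suggest. It also assumes inputs small enough that k stays within the normal float exponent range.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of exp(x) = exp(x - k*ln2) * 2^k with k = floor(x / ln2).
// Hypothetical helper, not the code merged in this PR.
// Assumes -126 <= k <= 127, i.e. no overflow/underflow handling.
inline float fastExpSketch(float x) {
  constexpr float ln2 = 0.693147181f;
  constexpr float ln2Inv = 1.44269504f;
  // Polynomial coefficients quoted in the review, highest degree first.
  constexpr float c[5] = {0.05924867f, 0.15514645f, 0.50308552f,
                          0.99968939f, 1.0000072f};

  const float k = std::floor(x * ln2Inv);
  const float y = x - k * ln2;  // reduced argument, roughly in [0, ln2)

  // Horner evaluation of the 4th-degree polynomial in y.
  float p = c[0];
  for (int i = 1; i < 5; ++i) p = p * y + c[i];

  // Build 2^k by placing the biased exponent (127 + k) in bits 23..30.
  const std::uint32_t bits =
      static_cast<std::uint32_t>(127 + static_cast<std::int32_t>(k)) << 23;
  float twoK;
  std::memcpy(&twoK, &bits, sizeof twoK);
  return p * twoK;
}
```

Comparing `fastExpSketch(x)` against `std::exp(x)` over a moderate range such as [-10, 10] is a quick way to see the size of the approximation error.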

MobileBert benchmarks:

Before:

------------------------------------------------------------------------------------
Benchmark                                          Time             CPU   Iterations
------------------------------------------------------------------------------------
BM_serving_default/process_time/real_time        907 ms          905 ms            1

After:

------------------------------------------------------------------------------------
Benchmark                                          Time             CPU   Iterations
------------------------------------------------------------------------------------
BM_serving_default/process_time/real_time        819 ms          815 ms            1


LogicalResult matchAndRewrite(LLVM::ExpOp op,
PatternRewriter &rewriter) const override {
constexpr float ln2Const = 0.69314718055994529;
Contributor

float literals only need to have 9 decimal digits to be fully accurate, afaik.
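
The nine-digit claim can be checked with a small sketch. It assumes `float` is IEEE-754 binary32; the helper name `roundTripsWithNineDigits` is hypothetical and only serves the illustration.

```cpp
#include <cstdio>
#include <cstdlib>
#include <limits>

// For binary32, 9 significant decimal digits are enough to round-trip a value.
static_assert(std::numeric_limits<float>::max_digits10 == 9,
              "this sketch assumes IEEE-754 binary32 float");

// The long literal from the PR and a 9-digit literal give the same float.
constexpr float ln2Long = 0.69314718055994529f;
constexpr float ln2Short = 0.693147181f;
static_assert(ln2Long == ln2Short, "both literals round to the same float");

// Hypothetical helper: print with 9 significant digits and parse back.
inline bool roundTripsWithNineDigits(float value) {
  char buf[32];
  std::snprintf(buf, sizeof buf, "%.9g", value);
  return std::strtof(buf, nullptr) == value;
}
```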

Contributor Author

Done!

PatternRewriter &rewriter) const override {
constexpr float ln2Const = 0.69314718055994529;
constexpr float ln2InvConst = 1.4426950408889634;
constexpr float cVaues[5] = {0.05924867, 0.15514645, 0.50308552, 0.99968939,
Contributor

did you mean cValues ?

Contributor Author

Done!

Comment on lines 36 to 37
constexpr float cVaues[5] = {0.05924867, 0.15514645, 0.50308552, 0.99968939,
1.0000072153251447};
Contributor

Can you explain in a comment how these cValues are computed?

(Side note: I checked whether current compilers generate good code for static const float k = std::log(2.0f), that is, not only evaluating k at compile time but also not emitting a lock around it. GCC does... but Clang does not! So you are (sadly) doing it right by using literals here: https://godbolt.org/z/dhd3Yb)

Contributor Author

They are a least-squares fit computed with numpy; I added a comment above.
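
The PR does not show the numpy script, so the following is only a sketch of how such a fit could be reproduced, written in C++ to match the rest of the code here. It assumes evenly spaced samples on [0, ln 2] and solves the normal equations directly; the function name `fitExpPolynomial` and the sample count are hypothetical, and the exact digits of the result depend on sampling choices that the PR does not state. Note that it returns coefficients lowest degree first, the reverse of the order used in the quoted array.

```cpp
#include <array>
#include <cmath>
#include <utility>

// Hypothetical helper: least-squares fit of a degree-4 polynomial to exp(y)
// on [0, ln 2], sampled at `samples` evenly spaced points.
// Returns c[0] + c[1]*y + ... + c[4]*y^4 (lowest degree first).
inline std::array<double, 5> fitExpPolynomial(int samples = 1000) {
  constexpr int n = 5;
  const double ln2 = std::log(2.0);
  double a[n][n + 1] = {};  // augmented normal equations [A^T A | A^T b]

  for (int s = 0; s < samples; ++s) {
    const double y = ln2 * s / (samples - 1);
    double pw[2 * n - 1];  // powers y^0 .. y^8
    pw[0] = 1.0;
    for (int i = 1; i < 2 * n - 1; ++i) pw[i] = pw[i - 1] * y;
    for (int r = 0; r < n; ++r) {
      for (int c = 0; c < n; ++c) a[r][c] += pw[r + c];
      a[r][n] += pw[r] * std::exp(y);
    }
  }

  // Gaussian elimination with partial pivoting.
  for (int col = 0; col < n; ++col) {
    int pivot = col;
    for (int r = col + 1; r < n; ++r)
      if (std::fabs(a[r][col]) > std::fabs(a[pivot][col])) pivot = r;
    for (int c = 0; c <= n; ++c) std::swap(a[col][c], a[pivot][c]);
    for (int r = col + 1; r < n; ++r) {
      const double f = a[r][col] / a[col][col];
      for (int c = col; c <= n; ++c) a[r][c] -= f * a[col][c];
    }
  }

  // Back substitution.
  std::array<double, 5> coeff{};
  for (int r = n - 1; r >= 0; --r) {
    double sum = a[r][n];
    for (int c = r + 1; c < n; ++c) sum -= a[r][c] * coeff[c];
    coeff[r] = sum / a[r][r];
  }
  return coeff;
}
```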

expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[4]);

// Compute 2^k with integer bitshift, 2^k = (127 + k) << 23
Contributor

Sorry, I don't understand what is meant in this equation, 2^k = (127 + k) << 23. The letter k appears on both sides of it, i just need help parsing what that means. A longer comment maybe?

Contributor Author

It's exp2(k) computed with a left shift. Added a comment.

Contributor

oh, thanks! i see what you're doing now - 23 is the number of mantissa bits, you're punning the float as an integer... so IIUC this relies on k being exactly an integral value. That's fairly reasonable as it was computed as the result of a floor, but I wonder what would happen for very large values which are not exactly representable as integers? i.e. absolute values larger than (1 / epsilon) ~= 1e+7 ? Are we safe from this case thanks to clamping of arguments somewhere? Or if you want to leave this as a TODO, that's fine too; there is a TODO below about float underflow/overflow, but here we are talking about much more 'tame' values (absolute values 1e+7 vs 1e+38). With clamping this is a non-issue because the exponential function is trivial so far away from 0, but without clamping it looks like there is a risk of this implementation of exponential producing surprising values for large absolute-value argument? E.g. exp(-1e+8f) should return 0, but might return some weird value here ? maybe add a testcase in the unit test, possibly commented out for now if you want to handle it later?

expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[4]);

// Compute 2^k with integer bitshift, 2^k = (127 + k) << 23
Value fPBias = rewriter.create<LLVM::ConstantOp>(
Contributor

What does the mnemonic fPBias refer to and why does it start with a lowercase f before the uppercase P ?

Contributor Author

Floating point bias (1 in this case)

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Adding a whole pass to apply one pattern seems wasteful. Can adding the pattern to "ConvertToLLVM" pass work? Maybe adjust the pattern benefit to make it apply preferentially?

@asaadaldien
Contributor Author

> Adding a whole pass to apply one pattern seems wasteful. Can adding the pattern to "ConvertToLLVM" pass work? Maybe adjust the pattern benefit to make it apply preferentially?

ConvertToLLVM is a dialect conversion (Std, Vector, IREE, HAL) -> LLVMIR, whereas this pattern is LLVMIR -> LLVMIR, so it will have to be applied after the dialect conversion anyway.

@MaheshRavishankar
Contributor

> Adding a whole pass to apply one pattern seems wasteful. Can adding the pattern to "ConvertToLLVM" pass work? Maybe adjust the pattern benefit to make it apply preferentially?
>
> ConvertToLLVM is a dialect conversion (Std, Vector, IREE, HAL) -> LLVMIR, whereas this pattern is LLVMIR -> LLVMIR, so it will have to be applied after the dialect conversion anyway.

Oops! You are obviously right. My bad!

Collaborator

@benvanik benvanik left a comment

This is awesome @asaadaldien and a really good example of how we can specialize things when we need to (the handwritten approach) vs. emitting runtime library calls that are impossible for LLVM to optimize against - here we can expect all those ops to at least get somewhat optimized based on context vs. a blackbox call.

@bjacob this seems like a good approach to start for toying with things like bringing over the things you know are good from xnnpack/etc :)

To make this better I think we'll want to move them into their own dedicated path so that people looking for them find them - I am not familiar with what the most relevant reusable llvm terminology for this kind of stuff is - it's effectively like an intrinsic expansion (like how llvm.memcpy can become just some loads/stores if the memcpy is small vs. a call out to the CRT/etc).

Maybe iree/compiler/Conversion/LLVM/Optimizations or ../Patterns or something. Then we have a nice place where new ones can get added - they don't have anything to do with linalg so they feel a bit weird here (and if I was a user wanting to plumb a custom op down from the high level into LLVM IR, linalg is not relevant - I just want a place to put my op->llvm or llvm->llvm pattern).

Fine with that as a followup, but definitely worth doing before we end up with a dozen of these all tangled together :)

(we can also have nice tests for such a folder)

@benvanik benvanik self-requested a review November 17, 2020 21:18
Collaborator

@benvanik benvanik left a comment

actually, just saw mahesh's response which indicates I'm not the only one confused that this was in linalg land next to a conversion pipeline :) let's move it to Conversion/LLVMToLLVM to match the other dirs in there or something now

@asaadaldien
Contributor Author

> let's move it to Conve...

Yeah, having an LLVMIR -> LLVMIR pass is confusing; moving it to the LLVMToLLVM directory.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

LGTM after the move to LLVM -> LLVM

Collaborator

@benvanik benvanik left a comment

looks awesome! is there any kind of lit test you could add to Conversion/LLVMToLLVM/test/ that would let us see the IR and verify it without needing to rely on the execution tests? or is it too messy?

@asaadaldien
Contributor Author

> ...that would let us see the IR and verify it without needing to rely on the execution tests? or is it too messy?

I think I can add a test to verify llvm.intr.exp -> correct_sequence_of_llvm_ops, but it's very verbose, so the only test added in this PR is a numerical check test.

@asaadaldien asaadaldien merged commit 45b70db into main Nov 18, 2020
@asaadaldien asaadaldien deleted the ataei-fast_approx_exp branch November 18, 2020 21:57
4 participants