-
Notifications
You must be signed in to change notification settings - Fork 6.1k
8351412: Add AVX-512 intrinsics for ML-KEM #24953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Welcome back ferakocz! A progress list of the required criteria for merging this PR into |
|
@ferakocz This change now passes all automated pre-integration checks. ℹ️ This project also has non-automated pre-integration requirements. Please see the file CONTRIBUTING.md for details. After integration, the commit message for the final commit will be: You can use pull request commands such as /summary, /contributor and /issue to adjust it as needed. At the time when this comment was updated there had been 355 new commits pushed to the
As there are no conflicts, your changes will automatically be rebased on top of these commits when integrating. If you prefer to avoid this automatic rebasing, please check the documentation for the /integrate command for further details. As you do not have Committer status in this project an existing Committer must agree to sponsor your change. Possible candidates are the reviewers of this PR (@lmesnik, @sviswa7) but any other Committer may sponsor as well. ➡️ To flag this PR as ready for integration with the above commit message, type |
Webrevs
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please explain in comments how this fix has been tested? I would like to understand which tests are relevant and which flags needs to be set to test this functionality.
| @@ -1,5 +1,5 @@ | |||
| /* | |||
| * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. | |||
| * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems the file contains only copyright changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only reviewed three intrinsics so far, more review to do.
| // a (short[256]) = c_rarg1 | ||
| // b (short[256]) = c_rarg2 | ||
| // c (short[256]) = c_rarg3 | ||
| // kyberConsts (short[40]) = c_rarg4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kyberConsts is not one of the arguments passed in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| // result (short[256]) = c_rarg0 | ||
| // a (short[256]) = c_rarg1 | ||
| // b (short[256]) = c_rarg2 | ||
| // kyberConsts (short[40]) = c_rarg3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kyberConsts is not one of the arguments passed in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| address generate_kyberAddPoly_2_avx512(StubGenerator *stubgen, | ||
| MacroAssembler *_masm) { | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Java code for "implKyberAddPoly(short[] result, short[] a, short[] b)" does BarrettReduction but the intrinsic code here does not. Is that intentional and how is the reduction handled?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the Java version is the one that is too cautious. There is Barrett reduction after at most 4 consecutive uses of mlKemAddPoly(), so doing the reduction in implKyberAddPoly() is not necessary. Thanks for discovering this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I have another question, is there a reason that the Java versions of AddPoly (both for 2 and 3 input) return 1, whereas the corresponding intrinsics return 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use that for debugging. E.g. it is fairly easy to change the Java code to call both the intrinsic and Java version and compare the results. I don't see any harm in leaving that in the production version, since it is always ignored.
| address generate_kyber12To16_avx512(StubGenerator *stubgen, | ||
| MacroAssembler *_masm) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If AVX512_VBMI and AVX512_VBMI2 is available, it looks to me that the loop body of this algorithm can be implemented using more efficient instructions in simple 5 steps:
Step 1:
Load 0-47, 48-95, 96-143, 144-191 condensed bytes into xmm0, xmm1, xmm2, xmm3 respectively using masked load.
Step 2:
Use vpermb to arrange xmm0 such that bytes 1, 4, 7, ... are duplicated
xmm0 before b47, b46, ..., b0 where each b is a byte
xmm0 after b47 b46 b46 b45, ......., b5 b4 b4 b3 b2 b1 b1 b0
Repeat this for xmm1, xmm2, xmm3
Step 3:
Use vpshldvw to shift every word (16 bits) in the xmm0 appropriately with variable shift
Shift word 31 by 4, word 30 by 0, ... word 3 by 4, word 2 by 0, word 1 by 4, word 0 by 0
Repeat this for xmm1, xmm2, xmm3
Step 4:
Use vpand to "and" each word element in xmm0 by 0xfff.
Repeat this for xmm1, xmm2, xmm3
Step 5:
Store xmm0 into parsed
Store xmm1 into parsed + 64
Store xmm2 into parsed +128
Store xmm3 into parsed + 192
If you think there is not sufficient time, we could look into it after the merge of this PR as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that way we can speed this up a little (well, in itself it might be significant), but with the current intrinsics, the contribution of this function to the overall running time is about 1.5%, so it would not matter that much, while on the other hand not all AVX-512 capable processors have vbmi.
So I would rather not do it in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another minor comment. Rest of the PR looks good to me.
| // Kyber barrett reduce function. | ||
| // | ||
| // coeffs (short[256]) = c_rarg0 | ||
| // kyberConsts (short[40]) = c_rarg1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kyberConsts is not an input parameter to implKyberBarrettReduce.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
|
@sviswa7, thanks a lot for the review! If you agree with my changes to load the constants using broadcasting instructions instead of full AVX register loads, would you be so kind as to approve the PR and sponsor my integration? |
| static void montmul(int outputRegs[], int inputRegs1[], int inputRegs2[], | ||
| int scratchRegs1[], int scratchRegs2[], MacroAssembler *_masm) { | ||
| for (int i = 0; i < 4; i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the intrinsic for montMul we are treating as if MONT_R_BITS is 16 and MONT_Q_INV_MOD_R is 0xF301 whereas in the Java code MONT_R_BITS is 20 and MONT_Q_INT_MOD_R is 0x8F301. Are these equivalent?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As used in this case, they are equivalent. For z = montmul(a,b), z will be between -q and q and congruent to a * b * R^-1 mod q, where R > 2 * q, R is a power of 2, -R/2 * q <= a * b < R/2 * q. For the Java code, we use R = 2^20 and for the intrinsic, R = 2^16. In our computations, b is always c * R mod q, so the montmul() really computes a * c mod q. In the Java code, we use 32-bit numbers for the computations, and we use R = 2^20 because that way the a * b numbers that occur during all computations stay in the required range (the inverse NTT computation is where they can grow the most), so we don't have to do Barrett reductions during that computation. For the intrinsics, we use R = 2^16, because this way we can do twice as much work in parallel, but we have to do Barrett reduction after levels 2 and 4 in the inverse NTT computation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the explanation. It would be good to add it as a comment in the stubGenerator_x86_64_kyber.cpp.
The broadcast instructions look good. I only have one query on montMul above that I have wondering about. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding the comment.
|
/integrate |
|
/sponsor |
|
Going to push as commit 972f2eb.
Your commit was automatically rebased without conflicts. |
|
Please also write a release note as the performance improvement is significant. Thanks! |
|
I haven't find answer an my question about testing. How this fix is tested? |
The change in the file test/jdk/sun/security/provider/acvp/Launcher.java in PR https://github.com/openjdk/jdk/pull/23860/files covers this as well. |
|
Thanks for pointing to the test. |
Done. https://bugs.openjdk.org/browse/JDK-8357741 Release Note: ML-KEM Performance Improved |
By using the AVX-512 vector registers the speed of the computation of the ML-KEM algorithms (key generation, encapsulation, decapsulation) can be approximately doubled.
Progress
Issue
Reviewers
Reviewing
Using
gitCheckout this PR locally:
$ git fetch https://git.openjdk.org/jdk.git pull/24953/head:pull/24953$ git checkout pull/24953Update a local copy of the PR:
$ git checkout pull/24953$ git pull https://git.openjdk.org/jdk.git pull/24953/headUsing Skara CLI tools
Checkout this PR locally:
$ git pr checkout 24953View PR using the GUI difftool:
$ git pr show -t 24953Using diff file
Download this PR as a diff file:
https://git.openjdk.org/jdk/pull/24953.diff
Using Webrev
Link to Webrev Comment