8351034: Add AVX-512 intrinsics for ML-DSA #23860
Conversation
👋 Welcome back ferakocz! A progress list of the required criteria for merging this PR into
@ferakocz This change now passes all automated pre-integration checks. ℹ️ This project also has non-automated pre-integration requirements. Please see the file CONTRIBUTING.md for details. After integration, the commit message for the final commit will be: You can use pull request commands such as /summary, /contributor and /issue to adjust it as needed. At the time when this comment was updated there had been 535 new commits pushed to the
As there are no conflicts, your changes will automatically be rebased on top of these commits when integrating. If you prefer to avoid this automatic rebasing, please check the documentation for the /integrate command for further details. As you do not have Committer status in this project an existing Committer must agree to sponsor your change. Possible candidates are the reviewers of this PR (@lmesnik, @jatin-bhateja, @sviswa7) but any other Committer may sponsor as well. ➡️ To flag this PR as ready for integration with the above commit message, type
Webrevs
ML-DSA benchmark results for this PR: ML-DSA no intrinsics
@ferakocz this pull request can not be integrated; to resolve the merge conflicts, run:

git checkout mldsa-avx512-intrinsics
git fetch https://git.openjdk.org/jdk.git master
git merge FETCH_HEAD
# resolve conflicts and follow the instructions given by git merge
git commit -m "Merge master"
git push
__ movl(iterations, 2);
__ BIND(L_loop);
Hi @ferakocz, kindly align the loop entry address using __ align64() here and at all the places before __ BIND(LOOP).
Hi, @jatin-bhateja, thanks for the suggestion. I have added __ align(OptoLoopAlignment); before all loop entries.
Hi @ferakocz,
Thanks! For efficient utilization of the Decoded ICache (please refer to Intel SDM section 3.4.2.5), code blocks should be aligned to 32-byte boundaries; a 64-byte aligned address is a superset of both 16- and 32-byte aligned addresses and also matches the cache line size. However, I noticed that we have been using OptoLoopAlignment in places in AES-GCM as well.
I introduced some errors into the generate_dilithiumAlmostInverseNtt_avx512 implementation, expecting to catch them through the existing ML_DSA tests under
test/jdk/sun/security/provider/acvp
But all the tests passed for me.
java -jar /home/jatinbha/sandboxes/jtreg/build/images/jtreg/lib/jtreg.jar -jdk:$JAVA_HOME -Djdk.test.lib.artifacts.ACVP-Server=/home/jatinbha/softwares/v1.1.0.38.zip -va -timeout:4 Launcher.java
Can you please point out a test I need to use for validation?
I think the easiest is to put a for (int i = 0; i < 1000; i++) loop around the switch statement in the run() method of the ML_DSA_Test class (test/jdk/sun/security/provider/acvp/ML_DSA_Test.java). (This is because the intrinsics kick in after a few thousand calls of the method.)
Hi @ferakocz, yes, we should modify the test or lower the compilation threshold with -Xbatch -XX:TieredCompileThreshold=0.1.
Alternatively, since the tests have a dependency on the Automated Cryptographic Validation Test server, I have created a simplified test which covers all the security levels.
Kindly include test/hotspot/jtreg/compiler/intrinsics/signature/TestModuleLatticeDSA.java
I have added a new command to the test test/jdk/sun/security/provider/acvp/Launcher.java. The line with the -Xcomp will invoke the intrinsics on the first call, so they will be tested.
void StubGenerator::generate_sha3_stubs() {
  if (UseSHA3Intrinsics) {
    StubRoutines::_sha3_implCompress = generate_sha3_implCompress(StubGenStubId::sha3_implCompress_id);
    StubRoutines::_double_keccak = generate_double_keccak();
Should UseDilithiumIntrinsics guard double_keccak generation?
No, that is more of a SHA3 thing, other algorithms can take advantage of it, too (e.g. ML-KEM).
@@ -0,0 +1,1404 @@
/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
Please update copyright year
Thanks, fixed.
StubRoutines::_dilithiumNttMult = generate_dilithiumNttMult_avx512();
StubRoutines::_dilithiumMontMulByConstant = generate_dilithiumMontMulByConstant_avx512();
StubRoutines::_dilithiumDecomposePoly = generate_dilithiumDecomposePoly_avx512();
}
Indentation fix needed
Thanks, fixed.
const Register constant2use = r10;
const Register roundsLeft = r11;
__ align(OptoLoopAlignment);
Redundant alignment before the label; it should be before its bind.
Thanks, fixed.
There are no new tests in the PR. How has the fix been tested by the openjdk tests?
I have just added one.
* @library /test/lib
* @modules java.base/sun.security.provider
* @run main Launcher
* @run main/othervm -Xcomp Launcher
Done.
Test changes look good.
Partial review, just didn't want to sit on comments for this long.
(Spent quite a bit of time catching up on papers and math required.)
The biggest roadblock I have following the code is the raw register numbers. (And more comments? Perhaps I need more math knowledge, but comments would help too.)
Also, 'hidden variables' (xmm30). Can't complain, because this is exactly what Vladimir Ivanov told me to do on my first PR #10582 (comment). Perhaps that discussion applies here too.
void generate_sha3_stubs();
address generate_sha3_implCompress(StubGenStubId stub_id);
address generate_double_keccak();
You can hide internal helper functions (i.e. montmulEven(*)) if you wish.
The trick is to add MacroAssembler* _masm as a parameter to the static (local) function. It's a trick I use to keep the header clean, but still have plenty of helpers.
__ evmovdquq(xmm17, Address(permsAndRots, 0), Assembler::AVX_512bit);
__ evmovdquq(xmm18, Address(permsAndRots, 64), Assembler::AVX_512bit);
__ evmovdquq(xmm19, Address(permsAndRots, 128), Assembler::AVX_512bit);
__ evmovdquq(xmm20, Address(permsAndRots, 192), Assembler::AVX_512bit);
__ evmovdquq(xmm21, Address(permsAndRots, 256), Assembler::AVX_512bit);
__ evmovdquq(xmm22, Address(permsAndRots, 320), Assembler::AVX_512bit);
__ evmovdquq(xmm23, Address(permsAndRots, 384), Assembler::AVX_512bit);
__ evmovdquq(xmm24, Address(permsAndRots, 448), Assembler::AVX_512bit);
__ evmovdquq(xmm25, Address(permsAndRots, 512), Assembler::AVX_512bit);
__ evmovdquq(xmm26, Address(permsAndRots, 576), Assembler::AVX_512bit);
__ evmovdquq(xmm27, Address(permsAndRots, 640), Assembler::AVX_512bit);
__ evmovdquq(xmm28, Address(permsAndRots, 704), Assembler::AVX_512bit);
__ evmovdquq(xmm29, Address(permsAndRots, 768), Assembler::AVX_512bit);
__ evmovdquq(xmm30, Address(permsAndRots, 832), Assembler::AVX_512bit);
__ evmovdquq(xmm31, Address(permsAndRots, 896), Assembler::AVX_512bit);
Matter of taste, but I liked the compactness of montmulEven; i.e.
for (int i = 0; i < 15; i++) {
  __ evmovdquq(xmm(17 + i), Address(permsAndRots, 64 * i), Assembler::AVX_512bit);
}
Changed.
__ BIND(rounds24_loop);
__ subl(roundsLeft, 1);
__ evmovdquw(xmm5, xmm0, Assembler::AVX_512bit);
Is there a pattern here that can be 'compacted' into a loop?
Unfortunately, no. This loop body is imported from generate_sha3_implCompress() and doubled, as explained in the comment about 15 lines above.
__ vpmuldq(xmm(scratchReg1), xmm(inputReg11), xmm(inputReg2), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 1), xmm(inputReg12), xmm(inputReg2 + 1), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 2), xmm(inputReg13), xmm(inputReg2 + 2), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 3), xmm(inputReg14), xmm(inputReg2 + 3), Assembler::AVX_512bit);
Another option for these four lines, to keep the style of the rest of the function:
int inputReg1[] = {inputReg11, inputReg12, inputReg13, inputReg14};
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(scratchReg1 + i), xmm(inputReg1[i]), xmm(inputReg2 + i), Assembler::AVX_512bit);
}
I have changed the whole structure instead.
// levels are the same for all the montmuls that we can do in parallel
// level 0
montmulEven(20, 8, 29, 20, 16, 4);
It would improve readability to know which parameter is a register, and which is a count, i.e.
montmulEven(xmm20, xmm8, xmm29, xmm20, xmm16, 4);
(It's not that bad once I remember that it's always the last parameter, but it does add to the 'mental load' one has to carry, and this code is already interesting enough.)
I have changed the structure, now it is clear(er) which parameter is what.
}
ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Perms[] = {
// collect montmul results into the destination register
Same as dilithiumAvx512Consts(): 'magic offsets'; except here they are harder to count (e.g. it is not visually clear what the offset of the ntt inverse is).
Could be split into three constant arrays to make the compiler count for us.
Well, it is 64 bytes per line (16 4-byte uint32_ts), not that hard :-) ...
Ha! I didn't realize it was 16 per line.. ran out of fingers while counting!!! :)
'Works for me, as long as it's a "premeditated" decision.'
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(i + scratchReg1), xmm(i + inputReg1), xmm((inputReg2 == 29) ? 29 : inputReg2 + i), Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ vpmulld(xmm(i + scratchReg2), xmm(i + scratchReg1), xmm30, Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(i + scratchReg2), xmm(i + scratchReg2), xmm31, Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ evpsubd(xmm(i + outputReg), k0, xmm(i + scratchReg1), xmm(i + scratchReg2), false, Assembler::AVX_512bit);
}
This is such a deceptively brilliant function!!! Took me a while to understand (and map to the Java montMul function). Perhaps it needs more comments.
The comment on line 99 does provide good hints, but I still had some trouble. I ended up annotating a copy quite a bit. I do think all 'clever code' needs comments. Here is my annotated version, if you want to copy anything out:
static void montmulEven2(XMMRegister outputReg, XMMRegister inputReg1, XMMRegister inputReg2,
                         XMMRegister scratchReg1, XMMRegister scratchReg2,
                         XMMRegister montQInvModR, XMMRegister dilithium_q,
                         int parCnt, MacroAssembler* _masm) {
  int output = outputReg->encoding();
  int input1 = inputReg1->encoding();
  int input2 = inputReg2->encoding();
  int scratch1 = scratchReg1->encoding();
  int scratch2 = scratchReg2->encoding();
  for (int i = 0; i < parCnt; i++) {
    // scratch1 = (int64) input1_even * input2_even
    // Java: long a = (long) b * (long) c;
    __ vpmuldq(xmm(i + scratch1), xmm(i + input1), xmm((input2 == 29) ? 29 : input2 + i), Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = int32(montQInvModR * (int32) scratch1)
    // Java: int aLow = (int) a;
    //       int m = MONT_Q_INV_MOD_R * aLow; // signed low product
    __ vpmulld(xmm(i + scratch2), xmm(i + scratch1), montQInvModR, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = (int64) scratch2_even * dilithium_q_even
    // Java: (long) m * MONT_Q
    __ vpmuldq(xmm(i + scratch2), xmm(i + scratch2), dilithium_q, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // output_odd = scratch1_odd - scratch2_odd
    // Java: aHigh - (int) (((long) m * MONT_Q) >> MONT_R_BITS)
    __ evpsubd(xmm(i + output), k0, xmm(i + scratch1), xmm(i + scratch2), false, Assembler::AVX_512bit);
  }
}
- add a comment that input2 can be xmm29, treated as a constant, not consecutive (i.e. zetas)
- candidate for ascii art: even/odd columns, implicit int/long casts (or more 'math' comments on what happens)
- use XMMRegisters instead of numbers (improves callsite readability)
- can use either inputReg1 = inputReg1->successor(), or get encoding() and keep the current style
- could be a static (local) function (hidden from the header), then pass _masm
- pass all registers used (helps seeing register allocation, confirms there are no overlaps)

False trails (i.e. nothing to do, but I thought about it already, so another reviewer doesn't have to):
- (ignore: worse performance) squash into a single for loop, let the cpu do out-of-order (and improve readability)
- xmm30/xmm31 (montQInvModR/dilithium_q) are constant. At a glance, it looks like they should be combined into one precomputed value, and paper 039.pdf suggests merging constants to precompute the product; but the constants are different and, looking at the Java code, there are several implicit casts.
For reductions of products inside the NTT this is not a problem because one has to multiply by the roots of unity
which are compile-time constants. So one can just precompute them with an additional
factor of β mod q so that the results after Montgomery reduction are in fact congruent to the desired value a
for (int i = 0; i < 4; i++) {
  __ evmovdqul(xmm(i), Address(poly1, i * 64), Assembler::AVX_512bit);
  __ evmovdqul(xmm(i + 4), Address(poly2, i * 64), Assembler::AVX_512bit);
}
montmulEven(8, 4, 29, 12, 16, 4);
for (int i = 0; i < 4; i++) {
  __ vpshufd(xmm(i + 8), xmm(i + 8), 0xB1, Assembler::AVX_512bit);
}
montmulEven(8, 0, 8, 12, 16, 4);
for (int i = 0; i < 4; i++) {
  __ vpshufd(xmm(i), xmm(i), 0xB1, Assembler::AVX_512bit);
  __ vpshufd(xmm(i + 4), xmm(i + 4), 0xB1, Assembler::AVX_512bit);
}
montmulEven(4, 4, 29, 12, 16, 4);
for (int i = 0; i < 4; i++) {
  __ vpshufd(xmm(i + 4), xmm(i + 4), 0xB1, Assembler::AVX_512bit);
}
montmulEven(0, 0, 4, 12, 16, 4);
for (int i = 0; i < 4; i++) {
  __ evpermt2d(xmm(i), xmm28, xmm(i + 8), Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
  __ evmovdqul(Address(result, i * 64), xmm(i), Assembler::AVX_512bit);
}
This is nice, compact and clean. The biggest issue I have with following this code is really all the 'raw' registers. I would much prefer symbolic names, but it's up to you to decide on style.
I ended up 'annotating' this snippet so I could understand it and confirm everything; as with montmulEven, I hope some of it can be useful for you to copy out.
XMMRegister POLY1[]    = {xmm0, xmm1, xmm2, xmm3};
XMMRegister POLY2[]    = {xmm4, xmm5, xmm6, xmm7};
XMMRegister SCRATCH1[] = {xmm12, xmm13, xmm14, xmm15};
XMMRegister SCRATCH2[] = {xmm16, xmm17, xmm18, xmm19};
XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};

for (int i = 0; i < 4; i++) {
  __ evmovdqul(POLY1[i], Address(poly1, i * 64), Assembler::AVX_512bit);
  __ evmovdqul(POLY2[i], Address(poly2, i * 64), Assembler::AVX_512bit);
}
// montmulEven: inputs are in even columns and output is in odd columns
// scratch3_even = poly2_even * montRSquareModQ // poly2 to montgomery domain
montmulEven2(SCRATCH3[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  // swap even/odd; 0xB1 == 2-3-0-1
  __ vpshufd(SCRATCH3[i], SCRATCH3[i], 0xB1, Assembler::AVX_512bit);
}
// scratch3_odd = poly1_even * scratch3_even = poly1_even * poly2_even * montRSquareModQ
montmulEven2(SCRATCH3[0], POLY1[0], SCRATCH3[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  __ vpshufd(POLY1[i], POLY1[i], 0xB1, Assembler::AVX_512bit);
  __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
}
// poly2_even = poly2_odd * montRSquareModQ // poly2 to montgomery domain
montmulEven2(POLY2[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
}
// poly1_odd = poly1_even * poly2_even
montmulEven2(POLY1[0], POLY1[0], POLY2[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
for (int i = 0; i < 4; i++) {
  // result is scrambled between scratch3_odd and poly1_odd; unscramble
  __ evpermt2d(POLY1[i], perms, SCRATCH3[i], Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
  __ evmovdqul(Address(result, i * 64), POLY1[i], Assembler::AVX_512bit);
}
With symbolic variable names, the code was much easier to follow conceptually. It also has the side benefit of making it obvious which XMM registers are used and that there are no conflicts.
I have rewritten it to use full montmuls (a new function) here and everywhere else. It is much easier to follow the code that way.
__ evpbroadcastd(xmm29, constant, Assembler::AVX_512bit); // constant multiplier
__ movl(len, 2);
Same comment here as for generate_dilithiumNttMult_avx512:
- constants can be loaded directly into XMM
- len can be removed by unrolling at compile time
- symbolic names could be used for registers
- comments could be added
Done.
// Dilithium multiply polynomials in the NTT domain.
// Implements
// static int implDilithiumNttMult(
I suppose there are no java changes in this PR, but I notice that the inputs are all assumed to have a fixed size.
Most/all intrinsics I have worked with had some sort of guard (e.g. Objects.checkFromIndexSize) right before the intrinsic java call. (It usually looks like it can be optimized away.) But I notice no such guard here on the java side.
These functions will not be used anywhere else, and in ML_DSA.java all of the arrays passed to intrinsics are of the correct size.
Works for me; just thought I would point it out, so it's a 'premeditated' decision.
Well, I ended up putting some asserts in the java code, just in case...
const int montMulPermsIdx = 0;
const int nttL4PermsIdx = 64;
const int nttL5PermsIdx = 192;
const int nttL6PermsIdx = 320;
const int nttL7PermsIdx = 448;
const int nttInvL0PermsIdx = 704;
const int nttInvL1PermsIdx = 832;
const int nttInvL2PermsIdx = 960;
const int nttInvL3PermsIdx = 1088;
const int nttInvL4PermsIdx = 1216;

static address dilithiumAvx512PermsAddr() {
  return (address) dilithiumAvx512Perms;
}
Hear me out.. ...
enums!!
enum nttPermOffset {
  montMulPermsIdx = 0,
  nttL4PermsIdx = 64,
  nttL5PermsIdx = 192,
  nttL6PermsIdx = 320,
  nttL7PermsIdx = 448,
  nttInvL0PermsIdx = 704,
  nttInvL1PermsIdx = 832,
  nttInvL2PermsIdx = 960,
  nttInvL3PermsIdx = 1088,
  nttInvL4PermsIdx = 1216,
};
static address dilithiumAvx512PermsAddr(nttPermOffset offset) {
  return (address) dilithiumAvx512Perms + offset;
}
Belay that comment.. now that I have looked at generate_dilithiumAlmostInverseNtt_avx512, I see why that's not the 'entire picture'..
I'll leave it as it is for now.
// Dilithium Intrinsics
// Currently we only have them for AVX512
#ifdef _LP64
if (supports_evex() && supports_avx512bw()) {
supports_evex check looks redundant.
These are checks for two different feature bits: CPU_AVX512F and CPU_AVX512BW. Are you saying that the latter implies the former in every implementation of the spec?
AVX512BW is built on top of the AVX512F spec. In the assembler and other places we only check BW in assertions, which implies EVEX.
I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!
}
}
static void loadPerm(int destinationRegs[], Register perms,
replXmm? i.e. this function is replicating (any) Xmm register, not just perm?..
Since I am only using it for permutation describers, I thought this way it is easier to follow what is happening.
// qinvmodR and q are repeated in all slots of Zmm30 and Zmm31, resp.
// Zmm8-Zmm23 used as scratch registers
// result goes to Zmm0-Zmm7
static void montMulByConst128(MacroAssembler *_masm) {
Looking at this function some more.. I think you could remove this function and replace it with two calls to montMul64?
montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch*, _masm);
montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch*, _masm);
Scratch would have to be defined..
// zetas (int[256]) = c_rarg1
//
//
static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
Similar comments as for generate_dilithiumAlmostInverseNtt_avx512:
- similar comment about the 'pair-wise' operation, updating [j] and [j+l] at a time
- somehow I had less trouble following the flow through registers here, perhaps I am getting used to it. FYI, I ended up renaming some as:
// xmm16_27 = Temp1
// xmm0_3 = Coeffs1
// xmm4_7 = Coeffs2
// xmm8_11 = Coeffs3
// xmm12_15 = Coeffs4 = Temp2
// xmm16_27 = Scratch
For me, it was easier to follow what goes where using the xmm... names (with the symbolic names you always have to remember which one overlaps with another and how much).
//
// coeffs (int[256]) = c_rarg0
// zetas (int[256]) = c_rarg1
static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
Done with this function; Perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!
Some general comments first, rest inline.
- The array names for registers helped a lot. And so did the new helper functions!
- The java version of this code is quite intimidating to vectorize.. 3D loop, with geometric iteration variables.. and the literature is even more intimidating (discrete convolutions which I havent touched in two decades, ffts, ntts, etc.) Here is my attempt at a comment to 'un-scare' the next reader, though feel free to reword however you like.
The core of the (Java) loop is this 'pair-wise' operation:
int a = coeffs[j];
int b = coeffs[j + offset];
coeffs[j] = (a + b);
coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);
There are 8 'levels' (0-7); ('levels' are equivalent to (unrolling) the outer (Java) loop)
At each level, the 'pair-wise-offset' doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).
To vectorize this Java code, observe that at each level, REGARDLESS the offset, half the operations are the SUM, and the other half is the
montgomery MULTIPLICATION (of the pair-difference with a constant). At each level, one 'just' has to shuffle
the coefficients, so that SUMs and MULTIPLICATIONs line up accordingly.
Otherwise, this pattern is 'lightly similar' to a discrete convolution (compute integral/summation of two functions at every offset)
- I still would prefer (more) symbolic register names.. I wouldn't hold my approval over it so won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:
// xmm8_11 = Perms1
// xmm12_15 = Perms2
// xmm16_27 = Scratch
// xmm0_3 = CoeffsPlus
// xmm4_7 = CoeffsMul
// xmm24_27 = CoeffsMinus (overlaps with Scratch)
(I made a similar comment, but I think it is now hidden after the last refactor)
- would prefer to see the helper functions get ALL the registers passed explicitly (i.e. currently
montMulPerm, montQInvModR, dilithium_q, xmm29 are implicit). As a general rule, I've tried to set up all the registers at the 'entry' function (generate_dilithium* in this case) and from there on, use symbolic names. Not always reasonable, but what I've grown used to seeing.
I added some more comments, but I kept the xmm... names for the registers, just like with the ntt function.
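The pair-wise 'butterfly' step discussed in this thread can be sketched in plain Java. This is a simplification for illustration only: an ordinary modular multiply stands in for montMul (no Montgomery factor), and the zeta constant below is made up, not one of the real MONT_ZETAS_FOR_NTT values.

```java
import java.util.Arrays;

public class ButterflyDemo {
    static final int Q = 8380417; // the ML-DSA modulus

    // plain modular multiply standing in for montMul (no Montgomery factor)
    static int mulMod(long a, long b) {
        return (int) ((((a % Q) * (b % Q)) % Q + Q) % Q);
    }

    // one 'level' of the transform: pairs at distance 'offset';
    // half the outputs are sums, the other half are scaled differences
    static void level(int[] c, int offset, int zeta) {
        for (int start = 0; start < c.length; start += 2 * offset) {
            for (int j = start; j < start + offset; j++) {
                int a = c[j];
                int b = c[j + offset];
                c[j] = (a + b) % Q;
                c[j + offset] = mulMod(a - b, zeta);
            }
        }
    }

    public static void main(String[] args) {
        int[] c = {5, 2};
        level(c, 1, 3);                         // a+b = 7, (a-b)*zeta = 9
        System.out.println(Arrays.toString(c)); // [7, 9]
    }
}
```

Running the 8 levels with offset 1, 2, 4, ..., 128 over a 256-element array reproduces the loop structure the intrinsic vectorizes.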
for (int i = 0; i < 8; i++) {
  __ evpaddd(xmm(i + 16), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
  __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
}
Fairly clean as is, but could also be two sub_add calls, I think (you have to swap order of add/sub in the helper, to be able to clobber xmm(i).. or swap register usage downstream, so perhaps not.. but would be cleaner)
sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
sub_add(CoeffsMul, &Scratch[4], Perms2, CoeffsMul, _masm);
If nothing else, I would have preferred to see the register array variables used.
I would rather leave this alone, too. I was considering the same, but decided that this is fairly easy to follow; it would be more complicated to either add a new helper function or to follow where there are overlaps in the symbolically named register sets.
store4Xmms(coeffs, 0, xmm16_19, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
montMulByConst128(_masm);
Would prefer explicit parameters here. But I think this could also be two montMul64 calls?
montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch, _masm);
montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch, _masm);
(I think there is one other use of montMulByConst128 where the same applies; then you could delete both montMulByConst128 and montmulEven.)
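For context on what the montMul helpers compute: a scalar Montgomery multiplication for the ML-DSA modulus can be sketched as below. Q is from the spec; QINV = q^-1 mod 2^32 is the constant used by the reference Dilithium code. The result range comment assumes |a * b| < 2^31 * q.

```java
public class MontMulDemo {
    static final int Q = 8380417;        // ML-DSA modulus
    static final int QINV = 58728449;    // q^-1 mod 2^32, so q * QINV == 1 (mod 2^32)

    // returns a * b * 2^-32 mod q, result in (-q, q) when |a * b| < 2^31 * q
    static int montMul(int a, int b) {
        long c = (long) a * b;           // full 64-bit product
        int m = (int) c * QINV;          // m == c * q^-1 (mod 2^32), low 32 bits only
        return (int) ((c - (long) m * Q) >> 32); // exact division by 2^32
    }

    public static void main(String[] args) {
        // 2^16 * 2^16 * 2^-32 == 1 (mod q)
        System.out.println(montMul(1 << 16, 1 << 16)); // 1
    }
}
```

The subtraction is exact because c - m*q is a multiple of 2^32 by construction, which is why the shift needs no rounding correction.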
for (int i = 0; i < 8; i += 2) {
  __ evpermi2d(xmm(i/2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i += 2) {
  __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
}
Wish there was a more 'abstract' way to arrange this, so it's obvious from the shape of the code which registers are inputs/outputs (i.e. using the register arrays). Even though it's just 'elementary index operations', i/2 + 16 is still 'clever'. Couldn't think of anything myself though (same elsewhere in this function for the table permutes).
Well, this is how it is when we have three inputs, one of which also serves as the output... At least the output is always the first one (so that one gets clobbered). This is why you have to replicate the permutation descriptor when you need both permutands later.
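For readers unfamiliar with the instruction: VPERMI2D treats the destination's prior contents as indices into the 32-element table formed by concatenating the two source registers, which is why the index register is clobbered. A plain-Java software model (hypothetical helper, 16 dwords per 512-bit register) makes the semantics explicit:

```java
import java.util.Arrays;

public class Permi2dModel {
    // software model of VPERMI2D on 512-bit registers (16 dwords each):
    // idx plays the role of the destination register before the instruction
    static int[] permi2d(int[] idx, int[] a, int[] b) {
        int[] dst = new int[16];
        for (int i = 0; i < 16; i++) {
            int sel = idx[i] & 31;             // 5 index bits; bit 4 selects a vs b
            dst[i] = sel < 16 ? a[sel] : b[sel - 16];
        }
        return dst;
    }

    public static void main(String[] args) {
        int[] a = new int[16], b = new int[16];
        for (int i = 0; i < 16; i++) { a[i] = i; b[i] = 100 + i; }
        // interleave the low halves of a and b, the kind of shuffle the NTT levels need
        int[] idx = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23};
        System.out.println(Arrays.toString(permi2d(idx, a, b)));
        // [0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105, 6, 106, 7, 107]
    }
}
```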
__ evpaddd(xmm4, k0, xmm0, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);
Fairly 'straightforward' transcription of the Java code.. no comments from me.
At first glance using xmm0_3, xmm4_7, etc. might have been a good idea, but you only save one line per 4x group. (Unless you have one big loop, but I suspect that gives you worse performance? Is that something you tried already? Might be worth it otherwise..)
I have considered this but decided to leave it alone (for the reason that you mentioned).
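For background on what the barrettAddend constant is for: Barrett reduction replaces a division by q with a multiply and a shift. Below is a generic, unsigned-input Java sketch for the ML-DSA modulus; it is not the exact signed-with-addend variant the intrinsic implements, and the shift K = 46 is an illustrative choice (the multiplier is derived from it, so the pair is consistent by construction).

```java
public class BarrettDemo {
    static final int Q = 8380417;          // ML-DSA modulus (~2^23)
    static final int K = 46;               // illustrative shift amount
    static final long M = (1L << K) / Q;   // precomputed floor(2^K / q)

    // reduce a mod q for 0 <= a < 2^31 without a division at runtime
    static int barrett(int a) {
        long qhat = (a * M) >> K;          // estimate of a / q, low by at most one
        int r = (int) (a - qhat * Q);      // 0 <= r < 2q
        return r >= Q ? r - Q : r;         // one conditional subtract fixes the estimate
    }

    public static void main(String[] args) {
        System.out.println(barrett(Integer.MAX_VALUE) == Integer.MAX_VALUE % Q); // true
    }
}
```

Because the quotient estimate is never high and at most one too low, a single conditional subtract is always enough, which is what makes the pattern branch-friendly and vectorizable.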
public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
                                      int twoGamma2, int multiplier) {
    assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
            && (highPart.length == ML_DSA_N);
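As a reading aid, the per-coefficient split that decompose performs can be modeled in plain Java. This sketch only shows the centered-remainder idea; the real ML-DSA routine additionally handles a modulus corner case and uses the multiplier trick, both omitted here. The twoGamma2 value in main is 2 * (q-1)/32, one of the standard ML-DSA gamma2 choices.

```java
public class DecomposeDemo {
    // split r into r == high * twoGamma2 + low,
    // with low centered in (-twoGamma2/2, twoGamma2/2]
    static int[] decompose(int r, int twoGamma2) {
        int low = r % twoGamma2;
        if (low > twoGamma2 / 2) {
            low -= twoGamma2;                    // center the remainder
        }
        return new int[] { (r - low) / twoGamma2, low };
    }

    public static void main(String[] args) {
        int[] hl = decompose(1000000, 523776);   // twoGamma2 = 2 * 261888
        System.out.println(hl[0] + " " + hl[1]); // 2 -47552
    }
}
```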
I wrote this test to check Java-to-intrinsic correspondence. It might be good to include it (and add the other 4 intrinsics). This is very similar to all my other fuzz tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself)
import java.util.Arrays;
import java.util.Random;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Constructor;

public class ML_DSA_Intrinsic_Test {
    public static void main(String[] args) throws Exception {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
        Constructor<?> constructor = kClazz.getDeclaredConstructor(int.class);
        constructor.setAccessible(true);

        Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle mult = lookup.unreflect(m);

        m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle multJava = lookup.unreflect(m);

        Random rnd = new Random();
        long seed = rnd.nextLong();
        rnd.setSeed(seed);
        // Note: it might be useful to increase this number during development of new intrinsics
        final int repeat = 1000000;

        int[] coeffs1 = new int[ML_DSA_N];
        int[] coeffs2 = new int[ML_DSA_N];
        int[] prod1 = new int[ML_DSA_N];
        int[] prod2 = new int[ML_DSA_N];
        try {
            for (int i = 0; i < repeat; i++) {
                run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
            }
            System.out.println("Fuzz Success");
        } catch (Throwable e) {
            System.out.println("Fuzz Failed: " + e);
        }
    }

    private static final int ML_DSA_N = 256;

    public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
                           MethodHandle mult, MethodHandle multJava, Random rnd,
                           long seed, int i) throws Exception, Throwable {
        for (int j = 0; j < ML_DSA_N; j++) {
            coeffs1[j] = rnd.nextInt();
            coeffs2[j] = rnd.nextInt();
        }
        mult.invoke(prod1, coeffs1, coeffs2);
        multJava.invoke(prod2, coeffs1, coeffs2);
        if (!Arrays.equals(prod1, prod2)) {
            throw new RuntimeException("[Seed " + seed + "@" + i + "] Result mismatch: "
                    + Arrays.toString(prod1) + " != " + Arrays.toString(prod2));
        }
    }
}
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java
We will consider it for a follow-up PR.
@vpaprotsk , thanks a lot for the very thorough review!
No further comments from me. (I did leave two questions, but nothing that requires code changes)
Thanks for addressing all my many (lengthy) comments and questions. And the refactor!
//
// Performs two keccak() computations in parallel. The steps of the
// two computations are executed interleaved.
static address generate_double_keccak(StubGenerator *stubgen, MacroAssembler *_masm) {
This function seems OK. I didn't do as line-by-line an 'exact' review as for the NTT intrinsics, but just put the new version into a diff next to the original function. Seems like a reasonably clean 'refactor' (hardcode the blocksize, add new input registers 10-14. Makes it really easy to spot vs the 0-4 original registers..)
I didn't realize before that the 'top 3 limbs' are wasted. I guess it doesn't matter, there are registers to spare aplenty and it makes the entire algorithm cleaner and easier to follow.
I did also stare at the algorithm with the 'What about AVX2' question.. This function would pretty much need to be rewritten, it looks like :/
Last two questions..
- how much performance is gained from doubling this function up?
- If that's worth it.. what if the input was quadrupled instead? (I scanned the Java code, it looked like NR was parametrized already to 2..) It looks like there are almost enough registers here to go to 4 (I think 3 would need to be freed up somehow.. alternatively, the upper 3 limbs are empty in all operations, perhaps they could be used instead.. at the expense of readability)
Well, the algorithm (keccak()) is doing the same things on 5 array elements. (It works on essentially a 5x5 matrix, doing row and column operations, so putting 5 array entries in a vector register was the "natural" thing to do.)
This function can only be used under very special circumstances, which occur during the generation of the "A matrix" in ML-KEM and ML-DSA; the speed of that matrix generation has almost doubled (I don't have exact numbers).
We are using 7 registers per state and 15 for the constants, so we have only 3 to spare. We could perhaps juggle the constants, keeping just the ones that will be needed next in registers and reloading them "just in time", but that might slow things down a bit - more load instructions executed + maybe some load delay. On the other hand, more parallelism. I might try it out.
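To make the "5x5 matrix, row and column operations" remark concrete, here is the theta step of Keccak-f[1600] in plain Java; each C[x]/D[x] computation is exactly the kind of 5-lane operation that maps naturally onto one vector register per row of the state.

```java
public class ThetaDemo {
    // the theta step of Keccak-f[1600]: column parities C, offsets D, then XOR into the state
    static void theta(long[][] a) {            // a[x][y], 5x5 lanes of 64 bits
        long[] c = new long[5];
        long[] d = new long[5];
        for (int x = 0; x < 5; x++) {
            c[x] = a[x][0] ^ a[x][1] ^ a[x][2] ^ a[x][3] ^ a[x][4];
        }
        for (int x = 0; x < 5; x++) {
            d[x] = c[(x + 4) % 5] ^ Long.rotateLeft(c[(x + 1) % 5], 1);
        }
        for (int x = 0; x < 5; x++) {
            for (int y = 0; y < 5; y++) {
                a[x][y] ^= d[x];               // same D[x] applied to the whole column
            }
        }
    }

    public static void main(String[] args) {
        long[][] a = new long[5][5];
        a[0][0] = 1;                           // single bit set
        theta(a);
        System.out.println(a[1][2] + " " + a[4][0]); // 1 2
    }
}
```

Interleaving two independent states, as generate_double_keccak does, just means running two copies of each of these lines back to back on disjoint registers, hiding instruction latency.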
__ mov64(rax, 1);
__ kmovbl(k1, rax);
__ addl(rax, 2);
__ kmovbl(k2, rax);
__ addl(rax, 4);
__ kmovbl(k3, rax);
__ addl(rax, 8);
__ kmovbl(k4, rax);
__ addl(rax, 16);
__ kmovbl(k5, rax);
We could use the sequence from generate_sha3_implCompress to set up the k registers; it has fewer dependencies:
__ movl(rax, 0x1F);
__ kmovbl(k5, rax);
__ kshiftrbl(k4, k5, 1);
__ kshiftrbl(k3, k5, 2);
__ kshiftrbl(k2, k5, 3);
__ kshiftrbl(k1, k5, 4);
Thanks! (I had copied/doubled this function from the single state version before you made me do this change on that one and I forgot to update the copy :-) ) Changed.
Hi @ferakocz , I verified the new version of the patch on Linux and Windows, and it works fine.
Thanks for addressing my comments. Your passion is contagious :-)
// levels 2 to 7 are done in 2 batches, by first saving half of the coefficients
// from level 1 into memory, doing all the level 2 to level 7 computations
In lines 344-347, we seem to be storing all the coefficients from level 1 into memory.
store4Xmms(coeffs, 0, xmm0_3, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm);
This seems to be an unnecessary store.
Thanks for catching that. Changed.
// level 4
loadPerm(xmm16_19, perms, nttL4PermsIdx, _masm);
loadPerm(xmm12_15, perms, nttL4PermsIdx + 64, _masm);
load4Xmms(xmm24_27, zetas, 4 * 512, _masm); // for level 3
The comment // for level 3 is not relevant here and could be removed.
Ooops. Deleted the comment.
__ subl(len, 4 * XMMBYTES);
__ addptr(highPart, 4 * XMMBYTES);
__ addptr(lowPart, 4 * XMMBYTES);
__ cmpl(len, 0);
__ jcc(Assembler::notEqual, L_loop);
It looks to me that subl and cmpl could be merged:
__ addptr(highPart, 4 * XMMBYTES);
__ addptr(lowPart, 4 * XMMBYTES);
__ subl(len, 4 * XMMBYTES);
__ jcc(Assembler::notEqual, L_loop);
Changed.
__ evporq(xmm18, k0, xmm18, xmm22, false, Assembler::AVX_512bit);
__ evporq(xmm19, k0, xmm19, xmm23, false, Assembler::AVX_512bit);
__ evpsubd(xmm12, k0, zero, one, false, Assembler::AVX_512bit); // -1
The -1 initialization could be done outside the loop.
Not really. All registers are used.
__ xorl(scratch, scratch);
__ evpbroadcastd(zero, scratch, Assembler::AVX_512bit); // 0
__ addl(scratch, 1);
__ evpbroadcastd(one, scratch, Assembler::AVX_512bit); // 1
A better way to initialize (0, 1, -1) vectors is:
// load 0 into int vector
vpxor(zero, zero, zero, Assembler::AVX_512bit);
// load -1 into int vector
vpternlogd(minusOne, 0xff, minusOne, minusOne, Assembler::AVX_512bit);
// load 1 into int vector
vpsubd(one, zero, minusOne, Assembler::AVX_512bit);
Where minusOne could be xmm31.
A broadcast from a general-purpose register to an xmm register is more expensive.
Changed.
Overall very clean and nicely done PR. Thanks a lot for considering my inputs.
That is in no small part thanks to the reviewers, especially to Volodymyr!

@ferakocz Once you do /integrate, I will be honored to sponsor your PR.

/integrate
Thanks!

/sponsor

Going to push as commit e87ff32.
Your commit was automatically rebased without conflicts.
By using the AVX-512 vector registers, the speed of the ML-DSA computations (key generation, document signing, signature verification) can be approximately doubled.
Reviewing
Using git
Checkout this PR locally:
$ git fetch https://git.openjdk.org/jdk.git pull/23860/head:pull/23860
$ git checkout pull/23860
Update a local copy of the PR:
$ git checkout pull/23860
$ git pull https://git.openjdk.org/jdk.git pull/23860/head
Using Skara CLI tools
Checkout this PR locally:
$ git pr checkout 23860
View PR using the GUI difftool:
$ git pr show -t 23860
Using diff file
Download this PR as a diff file:
https://git.openjdk.org/jdk/pull/23860.diff
Using Webrev
Link to Webrev