
Conversation

@ferakocz (Contributor) commented Mar 3, 2025

By using the AVX-512 vector registers, the speed of the ML-DSA computations (key generation, document signing, signature verification) can be approximately doubled.


Progress

  • Change must be properly reviewed (1 review required, with at least 1 Reviewer)
  • Change must not contain extraneous whitespace
  • Commit message must refer to an issue

Issue

  • JDK-8351034: Add AVX-512 intrinsics for ML-DSA (Enhancement - P3)

Reviewers

Reviewing

Using git

Checkout this PR locally:
$ git fetch https://git.openjdk.org/jdk.git pull/23860/head:pull/23860
$ git checkout pull/23860

Update a local copy of the PR:
$ git checkout pull/23860
$ git pull https://git.openjdk.org/jdk.git pull/23860/head

Using Skara CLI tools

Checkout this PR locally:
$ git pr checkout 23860

View PR using the GUI difftool:
$ git pr show -t 23860

Using diff file

Download this PR as a diff file:
https://git.openjdk.org/jdk/pull/23860.diff

Using Webrev

Link to Webrev Comment

bridgekeeper bot commented Mar 3, 2025

👋 Welcome back ferakocz! A progress list of the required criteria for merging this PR into master will be added to the body of your pull request. There are additional pull request commands available for use with this pull request.

openjdk bot commented Mar 3, 2025

@ferakocz This change now passes all automated pre-integration checks.

ℹ️ This project also has non-automated pre-integration requirements. Please see the file CONTRIBUTING.md for details.

After integration, the commit message for the final commit will be:

8351034: Add AVX-512 intrinsics for ML-DSA

Reviewed-by: sviswanathan, lmesnik, vpaprotski, jbhateja

You can use pull request commands such as /summary, /contributor and /issue to adjust it as needed.

At the time this comment was updated, 535 new commits had been pushed to the master branch.

As there are no conflicts, your changes will automatically be rebased on top of these commits when integrating. If you prefer to avoid this automatic rebasing, please check the documentation for the /integrate command for further details.

As you do not have Committer status in this project an existing Committer must agree to sponsor your change. Possible candidates are the reviewers of this PR (@lmesnik, @jatin-bhateja, @sviswa7) but any other Committer may sponsor as well.

➡️ To flag this PR as ready for integration with the above commit message, type /integrate in a new comment. (Afterwards, your sponsor types /sponsor in a new comment to perform the integration).

@openjdk openjdk bot changed the title JDK-8351034 Add AVX-512 intrinsics for ML-DSA 8351034: Add AVX-512 intrinsics for ML-DSA Mar 3, 2025
@openjdk openjdk bot added the rfr Pull request is ready for review label Mar 3, 2025
openjdk bot commented Mar 3, 2025

@ferakocz The following labels will be automatically applied to this pull request:

  • graal
  • hotspot
  • security

When this pull request is ready to be reviewed, an "RFR" email will be sent to the corresponding mailing lists. If you would like to change these labels, use the /label pull request command.

@openjdk openjdk bot added graal graal-dev@openjdk.org security security-dev@openjdk.org hotspot hotspot-dev@openjdk.org labels Mar 3, 2025
mlbridge bot commented Mar 3, 2025

@mcpowers (Contributor) commented Mar 4, 2025

ML-DSA benchmark results for this PR

keygen    ML-DSA-44    96 us/op
keygen    ML-DSA-65   200 us/op
keygen    ML-DSA-87   272 us/op
siggen    ML-DSA-44   297 us/op
siggen    ML-DSA-65   452 us/op
siggen    ML-DSA-87   728 us/op
sigver    ML-DSA-44   115 us/op
sigver    ML-DSA-65   176 us/op
sigver    ML-DSA-87   290 us/op

ML-DSA no intrinsics

keygen    ML-DSA-44   169 us/op
keygen    ML-DSA-65   302 us/op
keygen    ML-DSA-87   444 us/op
siggen    ML-DSA-44   696 us/op
siggen    ML-DSA-65  1114 us/op
siggen    ML-DSA-87  1828 us/op
sigver    ML-DSA-44   187 us/op
sigver    ML-DSA-65   295 us/op
sigver    ML-DSA-87   473 us/op
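For context (my own arithmetic, not part of the PR), the two runs above imply roughly 1.5–2.5x speedups, with the largest gains in signing. A quick sketch computing the ratios from the table values:

```java
// Divides the no-intrinsics times by the intrinsics times from the two
// benchmark tables above (values in us/op, same row order).
public class MlDsaSpeedup {
    static final double[] INTRINSICS = { 96, 200, 272, 297,  452,  728, 115, 176, 290};
    static final double[] BASELINE   = {169, 302, 444, 696, 1114, 1828, 187, 295, 473};

    static double speedup(int i) {
        return BASELINE[i] / INTRINSICS[i];
    }

    public static void main(String[] args) {
        String[] ops = {"keygen-44", "keygen-65", "keygen-87",
                        "siggen-44", "siggen-65", "siggen-87",
                        "sigver-44", "sigver-65", "sigver-87"};
        for (int i = 0; i < ops.length; i++) {
            System.out.printf("%s: %.2fx%n", ops[i], speedup(i));
        }
    }
}
```

The siggen rows come out at roughly 2.3–2.5x, keygen and sigver at roughly 1.5–1.8x.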

openjdk bot commented Mar 5, 2025

@ferakocz this pull request cannot be integrated into master due to one or more merge conflicts. To resolve these merge conflicts and update this pull request you can run the following commands in the local repository for your personal fork:

git checkout mldsa-avx512-intrinsics
git fetch https://git.openjdk.org/jdk.git master
git merge FETCH_HEAD
# resolve conflicts and follow the instructions given by git merge
git commit -m "Merge master"
git push

@openjdk openjdk bot added the merge-conflict Pull request has merge conflict with target branch label Mar 5, 2025
@openjdk openjdk bot removed the merge-conflict Pull request has merge conflict with target branch label Mar 5, 2025

__ movl(iterations, 2);

__ BIND(L_loop);
@jatin-bhateja (Member) commented Mar 5, 2025:

Hi @ferakocz, kindly align the loop entry address using __align64() here and at all the places before __BIND(LOOP).

@ferakocz (Contributor Author):

Hi, @jatin-bhateja, thanks for the suggestion. I have added __ align(OptoLoopAlignment); before all loop entries.

@jatin-bhateja (Member) commented Mar 5, 2025:

Hi @ferakocz ,

Thanks! For efficient utilization of the Decoded ICache (please refer to Intel SDM section 3.4.2.5), code blocks should be aligned to 32-byte boundaries; a 64-byte-aligned address is a superset of both 16- and 32-byte-aligned addresses and also matches the cache line size. However, I noticed that we have been using OptoLoopAlignment in places in AES-GCM as well.

I introduced some errors into the generate_dilithiumAlmostInverseNtt_avx512 implementation, expecting to catch them through the existing ML-DSA tests under
test/jdk/sun/security/provider/acvp

But all the tests passed for me.
java -jar /home/jatinbha/sandboxes/jtreg/build/images/jtreg/lib/jtreg.jar -jdk:$JAVA_HOME -Djdk.test.lib.artifacts.ACVP-Server=/home/jatinbha/softwares/v1.1.0.38.zip -va -timeout:4 Launcher.java

Can you please point out a test I need to use for validation?

@ferakocz (Contributor Author):

I think the easiest is to put a for (int i = 0; i < 1000; i++) loop around the switch statement in the run() method of the ML_DSA_Test class (test/jdk/sun/security/provider/acvp/ML_DSA_Test.java). (This is because the intrinsics kick in only after a few thousand calls to the method.)
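The reasoning behind the suggestion: HotSpot only swaps in an intrinsic after the enclosing Java method has been invoked enough times to get JIT-compiled, so a test that calls each operation once never exercises the stub. A minimal sketch of the warm-up pattern (the method names here are illustrative, not the actual test code):

```java
public class WarmupSketch {
    static int invocations = 0;

    // Stand-in for one ML-DSA operation (keygen/siggen/sigver in the real test).
    static void mlDsaOperation() {
        invocations++;
    }

    public static void main(String[] args) {
        // Repeat the test body so the method crosses the JIT compilation
        // threshold; after that, calls run through the compiled (intrinsic)
        // path and any bug in the hand-written assembly becomes observable.
        for (int i = 0; i < 1000; i++) {
            mlDsaOperation();
        }
        System.out.println(invocations); // 1000
    }
}
```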

@jatin-bhateja (Member):

Hi @ferakocz, yes, we should modify the test or lower the compilation threshold with -Xbatch -XX:TieredCompileThreshold=0.1.

Alternatively, since the tests have a dependency on the Automatic Cryptographic Validation Test server, I have created a simplified test which covers all the security levels.

Kindly include test/hotspot/jtreg/compiler/intrinsics/signature/TestModuleLatticeDSA.java

@ferakocz (Contributor Author):

I have added a new @run command to the test test/jdk/sun/security/provider/acvp/Launcher.java. The line with -Xcomp will invoke the intrinsics on the first call, so they will be tested.

void StubGenerator::generate_sha3_stubs() {
if (UseSHA3Intrinsics) {
StubRoutines::_sha3_implCompress = generate_sha3_implCompress(StubGenStubId::sha3_implCompress_id);
StubRoutines::_double_keccak = generate_double_keccak();
Member:

Should UseDilithiumIntrinsics guard double_keccak generation?

@ferakocz (Contributor Author):

No, that is more of a SHA3 thing; other algorithms can take advantage of it too (e.g. ML-KEM).

@@ -0,0 +1,1404 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
Member:

Please update the copyright year.

@ferakocz (Contributor Author):

Thanks, fixed.

StubRoutines::_dilithiumNttMult = generate_dilithiumNttMult_avx512();
StubRoutines::_dilithiumMontMulByConstant = generate_dilithiumMontMulByConstant_avx512();
StubRoutines::_dilithiumDecomposePoly = generate_dilithiumDecomposePoly_avx512();
}
Member:

Indentation fix needed.

@ferakocz (Contributor Author):

Thanks, fixed.

const Register constant2use = r10;
const Register roundsLeft = r11;

__ align(OptoLoopAlignment);
Member:

Redundant alignment before the label; it should be before its bind.

@ferakocz (Contributor Author):

Thanks, fixed.

@lmesnik (Member) commented Mar 10, 2025

There are no new tests in the PR. How has the fix been tested by OpenJDK tests?

@ferakocz (Contributor Author):

> There are no new tests in the PR. How has the fix been tested by OpenJDK tests?

I have just added one.

* @library /test/lib
* @modules java.base/sun.security.provider
* @run main Launcher
* @run main/othervm -Xcomp Launcher
Member:

Thank you for adding this case. Please add it as a separate testcase:
/*
 * @test
 * @summary Test verifies intrinsic implementation.
 * @library /test/lib
 * @modules java.base/sun.security.provider
 * @run main/othervm -Xcomp Launcher
 */

@ferakocz (Contributor Author):

Done.

@lmesnik (Member) left a comment:

Test changes look good.

@openjdk openjdk bot added the ready Pull request is ready to be integrated label Mar 17, 2025
@vpaprotsk (Contributor) left a comment:

Partial review; I just didn't want to sit on these comments for this long.

(I spent quite a bit of time catching up on the papers and math required.)

The biggest roadblock I have in following the code is the raw register numbers. (And more comments? Perhaps I need more math knowledge, but comments would help too.)

Also, 'hidden variables' (xmm30). I can't complain, because this is exactly what Vladimir Ivanov told me to do on my first PR #10582 (comment). Perhaps that discussion applies here too.

void generate_sha3_stubs();
address generate_sha3_implCompress(StubGenStubId stub_id);

address generate_double_keccak();
Contributor:

You can hide internal helper functions (e.g. montmulEven(*)) if you wish.

The trick is to add MacroAssembler* _masm as a parameter to the static (local) function. It's a trick I use to keep the header clean but still have plenty of helpers.

Comment on lines 395 to 409
__ evmovdquq(xmm17, Address(permsAndRots, 0), Assembler::AVX_512bit);
__ evmovdquq(xmm18, Address(permsAndRots, 64), Assembler::AVX_512bit);
__ evmovdquq(xmm19, Address(permsAndRots, 128), Assembler::AVX_512bit);
__ evmovdquq(xmm20, Address(permsAndRots, 192), Assembler::AVX_512bit);
__ evmovdquq(xmm21, Address(permsAndRots, 256), Assembler::AVX_512bit);
__ evmovdquq(xmm22, Address(permsAndRots, 320), Assembler::AVX_512bit);
__ evmovdquq(xmm23, Address(permsAndRots, 384), Assembler::AVX_512bit);
__ evmovdquq(xmm24, Address(permsAndRots, 448), Assembler::AVX_512bit);
__ evmovdquq(xmm25, Address(permsAndRots, 512), Assembler::AVX_512bit);
__ evmovdquq(xmm26, Address(permsAndRots, 576), Assembler::AVX_512bit);
__ evmovdquq(xmm27, Address(permsAndRots, 640), Assembler::AVX_512bit);
__ evmovdquq(xmm28, Address(permsAndRots, 704), Assembler::AVX_512bit);
__ evmovdquq(xmm29, Address(permsAndRots, 768), Assembler::AVX_512bit);
__ evmovdquq(xmm30, Address(permsAndRots, 832), Assembler::AVX_512bit);
__ evmovdquq(xmm31, Address(permsAndRots, 896), Assembler::AVX_512bit);
Contributor:

Matter of taste, but I liked the compactness of montmulEven; i.e.

for (int i = 0; i < 15; i++)
    __ evmovdquq(xmm(17 + i), Address(permsAndRots, 64 * i), Assembler::AVX_512bit);

@ferakocz (Contributor Author):

Changed.

__ BIND(rounds24_loop);
__ subl( roundsLeft, 1);

__ evmovdquw(xmm5, xmm0, Assembler::AVX_512bit);
Contributor:

Is there a pattern here that can be 'compacted' into a loop?

@ferakocz (Contributor Author):

Unfortunately, no. This loop body is imported from generate_sha3_implCompress() and doubled, as explained in the comment about 15 lines above.

Comment on lines 137 to 140
__ vpmuldq(xmm(scratchReg1), xmm(inputReg11), xmm(inputReg2), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 1), xmm(inputReg12), xmm(inputReg2 + 1), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 2), xmm(inputReg13), xmm(inputReg2 + 2), Assembler::AVX_512bit);
__ vpmuldq(xmm(scratchReg1 + 3), xmm(inputReg14), xmm(inputReg2 + 3), Assembler::AVX_512bit);
Contributor:

Another option for these four lines, to keep the style of the rest of the function:

int inputReg1[] = {inputReg11, inputReg12, inputReg13, inputReg14};
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(scratchReg1 + i), xmm(inputReg1[i]), xmm(inputReg2 + i), Assembler::AVX_512bit);
}

@ferakocz (Contributor Author):

I have changed the whole structure instead.

// levels are the same for all the montmuls that we can do in parallel

// level 0
montmulEven(20, 8, 29, 20, 16, 4);
Contributor:

It would improve readability to know which parameter is a register and which is a count, i.e.

montmulEven(xmm20, xmm8, xmm29, xmm20, xmm16, 4);

(It's not that bad once I remember that it's always the last parameter, but it does add to the 'mental load' one has to carry, and this code is already interesting enough.)

@ferakocz (Contributor Author):

I have changed the structure, now it is clear(er) which parameter is what.

}

ATTRIBUTE_ALIGNED(64) static const uint32_t dilithiumAvx512Perms[] = {
// collect montmul results into the destination register
Contributor:

Same as dilithiumAvx512Consts(): 'magic offsets', except here they are harder to count (e.g. it is not visually clear what the offset of the NTT inverse is).

Could be split into three constant arrays to make the compiler count for us.

@ferakocz (Contributor Author):

Well, it is 64 bytes per line (16 4-byte uint32_ts), not that hard :-) ...

Contributor:

Ha! I didn't realize it was 16 per line.. ran out of fingers while counting!!! :)

'Works for me, as long as it's a "premeditated" decision.'

Comment on lines 116 to 127
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(i + scratchReg1), xmm(i + inputReg1), xmm((inputReg2 == 29) ? 29 : inputReg2 + i), Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ vpmulld(xmm(i + scratchReg2), xmm(i + scratchReg1), xmm30, Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ vpmuldq(xmm(i + scratchReg2), xmm(i + scratchReg2), xmm31, Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
  __ evpsubd(xmm(i + outputReg), k0, xmm(i + scratchReg1), xmm(i + scratchReg2), false, Assembler::AVX_512bit);
}
Contributor:

This is such a deceptively brilliant function!!! It took me a while to understand (and map to the Java montMul function). Perhaps it needs more comments.

The comment on line 99 does provide good hints, but I still had some trouble. I ended up annotating a copy quite a bit. I do think all 'clever code' needs comments. Here is my annotated version, if you want to copy out anything:

static void montmulEven2(XMMRegister outputReg, XMMRegister inputReg1,  XMMRegister inputReg2, XMMRegister scratchReg1, 
    XMMRegister scratchReg2, XMMRegister montQInvModR, XMMRegister dilithium_q, int parCnt, MacroAssembler* _masm) {
  
  int output = outputReg->encoding();
  int input1 = inputReg1->encoding();
  int input2 = inputReg2->encoding();
  int scratch1 = scratchReg1->encoding();
  int scratch2 = scratchReg2->encoding();
  for (int i = 0; i < parCnt; i++) {
    // scratch1 = (int64)input1_even*input2_even
    //    Java: long a = (long) b * (long) c;
    __ vpmuldq(xmm(i + scratch1), xmm(i + input1), xmm((input2 == 29) ? 29 : input2 + i), Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = int32(montQInvModR*(int32)scratch1)
    //    Java: int aLow = (int) a;
    //    Java: int m = MONT_Q_INV_MOD_R * aLow; // signed low product
    __ vpmulld(xmm(i + scratch2), xmm(i + scratch1), montQInvModR, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // scratch2 = (int64)scratch2_even*dilithium_q_even
    //    Java: ((long)m * MONT_Q)
    __ vpmuldq(xmm(i + scratch2), xmm(i + scratch2), dilithium_q, Assembler::AVX_512bit);
  }
  for (int i = 0; i < parCnt; i++) {
    // output_odd = scratch1_odd - scratch2_odd
    //    Java: (aHigh - (int) (("scratch2") >> MONT_R_BITS))
    __ evpsubd(xmm(i + output), k0, xmm(i + scratch1), xmm(i + scratch2), false, Assembler::AVX_512bit);
  }
}
  • add comment that input2 can be xmm29, treated as constants, not consecutive (i.e. zetas)
  • Candidate for ascii art, even/odd columns, implicit int/long casts (or more 'math' comments on what happens)
  • use XMMRegisters instead of numbers (improve callsite readability)
    • can use either inputReg1 = inputReg1->successor()
    • or get encoding() and keep current style
  • could be static (local) function (hide from header), then pass _masm
  • pass all registers used (helps seeing register allocation, confirm no overlaps)

False trails (i.e. nothing to do, but I thought about it already, so other reviewers don't have to?)

  • (ignore: worse performance) squash into a single for loop, let cpu do out-of-order (and improve readability)
  • xmm30/xmm31 (montQInvModR/dilithium_q) are constant. At a glance, it looks like they should be combined into one precomputed constant. And paper 039.pdf suggests merging constants by precomputing the product; but these are different constants, and looking at the Java, there are several implicit casts:
For reductions of products inside the NTT this is not a problem because one has to multiply by the roots of unity
which are compile-time constants. So one can just precompute them with an additional
factor of β mod q so that the results after Montgomery reduction are in fact congruent to the desired value a
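For reference, the scalar operation that the four vector steps above implement is standard signed Montgomery reduction mod the ML-DSA prime q = 8380417 with R = 2^32. A minimal Java sketch (constants taken from the public Dilithium reference material, not from this PR; method name is illustrative):

```java
public class MontReduceSketch {
    static final int Q = 8380417;      // ML-DSA modulus
    static final int Q_INV = 58728449; // Q^-1 mod 2^32 (Dilithium reference constant)

    // Returns r with r * 2^32 == a (mod Q); corresponds to the
    // vpmulld (low product), vpmuldq (high product) and evpsubd steps.
    static int montReduce(long a) {
        int m = (int) a * Q_INV;                 // low 32-bit product, wraps mod 2^32
        return (int) ((a - (long) m * Q) >> 32); // a - m*Q is divisible by 2^32
    }

    public static void main(String[] args) {
        long a = 1234567L * 7654321L;            // product of two small coefficients
        int r = montReduce(a);
        // r * 2^32 differs from a by an exact multiple of Q
        System.out.println((((long) r << 32) - a) % Q == 0); // true
    }
}
```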

Comment on lines 1017 to 1041
for (int i = 0; i < 4; i++) {
__ evmovdqul(xmm(i), Address(poly1, i * 64), Assembler::AVX_512bit);
__ evmovdqul(xmm(i + 4), Address(poly2, i * 64), Assembler::AVX_512bit);
}

montmulEven(8, 4, 29, 12, 16, 4);
for (int i = 0; i < 4; i++) {
__ vpshufd(xmm(i + 8), xmm(i + 8), 0xB1, Assembler::AVX_512bit);
}
montmulEven(8, 0, 8, 12, 16, 4);
for (int i = 0; i < 4; i++) {
__ vpshufd(xmm(i), xmm(i), 0xB1, Assembler::AVX_512bit);
__ vpshufd(xmm(i + 4), xmm(i + 4), 0xB1, Assembler::AVX_512bit);
}
montmulEven(4, 4, 29, 12, 16, 4);
for (int i = 0; i < 4; i++) {
__ vpshufd(xmm(i + 4), xmm(i + 4), 0xB1, Assembler::AVX_512bit);
}
montmulEven(0, 0, 4, 12, 16, 4);
for (int i = 0; i < 4; i++) {
__ evpermt2d(xmm(i), xmm28, xmm(i + 8), Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evmovdqul(Address(result, i * 64), xmm(i), Assembler::AVX_512bit);
}
Contributor:

This is nice, compact and clean. The biggest issue I have with following this code is really all the 'raw' registers. I would much rather prefer symbolic names, but it is up to you to decide on style.

I ended up 'annotating' this snippet so I could understand it and confirm everything; as with montmulEven, I hope some of it can be useful for you to copy out.

  XMMRegister POLY1[] = {xmm0, xmm1, xmm2, xmm3};
  XMMRegister POLY2[] = {xmm4, xmm5, xmm6, xmm7};
  XMMRegister SCRATCH1[] = {xmm12, xmm13, xmm14, xmm15};
  XMMRegister SCRATCH2[] = {xmm16, xmm17, xmm18, xmm19};
  XMMRegister SCRATCH3[] = {xmm8, xmm9, xmm10, xmm11};
  for (int i = 0; i < 4; i++) {
    __ evmovdqul(POLY1[i], Address(poly1, i * 64), Assembler::AVX_512bit);
    __ evmovdqul(POLY2[i], Address(poly2, i * 64), Assembler::AVX_512bit);
  }

  // montmulEven: inputs are in even columns and output is in odd columns
  // scratch3_even = poly2_even*montRSquareModQ // poly2 to montgomery domain
  montmulEven2(SCRATCH3[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
  for (int i = 0; i < 4; i++) {
    // swap even/odd; 0xB1 == 2-3-0-1
    __ vpshufd(SCRATCH3[i], SCRATCH3[i], 0xB1, Assembler::AVX_512bit);
  }

  // scratch3_odd = poly1_even*scratch3_even = poly1_even*poly2_even*montRSquareModQ
  montmulEven2(SCRATCH3[0], POLY1[0], SCRATCH3[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
  for (int i = 0; i < 4; i++) {
    __ vpshufd(POLY1[i], POLY1[i], 0xB1, Assembler::AVX_512bit);
    __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
  }

  // poly2_even = poly2_odd*montRSquareModQ // poly2 to montgomery domain
  montmulEven2(POLY2[0], POLY2[0], montRSquareModQ, SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
  for (int i = 0; i < 4; i++) {
    __ vpshufd(POLY2[i], POLY2[i], 0xB1, Assembler::AVX_512bit);
  }

  // poly1_odd = poly1_even*poly2_even
  montmulEven2(POLY1[0], POLY1[0], POLY2[0], SCRATCH1[0], SCRATCH2[0], montQInvModR, dilithium_q, 4, _masm);
  for (int i = 0; i < 4; i++) {
    // result is scrambled between scratch3_odd and poly1_odd; unscramble
    __ evpermt2d(POLY1[i], perms, SCRATCH3[i], Assembler::AVX_512bit);
  }
  for (int i = 0; i < 4; i++) {
    __ evmovdqul(Address(result, i * 64), POLY1[i], Assembler::AVX_512bit);
  }

With symbolic variable names, the code was much easier to follow conceptually. It also has the side benefit of making it obvious which XMM registers are used and that there are no conflicts.

@ferakocz (Contributor Author):

I have rewritten it to use full montmuls (a new function) here and everywhere else. It is much easier to follow the code that way.


__ evpbroadcastd(xmm29, constant, Assembler::AVX_512bit); // constant multiplier

__ movl(len, 2);
Contributor:

Same comment here as for generate_dilithiumNttMult_avx512:

  • constants can be loaded directly into XMM
  • len can be removed by unrolling at compile time
  • symbolic names could be used for registers
  • comments could be added

@ferakocz (Contributor Author):

Done.


// Dilithium multiply polynomials in the NTT domain.
// Implements
// static int implDilithiumNttMult(
Contributor:

I suppose there are no Java changes in this PR, but I notice that the inputs are all assumed to have a fixed size.

Most/all intrinsics I have worked with had some sort of guard (e.g. Objects.checkFromIndexSize) right before the intrinsic Java call. (It usually looks like it can be optimized away.) But I notice no such guard here on the Java side.

@ferakocz (Contributor Author):

These functions will not be used anywhere else, and in ML_DSA.java all of the arrays passed to the intrinsics are of the correct size.

Contributor:

Works for me; just thought I would point it out, so it's a 'premeditated' decision.

@ferakocz (Contributor Author):

Well, I ended up putting some asserts in the Java code, just in case...

Comment on lines +106 to +119
const int montMulPermsIdx = 0;
const int nttL4PermsIdx = 64;
const int nttL5PermsIdx = 192;
const int nttL6PermsIdx = 320;
const int nttL7PermsIdx = 448;
const int nttInvL0PermsIdx = 704;
const int nttInvL1PermsIdx = 832;
const int nttInvL2PermsIdx = 960;
const int nttInvL3PermsIdx = 1088;
const int nttInvL4PermsIdx = 1216;

static address dilithiumAvx512PermsAddr() {
return (address) dilithiumAvx512Perms;
}
Contributor:

Hear me out.. ...
enums!!

enum nttPermOffset {
  montMulPermsIdx = 0,
  nttL4PermsIdx = 64,
  nttL5PermsIdx = 192,
  nttL6PermsIdx = 320,
  nttL7PermsIdx = 448,
  nttInvL0PermsIdx = 704,
  nttInvL1PermsIdx = 832,
  nttInvL2PermsIdx = 960,
  nttInvL3PermsIdx = 1088,
  nttInvL4PermsIdx = 1216,
};
static address dilithiumAvx512PermsAddr(nttPermOffset offset) {
  return (address) dilithiumAvx512Perms + offset;
}

Contributor:

Belay that comment.. now that I've looked at generate_dilithiumAlmostInverseNtt_avx512, I see why that's not the 'entire picture'..

@ferakocz (Contributor Author):

I'll leave it as it is for now.

// Dilithium Intrinsics
// Currently we only have them for AVX512
#ifdef _LP64
if (supports_evex() && supports_avx512bw()) {
Member:

The supports_evex check looks redundant.

@ferakocz (Contributor Author):

These are checks for two different feature bits: CPU_AVX512F and CPU_AVX512BW. Are you saying that the latter implies the former in every implementation of the spec?

Member:

AVX512BW is built on top of the AVX512F spec. In the assembler and other places we only check BW in assertions, which implies EVEX.

@vpaprotsk (Contributor) left a comment:

I still need to have a look at the SHA3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!

}
}

static void loadPerm(int destinationRegs[], Register perms,
Contributor:

replXmm? i.e. this function is replicating (any) Xmm register, not just perm?..

@ferakocz (Contributor Author):

Since I am only using it for permutation descriptors, I thought this way it is easier to follow what is happening.


// qinvmodR and q are repeated in all slots of Zmm30 and Zmm31, resp.
// Zmm8-Zmm23 used as scratch registers
// result goes to Zmm0-Zmm7
static void montMulByConst128(MacroAssembler *_masm) {
Contributor:

Looking at this function some more.. I think you could remove this function and replace it with two calls to montMul64?

  montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch*, _masm);
  montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch*, _masm);

Scratch would have to be defined..

// zetas (int[256]) = c_rarg1
//
//
static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
Contributor:

Similar comments as for generate_dilithiumAlmostInverseNtt_avx512:

  • similar comment about the 'pair-wise' operation, updating [j] and [j+l] at a time..
  • somehow I had less trouble following the flow through registers here; perhaps I am getting used to it. FYI, I ended up renaming some as:
// xmm16_27 = Temp1
// xmm0_3 = Coeffs1
// xmm4_7 = Coeffs2
// xmm8_11 = Coeffs3
// xmm12_15 = Coeffs4 = Temp2
// xmm16_27 = Scratch

Contributor Author

For me, it was easier to follow what goes where using the xmm... names (with the symbolic names you always have to remember which one overlaps with another and how much).

//
// coeffs (int[256]) = c_rarg0
// zetas (int[256]) = c_rarg1
static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
Contributor

Done with this function; Perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever!

Some general comments first, rest inline.

  • The array names for registers helped a lot. And so did the new helper functions!
  • The Java version of this code is quite intimidating to vectorize.. 3D loop, with geometric iteration variables.. and the literature is even more intimidating (discrete convolutions, which I haven't touched in two decades, FFTs, NTTs, etc.). Here is my attempt at a comment to 'un-scare' the next reader, though feel free to reword however you like.
The core of the (Java) loop is this 'pair-wise' operation:
        int a = coeffs[j];
        int b = coeffs[j + offset];
        coeffs[j] = (a + b);
        coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);

There are 8 'levels' (0-7); ('levels' are equivalent to (unrolling) the outer (Java) loop)
At each level, the 'pair-wise-offset' doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128).

To vectorize this Java code, observe that at each level, REGARDLESS of the offset, half the operations are SUMs, and the other half are
Montgomery MULTIPLICATIONs (of the pair-difference with a constant). At each level, one 'just' has to shuffle
the coefficients so that SUMs and MULTIPLICATIONs line up accordingly.

Otherwise, this pattern is 'lightly similar' to a discrete convolution (compute integral/summation of two functions at every offset)
  • I still would prefer (more) symbolic register names.. I wouldn't hold my approval over it so won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:
// xmm8_11  = Perms1
// xmm12_15 = Perms2
// xmm16_27 = Scratch
// xmm0_3 = CoeffsPlus
// xmm4_7 = CoeffsMul
// xmm24_27 = CoeffsMinus (overlaps with Scratch)

(I made a similar comment, but I think it is now hidden after the last refactor)

  • would prefer to see the helper functions get ALL the registers passed explicitly (i.e. currently montMulPerm, montQInvModR, dilithium_q, xmm29 are implicit). As a general rule, I've tried to set up all the registers at the 'entry' function (generate_dilithium* in this case) and, from there on, use symbolic names. Not always reasonable, but it's what I've grown used to seeing.
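The 'pair-wise' core quoted above can be fleshed out into a hypothetical scalar sketch of the whole loop shape (names follow the quoted core; the Montgomery multiply is stubbed with a plain modular multiply just to keep the sketch runnable). It also counts operations per level, which makes the 'half SUMs, half MULs at every level' observation concrete:

```java
// Hypothetical scalar sketch of the inverse-NTT loop shape described
// above: 8 levels, with the pair-wise offset doubling each level
// (1, 2, 4, ..., 128). Not the JDK implementation.
public class NttShape {
    static final int N = 256;
    static final int Q = 8380417; // ML-DSA modulus

    // Stand-in for the real Montgomery multiply, just to keep this runnable.
    static int montMul(int a, int b) {
        return (int) Math.floorMod((long) a * b, (long) Q);
    }

    // Runs the loop and returns how many (SUM, montMul) pairs each of the
    // 8 levels performs.
    static int[] pairsPerLevel(int[] coeffs, int[] zetas) {
        int[] pairs = new int[8];
        int m = 0;
        for (int level = 0, offset = 1; offset < N; level++, offset *= 2) {
            for (int start = 0; start < N; start += 2 * offset) {
                for (int j = start; j < start + offset; j++) {
                    int a = coeffs[j];
                    int b = coeffs[j + offset];
                    coeffs[j] = a + b;                             // the SUM half
                    coeffs[j + offset] = montMul(a - b, zetas[m]); // the MUL half
                    pairs[level]++;
                }
                m++;
            }
        }
        return pairs;
    }
}
```

Every level performs exactly 128 add/multiply pairs, which is why the vectorized code can keep the same add/montMul schedule at every level and only vary the shuffle.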

Contributor Author

I added some more comments, but I kept the xmm... names for the registers, just like with the ntt function.

Comment on lines +650 to +656
for (int i = 0; i < 8; i++) {
__ evpaddd(xmm(i + 16), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit);
}

for (int i = 0; i < 8; i++) {
__ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
}
Contributor

Fairly clean as is, but could also be two sub_add calls, I think (you have to swap order of add/sub in the helper, to be able to clobber xmm(i).. or swap register usage downstream, so perhaps not.. but would be cleaner)

  sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
  sub_add(CoeffsMul,  &Scratch[4], Perms2, CoeffsMul, _masm);

If nothing else, I would have preferred to see the use of the register array variables.

Contributor Author

I would rather leave this alone, too. I was considering the same, but decided that this is fairly easy to follow; it would be more complicated to either add a new helper function or to follow where the symbolically named register sets overlap.


store4Xmms(coeffs, 0, xmm16_19, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
montMulByConst128(_masm);
Contributor

Would prefer explicit parameters here. But I think this could also be two montMul64 calls?

  montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch, _masm);
  montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch, _masm);

(I think there is one other use of montMulByConst128 where the same applies; then you could delete both montMulByConst128 and montmulEven.)

Comment on lines +416 to +421
for (int i = 0; i < 8; i += 2) {
__ evpermi2d(xmm(i/2 + 16), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i += 2) {
__ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
}
Contributor

Wish there was a more 'abstract' way to arrange this, so it's obvious from the shape of the code which registers are inputs/outputs (i.e. using the register arrays). Even though it's just 'elementary index operations', i/2 + 16 is still 'clever'. Couldn't think of anything myself though (same elsewhere in this function for the table permutes).

Contributor Author

Well, this is how it is when we have three inputs, one of which also doubles as the output... At least the output is always the first one (so that one gets clobbered). This is why you have to replicate the permutation describer when you need both permutands later.
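For the next reader, a scalar model of what vpermi2d does may help (hypothetical helper, lane count generalized to any power of two): each result lane is selected from one of two source tables by the index vector, and the index vector itself is the destination, which is why the describer gets clobbered.

```java
// Scalar model of VPERMI2D (AVX-512, 32-bit lanes). The 'indices' operand
// is both the permutation describer and, in hardware, the destination.
public class Permi2d {
    static int[] permi2d(int[] indices, int[] a, int[] b) {
        int n = a.length; // 16 lanes for a 512-bit register
        int[] out = new int[n];
        for (int i = 0; i < n; i++) {
            int sel = indices[i] & (2 * n - 1);     // low bits select a lane
            out[i] = sel < n ? a[sel] : b[sel - n]; // top bit selects the table
        }
        return out; // the hardware writes this over 'indices'
    }
}
```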

Comment on lines +868 to +871
__ evpaddd(xmm4, k0, xmm0, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
__ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);
Contributor

Fairly 'straightforward' transcription of the Java code.. no comments from me.

At first glance using xmm0_3, xmm4_7, etc. might have been a good idea, but you only save one line per 4x group. (Unless you have one big loop, but I suspect that gives you worse performance? Is that something you tried already? Might be worth it otherwise..)

Contributor Author

I have considered this but decided to leave it alone (for the reason that you mentioned).

public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
&& (highPart.length == ML_DSA_N);
Contributor

I wrote this test to check Java-to-intrinsic correspondence. Might be good to include it (and add the other 4 intrinsics). This is very similar to all my other Fuzz tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself).

import java.util.Arrays;
import java.util.Random;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Constructor;

public class ML_DSA_Intrinsic_Test {

    public static void main(String[] args) throws Exception {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
        Constructor<?> constructor = kClazz.getDeclaredConstructor(
                int.class);
        constructor.setAccessible(true);
        
        Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle mult = lookup.unreflect(m);

        m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
                int[].class, int[].class, int[].class);
        m.setAccessible(true);
        MethodHandle multJava = lookup.unreflect(m);

        Random rnd = new Random();
        long seed = rnd.nextLong();
        rnd.setSeed(seed);
        //Note: it might be useful to increase this number during development of new intrinsics
        final int repeat = 1000000;
        int[] coeffs1 = new int[ML_DSA_N];
        int[] coeffs2 = new int[ML_DSA_N];
        int[] prod1 = new int[ML_DSA_N];
        int[] prod2 = new int[ML_DSA_N];
        try {
            for (int i = 0; i < repeat; i++) {
                run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
            }
            System.out.println("Fuzz Success");
        } catch (Throwable e) {
            System.out.println("Fuzz Failed: " + e);
        }
    }

    private static final int ML_DSA_N = 256;
    public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2, 
        MethodHandle mult, MethodHandle multJava, Random rnd, 
        long seed, int i) throws Exception, Throwable {
        for (int j = 0; j<ML_DSA_N; j++) {
            coeffs1[j] = rnd.nextInt();
            coeffs2[j] = rnd.nextInt();
        }

        mult.invoke(prod1, coeffs1, coeffs2);
        multJava.invoke(prod2, coeffs1, coeffs2);

        if (!Arrays.equals(prod1, prod2)) {
                throw new RuntimeException("[Seed "+seed+"@"+i+"] Result mismatch: " + Arrays.toString(prod1) + " != " + Arrays.toString(prod2));
        }
    }
}
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED  -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java

Contributor Author

We will consider it for a follow-up PR.

@ferakocz
Contributor Author

@vpaprotsk , thanks a lot for the very thorough review!

Contributor

@vpaprotsk vpaprotsk left a comment

No further comments from me. (I did leave two questions, but nothing that requires code changes)

Thanks for addressing all my many (lengthy) comments and questions. And the refactor!

//
// Performs two keccak() computations in parallel. The steps of the
// two computations are executed interleaved.
static address generate_double_keccak(StubGenerator *stubgen, MacroAssembler *_masm) {
Contributor

This function seems OK. I didn't do as line-by-line an 'exact' review as for the NTT intrinsics; I just put the new version into a diff next to the original function. Seems like a reasonably clean 'refactor' (hardcode the block size, add new input registers 10-14; makes it really easy to spot vs the original registers 0-4..)

I didn't realize before that the 'top 3 limbs' are wasted. I guess it doesn't matter; there are plenty of registers to spare, and it makes the entire algorithm cleaner and easier to follow.

I did also stare at the algorithm with the 'What about AVX2' question.. It looks like this function would pretty much need to be rewritten :/

Last two questions..

  • how much performance is gained from doubling this function up?
  • If that's worth it.. what if the input was quadrupled instead? (I scanned the Java code; it looked like NR was already parametrized to 2..) It looks like there are almost enough registers here to go to 4 (I think 3 would need to be freed up somehow.. alternatively, the upper 3 limbs are empty in all operations, so perhaps they could be used instead.. at the expense of readability)

Contributor Author

@ferakocz ferakocz Apr 2, 2025

Well, the algorithm (keccak()) is doing the same things on 5 array elements (It works on essentially a 5x5 matrix doing row and column operations, so putting 5 array entries in a vector register was the "natural" thing to do).

This function can only be used under very special circumstances, which occur during the generation of the "A matrix" in ML-KEM and ML-DSA; the speed of that matrix generation has almost doubled (I don't have exact numbers).

We are using 7 registers per state and 15 for the constants, so we have only 3 to spare. We could perhaps juggle the constants, keeping just the ones needed next in registers and reloading them "just in time", but that might slow things down a bit - more load instructions executed + maybe some load delay. On the other hand, more parallelism. I might try it out.

Comment on lines 350 to 359
__ mov64(rax,1);
__ kmovbl(k1, rax);
__ addl(rax,2);
__ kmovbl(k2, rax);
__ addl(rax, 4);
__ kmovbl(k3, rax);
__ addl(rax, 8);
__ kmovbl(k4, rax);
__ addl(rax, 16);
__ kmovbl(k5, rax);

We could use the sequence from generate_sha3_implCompress to setup the K registers, that has less dependency:

__ movl(rax, 0x1F);
__ kmovbl(k5, rax);
__ kshiftrbl(k4, k5, 1);
__ kshiftrbl(k3, k5, 2);
__ kshiftrbl(k2, k5, 3);
__ kshiftrbl(k1, k5, 4);

Contributor Author

Thanks! (I had copied/doubled this function from the single-state version before you made me do this change on that one, and I forgot to update the copy :-) ) Changed.

Member

@jatin-bhateja jatin-bhateja left a comment

Hi @ferakocz, I verified the new version of the patch on Linux and Windows, and it works fine.

Thanks for addressing my comments. Your passion is contagious :-)

@openjdk openjdk bot added the ready Pull request is ready to be integrated label Apr 2, 2025
Comment on lines +338 to +339
// levels 2 to 7 are done in 2 batches, by first saving half of the coefficients
// from level 1 into memory, doing all the level 2 to level 7 computations

On lines 344-347, we seem to be storing all the coefficients from level 1 into memory.

Comment on lines 344 to 345
store4Xmms(coeffs, 0, xmm0_3, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm);

This seems to be an unnecessary store.

Contributor Author

Thanks for catching that. Changed.

// level 4
loadPerm(xmm16_19, perms, nttL4PermsIdx, _masm);
loadPerm(xmm12_15, perms, nttL4PermsIdx + 64, _masm);
load4Xmms(xmm24_27, zetas, 4 * 512, _masm); // for level 3

The comment // for level3 is not relevant here and could be removed.

Contributor Author

Oops. Deleted the comment.

Comment on lines 1011 to 1015
__ subl(len, 4 * XMMBYTES);
__ addptr(highPart, 4 * XMMBYTES);
__ addptr(lowPart, 4 * XMMBYTES);
__ cmpl(len, 0);
__ jcc(Assembler::notEqual, L_loop);

It looks to me that the subl and cmpl could be merged, since subl already sets the flags the branch needs:
__ addptr(highPart, 4 * XMMBYTES);
__ addptr(lowPart, 4 * XMMBYTES);
__ subl(len, 4 * XMMBYTES);
__ jcc(Assembler::notEqual, L_loop);

Contributor Author

Changed.

__ evporq(xmm18, k0, xmm18, xmm22, false, Assembler::AVX_512bit);
__ evporq(xmm19, k0, xmm19, xmm23, false, Assembler::AVX_512bit);

__ evpsubd(xmm12, k0, zero, one, false, Assembler::AVX_512bit); // -1

The -1 initialization could be done outside the loop.

Contributor Author

Not really. All registers are used.

Comment on lines 799 to 802
__ xorl(scratch, scratch);
__ evpbroadcastd(zero, scratch, Assembler::AVX_512bit); // 0
__ addl(scratch, 1);
__ evpbroadcastd(one, scratch, Assembler::AVX_512bit); // 1

A better way to initialize (0, 1, -1) vectors is:
// load 0 into int vector
vpxor(zero, zero, zero, Assembler::AVX_512bit);
// load -1 into int vector
vpternlogd(minusOne, 0xff, minusOne, minusOne, Assembler::AVX_512bit);
// load 1 into int vector
vpsubd(one, zero, minusOne, Assembler::AVX_512bit);

Where minusOne could be xmm31.

A broadcast from r register to xmm register is more expensive.
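As an aside on why the vpternlogd trick works: the 8-bit immediate is a bitwise truth table over the three operands, so immediate 0xff yields all-ones regardless of the inputs. A scalar model (hypothetical helper, one 32-bit lane):

```java
// Scalar model of VPTERNLOGD: for each bit position, a 3-bit index is
// formed from the corresponding bits of (a, b, c), and the result bit is
// looked up in the 8-bit immediate. imm = 0xff therefore produces -1
// (all-ones) no matter what the inputs are; imm = 0x96 is a 3-way XOR.
public class Ternlog {
    static int ternlog(int a, int b, int c, int imm) {
        int out = 0;
        for (int bit = 0; bit < 32; bit++) {
            int idx = (((a >>> bit) & 1) << 2)
                    | (((b >>> bit) & 1) << 1)
                    |  ((c >>> bit) & 1);
            out |= ((imm >>> idx) & 1) << bit;
        }
        return out;
    }
}
```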

Contributor Author

Changed.

@openjdk openjdk bot removed the ready Pull request is ready to be integrated label Apr 8, 2025
@sviswa7 sviswa7 left a comment

Overall very clean and nicely done PR. Thanks a lot for considering my inputs.

@openjdk openjdk bot added the ready Pull request is ready to be integrated label Apr 8, 2025
@ferakocz
Contributor Author

ferakocz commented Apr 9, 2025

Overall very clean and nicely done PR. Thanks a lot for considering my inputs.

That is in no small part thanks to the reviewers, especially to Volodymyr!
@lmesnik, @jatin-bhateja, @sviswa7 would one of you /sponsor me with the integration?

@sviswa7

sviswa7 commented Apr 9, 2025

@ferakocz Once you do /integrate, I will be honored to sponsor your PR.

@ferakocz
Contributor Author

ferakocz commented Apr 9, 2025

/integrate

@ferakocz
Contributor Author

ferakocz commented Apr 9, 2025

@ferakocz Once you do /integrate, I will be honored to sponsor your PR.

Thanks!

@openjdk openjdk bot added the sponsor Pull request is ready to be sponsored label Apr 9, 2025
@openjdk

openjdk bot commented Apr 9, 2025

@ferakocz
Your change (at version 0b0d096) is now ready to be sponsored by a Committer.

@sviswa7

sviswa7 commented Apr 9, 2025

/sponsor

@openjdk

openjdk bot commented Apr 9, 2025

Going to push as commit e87ff32.
Since your change was applied there have been 539 commits pushed to the master branch:

Your commit was automatically rebased without conflicts.

@openjdk openjdk bot added the integrated Pull request has been integrated label Apr 9, 2025
@openjdk openjdk bot closed this Apr 9, 2025
@openjdk openjdk bot removed the ready (Pull request is ready to be integrated), rfr (Pull request is ready for review) and sponsor (Pull request is ready to be sponsored) labels Apr 9, 2025
@openjdk

openjdk bot commented Apr 9, 2025

@sviswa7 @ferakocz Pushed as commit e87ff32.

💡 You may see a message that your pull request was closed with unmerged commits. This can be safely ignored.


Labels

graal (graal-dev@openjdk.org), hotspot (hotspot-dev@openjdk.org), integrated (Pull request has been integrated), security (security-dev@openjdk.org)

6 participants