86 changes: 27 additions & 59 deletions src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp
@@ -120,51 +120,15 @@ static address dilithiumAvx512PermsAddr() {

// We do Montgomery multiplications of two vectors of 16 ints each in 4 steps:
// 1. Do the multiplications of the corresponding even numbered slots into
// the odd numbered slots of a third register using montmulEven().
// the odd numbered slots of a third register.
// 2. Swap the even and odd numbered slots of the original input registers.
// 3. Similar to step 1, but into a different output register.
// 4. Combine the outputs of step 1 and step 3 into the output of the Montgomery
// multiplication.
// (For levels 0-6 of the Ntt and levels 1-7 of the inverse Ntt we only swap
// the odd-even slots of the first multiplicand, because in the second (the
// zetas) the odd slots contain the same number as the corresponding even
// ones.)
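
For reference, a minimal scalar sketch of what these four steps compute in each
32-bit lane, assuming the standard Dilithium constants (q = 8380417,
q^-1 mod 2^32 = 58728449); montMulScalar is a hypothetical helper for
illustration only, not part of this file:

#include <cstdint>

static const int32_t Q    = 8380417;   // Dilithium q, broadcast in Zmm_31
static const int32_t QINV = 58728449;  // q^-1 mod 2^32, broadcast in Zmm_30

// Montgomery product a * b * 2^-32 mod q of one lane; |result| < q.
static int32_t montMulScalar(int32_t a, int32_t b) {
  int64_t prod = (int64_t)a * b;      // vpmuldq: 32x32 -> 64-bit product
  int32_t m = (int32_t)prod * QINV;   // vpmulld: low half times q^-1 mod 2^32
  int64_t t = prod - (int64_t)m * Q;  // vpmuldq + evpsubd: subtract m * q
  return (int32_t)(t >> 32);          // high 32 bits, i.e. the odd slots
}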

// Montgomery multiplication of the *even* numbered slices of sequences
// of parCnt consecutive registers. Zmm_inputReg1 and Zmm_inputReg2 are the
// starts of the two sequences. When inputReg2 == 29, we use that register as
// the second multiplicand in each multiplication.
// The result goes to the *odd* numbered slices of Zmm_outputReg.
// The parCnt-long consecutive sequences of registers that start with
// Zmm_scratch1 and Zmm_scratch2 are used as scratch registers, so their
// contents will be clobbered.
// The output register sequence can overlap any of the input and scratch
// register sequences, however the two scratch register sequences should be
// non-overlapping.
// Zmm_31 should contain q and
// Zmm_30 should contain q^-1 mod 2^32 in all of their slices.
static void montmulEven(int outputReg, int inputReg1, int inputReg2,
int scratchReg1, int scratchReg2,
int parCnt, MacroAssembler *_masm) {
for (int i = 0; i < parCnt; i++) {
__ vpmuldq(xmm(i + scratchReg1), xmm(i + inputReg1),
xmm((inputReg2 == 29) ? 29 : inputReg2 + i), Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
__ vpmulld(xmm(i + scratchReg2), xmm(i + scratchReg1), montQInvModR,
Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
__ vpmuldq(xmm(i + scratchReg2), xmm(i + scratchReg2), dilithium_q,
Assembler::AVX_512bit);
}
for (int i = 0; i < parCnt; i++) {
__ evpsubd(xmm(i + outputReg), k0, xmm(i + scratchReg1),
xmm(i + scratchReg2), false, Assembler::AVX_512bit);
}
}

// Full Montgomery multiplication of the corresponding slices of two register
// sets of 4 registers each. The indexes of the registers to be multiplied
// The indexes of the registers to be multiplied
// are in inputRegs1[] and inputRegs2[].
// The results go to the registers whose indexes are in outputRegs.
// scratchRegs should contain 12 different register indexes.
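
To make the even/odd slot choreography concrete, here is a hypothetical scalar
model of one 16-lane vector pair (montMulScalar as sketched above; the real
code swaps slots with vpshufd and merges the halves with evpermt2d):

// Full Montgomery multiplication of two 16-int vectors via the 4-step scheme.
static void montMulVec(int32_t out[16], const int32_t a[16], const int32_t b[16]) {
  int32_t evenProd[16], oddProd[16];
  for (int i = 0; i < 16; i += 2) {
    evenProd[i + 1] = montMulScalar(a[i], b[i]);         // step 1: even slots
    oddProd[i + 1]  = montMulScalar(a[i + 1], b[i + 1]); // steps 2-3: after swap
  }
  for (int i = 0; i < 16; i += 2) {                      // step 4: combine
    out[i]     = evenProd[i + 1];
    out[i + 1] = oddProd[i + 1];
  }
}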
@@ -235,24 +199,6 @@ static void montMul64(int outputRegs[], int inputRegs1[], int inputRegs2[],
montMul64(outputRegs, inputRegs1, inputRegs2, scratchRegs, false, _masm);
}

// input in Zmm0-Zmm7, the constant is repeated in all slots of Zmm29
// qinvmodR and q are repeated in all slots of Zmm30 and Zmm31, resp.
// Zmm8-Zmm23 used as scratch registers
// result goes to Zmm0-Zmm7
static void montMulByConst128(MacroAssembler *_masm) {
montmulEven(8, 0, 29, 8, 16, 8, _masm);

for (int i = 0; i < 8; i++) {
__ vpshufd(xmm(i),xmm(i), 0xB1, Assembler::AVX_512bit);
}

montmulEven(0, 0, 29, 0, 16, 8, _masm);

for (int i = 0; i < 8; i++) {
__ evpermt2d(xmm(i), montMulPerm, xmm(i + 8), Assembler::AVX_512bit);
}
}

static void sub_add(int subResult[], int addResult[],
Contributor: Big fan of all these helper functions! Makes reading the top level functions way easier, thanks for refactoring!
int input1[], int input2[], MacroAssembler *_masm) {

@@ -344,6 +290,16 @@ static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,
__ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr()));

__ evmovdqul(montMulPerm, Address(perms, montMulPermsIdx), Assembler::AVX_512bit);

// Each level represents one iteration of the outer for loop of the Java
// version. In each of these iterations half of the coefficients are
// (Montgomery) multiplied by the zeta corresponding to the coefficient, and
// these products are then added to and subtracted from the other half of
// the coefficients. In each level we collect the coefficients that need to
// be multiplied by the zetas into one set of vector registers and the rest
// into another (using evpermi2d() instructions where necessary, i.e. in
// levels 4-7), then redistribute the addition/subtraction results.
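
A hypothetical scalar sketch of one such level, following the FIPS 204
reference loop structure rather than the actual Java source (montMulScalar as
sketched earlier; len is 128 at level 0 and halves on each level):

// One forward-NTT level over the 256 coefficients of a polynomial.
static void nttLevel(int32_t coeffs[256], const int32_t zetas[], int len, int &k) {
  for (int start = 0; start < 256; start += 2 * len) {
    int32_t zeta = zetas[k++];
    for (int j = start; j < start + len; j++) {
      int32_t t = montMulScalar(zeta, coeffs[j + len]); // multiply one half by zeta
      coeffs[j + len] = coeffs[j] - t;                  // subtract product from the other
      coeffs[j]       = coeffs[j] + t;                  // and add it
    }
  }
}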

// For levels 0 and 1 the zetas do not differ within the 4 xmm registers
// that we would use for them, so we use only one, xmm29.
loadXmm29(zetas, 0, _masm);
@@ -534,6 +490,16 @@ static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,
ExternalAddress(dilithiumAvx512ConstsAddr(dilithium_qIdx)),
Assembler::AVX_512bit, scratch); // q

// Each level represents one iteration of the outer for loop of the
// Java version.
// In each of these iterations half of the coefficients are added to and
// subtracted from the other half, and the result of the subtraction is then
// (Montgomery) multiplied by the corresponding zetas. In each level we
// collect the coefficients (using evpermi2d() instructions where necessary,
// i.e. on levels 0-4) so that the results of the additions and subtractions
// land in vector registers aligned with each other and with the zetas.
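
Again as a hypothetical scalar sketch (mirroring the reference inverse-NTT
loop; the actual zeta ordering and signs come from the tables this stub loads):

// One inverse-NTT level: add/subtract first, then Montgomery-multiply the
// difference by the zeta; len is 1 at level 0 and doubles on each level.
static void invNttLevel(int32_t coeffs[256], const int32_t zetas[], int len, int &k) {
  for (int start = 0; start < 256; start += 2 * len) {
    int32_t zeta = zetas[k++];
    for (int j = start; j < start + len; j++) {
      int32_t t = coeffs[j];
      coeffs[j]       = t + coeffs[j + len];                      // additions
      coeffs[j + len] = montMulScalar(zeta, t - coeffs[j + len]); // subtractions
    }
  }
}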

// We do levels 0-6 in two batches, each batch entirely in the vector registers
load4Xmms(xmm0_3, coeffs, 0, _masm);
load4Xmms(xmm4_7, coeffs, 4 * XMMBYTES, _masm);
@@ -657,7 +623,8 @@ static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,

store4Xmms(coeffs, 0, xmm16_19, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
montMulByConst128(_masm);
montMul64(xmm0_3, xmm0_3, xmm29_29, xmm16_27, _masm);
montMul64(xmm4_7, xmm4_7, xmm29_29, xmm16_27, _masm);
store4Xmms(coeffs, 8 * XMMBYTES, xmm0_3, _masm);
store4Xmms(coeffs, 12 * XMMBYTES, xmm4_7, _masm);

@@ -761,7 +728,7 @@ static address generate_dilithiumMontMulByConstant_avx512(StubGenerator *stubgen

__ lea(perms, ExternalAddress(dilithiumAvx512PermsAddr()));

// the following four vector registers are used in montMulByConst128
// the following four vector registers are used in montMul64
__ vpbroadcastd(montQInvModR,
ExternalAddress(dilithiumAvx512ConstsAddr(montQInvModRIdx)),
Assembler::AVX_512bit, scratch); // q^-1 mod 2^32
@@ -778,7 +745,8 @@ static address generate_dilithiumMontMulByConstant_avx512(StubGenerator *stubgen

load4Xmms(xmm0_3, coeffs, 0, _masm);
load4Xmms(xmm4_7, coeffs, 4 * XMMBYTES, _masm);
montMulByConst128(_masm);
montMul64(xmm0_3, xmm0_3, xmm29_29, xmm16_27, _masm);
montMul64(xmm4_7, xmm4_7, xmm29_29, xmm16_27, _masm);
store4Xmms(coeffs, 0, xmm0_3, _masm);
store4Xmms(coeffs, 4 * XMMBYTES, xmm4_7, _masm);

89 changes: 27 additions & 62 deletions src/hotspot/cpu/x86/stubGenerator_x86_64_sha3.cpp
@@ -38,6 +38,8 @@

#define BIND(label) bind(label); BLOCK_COMMENT(#label ":")

#define xmm(i) as_XMMRegister(i)

// Constants
ATTRIBUTE_ALIGNED(64) static const uint64_t round_consts_arr[24] = {
0x0000000000000001L, 0x0000000000008082L, 0x800000000000808AL,
@@ -149,28 +151,14 @@ static address generate_sha3_implCompress(StubGenStubId stub_id,
__ kshiftrwl(k1, k5, 4);

// load the state
__ evmovdquq(xmm0, k5, Address(state, 0), false, Assembler::AVX_512bit);
__ evmovdquq(xmm1, k5, Address(state, 40), false, Assembler::AVX_512bit);
__ evmovdquq(xmm2, k5, Address(state, 80), false, Assembler::AVX_512bit);
__ evmovdquq(xmm3, k5, Address(state, 120), false, Assembler::AVX_512bit);
__ evmovdquq(xmm4, k5, Address(state, 160), false, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(i), k5, Address(state, i * 40), false, Assembler::AVX_512bit);
}

// load the permutation and rotation constants
__ evmovdquq(xmm17, Address(permsAndRots, 0), Assembler::AVX_512bit);
__ evmovdquq(xmm18, Address(permsAndRots, 64), Assembler::AVX_512bit);
__ evmovdquq(xmm19, Address(permsAndRots, 128), Assembler::AVX_512bit);
__ evmovdquq(xmm20, Address(permsAndRots, 192), Assembler::AVX_512bit);
__ evmovdquq(xmm21, Address(permsAndRots, 256), Assembler::AVX_512bit);
__ evmovdquq(xmm22, Address(permsAndRots, 320), Assembler::AVX_512bit);
__ evmovdquq(xmm23, Address(permsAndRots, 384), Assembler::AVX_512bit);
__ evmovdquq(xmm24, Address(permsAndRots, 448), Assembler::AVX_512bit);
__ evmovdquq(xmm25, Address(permsAndRots, 512), Assembler::AVX_512bit);
__ evmovdquq(xmm26, Address(permsAndRots, 576), Assembler::AVX_512bit);
__ evmovdquq(xmm27, Address(permsAndRots, 640), Assembler::AVX_512bit);
__ evmovdquq(xmm28, Address(permsAndRots, 704), Assembler::AVX_512bit);
__ evmovdquq(xmm29, Address(permsAndRots, 768), Assembler::AVX_512bit);
__ evmovdquq(xmm30, Address(permsAndRots, 832), Assembler::AVX_512bit);
__ evmovdquq(xmm31, Address(permsAndRots, 896), Assembler::AVX_512bit);
for (int i = 0; i < 15; i++) {
__ evmovdquq(xmm(i + 17), Address(permsAndRots, i * 64), Assembler::AVX_512bit);
}

__ align(OptoLoopAlignment);
__ BIND(sha3_loop);
@@ -317,11 +305,9 @@
}

// store the state
__ evmovdquq(Address(state, 0), k5, xmm0, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 40), k5, xmm1, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 80), k5, xmm2, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 120), k5, xmm3, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 160), k5, xmm4, true, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state, i * 40), k5, xmm(i), true, Assembler::AVX_512bit);
}

__ pop(r14);
__ pop(r13);
@@ -373,34 +359,18 @@ static address generate_double_keccak(StubGenerator *stubgen, MacroAssembler *_m
__ kmovbl(k5, rax);

// load the states
__ evmovdquq(xmm0, k5, Address(state0, 0), false, Assembler::AVX_512bit);
__ evmovdquq(xmm1, k5, Address(state0, 40), false, Assembler::AVX_512bit);
__ evmovdquq(xmm2, k5, Address(state0, 80), false, Assembler::AVX_512bit);
__ evmovdquq(xmm3, k5, Address(state0, 120), false, Assembler::AVX_512bit);
__ evmovdquq(xmm4, k5, Address(state0, 160), false, Assembler::AVX_512bit);

__ evmovdquq(xmm10, k5, Address(state1, 0), false, Assembler::AVX_512bit);
__ evmovdquq(xmm11, k5, Address(state1, 40), false, Assembler::AVX_512bit);
__ evmovdquq(xmm12, k5, Address(state1, 80), false, Assembler::AVX_512bit);
__ evmovdquq(xmm13, k5, Address(state1, 120), false, Assembler::AVX_512bit);
__ evmovdquq(xmm14, k5, Address(state1, 160), false, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(i), k5, Address(state0, i * 40), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(10 + i), k5, Address(state1, i * 40), false, Assembler::AVX_512bit);
}

// load the permutation and rotation constants
__ evmovdquq(xmm17, Address(permsAndRots, 0), Assembler::AVX_512bit);
__ evmovdquq(xmm18, Address(permsAndRots, 64), Assembler::AVX_512bit);
__ evmovdquq(xmm19, Address(permsAndRots, 128), Assembler::AVX_512bit);
__ evmovdquq(xmm20, Address(permsAndRots, 192), Assembler::AVX_512bit);
__ evmovdquq(xmm21, Address(permsAndRots, 256), Assembler::AVX_512bit);
__ evmovdquq(xmm22, Address(permsAndRots, 320), Assembler::AVX_512bit);
__ evmovdquq(xmm23, Address(permsAndRots, 384), Assembler::AVX_512bit);
__ evmovdquq(xmm24, Address(permsAndRots, 448), Assembler::AVX_512bit);
__ evmovdquq(xmm25, Address(permsAndRots, 512), Assembler::AVX_512bit);
__ evmovdquq(xmm26, Address(permsAndRots, 576), Assembler::AVX_512bit);
__ evmovdquq(xmm27, Address(permsAndRots, 640), Assembler::AVX_512bit);
__ evmovdquq(xmm28, Address(permsAndRots, 704), Assembler::AVX_512bit);
__ evmovdquq(xmm29, Address(permsAndRots, 768), Assembler::AVX_512bit);
__ evmovdquq(xmm30, Address(permsAndRots, 832), Assembler::AVX_512bit);
__ evmovdquq(xmm31, Address(permsAndRots, 896), Assembler::AVX_512bit);

for (int i = 0; i < 15; i++) {
__ evmovdquq(xmm(17 + i), Address(permsAndRots, i * 64), Assembler::AVX_512bit);
}

// there will be 24 keccak rounds
// The same operations as the ones in generate_sha3_implCompress are
@@ -519,17 +489,12 @@ static address generate_double_keccak(StubGenerator *stubgen, MacroAssembler *_m
__ jcc(Assembler::notEqual, rounds24_loop);

// store the states
__ evmovdquq(Address(state0, 0), k5, xmm0, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state0, 40), k5, xmm1, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state0, 80), k5, xmm2, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state0, 120), k5, xmm3, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state0, 160), k5, xmm4, true, Assembler::AVX_512bit);

__ evmovdquq(Address(state1, 0), k5, xmm10, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state1, 40), k5, xmm11, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state1, 80), k5, xmm12, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state1, 120), k5, xmm13, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state1, 160), k5, xmm14, true, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state0, i * 40), k5, xmm(i), true, Assembler::AVX_512bit);
}
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state1, i * 40), k5, xmm(10 + i), true, Assembler::AVX_512bit);
}

__ leave(); // required for proper stackwalking of RuntimeStub frame
__ ret(0);