diff --git a/src/hotspot/cpu/aarch64/register_aarch64.cpp b/src/hotspot/cpu/aarch64/register_aarch64.cpp index 349845154e2fe..82683daae4f08 100644 --- a/src/hotspot/cpu/aarch64/register_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/register_aarch64.cpp @@ -58,23 +58,3 @@ const char* PRegister::PRegisterImpl::name() const { }; return is_valid() ? names[encoding()] : "pnoreg"; } - -// convenience methods for splitting 8-way vector register sequences -// in half -- needed because vector operations can normally only be -// benefit from 4-way instruction parallelism - -VSeq<4> vs_front(const VSeq<8>& v) { - return VSeq<4>(v.base(), v.delta()); -} - -VSeq<4> vs_back(const VSeq<8>& v) { - return VSeq<4>(v.base() + 4 * v.delta(), v.delta()); -} - -VSeq<4> vs_even(const VSeq<8>& v) { - return VSeq<4>(v.base(), v.delta() * 2); -} - -VSeq<4> vs_odd(const VSeq<8>& v) { - return VSeq<4>(v.base() + 1, v.delta() * 2); -} diff --git a/src/hotspot/cpu/aarch64/register_aarch64.hpp b/src/hotspot/cpu/aarch64/register_aarch64.hpp index 45578336cfeaa..feb189db93853 100644 --- a/src/hotspot/cpu/aarch64/register_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/register_aarch64.hpp @@ -436,19 +436,20 @@ enum RC { rc_bad, rc_int, rc_float, rc_predicate, rc_stack }; // inputs into front and back halves or odd and even halves (see // convenience methods below). +// helper macro for computing register masks +#define VS_MASK_BIT(base, delta, i) (1 << (base + delta * i)) + template class VSeq { static_assert(N >= 2, "vector sequence length must be greater than 1"); - static_assert(N <= 8, "vector sequence length must not exceed 8"); - static_assert((N & (N - 1)) == 0, "vector sequence length must be power of two"); private: int _base; // index of first register in sequence int _delta; // increment to derive successive indices public: VSeq(FloatRegister base_reg, int delta = 1) : VSeq(base_reg->encoding(), delta) { } VSeq(int base, int delta = 1) : _base(base), _delta(delta) { - assert (_base >= 0, "invalid base register"); - assert (_delta >= 0, "invalid register delta"); - assert ((_base + (N - 1) * _delta) < 32, "range exceeded"); + assert (_base >= 0 && _base <= 31, "invalid base register"); + assert ((_base + (N - 1) * _delta) >= 0, "register range underflow"); + assert ((_base + (N - 1) * _delta) < 32, "register range overflow"); } // indexed access to sequence FloatRegister operator [](int i) const { @@ -457,27 +458,89 @@ template class VSeq { } int mask() const { int m = 0; - int bit = 1 << _base; for (int i = 0; i < N; i++) { - m |= bit << (i * _delta); + m |= VS_MASK_BIT(_base, _delta, i); } return m; } int base() const { return _base; } int delta() const { return _delta; } + bool is_constant() const { return _delta == 0; } }; -// declare convenience methods for splitting vector register sequences - -VSeq<4> vs_front(const VSeq<8>& v); -VSeq<4> vs_back(const VSeq<8>& v); -VSeq<4> vs_even(const VSeq<8>& v); -VSeq<4> vs_odd(const VSeq<8>& v); - -// methods for use in asserts to check VSeq inputs and oupts are +// methods for use in asserts to check VSeq inputs and outputs are // either disjoint or equal template bool vs_disjoint(const VSeq& n, const VSeq& m) { return (n.mask() & m.mask()) == 0; } template bool vs_same(const VSeq& n, const VSeq& m) { return n.mask() == m.mask(); } +// method for use in asserts to check whether registers appearing in +// an output sequence will be written before they are read from an +// input sequence. 
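The write-before-read check introduced above (and defined immediately below) can be illustrated without the HotSpot types. The following is a standalone sketch — the Seq struct, helper names and test values are ours, not part of the patch — using the same base/delta/mask arithmetic to show why an output sequence that trails its input by one register is rejected, while disjoint and in-place sequences pass.

#include <cassert>

// Minimal stand-in for VSeq<N>: register indices base, base+delta, base+2*delta, ...
struct Seq { int base; int delta; };

// Same bit trick as VS_MASK_BIT in the patch: one bit per register index.
static int mask_bit(int base, int delta, int i) { return 1 << (base + delta * i); }

static int mask(const Seq& s, int n) {
  int m = 0;
  for (int i = 0; i < n; i++) m |= mask_bit(s.base, s.delta, i);
  return m;
}

// Mirrors vs_write_before_read: step i of an element-wise op reads vin[i] and
// writes vout[i]; report true if some step writes a register that a later
// step still needs to read.
static bool write_before_read(const Seq& vout, const Seq& vin, int n) {
  int pending_reads = mask(vin, n);
  int writes = 0;
  for (int i = 0; i < n - 1; i++) {
    if ((writes & pending_reads) != 0) return true;
    if (vin.delta != 0)                              // constant sequences re-read the same register
      pending_reads ^= mask_bit(vin.base, vin.delta, i); // read i has been retired
    writes |= mask_bit(vout.base, vout.delta, i);        // write i has landed
  }
  return false;
}

int main() {
  // v1..v4 = f(v0..v3): step 0 writes v1 but step 1 still needs to read v1 -> hazard
  assert(write_before_read(Seq{1, 1}, Seq{0, 1}, 4));
  // v4..v7 = f(v0..v3): disjoint, no hazard
  assert(!write_before_read(Seq{4, 1}, Seq{0, 1}, 4));
  // v0..v3 = f(v0..v3): in-place, each read happens before the same-step write
  assert(!write_before_read(Seq{0, 1}, Seq{0, 1}, 4));
  return 0;
}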
+ +template bool vs_write_before_read(const VSeq& vout, const VSeq& vin) { + int b_in = vin.base(); + int d_in = vin.delta(); + int b_out = vout.base(); + int d_out = vout.delta(); + int bit_in = 1 << b_in; + int bit_out = 1 << b_out; + int mask_read = vin.mask(); // all pending reads + int mask_write = 0; // no writes as yet + + + for (int i = 0; i < N - 1; i++) { + // check whether a pending read clashes with a write + if ((mask_write & mask_read) != 0) { + return true; + } + // remove the pending input (so long as this is a constant + // sequence) + if (d_in != 0) { + mask_read ^= VS_MASK_BIT(b_in, d_in, i); + } + // record the next write + mask_write |= VS_MASK_BIT(b_out, d_out, i); + } + // no write before read + return false; +} + +// convenience methods for splitting 8-way of 4-way vector register +// sequences in half -- needed because vector operations can normally +// benefit from 4-way instruction parallelism or, occasionally, 2-way +// parallelism + +template +VSeq vs_front(const VSeq& v) { + static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even"); + return VSeq(v.base(), v.delta()); +} + +template +VSeq vs_back(const VSeq& v) { + static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even"); + return VSeq(v.base() + N / 2 * v.delta(), v.delta()); +} + +template +VSeq vs_even(const VSeq& v) { + static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even"); + return VSeq(v.base(), v.delta() * 2); +} + +template +VSeq vs_odd(const VSeq& v) { + static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even"); + return VSeq(v.base() + v.delta(), v.delta() * 2); +} + +// convenience method to construct a vector register sequence that +// indexes its elements in reverse order to the original + +template +VSeq vs_reverse(const VSeq& v) { + return VSeq(v.base() + (N - 1) * v.delta(), -v.delta()); +} + #endif // CPU_AARCH64_REGISTER_AARCH64_HPP diff --git a/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp b/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp index a893aacaaf2dd..1107ec0a8f82a 100644 --- a/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp @@ -44,7 +44,7 @@ do_arch_blob, \ do_arch_entry, \ do_arch_entry_init) \ - do_arch_blob(compiler, 55000 ZGC_ONLY(+5000)) \ + do_arch_blob(compiler, 75000 ZGC_ONLY(+5000)) \ do_stub(compiler, vector_iota_indices) \ do_arch_entry(aarch64, compiler, vector_iota_indices, \ vector_iota_indices, vector_iota_indices) \ diff --git a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp index f0f145e3d7612..b2756f2b2d002 100644 --- a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp @@ -4651,6 +4651,11 @@ class StubGenerator: public StubCodeGenerator { template void vs_addv(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); for (int i = 0; i < N; i++) { __ addv(v[i], T, v1[i], v2[i]); } @@ -4659,6 +4664,11 @@ class StubGenerator: public StubCodeGenerator { template void vs_subv(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || 
!v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); for (int i = 0; i < N; i++) { __ subv(v[i], T, v1[i], v2[i]); } @@ -4667,6 +4677,11 @@ class StubGenerator: public StubCodeGenerator { template void vs_mulv(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); for (int i = 0; i < N; i++) { __ mulv(v[i], T, v1[i], v2[i]); } @@ -4674,6 +4689,10 @@ class StubGenerator: public StubCodeGenerator { template void vs_negr(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); for (int i = 0; i < N; i++) { __ negr(v[i], T, v1[i]); } @@ -4682,6 +4701,10 @@ class StubGenerator: public StubCodeGenerator { template void vs_sshr(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, int shift) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); for (int i = 0; i < N; i++) { __ sshr(v[i], T, v1[i], shift); } @@ -4689,6 +4712,11 @@ class StubGenerator: public StubCodeGenerator { template void vs_andr(const VSeq& v, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); for (int i = 0; i < N; i++) { __ andr(v[i], __ T16B, v1[i], v2[i]); } @@ -4696,18 +4724,51 @@ class StubGenerator: public StubCodeGenerator { template void vs_orr(const VSeq& v, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); for (int i = 0; i < N; i++) { __ orr(v[i], __ T16B, v1[i], v2[i]); } } template - void vs_notr(const VSeq& v, const VSeq& v1) { + void vs_notr(const VSeq& v, const VSeq& v1) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); for (int i = 0; i < N; i++) { __ notr(v[i], __ T16B, v1[i]); } } + template + void vs_sqdmulh(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); 
+ assert(!vs_write_before_read(v, v2), "output overwrites input"); + for (int i = 0; i < N; i++) { + __ sqdmulh(v[i], T, v1[i], v2[i]); + } + } + + template + void vs_mlsv(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); + for (int i = 0; i < N; i++) { + __ mlsv(v[i], T, v1[i], v2[i]); + } + } + // load N/2 successive pairs of quadword values from memory in order // into N successive vector registers of the sequence via the // address supplied in base. @@ -4723,6 +4784,7 @@ class StubGenerator: public StubCodeGenerator { // in base using post-increment addressing template void vs_ldpq_post(const VSeq& v, Register base) { + static_assert((N & (N - 1)) == 0, "sequence length must be even"); for (int i = 0; i < N; i += 2) { __ ldpq(v[i], v[i+1], __ post(base, 32)); } @@ -4733,11 +4795,55 @@ class StubGenerator: public StubCodeGenerator { // supplied in base using post-increment addressing template void vs_stpq_post(const VSeq& v, Register base) { + static_assert((N & (N - 1)) == 0, "sequence length must be even"); for (int i = 0; i < N; i += 2) { __ stpq(v[i], v[i+1], __ post(base, 32)); } } + // load N/2 pairs of quadword values from memory de-interleaved into + // N vector registers 2 at a time via the address supplied in base + // using post-increment addressing. + template + void vs_ld2_post(const VSeq& v, Assembler::SIMD_Arrangement T, Register base) { + static_assert((N & (N - 1)) == 0, "sequence length must be even"); + for (int i = 0; i < N; i += 2) { + __ ld2(v[i], v[i+1], T, __ post(base, 32)); + } + } + + // store N vector registers interleaved into N/2 pairs of quadword + // memory locations via the address supplied in base using + // post-increment addressing. + template + void vs_st2_post(const VSeq& v, Assembler::SIMD_Arrangement T, Register base) { + static_assert((N & (N - 1)) == 0, "sequence length must be even"); + for (int i = 0; i < N; i += 2) { + __ st2(v[i], v[i+1], T, __ post(base, 32)); + } + } + + // load N quadword values from memory de-interleaved into N vector + // registers 3 elements at a time via the address supplied in base. + template + void vs_ld3(const VSeq& v, Assembler::SIMD_Arrangement T, Register base) { + static_assert(N == ((N / 3) * 3), "sequence length must be multiple of 3"); + for (int i = 0; i < N; i += 3) { + __ ld3(v[i], v[i+1], v[i+2], T, base); + } + } + + // load N quadword values from memory de-interleaved into N vector + // registers 3 elements at a time via the address supplied in base + // using post-increment addressing. 
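For readers less familiar with the NEON structure loads these wrappers build on: an LD3 step with 16B lanes consumes 48 bytes and de-interleaves them with a stride of three (the LD2/ST2 wrappers above do the same with a stride of two, and ST3 re-interleaves). Below is a minimal scalar model of one such step — standalone C++, not the assembler API.

#include <array>
#include <cassert>
#include <cstdint>

// Scalar model of one LD3 {Vt.16B, Vt2.16B, Vt3.16B}, [base], #48 step:
// byte j of register r receives in[3*j + r], i.e. stride-3 de-interleaving.
static void ld3_16b(const uint8_t* in,
                    std::array<uint8_t, 16>& v0,
                    std::array<uint8_t, 16>& v1,
                    std::array<uint8_t, 16>& v2) {
  for (int j = 0; j < 16; j++) {
    v0[j] = in[3 * j + 0];
    v1[j] = in[3 * j + 1];
    v2[j] = in[3 * j + 2];
  }
}

int main() {
  uint8_t mem[48];
  for (int i = 0; i < 48; i++) mem[i] = (uint8_t)i;
  std::array<uint8_t, 16> a, b, c;
  ld3_16b(mem, a, b, c);
  assert(a[5] == 15 && b[5] == 16 && c[5] == 17); // bytes 15,16,17 land in lane 5
  return 0;
}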
+ template + void vs_ld3_post(const VSeq& v, Assembler::SIMD_Arrangement T, Register base) { + static_assert(N == ((N / 3) * 3), "sequence length must be multiple of 3"); + for (int i = 0; i < N; i += 3) { + __ ld3(v[i], v[i+1], v[i+2], T, __ post(base, 48)); + } + } + // load N/2 pairs of quadword values from memory into N vector // registers via the address supplied in base with each pair indexed // using the the start offset plus the corresponding entry in the @@ -4810,23 +4916,29 @@ class StubGenerator: public StubCodeGenerator { } } - // Helper routines for various flavours of dilithium montgomery - // multiply + // Helper routines for various flavours of montgomery multiply - // Perform 16 32-bit Montgomery multiplications in parallel - // See the montMul() method of the sun.security.provider.ML_DSA class. + // Perform 16 32-bit (4x4S) or 32 16-bit (4 x 8H) Montgomery + // multiplications in parallel + // + + // See the montMul() method of the sun.security.provider.ML_DSA + // class. // - // Computes 4x4S results - // a = b * c * 2^-32 mod MONT_Q - // Inputs: vb, vc - 4x4S vector register sequences - // vq - 2x4S constants - // Temps: vtmp - 4x4S vector sequence trashed after call - // Outputs: va - 4x4S vector register sequences + // Computes 4x4S results or 8x8H results + // a = b * c * 2^MONT_R_BITS mod MONT_Q + // Inputs: vb, vc - 4x4S or 4x8H vector register sequences + // vq - 2x4S or 2x8H constants + // Temps: vtmp - 4x4S or 4x8H vector sequence trashed after call + // Outputs: va - 4x4S or 4x8H vector register sequences // vb, vc, vtmp and vq must all be disjoint // va must be disjoint from all other inputs/temps or must equal vc - // n.b. MONT_R_BITS is 32, so the right shift by it is implicit. - void dilithium_montmul16(const VSeq<4>& va, const VSeq<4>& vb, const VSeq<4>& vc, - const VSeq<4>& vtmp, const VSeq<2>& vq) { + // va must have a non-zero delta i.e. it must not be a constant vseq. + // n.b. MONT_R_BITS is 16 or 32, so the right shift by it is implicit. 
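As a cross-check on the sqdmulh/mulv/shsubv sequence documented above and scheduled in the helper just below, here is a hedged scalar model of one 16-bit (8H) lane. It is standalone C++ mirroring the instruction semantics rather than the stub code, and it uses the Kyber constants added to _kyberConsts later in this patch: q = 0x0D01 = 3329 and qinv = 0xF301 = q^-1 mod 2^16. The net effect is that the Montgomery factor 2^16 is divided out — the returned r satisfies r * 2^16 == b * c (mod q) — and when one operand is fully reduced (|c| < q) the magnitude of r stays below q.

#include <cassert>
#include <cstdint>

static const int16_t KYBER_Q    = 3329;   // 0x0D01
static const int16_t KYBER_QINV = -3327;  // 0xF301 == Q^-1 mod 2^16

// One 8H lane of the montmul helper: sqdmulh yields the high half of 2*b*c,
// mulv the low half, and shsubv halves the final subtraction. Saturation in
// sqdmulh only matters for b == c == -32768, which the callers avoid.
static int16_t montmul16(int16_t b, int16_t c) {
  int32_t p    = (int32_t)b * c;
  int16_t a_hi = (int16_t)(((int64_t)2 * p) >> 16);          // sqdmulh(b, c)
  int16_t a_lo = (int16_t)p;                                  // mulv(b, c), low half
  int16_t m    = (int16_t)(a_lo * KYBER_QINV);                // mulv(aLow, qinv), low half
  int16_t n    = (int16_t)(((int32_t)2 * m * KYBER_Q) >> 16); // sqdmulh(m, q)
  return (int16_t)((a_hi - n) >> 1);                          // shsubv(aHigh, n)
}

int main() {
  int16_t b = -7423, c = 2500;   // |c| < q
  int16_t r = montmul16(b, c);
  // the Montgomery factor is divided out: r * 2^16 == b * c (mod q) ...
  assert(((int64_t)r * 65536 - (int64_t)b * c) % KYBER_Q == 0);
  // ... and with one operand fully reduced the result stays inside (-q, q)
  assert(r > -KYBER_Q && r < KYBER_Q);
  return 0;
}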
+ void vs_montmul4(const VSeq<4>& va, const VSeq<4>& vb, const VSeq<4>& vc, + Assembler::SIMD_Arrangement T, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + assert (T == __ T4S || T == __ T8H, "invalid arrangement for montmul"); assert(vs_disjoint(vb, vc), "vb and vc overlap"); assert(vs_disjoint(vb, vq), "vb and vq overlap"); assert(vs_disjoint(vb, vtmp), "vb and vtmp overlap"); @@ -4840,40 +4952,106 @@ class StubGenerator: public StubCodeGenerator { assert(vs_disjoint(va, vb), "va and vb overlap"); assert(vs_disjoint(va, vq), "va and vq overlap"); assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); + assert(!va.is_constant(), "output vector must identify 4 different registers"); // schedule 4 streams of instructions across the vector sequences for (int i = 0; i < 4; i++) { - __ sqdmulh(vtmp[i], __ T4S, vb[i], vc[i]); // aHigh = hi32(2 * b * c) - __ mulv(va[i], __ T4S, vb[i], vc[i]); // aLow = lo32(b * c) + __ sqdmulh(vtmp[i], T, vb[i], vc[i]); // aHigh = hi32(2 * b * c) + __ mulv(va[i], T, vb[i], vc[i]); // aLow = lo32(b * c) } for (int i = 0; i < 4; i++) { - __ mulv(va[i], __ T4S, va[i], vq[0]); // m = aLow * qinv + __ mulv(va[i], T, va[i], vq[0]); // m = aLow * qinv } for (int i = 0; i < 4; i++) { - __ sqdmulh(va[i], __ T4S, va[i], vq[1]); // n = hi32(2 * m * q) + __ sqdmulh(va[i], T, va[i], vq[1]); // n = hi32(2 * m * q) } for (int i = 0; i < 4; i++) { - __ shsubv(va[i], __ T4S, vtmp[i], va[i]); // a = (aHigh - n) / 2 + __ shsubv(va[i], T, vtmp[i], va[i]); // a = (aHigh - n) / 2 } } - // Perform 2x16 32-bit Montgomery multiplications in parallel - // See the montMul() method of the sun.security.provider.ML_DSA class. + // Perform 8 32-bit (4x4S) or 16 16-bit (2 x 8H) Montgomery + // multiplications in parallel // - // Computes 8x4S results - // a = b * c * 2^-32 mod MONT_Q - // Inputs: vb, vc - 8x4S vector register sequences - // vq - 2x4S constants - // Temps: vtmp - 4x4S vector sequence trashed after call - // Outputs: va - 8x4S vector register sequences + + // See the montMul() method of the sun.security.provider.ML_DSA + // class. + // + // Computes 4x4S results or 8x8H results + // a = b * c * 2^MONT_R_BITS mod MONT_Q + // Inputs: vb, vc - 4x4S or 4x8H vector register sequences + // vq - 2x4S or 2x8H constants + // Temps: vtmp - 4x4S or 4x8H vector sequence trashed after call + // Outputs: va - 4x4S or 4x8H vector register sequences // vb, vc, vtmp and vq must all be disjoint // va must be disjoint from all other inputs/temps or must equal vc - // n.b. MONT_R_BITS is 32, so the right shift by it is implicit. - void vs_montmul32(const VSeq<8>& va, const VSeq<8>& vb, const VSeq<8>& vc, - const VSeq<4>& vtmp, const VSeq<2>& vq) { + // va must have a non-zero delta i.e. it must not be a constant vseq. + // n.b. MONT_R_BITS is 16 or 32, so the right shift by it is implicit. 
+ void vs_montmul2(const VSeq<2>& va, const VSeq<2>& vb, const VSeq<2>& vc, + Assembler::SIMD_Arrangement T, + const VSeq<2>& vtmp, const VSeq<2>& vq) { + assert (T == __ T4S || T == __ T8H, "invalid arrangement for montmul"); + assert(vs_disjoint(vb, vc), "vb and vc overlap"); + assert(vs_disjoint(vb, vq), "vb and vq overlap"); + assert(vs_disjoint(vb, vtmp), "vb and vtmp overlap"); + + assert(vs_disjoint(vc, vq), "vc and vq overlap"); + assert(vs_disjoint(vc, vtmp), "vc and vtmp overlap"); + + assert(vs_disjoint(vq, vtmp), "vq and vtmp overlap"); + + assert(vs_disjoint(va, vc) || vs_same(va, vc), "va and vc neither disjoint nor equal"); + assert(vs_disjoint(va, vb), "va and vb overlap"); + assert(vs_disjoint(va, vq), "va and vq overlap"); + assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); + assert(!va.is_constant(), "output vector must identify 2 different registers"); + + // schedule 2 streams of i& va, const VSeq<4>& vb, const VSeq<4>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // Use the helper routine to schedule a 4x4S montgomery multiply. + // It will assert that the register use is valid + vs_montmul4(va, vb, vc, __ T4S, vtmp, vq); + } + + // Perform 2x16 32-bit Montgomery multiplications in parallel + + void dilithium_montmul32(const VSeq<8>& va, const VSeq<8>& vb, const VSeq<8>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // Schedule two successive 4x4S multiplies via the montmul helper + // on the front and back halves of va, vb and vc. The helper will + // assert that the register use has no overlap conflicts on each + // individual call but we also need to ensure that the necessary + // disjoint/equality constraints are met across both calls. + // vb, vc, vtmp and vq must be disjoint. va must either be // disjoint from all other registers or equal vc @@ -4891,8 +5069,8 @@ class StubGenerator: public StubCodeGenerator { assert(vs_disjoint(va, vq), "va and vq overlap"); assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); - // we need to multiply the front and back halves of each sequence - // 4x4S at a time because + // we multiply the front and back halves of each sequence 4 at a + // time because // // 1) we are currently only able to get 4-way instruction // parallelism at best @@ -4901,8 +5079,8 @@ class StubGenerator: public StubCodeGenerator { // scratch registers to hold intermediate results so vtmp can only // be a VSeq<4> which means we only have 4 scratch slots - dilithium_montmul16(vs_front(va), vs_front(vb), vs_front(vc), vtmp, vq); - dilithium_montmul16(vs_back(va), vs_back(vb), vs_back(vc), vtmp, vq); + vs_montmul4(vs_front(va), vs_front(vb), vs_front(vc), __ T4S, vtmp, vq); + vs_montmul4(vs_back(va), vs_back(vb), vs_back(vc), __ T4S, vtmp, vq); } // perform combined montmul then add/sub on 4x4S vectors @@ -4975,7 +5153,7 @@ class StubGenerator: public StubCodeGenerator { // load next 8x4S inputs == b vs_ldpq_post(vs2, zetas); // compute a == c2 * b mod MONT_Q - vs_montmul32(vs2, vs1, vs2, vtmp, vq); + dilithium_montmul32(vs2, vs1, vs2, vtmp, vq); // load 8x4s coefficients via first start pos == c1 vs_ldpq_indexed(vs1, coeffs, c1Start, offsets); // compute a1 = c1 + a @@ -5029,8 +5207,8 @@ class StubGenerator: public StubCodeGenerator { VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 VSeq<2> vq(30); // n.b. 
constants overlap vs3 - int offsets[4] = {0, 32, 64, 96}; - int offsets1[8] = {16, 48, 80, 112, 144, 176, 208, 240 }; + int offsets[4] = { 0, 32, 64, 96}; + int offsets1[8] = { 16, 48, 80, 112, 144, 176, 208, 240 }; int offsets2[8] = { 0, 32, 64, 96, 128, 160, 192, 224 }; __ add(result, coeffs, 0); __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); @@ -5056,7 +5234,7 @@ class StubGenerator: public StubCodeGenerator { // load next 32 (8x4S) inputs = b vs_ldpq_post(vs2, zetas); // a = b montul c1 - vs_montmul32(vs2, vs1, vs2, vtmp, vq); + dilithium_montmul32(vs2, vs1, vs2, vtmp, vq); // load 32 (8x4S) coefficients via second offsets = c2 vs_ldr_indexed(vs1, __ Q, coeffs, i, offsets2); // add/sub with result of multiply @@ -5188,7 +5366,7 @@ class StubGenerator: public StubCodeGenerator { // load b next 32 (8x4S) inputs vs_ldpq_post(vs2, zetas); // a = a1 montmul b - vs_montmul32(vs2, vs1, vs2, vtmp, vq); + dilithium_montmul32(vs2, vs1, vs2, vtmp, vq); // save a relative to second start index vs_stpq_indexed(vs2, coeffs, c2Start, offsets); @@ -5306,7 +5484,7 @@ class StubGenerator: public StubCodeGenerator { // reload constants q, qinv -- they were clobbered earlier vs_ldpq(vq, dilithiumConsts); // qInv, q // compute a1 = b montmul c - vs_montmul32(vs2, vs1, vs2, vtmp, vq); + dilithium_montmul32(vs2, vs1, vs2, vtmp, vq); // store a1 32 (8x4S) coefficients via second offsets vs_str_indexed(vs2, __ Q, coeffs, i, offsets2); } @@ -5370,9 +5548,9 @@ class StubGenerator: public StubCodeGenerator { // c load 32 (8x4S) next inputs from poly2 vs_ldpq_post(vs2, poly2); // compute a = b montmul c - vs_montmul32(vs2, vs1, vs2, vtmp, vq); + dilithium_montmul32(vs2, vs1, vs2, vtmp, vq); // compute a = rsquare montmul a - vs_montmul32(vs2, vrsquare, vs2, vtmp, vq); + dilithium_montmul32(vs2, vrsquare, vs2, vtmp, vq); // save a 32 (8x4S) results vs_stpq_post(vs2, result); @@ -5433,7 +5611,7 @@ class StubGenerator: public StubCodeGenerator { // load next 32 inputs vs_ldpq_post(vs2, coeffs); // mont mul by constant - vs_montmul32(vs2, vconst, vs2, vtmp, vq); + dilithium_montmul32(vs2, vconst, vs2, vtmp, vq); // write next 32 results vs_stpq_post(vs2, result); @@ -5605,6 +5783,1082 @@ class StubGenerator: public StubCodeGenerator { } + // Perform 16 16-bit Montgomery multiplications in parallel + + void kyber_montmul16(const VSeq<2>& va, const VSeq<2>& vb, const VSeq<2>& vc, + const VSeq<2>& vtmp, const VSeq<2>& vq) { + // Use the helper routine to schedule a 2x8H montgomery multiply. + // It will assert that the register use is valid + vs_montmul2(va, vb, vc, __ T8H, vtmp, vq); + } + + // Perform 32 16-bit Montgomery multiplications in parallel + + void kyber_montmul32(const VSeq<4>& va, const VSeq<4>& vb, const VSeq<4>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // Use the helper routine to schedule a 4x8H montgomery multiply. + // It will assert that the register use is valid + vs_montmul4(va, vb, vc, __ T8H, vtmp, vq); + } + + // Perform 64 16-bit Montgomery multiplications in parallel + + void kyber_montmul64(const VSeq<8>& va, const VSeq<8>& vb, const VSeq<8>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // Schedule two successive 4x8H multiplies via the montmul helper + // on the front and back halves of va, vb and vc. The helper will + // assert that the register use has no overlap conflicts on each + // individual call but we also need to ensure that the necessary + // disjoint/equality constraints are met across both calls. 
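Concretely, for the sequences the Kyber stubs below pass in, VSeq<8> va(0) names v0..v7; under the splitter definitions added to register_aarch64.hpp, vs_front(va) is VSeq<4>(0) = v0..v3 and vs_back(va) is VSeq<4>(4) = v4..v7, so each of the two vs_montmul4 calls issued here operates on one disjoint half of every 8-register sequence.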
+ + // vb, vc, vtmp and vq must be disjoint. va must either be + // disjoint from all other registers or equal vc + + assert(vs_disjoint(vb, vc), "vb and vc overlap"); + assert(vs_disjoint(vb, vq), "vb and vq overlap"); + assert(vs_disjoint(vb, vtmp), "vb and vtmp overlap"); + + assert(vs_disjoint(vc, vq), "vc and vq overlap"); + assert(vs_disjoint(vc, vtmp), "vc and vtmp overlap"); + + assert(vs_disjoint(vq, vtmp), "vq and vtmp overlap"); + + assert(vs_disjoint(va, vc) || vs_same(va, vc), "va and vc neither disjoint nor equal"); + assert(vs_disjoint(va, vb), "va and vb overlap"); + assert(vs_disjoint(va, vq), "va and vq overlap"); + assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); + + // we multiply the front and back halves of each sequence 4 at a + // time because + // + // 1) we are currently only able to get 4-way instruction + // parallelism at best + // + // 2) we need registers for the constants in vq and temporary + // scratch registers to hold intermediate results so vtmp can only + // be a VSeq<4> which means we only have 4 scratch slots + + vs_montmul4(vs_front(va), vs_front(vb), vs_front(vc), __ T8H, vtmp, vq); + vs_montmul4(vs_back(va), vs_back(vb), vs_back(vc), __ T8H, vtmp, vq); + } + + void kyber_montmul32_sub_add(const VSeq<4>& va0, const VSeq<4>& va1, const VSeq<4>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // compute a = montmul(a1, c) + kyber_montmul32(vc, va1, vc, vtmp, vq); + // ouptut a1 = a0 - a + vs_subv(va1, __ T8H, va0, vc); + // and a0 = a0 + a + vs_addv(va0, __ T8H, va0, vc); + } + + void kyber_sub_add_montmul32(const VSeq<4>& va0, const VSeq<4>& va1, const VSeq<4>& vb, + const VSeq<4>& vtmp1, const VSeq<4>& vtmp2, const VSeq<2>& vq) { + // compute c = a0 - a1 + vs_subv(vtmp1, __ T8H, va0, va1); + // output a0 = a0 + a1 + vs_addv(va0, __ T8H, va0, va1); + // output a1 = b montmul c + kyber_montmul32(va1, vtmp1, vb, vtmp2, vq); + } + + void kyber_load64coeffs(const VSeq<8>& v, Register coeffs) { + vs_ldpq_post(v, coeffs); + } + + void kyber_load64zetas(const VSeq<8>& v, Register zetas) { + vs_ldpq_post(v, zetas); + } + + void kyber_load32zetas(const VSeq<4>& v, Register zetas) { + vs_ldpq_post(v, zetas); + } + + void kyber_store64coeffs(VSeq<8> v, Register tmpAddr) { + vs_stpq_post(v, tmpAddr); + } + + // Kyber NTT function. + // Implements + // static int implKyberNtt(short[] poly, short[] ntt_zetas) {} + // + // coeffs (short[256]) = c_rarg0 + // ntt_zetas (short[256]) = c_rarg1 + address generate_kyberNtt() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberNtt_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register coeffs = c_rarg0; + const Register zetas = c_rarg1; + + const Register kyberConsts = r10; + const Register tmpAddr = r11; + + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x8H inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. constants overlap vs3 + + + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + // load the montmul constants + vs_ldpq(vq, kyberConsts); + + // Each level corresponds to an iteration of the outermost loop of the + // Java method seilerNTT(int[] coeffs). There are some differences + // from what is done in the seilerNTT() method, though: + // 1. The computation is using 16-bit signed values, we do not convert them + // to ints here. + // 2. 
The zetas are delivered in a bigger array, 128 zetas are stored in + // this array for each level, it is easier that way to fill up the vector + // registers. + // 3. In the seilerNTT() method we use R = 2^20 for the Montgomery + // multiplications (this is because that way there should not be any + // overflow during the inverse NTT computation), here we usr R = 2^16 so + // that we can use the 16-bit arithmetic in the vector unit. + // + // On each level, we fill up the vector registers in such a way that the + // array elements that need to be multiplied by the zetas be in one + // set of vector registers while the corresponding ones that don't need to + // be multiplied, in another set. We can do 32 Montgomery multiplications + // in parallel, using 12 vector registers interleaving the steps of 4 + // identical computations, each done on 8 16-bit values per register. + // level 0 + __ add(tmpAddr, coeffs, 256); + kyber_load64coeffs(vs1, tmpAddr); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 0); + kyber_load64coeffs(vs1, tmpAddr); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 0); + vs_stpq_post(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 256); + vs_stpq_post(vs3, tmpAddr); + // restore montmul constants + vs_ldpq(vq, kyberConsts); + kyber_load64coeffs(vs1, tmpAddr); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 128); + kyber_load64coeffs(vs1, tmpAddr); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 128); + kyber_store64coeffs(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 384); + kyber_store64coeffs(vs3, tmpAddr); + + // level 1 + // restore montmul constants + vs_ldpq(vq, kyberConsts); + __ add(tmpAddr, coeffs, 128); + kyber_load64coeffs(vs1, tmpAddr); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 0); + kyber_load64coeffs(vs1, tmpAddr); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 0); + kyber_store64coeffs(vs1, tmpAddr); + kyber_store64coeffs(vs3, tmpAddr); + vs_ldpq(vq, kyberConsts); + __ add(tmpAddr, coeffs, 384); + kyber_load64coeffs(vs1, tmpAddr); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 256); + kyber_load64coeffs(vs1, tmpAddr); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 256); + kyber_store64coeffs(vs1, tmpAddr); + kyber_store64coeffs(vs3, tmpAddr); + + // level 2 + vs_ldpq(vq, kyberConsts); + int offsets1[4] = { 0, 32, 128, 160 }; + vs_ldpq_indexed(vs1, coeffs, 64, offsets1); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldpq_indexed(vs1, coeffs, 0, offsets1); + // kyber_subv_addv64(); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 0); + vs_stpq_post(vs_front(vs1), tmpAddr); + vs_stpq_post(vs_front(vs3), tmpAddr); + vs_stpq_post(vs_back(vs1), tmpAddr); + vs_stpq_post(vs_back(vs3), tmpAddr); + vs_ldpq(vq, kyberConsts); + vs_ldpq_indexed(vs1, tmpAddr, 64, offsets1); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldpq_indexed(vs1, coeffs, 256, offsets1); + // kyber_subv_addv64(); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. 
trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 256); + vs_stpq_post(vs_front(vs1), tmpAddr); + vs_stpq_post(vs_front(vs3), tmpAddr); + vs_stpq_post(vs_back(vs1), tmpAddr); + vs_stpq_post(vs_back(vs3), tmpAddr); + + // level 3 + vs_ldpq(vq, kyberConsts); + int offsets2[4] = { 0, 64, 128, 192 }; + vs_ldpq_indexed(vs1, coeffs, 32, offsets2); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldpq_indexed(vs1, coeffs, 0, offsets2); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs1, coeffs, 0, offsets2); + vs_stpq_indexed(vs3, coeffs, 32, offsets2); + + vs_ldpq(vq, kyberConsts); + vs_ldpq_indexed(vs1, coeffs, 256 + 32, offsets2); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldpq_indexed(vs1, coeffs, 256, offsets2); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs1, coeffs, 256, offsets2); + vs_stpq_indexed(vs3, coeffs, 256 + 32, offsets2); + + // level 4 + vs_ldpq(vq, kyberConsts); + int offsets3[8] = { 0, 32, 64, 96, 128, 160, 192, 224 }; + vs_ldr_indexed(vs1, __ Q, coeffs, 16, offsets3); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldr_indexed(vs1, __ Q, coeffs, 0, offsets3); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + vs_str_indexed(vs1, __ Q, coeffs, 0, offsets3); + vs_str_indexed(vs3, __ Q, coeffs, 16, offsets3); + + vs_ldpq(vq, kyberConsts); + vs_ldr_indexed(vs1, __ Q, coeffs, 256 + 16, offsets3); + kyber_load64zetas(vs2, zetas); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_ldr_indexed(vs1, __ Q, coeffs, 256, offsets3); + vs_subv(vs3, __ T8H, vs1, vs2); // n.b. 
trashes vq + vs_addv(vs1, __ T8H, vs1, vs2); + vs_str_indexed(vs1, __ Q, coeffs, 256, offsets3); + vs_str_indexed(vs3, __ Q, coeffs, 256 + 16, offsets3); + + // level 5 + vs_ldpq(vq, kyberConsts); + int offsets4[4] = { 0, 32, 64, 96 }; + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4); + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 128, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 128, offsets4); + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 256, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 256, offsets4); + + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4); + + // level 6 + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4); + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 128, offsets4); + // __ ldpq(v18, v19, __ post(zetas, 32)); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 128, offsets4); + + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 256, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 256, offsets4); + + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 384, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 384, offsets4); + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber Inverse NTT function + // Implements + // static int implKyberInverseNtt(short[] poly, short[] zetas) {} + // + // coeffs (short[256]) = c_rarg0 + // ntt_zetas (short[256]) = c_rarg1 + address generate_kyberInverseNtt() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberInverseNtt_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register coeffs = c_rarg0; + const Register zetas = c_rarg1; + + const Register kyberConsts = r10; + const Register tmpAddr = r11; + const Register tmpAddr2 = c_rarg2; + + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x8H inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. 
constants overlap vs3 + + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + + // level 0 + vs_ldpq(vq, kyberConsts); + int offsets4[4] = { 0, 32, 64, 96 }; + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4); + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 128, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 128, offsets4); + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 256, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 256, offsets4); + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 384, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 384, offsets4); + + // level 1 + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4); + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 128, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 128, offsets4); + + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 256, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 256, offsets4); + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4); + kyber_load32zetas(vs_front(vs2), zetas); + kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vs_back(vs2), vtmp, vq); + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4); + + // level 2 + int offsets3[8] = { 0, 32, 64, 96, 128, 160, 192, 224 }; + vs_ldr_indexed(vs1, __ Q, coeffs, 0, offsets3); + vs_ldr_indexed(vs2, __ Q, coeffs, 16, offsets3); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_str_indexed(vs3, __ Q, coeffs, 0, offsets3); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_str_indexed(vs2, __ Q, coeffs, 16, offsets3); + + vs_ldr_indexed(vs1, __ Q, coeffs, 256, offsets3); + vs_ldr_indexed(vs2, __ Q, coeffs, 256 + 16, offsets3); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. 
trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_str_indexed(vs3, __ Q, coeffs, 256, offsets3); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_str_indexed(vs2, __ Q, coeffs, 256 + 16, offsets3); + // Barrett reduction at indexes where overflow may happen + __ add(tmpAddr, kyberConsts, 16); + vs_ldpq(vq, tmpAddr); // *** say what these values are *** + VSeq<8> vq1 = VSeq<8>(vq[0], 0); // 2 constant 8 sequences + VSeq<8> vq2 = VSeq<8>(vq[1], 0); // for above two kyber constants + VSeq<8> vq3 = VSeq<8>(v29, 0); // 3rd sequence for const montmul + vs_ldr_indexed(vs1, __ Q, coeffs, 0, offsets3); + vs_sqdmulh(vs2, __ T8H, vs1, vq2); + vs_sshr(vs2, __ T8H, vs2, 11); + vs_mlsv(vs1, __ T8H, vs2, vq1); + vs_str_indexed(vs1, __ Q, coeffs, 0, offsets3); + vs_ldr_indexed(vs1, __ Q, coeffs, 256, offsets3); + vs_sqdmulh(vs2, __ T8H, vs1, vq2); + vs_sshr(vs2, __ T8H, vs2, 11); + vs_mlsv(vs1, __ T8H, vs2, vq1); + vs_str_indexed(vs1, __ Q, coeffs, 256, offsets3); + + // level 3 + int offsets2[4] = { 0, 64, 128, 192 }; + vs_ldpq_indexed(vs1, coeffs, 0, offsets2); + vs_ldpq_indexed(vs2, coeffs, 32, offsets2); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs3, coeffs, 0, offsets2); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_stpq_indexed(vs2, coeffs, 32, offsets2); + + vs_ldpq_indexed(vs1, coeffs, 256, offsets2); + vs_ldpq_indexed(vs2, coeffs, 256 + 32, offsets2); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs3, coeffs, 256, offsets2); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_stpq_indexed(vs2, coeffs, 256 + 32, offsets2); + + // level 4 + int offsets1[4] = { 0, 32, 128, 160 }; + vs_ldpq_indexed(vs1, coeffs, 0, offsets1); + vs_ldpq_indexed(vs2, coeffs, 64, offsets1); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs3, coeffs, 0, offsets1); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_stpq_indexed(vs2, coeffs, 64, offsets1); + + vs_ldpq_indexed(vs1, coeffs, 256, offsets1); + vs_ldpq_indexed(vs2, coeffs, 256 + 64, offsets1); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + vs_stpq_indexed(vs3, coeffs, 256, offsets1); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + vs_stpq_indexed(vs2, coeffs, 256 + 64, offsets1); + + // level 5 + __ add(tmpAddr, coeffs, 0); + kyber_load64coeffs(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 128); + kyber_load64coeffs(vs2, tmpAddr); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 0); + kyber_store64coeffs(vs3, tmpAddr); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 128); + kyber_store64coeffs(vs2, tmpAddr); + + kyber_load64coeffs(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 384); + kyber_load64coeffs(vs2, tmpAddr); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. 
trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 256); + kyber_store64coeffs(vs3, tmpAddr); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 384); + kyber_store64coeffs(vs2, tmpAddr); + // Barrett reduction at indexes where overflow may happen + __ add(tmpAddr, kyberConsts, 16); + vs_ldpq(vq, tmpAddr); // !!! what are these constants? + int offsets0[2] = { 0, 256 }; + vs_ldpq_indexed(vs_front(vs1), coeffs, 0, offsets0); + vs_sqdmulh(vs2, __ T8H, vs1, vq2); + vs_sshr(vs2, __ T8H, vs2, 11); + vs_mlsv(vs1, __ T8H, vs2, vq1); + vs_stpq_indexed(vs_front(vs1), coeffs, 0, offsets0); + // level 6 + __ add(tmpAddr, coeffs, 0); + kyber_load64coeffs(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 256); + kyber_load64coeffs(vs2, tmpAddr); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 0); + kyber_store64coeffs(vs3, tmpAddr); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 256); + kyber_store64coeffs(vs2, tmpAddr); + + __ add(tmpAddr, coeffs, 128); + kyber_load64coeffs(vs1, tmpAddr); + __ add(tmpAddr, coeffs, 384); + kyber_load64coeffs(vs2, tmpAddr); + vs_addv(vs3, __ T8H, vs1, vs2); // n.b. trashes vq + vs_subv(vs1, __ T8H, vs1, vs2); + __ add(tmpAddr, coeffs, 128); + kyber_store64coeffs(vs3, tmpAddr); + kyber_load64zetas(vs2, zetas); + vs_ldpq(vq, kyberConsts); + kyber_montmul64(vs2, vs1, vs2, vtmp, vq); + __ add(tmpAddr, coeffs, 384); + kyber_store64coeffs(vs2, tmpAddr); + // multiply by 2^-n + __ add(tmpAddr, kyberConsts, 48); + __ ldr(v29, __ Q, tmpAddr); // loads ??? into constant sequence vq3 + vs_ldpq(vq, kyberConsts); + __ add(tmpAddr, coeffs, 0); + kyber_load64coeffs(vs1, tmpAddr); + kyber_montmul64(vs2, vs1, vq3, vtmp, vq); + __ add(tmpAddr, coeffs, 0); + kyber_store64coeffs(vs2, tmpAddr); + + kyber_load64coeffs(vs1, tmpAddr); + kyber_montmul64(vs2, vs1, vq3, vtmp, vq); + __ add(tmpAddr, coeffs, 128); + kyber_store64coeffs(vs2, tmpAddr); + + kyber_load64coeffs(vs1, tmpAddr); + kyber_montmul64(vs2, vs1, vq3, vtmp, vq); + __ add(tmpAddr, coeffs, 256); + kyber_store64coeffs(vs2, tmpAddr); + + kyber_load64coeffs(vs1, tmpAddr); + kyber_montmul64(vs2, vs1, vq3, vtmp, vq); + __ add(tmpAddr, coeffs, 384); + kyber_store64coeffs(vs2, tmpAddr); + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber multiply polynomials in the NTT domain. 
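The arithmetic this stub vectorises is the standard ML-KEM base multiplication: in the NTT domain a polynomial is held as 128 degree-one residues modulo (X^2 - zeta), and two such residues multiply as in the scalar sketch below. basemul is an illustrative name, and montmul16 stands for a Montgomery product such as the one modelled earlier, so each product carries an extra 2^-16 factor; the quadword the stub loads from byte offset 64 of _kyberConsts, 0x0549 = 1353, appears to be 2^32 mod q, i.e. the R^2 constant that the final Montgomery multiply uses to cancel those factors.

// (a0 + a1*X) * (b0 + b1*X) mod (X^2 - zeta)
//   = (a0*b0 + a1*b1*zeta) + (a0*b1 + a1*b0)*X
// zeta is the (signed) twiddle for this residue, taken from the zetas table.
static void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2],
                    int16_t zeta) {
  r[0] = (int16_t)(montmul16(a[0], b[0]) + montmul16(montmul16(a[1], b[1]), zeta));
  r[1] = (int16_t)(montmul16(a[0], b[1]) + montmul16(a[1], b[0]));
}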
+ // Implements + // static int implKyberNttMult( + // short[] result, short[] ntta, short[] nttb, short[] zetas) {} + // + // result (short[256]) = c_rarg0 + // ntta (short[256]) = c_rarg1 + // nttb (short[256]) = c_rarg2 + // zetas (short[128]) = c_rarg3 + address generate_kyberNttMult() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberNttMult_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register result = c_rarg0; + const Register ntta = c_rarg1; + const Register nttb = c_rarg2; + const Register zetas = c_rarg3; + + const Register kyberConsts = r10; + const Register limit = r11; + + VSeq<4> vs1(0), vs2(4); // 4 sets of 8x8H inputs/outputs/tmps + VSeq<4> vs3(16), vs4(20); + VSeq<2> vq(30); // pair of constants for montmul + VSeq<2> vz(28); // pair of zetas + VSeq<4> vc(27, 0); // constant sequence for montmul + + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + + Label kyberNttMult_loop; + + __ add(limit, result, 512); + + // load q and qinv + vs_ldpq(vq, kyberConsts); + // load which kyber constant at offset 64? + __ add(kyberConsts, kyberConsts, 64); + __ ldr(v27, __ Q, kyberConsts); + + __ BIND(kyberNttMult_loop); + // load 16 zetas + vs_ldpq_post(vz, zetas); + // load 2 sets of 32 coefficients from the two input arrays + vs_ld2_post(vs_front(vs1), __ T8H, ntta); + vs_ld2_post(vs_back(vs1), __ T8H, nttb); + vs_ld2_post(vs_front(vs4), __ T8H, ntta); + vs_ld2_post(vs_back(vs4), __ T8H, nttb); + // montmul the first and second pair of values loaded into vs1 + // in order and then with one pair reversed storing the two + // results in vs3 + kyber_montmul16(vs_front(vs3), vs_front(vs1), vs_back(vs1), vs_front(vs2), vq); + kyber_montmul16(vs_back(vs3), vs_front(vs1), vs_reverse(vs_back(vs1)), vs_back(vs2), vq); + // montmul the first and second pair of values loaded into vs4 + // in order and then with one pair reversed storing the two + // results in vs1 + kyber_montmul16(vs_front(vs1), vs_front(vs4), vs_back(vs4), vs_front(vs2), vq); + kyber_montmul16(vs_back(vs1), vs_front(vs4), vs_reverse(vs_back(vs4)), vs_back(vs2), vq); + // for each pair of results pick the second value in the first + // pair to create a sequence that we montmul by the zetas + // i.e. we want sequence + int delta = vs1[1]->encoding() - vs3[1]->encoding(); + VSeq<2> vs5(vs3[1], delta); + kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq); + // add results in pairs storing in vs3 + vs_addv(vs_front(vs3), __ T8H, vs_even(vs3), vs_odd(vs3)); + vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1)); + // montmul result by constant vc and store result in vs1 + kyber_montmul32(vs1, vs3, vc, vs2, vq); + // store the four results as two interleaved pairs of + // quadwords + vs_st2_post(vs1, __ T8H, result); + + __ cmp(result, limit); + __ br(Assembler::NE, kyberNttMult_loop); + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber add 2 polynomials. 
+ // Implements + // static int implKyberAddPoly(short[] result, short[] a, short[] b) {} + // + // result (short[256]) = c_rarg0 + // a (short[256]) = c_rarg1 + // b (short[256]) = c_rarg2 + address generate_kyberAddPoly_2() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberAddPoly_2_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register result = c_rarg0; + const Register a = c_rarg1; + const Register b = c_rarg2; + + const Register kyberConsts = r11; + + // we sum 256 sets of values in total i.e. 32 x 8H quadwords. + // So, we can load, add and store the data in 3 groups of 11, + // 11 and 10 at a time i.e. we need to map sets of 10 or 11 + // registers. A further constraint is that the mapping needs + // to skip callee saves. So, we allocate the register + // sequences using two 8 sequences, two 2 sequences and two + // single registers. + VSeq<8> vs1_1(0); + VSeq<2> vs1_2(16); + FloatRegister vs1_3 = v28; + VSeq<8> vs2_1(18); + VSeq<2> vs2_2(26); + FloatRegister vs2_3 = v29; + // we also need corresponding constant sequences + VSeq<8> vc_1(31, 0); // two constant vector sequences + VSeq<2> vc_2(31, 0); + FloatRegister vc_3 = v31; + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + + __ ldr(vc_3, __ Q, Address(kyberConsts, 16)); + for (int i = 0; i < 3; i++) { + // load 80 or 88 values from a into vs1_1/2/3 + vs_ldpq_post(vs1_1, a); + vs_ldpq_post(vs1_2, a); + if (i < 2) { + __ ldr(vs1_3, __ Q, __ post(a, 16)); + } + // load 80 or 88 values from b into vs2_1/2/3 + vs_ldpq_post(vs2_1, b); + vs_ldpq_post(vs2_2, b); + if (i < 2) { + __ ldr(vs2_3, __ Q, __ post(b, 16)); + } + // sum 80 or 88 values across vs1 and vs2 into vs1 + vs_addv(vs1_1, __ T8H, vs1_1, vs2_1); + vs_addv(vs1_2, __ T8H, vs1_2, vs2_2); + if (i < 2) { + __ addv(vs1_3, __ T8H, vs1_3, vs2_3); + } + // add constant to all 80 or 88 results + vs_addv(vs1_1, __ T8H, vs1_1, vc_1); + vs_addv(vs1_2, __ T8H, vs1_2, vc_2); + if (i < 2) { + __ addv(vs1_3, __ T8H, vs1_3, vc_3); + } + // store 80 or 88 values + vs_stpq_post(vs1_1, result); + vs_stpq_post(vs1_2, result); + if (i < 2) { + __ str(vs1_3, __ Q, __ post(result, 16)); + } + } + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber add 3 polynomials. + // Implements + // static int implKyberAddPoly(short[] result, short[] a, short[] b, short[] c) {} + // + // result (short[256]) = c_rarg0 + // a (short[256]) = c_rarg1 + // b (short[256]) = c_rarg2 + // c (short[256]) = c_rarg3 + address generate_kyberAddPoly_3() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberAddPoly_3_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register result = c_rarg0; + const Register a = c_rarg1; + const Register b = c_rarg2; + const Register c = c_rarg3; + + const Register kyberConsts = r11; + + // As above we sum 256 sets of values in total i.e. 32 x 8H + // quadwords. So, we can load, add and store the data in 3 + // groups of 11, 11 and 10 at a time i.e. we need to map sets + // of 10 or 11 registers. A further constraint is that the + // mapping needs to skip callee saves. So, we allocate the + // register sequences using two 8 sequences, two 2 sequences + // and two single registers. 
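In both addPoly stubs the "constant" added to every lane is the quadword at byte offset 16 of _kyberConsts, i.e. eight copies of 0x0D01 = q = 3329, so each result lane is congruent to the plain sum mod q; adding q presumably keeps the 16-bit lanes in the range the Java callers and the later reductions expect. One lane of the two-operand stub above, as a scalar sketch (the three-operand variant below also folds in c[i]; the function name is illustrative):

static inline int16_t kyber_add2_lane(int16_t a, int16_t b) {
  return (int16_t)(a + b + 3329);   // + q: still congruent to a + b mod q
}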
+ VSeq<8> vs1_1(0); + VSeq<2> vs1_2(16); + FloatRegister vs1_3 = v28; + VSeq<8> vs2_1(18); + VSeq<2> vs2_2(26); + FloatRegister vs2_3 = v29; + // we also need corresponding constant sequences + VSeq<8> vc_1(31, 0); // two constant vector sequences + VSeq<2> vc_2(31, 0); + FloatRegister vc_3 = v31; + + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + + __ ldr(vc_3, __ Q, Address(kyberConsts, 16)); + for (int i = 0; i < 3; i++) { + // load 80 or 88 values from a into vs1_1/2/3 + vs_ldpq_post(vs1_1, a); + vs_ldpq_post(vs1_2, a); + if (i < 2) { + __ ldr(vs1_3, __ Q, __ post(a, 16)); + } + // load 80 or 88 values from b into vs2_1/2/3 + vs_ldpq_post(vs2_1, b); + vs_ldpq_post(vs2_2, b); + if (i < 2) { + __ ldr(vs2_3, __ Q, __ post(b, 16)); + } + // sum 80 or 88 values across vs1 and vs2 into vs1 + vs_addv(vs1_1, __ T8H, vs1_1, vs2_1); + vs_addv(vs1_2, __ T8H, vs1_2, vs2_2); + if (i < 2) { + __ addv(vs1_3, __ T8H, vs1_3, vs2_3); + } + // load 80 or 88 values from c into vs2_1/2/3 + vs_ldpq_post(vs2_1, c); + vs_ldpq_post(vs2_2, c); + if (i < 2) { + __ ldr(vs2_3, __ Q, __ post(c, 16)); + } + // sum 80 or 88 values across vs1 and vs2 into vs1 + vs_addv(vs1_1, __ T8H, vs1_1, vs2_1); + vs_addv(vs1_2, __ T8H, vs1_2, vs2_2); + if (i < 2) { + __ addv(vs1_3, __ T8H, vs1_3, vs2_3); + } + // add constant to all 80 or 88 results + vs_addv(vs1_1, __ T8H, vs1_1, vc_1); + vs_addv(vs1_2, __ T8H, vs1_2, vc_2); + if (i < 2) { + __ addv(vs1_3, __ T8H, vs1_3, vc_3); + } + // store 80 or 88 values + vs_stpq_post(vs1_1, result); + vs_stpq_post(vs1_2, result); + if (i < 2) { + __ str(vs1_3, __ Q, __ post(result, 16)); + } + } + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber parse XOF output to polynomial coefficient candidates + // or decodePoly(12, ...). + // Implements + // static int implKyber12To16( + // byte[] condensed, int index, short[] parsed, int parsedLength) {} + // + // (parsedLength or (parsedLength - 48) must be divisible by 64.) + // + // condensed (byte[]) = c_rarg0 + // condensedIndex = c_rarg1 + // parsed (short[112 or 256]) = c_rarg2 + // parsedLength (112 or 256) = c_rarg3 + address generate_kyber12To16() { + Label L_F00, L_loop, L_end; + + __ BIND(L_F00); + __ emit_int64(0x0f000f000f000f00); + __ emit_int64(0x0f000f000f000f00); + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyber12To16_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register condensed = c_rarg0; + const Register condensedOffs = c_rarg1; + const Register parsed = c_rarg2; + const Register parsedLength = c_rarg3; + + const Register tmpAddr = r11; + + // data is input 96 bytes at a time i.e. in groups of 6 x 16B + // quadwords so we need a 6 vector sequence for the inputs. + // Parsing produces 64 shorts, employing two 8 vector + // sequences to store and combine the intermediate data. + VSeq<6> vin(24); + VSeq<8> va(0), vb(16); + + __ adr(tmpAddr, L_F00); + __ ldr(v31, __ Q, tmpAddr); // 8H times 0x0f00 + __ add(condensed, condensed, condensedOffs); + + __ BIND(L_loop); + // load 96 (6 x 16B) byte values + vs_ld3_post(vin, __ T16B, condensed); + + // expand groups of input bytes in vin to shorts in va and vb + // n.b. 
target elements 2 and 3 duplicate elements 4 and 5 + __ ushll(va[0], __ T8H, vin[0], __ T8B, 0); + __ ushll2(va[1], __ T8H, vin[0], __ T16B, 0); + __ ushll(va[2], __ T8H, vin[1], __ T8B, 0); + __ ushll2(va[3], __ T8H, vin[1], __ T16B, 0); + __ ushll(va[4], __ T8H, vin[1], __ T8B, 0); + __ ushll2(va[5], __ T8H, vin[1], __ T16B, 0); + + __ ushll(vb[0], __ T8H, vin[3], __ T8B, 0); + __ ushll2(vb[1], __ T8H, vin[3], __ T16B, 0); + __ ushll(vb[2], __ T8H, vin[4], __ T8B, 0); + __ ushll2(vb[3], __ T8H, vin[4], __ T16B, 0); + __ ushll(vb[4], __ T8H, vin[4], __ T8B, 0); + __ ushll2(vb[5], __ T8H, vin[4], __ T16B, 0); + + // offset duplicated elements in va and vb by 8 + __ shl(va[2], __ T8H, va[2], 8); + __ shl(va[3], __ T8H, va[3], 8); + __ shl(vb[2], __ T8H, vb[2], 8); + __ shl(vb[3], __ T8H, vb[3], 8); + + // expand remaining input bytes in vin to shorts in va and vb + // but this time pre-shifted by 4 + __ ushll(va[6], __ T8H, vin[2], __ T8B, 4); + __ ushll2(va[7], __ T8H, vin[2], __ T16B, 4); + __ ushll(vb[6], __ T8H, vin[5], __ T8B, 4); + __ ushll2(vb[7], __ T8H, vin[5], __ T16B, 4); + + // split the duplicated 8 bit values into two distinct 4 bit + // upper and lower halves using a mask or a shift + __ andr(va[2], __ T16B, va[2], v31); + __ andr(va[3], __ T16B, va[3], v31); + __ ushr(va[4], __ T8H, va[4], 4); + __ ushr(va[5], __ T8H, va[5], 4); + __ andr(vb[2], __ T16B, vb[2], v31); + __ andr(vb[3], __ T16B, vb[3], v31); + __ ushr(vb[4], __ T8H, vb[4], 4); + __ ushr(vb[5], __ T8H, vb[5], 4); + + // sum resulting short values into the front halves of va and + // vb pairing registers offset by stride 2 + __ addv(va[0], __ T8H, va[0], va[2]); + __ addv(va[2], __ T8H, va[1], va[3]); + __ addv(va[1], __ T8H, va[4], va[6]); + __ addv(va[3], __ T8H, va[5], va[7]); + __ addv(vb[0], __ T8H, vb[0], vb[2]); + __ addv(vb[2], __ T8H, vb[1], vb[3]); + __ addv(vb[1], __ T8H, vb[4], vb[6]); + __ addv(vb[3], __ T8H, vb[5], vb[7]); + + // store results interleaved as shorts + vs_st2_post(vs_front(va), __ T8H, parsed); + vs_st2_post(vs_front(vb), __ T8H, parsed); + + __ sub(parsedLength, parsedLength, 64); + __ cmp(parsedLength, (u1)64); + __ br(Assembler::GE, L_loop); + __ cbz(parsedLength, L_end); + + // if anything is left it should be a final 72 bytes. so we + // load 48 bytes into both lanes of front(vin) and 24 bytes + // into the lower lane of back(vin) + vs_ld3_post(vs_front(vin), __ T16B, condensed); + vs_ld3(vs_back(vin), __ T8B, condensed); + + // expand groups of input bytes in vin to shorts in va and vb + // n.b. target elements 2 and 3 of va duplicate elements 4 and + // 5 and target element 2 of vb duplicates element 4. 
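The masking and shifting in this stub — both the main loop above and the tail handling below — implement the standard ML-KEM 12-bit decoding: after the stride-3 de-interleave, each triple of bytes (b0, b1, b2) yields two 12-bit coefficients. A standalone scalar sketch follows (decode12 is an illustrative name, not a routine in this patch).

#include <cassert>
#include <cstdint>

// Scalar model of the 12-bit -> 16-bit expansion: every 3 input bytes carry
// two 12-bit coefficients.
static void decode12(const uint8_t b[3], uint16_t d[2]) {
  d[0] = (uint16_t)(b[0] | ((b[1] & 0x0F) << 8));  // b0 plus low nibble of b1
  d[1] = (uint16_t)((b[1] >> 4) | (b[2] << 4));    // high nibble of b1 plus b2
}

int main() {
  const uint8_t in[3] = {0xAB, 0xCD, 0xEF};
  uint16_t out[2];
  decode12(in, out);
  assert(out[0] == 0x0DAB && out[1] == 0x0EFC);
  return 0;
}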
+ __ ushll(va[0], __ T8H, vin[0], __ T8B, 0); + __ ushll2(va[1], __ T8H, vin[0], __ T16B, 0); + __ ushll(va[2], __ T8H, vin[1], __ T8B, 0); + __ ushll2(va[3], __ T8H, vin[1], __ T16B, 0); + __ ushll(va[4], __ T8H, vin[1], __ T8B, 0); + __ ushll2(va[5], __ T8H, vin[1], __ T16B, 0); + + __ ushll(vb[0], __ T8H, vin[3], __ T8B, 0); + __ ushll(vb[2], __ T8H, vin[4], __ T8B, 0); + __ ushll(vb[4], __ T8H, vin[4], __ T8B, 0); + + // offset duplicated elements in va and vb by 8 + __ shl(va[2], __ T8H, va[2], 8); + __ shl(va[3], __ T8H, va[3], 8); + __ shl(vb[2], __ T8H, vb[2], 8); + + // expand remaining input bytes in vin to shorts in va and vb + // but this time pre-shifted by 4 + __ ushll(va[6], __ T8H, vin[2], __ T8B, 4); + __ ushll2(va[7], __ T8H, vin[2], __ T16B, 4); + __ ushll(vb[6], __ T8H, vin[5], __ T8B, 4); + + // split the duplicated 8 bit values into two distinct 4 bit + // upper and lower halves using a mask or a shift + __ andr(va[2], __ T16B, va[2], v31); + __ andr(va[3], __ T16B, va[3], v31); + __ ushr(va[4], __ T8H, va[4], 4); + __ ushr(va[5], __ T8H, va[5], 4); + __ andr(vb[2], __ T16B, vb[2], v31); + __ ushr(vb[4], __ T8H, vb[4], 4); + + // sum resulting short values into the front halves of va and + // vb pairing registers offset by stride 2 + __ addv(va[0], __ T8H, va[0], va[2]); + __ addv(va[2], __ T8H, va[1], va[3]); + __ addv(va[1], __ T8H, va[4], va[6]); + __ addv(va[3], __ T8H, va[5], va[7]); + __ addv(vb[0], __ T8H, vb[0], vb[2]); + __ addv(vb[1], __ T8H, vb[4], vb[6]); + + // store results interleaved as shorts + vs_st2_post(vs_front(va), __ T8H, parsed); + vs_st2_post(vs_front(vs_front(vb)), __ T8H, parsed); + + __ BIND(L_end); + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + + // Kyber barrett reduce function. + // Implements + // static int implKyberBarrettReduce(short[] coeffs) {} + // + // coeffs (short[256]) = c_rarg0 + address generate_kyberBarrettReduce() { + + __ align(CodeEntryAlignment); + StubGenStubId stub_id = StubGenStubId::kyberBarrettReduce_id; + StubCodeMark mark(this, stub_id); + address start = __ pc(); + __ enter(); + + const Register coeffs = c_rarg0; + + const Register kyberConsts = r10; + const Register result = r11; + + // As above we process 256 sets of values in total i.e. 32 x + // 8H quadwords. So, we can load, add and store the data in 3 + // groups of 11, 11 and 10 at a time i.e. we need to map sets + // of 10 or 11 registers. A further constraint is that the + // mapping needs to skip callee saves. So, we allocate the + // register sequences using two 8 sequences, two 2 sequences + // and two single registers. 
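+    // Each 16-bit lane x is reduced as in the Java fallback
+    // implKyberBarrettReduceJava: compute t ~= x / q and subtract t * q.
+    // Here sqdmulh by the constant 0x4EBF (= 20159) yields (2 * x * 20159) >> 16,
+    // the following sshr #11 supplies the remaining shift (so t = (x * 20159) >> 26),
+    // and mlsv subtracts t * q with q = 0x0D01 = 3329; the two constants are
+    // the second and third rows of _kyberConsts, loaded below.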
+ VSeq<8> vs1_1(0); + VSeq<2> vs1_2(16); + FloatRegister vs1_3 = v28; + VSeq<8> vs2_1(18); + VSeq<2> vs2_2(26); + FloatRegister vs2_3 = v29; + // we also need a pair of corresponding constant sequences + VSeq<8> vc1_1(30, 0); + VSeq<2> vc1_2(30, 0); + FloatRegister vc1_3 = v30; + VSeq<8> vc2_1(31, 0); + VSeq<2> vc2_2(31, 0); + FloatRegister vc2_3 = v31; + + __ add(result, coeffs, 0); + __ lea(kyberConsts, ExternalAddress((address) StubRoutines::aarch64::_kyberConsts)); + + __ add(kyberConsts, kyberConsts, 16); + __ ldpq(vc1_3, vc2_3, kyberConsts); + + for (int i = 0; i < 3; i++) { + // load 80 or 88 coefficients + vs_ldpq_post(vs1_1, coeffs); + vs_ldpq_post(vs1_2, coeffs); + if (i < 2) { + __ ldr(vs1_3, __ Q, __ post(coeffs, 16)); + } + vs_sqdmulh(vs2_1, __ T8H, vs1_1, vc2_1); + vs_sqdmulh(vs2_2, __ T8H, vs1_2, vc2_2); + if (i < 2) { + __ sqdmulh(vs2_3, __ T8H, vs1_3, vc2_3); + } + vs_sshr(vs2_1, __ T8H, vs2_1, 11); + vs_sshr(vs2_2, __ T8H, vs2_2, 11); + if (i < 2) { + __ sshr(vs2_3, __ T8H, vs2_3, 11); + } + vs_mlsv(vs1_1, __ T8H, vs2_1, vc1_1); + vs_mlsv(vs1_2, __ T8H, vs2_2, vc1_2); + if (i < 2) { + __ mlsv(vs1_3, __ T8H, vs2_3, vc1_3); + } + vs_stpq_post(vs1_1, result); + vs_stpq_post(vs1_2, result); + if (i < 2) { + __ str(vs1_3, __ Q, __ post(result, 16)); + } + } + + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + return start; + } + /** * Arguments: * @@ -10011,6 +11265,16 @@ class StubGenerator: public StubCodeGenerator { StubRoutines::_dilithiumDecomposePoly = generate_dilithiumDecomposePoly(); } + if (UseKyberIntrinsics) { + StubRoutines::_kyberNtt = generate_kyberNtt(); + StubRoutines::_kyberInverseNtt = generate_kyberInverseNtt(); + StubRoutines::_kyberNttMult = generate_kyberNttMult(); + StubRoutines::_kyberAddPoly_2 = generate_kyberAddPoly_2(); + StubRoutines::_kyberAddPoly_3 = generate_kyberAddPoly_3(); + StubRoutines::_kyber12To16 = generate_kyber12To16(); + StubRoutines::_kyberBarrettReduce = generate_kyberBarrettReduce(); + } + if (UseBASE64Intrinsics) { StubRoutines::_base64_encodeBlock = generate_base64_encodeBlock(); StubRoutines::_base64_decodeBlock = generate_base64_decodeBlock(); diff --git a/src/hotspot/cpu/aarch64/stubRoutines_aarch64.cpp b/src/hotspot/cpu/aarch64/stubRoutines_aarch64.cpp index 536583ff40c0b..099b0a61b76c1 100644 --- a/src/hotspot/cpu/aarch64/stubRoutines_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/stubRoutines_aarch64.cpp @@ -57,6 +57,16 @@ ATTRIBUTE_ALIGNED(64) uint32_t StubRoutines::aarch64::_dilithiumConsts[] = 5373807, 5373807, 5373807, 5373807 // addend for modular reduce }; +ATTRIBUTE_ALIGNED(64) uint16_t StubRoutines::aarch64::_kyberConsts[] = +{ + 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, + 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, + 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, + 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, + 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, + 0x0F00, 0x0F00, 0x0F00, 0x0F00, 0x0F00, 0x0F00, 0x0F00, 0x0F00 +}; + /** * crc_table[] from jdk/src/share/native/java/util/zip/zlib-1.2.5/crc32.h */ diff --git a/src/hotspot/cpu/aarch64/stubRoutines_aarch64.hpp b/src/hotspot/cpu/aarch64/stubRoutines_aarch64.hpp index 857bb2ff10a91..7be27cd17f579 100644 --- a/src/hotspot/cpu/aarch64/stubRoutines_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/stubRoutines_aarch64.hpp @@ -111,6 +111,7 @@ class aarch64 { private: static uint32_t _dilithiumConsts[]; + static 
uint16_t _kyberConsts[];
   static juint    _crc_table[];
   static jubyte   _adler_table[];
   // begin trigonometric tables block. See comments in .cpp file
diff --git a/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp b/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp
index 0f04fee79220a..1a9dd77c64683 100644
--- a/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp
+++ b/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp
@@ -425,6 +425,17 @@ void VM_Version::initialize() {
     FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
   }
 
+  if (_features & CPU_ASIMD) {
+    if (FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
+      UseKyberIntrinsics = true;
+    }
+  } else if (UseKyberIntrinsics) {
+    if (!FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
+      warning("Kyber intrinsics require ASIMD instructions");
+    }
+    FLAG_SET_DEFAULT(UseKyberIntrinsics, false);
+  }
+
   if (FLAG_IS_DEFAULT(UseBASE64Intrinsics)) {
     UseBASE64Intrinsics = true;
   }
diff --git a/src/hotspot/share/classfile/vmIntrinsics.cpp b/src/hotspot/share/classfile/vmIntrinsics.cpp
index 8011b05969724..1ba11ffe2f119 100644
--- a/src/hotspot/share/classfile/vmIntrinsics.cpp
+++ b/src/hotspot/share/classfile/vmIntrinsics.cpp
@@ -495,6 +495,15 @@ bool vmIntrinsics::disabled_by_jvm_flags(vmIntrinsics::ID id) {
   case vmIntrinsics::_dilithiumDecomposePoly:
     if (!UseDilithiumIntrinsics) return true;
     break;
+  case vmIntrinsics::_kyberNtt:
+  case vmIntrinsics::_kyberInverseNtt:
+  case vmIntrinsics::_kyberNttMult:
+  case vmIntrinsics::_kyberAddPoly_2:
+  case vmIntrinsics::_kyberAddPoly_3:
+  case vmIntrinsics::_kyber12To16:
+  case vmIntrinsics::_kyberBarrettReduce:
+    if (!UseKyberIntrinsics) return true;
+    break;
   case vmIntrinsics::_base64_encodeBlock:
   case vmIntrinsics::_base64_decodeBlock:
     if (!UseBASE64Intrinsics) return true;
diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp b/src/hotspot/share/classfile/vmIntrinsics.hpp
index 93b67301b4bad..c6ac5c032160f 100644
--- a/src/hotspot/share/classfile/vmIntrinsics.hpp
+++ b/src/hotspot/share/classfile/vmIntrinsics.hpp
@@ -588,6 +588,26 @@ class methodHandle;
  do_intrinsic(_dilithiumDecomposePoly, sun_security_provider_ML_DSA, \
               dilithiumDecomposePoly_name, IaIaIaIII_signature, F_S) \
   do_name(dilithiumDecomposePoly_name, "implDilithiumDecomposePoly") \
+ /* support for com.sun.crypto.provider.ML_KEM */ \
+ do_class(com_sun_crypto_provider_ML_KEM, "com/sun/crypto/provider/ML_KEM") \
+ do_signature(SaSaSaSaI_signature, "([S[S[S[S)I") \
+ do_signature(BaISaII_signature, "([BI[SI)I") \
+ do_signature(SaSaSaI_signature, "([S[S[S)I") \
+ do_signature(SaSaI_signature, "([S[S)I") \
+ do_signature(SaI_signature, "([S)I") \
+ do_name(kyberAddPoly_name, "implKyberAddPoly") \
+ do_intrinsic(_kyberNtt, com_sun_crypto_provider_ML_KEM, kyberNtt_name, SaSaI_signature, F_S) \
+  do_name(kyberNtt_name, "implKyberNtt") \
+ do_intrinsic(_kyberInverseNtt, com_sun_crypto_provider_ML_KEM, kyberInverseNtt_name, SaSaI_signature, F_S) \
+  do_name(kyberInverseNtt_name, "implKyberInverseNtt") \
+ do_intrinsic(_kyberNttMult, com_sun_crypto_provider_ML_KEM, kyberNttMult_name, SaSaSaSaI_signature, F_S) \
+  do_name(kyberNttMult_name, "implKyberNttMult") \
+ do_intrinsic(_kyberAddPoly_2, com_sun_crypto_provider_ML_KEM, 
kyberAddPoly_name, SaSaSaI_signature, F_S) \ + do_intrinsic(_kyberAddPoly_3, com_sun_crypto_provider_ML_KEM, kyberAddPoly_name, SaSaSaSaI_signature, F_S) \ + do_intrinsic(_kyber12To16, com_sun_crypto_provider_ML_KEM, kyber12To16_name, BaISaII_signature, F_S) \ + do_name(kyber12To16_name, "implKyber12To16") \ + do_intrinsic(_kyberBarrettReduce, com_sun_crypto_provider_ML_KEM, kyberBarrettReduce_name, SaI_signature, F_S) \ + do_name(kyberBarrettReduce_name, "implKyberBarrettReduce") \ \ /* support for java.util.zip */ \ do_class(java_util_zip_CRC32, "java/util/zip/CRC32") \ diff --git a/src/hotspot/share/jvmci/vmStructs_jvmci.cpp b/src/hotspot/share/jvmci/vmStructs_jvmci.cpp index 3cbb1512cd06a..69c181d552c67 100644 --- a/src/hotspot/share/jvmci/vmStructs_jvmci.cpp +++ b/src/hotspot/share/jvmci/vmStructs_jvmci.cpp @@ -400,6 +400,13 @@ static_field(StubRoutines, _dilithiumNttMult, address) \ static_field(StubRoutines, _dilithiumMontMulByConstant, address) \ static_field(StubRoutines, _dilithiumDecomposePoly, address) \ + static_field(StubRoutines, _kyberNtt, address) \ + static_field(StubRoutines, _kyberInverseNtt, address) \ + static_field(StubRoutines, _kyberNttMult, address) \ + static_field(StubRoutines, _kyberAddPoly_2, address) \ + static_field(StubRoutines, _kyberAddPoly_3, address) \ + static_field(StubRoutines, _kyber12To16, address) \ + static_field(StubRoutines, _kyberBarrettReduce, address) \ static_field(StubRoutines, _updateBytesCRC32, address) \ static_field(StubRoutines, _crc_table_adr, address) \ static_field(StubRoutines, _crc32c_table_addr, address) \ diff --git a/src/hotspot/share/opto/c2compiler.cpp b/src/hotspot/share/opto/c2compiler.cpp index 3effa8eee0498..2563d14b6c543 100644 --- a/src/hotspot/share/opto/c2compiler.cpp +++ b/src/hotspot/share/opto/c2compiler.cpp @@ -797,6 +797,13 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) { case vmIntrinsics::_dilithiumNttMult: case vmIntrinsics::_dilithiumMontMulByConstant: case vmIntrinsics::_dilithiumDecomposePoly: + case vmIntrinsics::_kyberNtt: + case vmIntrinsics::_kyberInverseNtt: + case vmIntrinsics::_kyberNttMult: + case vmIntrinsics::_kyberAddPoly_2: + case vmIntrinsics::_kyberAddPoly_3: + case vmIntrinsics::_kyber12To16: + case vmIntrinsics::_kyberBarrettReduce: case vmIntrinsics::_base64_encodeBlock: case vmIntrinsics::_base64_decodeBlock: case vmIntrinsics::_poly1305_processBlocks: diff --git a/src/hotspot/share/opto/escape.cpp b/src/hotspot/share/opto/escape.cpp index 23cf8a67be7b6..63d7e65d2d460 100644 --- a/src/hotspot/share/opto/escape.cpp +++ b/src/hotspot/share/opto/escape.cpp @@ -2197,6 +2197,13 @@ void ConnectionGraph::process_call_arguments(CallNode *call) { strcmp(call->as_CallLeaf()->_name, "dilithiumNttMult") == 0 || strcmp(call->as_CallLeaf()->_name, "dilithiumMontMulByConstant") == 0 || strcmp(call->as_CallLeaf()->_name, "dilithiumDecomposePoly") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberNtt") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberInverseNtt") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberNttMult") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_2") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_3") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyber12To16") == 0 || + strcmp(call->as_CallLeaf()->_name, "kyberBarrettReduce") == 0 || strcmp(call->as_CallLeaf()->_name, "encodeBlock") == 0 || strcmp(call->as_CallLeaf()->_name, "decodeBlock") == 0 || strcmp(call->as_CallLeaf()->_name, "md5_implCompress") == 0 || diff --git 
a/src/hotspot/share/opto/library_call.cpp b/src/hotspot/share/opto/library_call.cpp index 3bb432ac6077b..20c084ada4bcb 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -636,6 +636,20 @@ bool LibraryCallKit::try_to_inline(int predicate) { return inline_dilithiumMontMulByConstant(); case vmIntrinsics::_dilithiumDecomposePoly: return inline_dilithiumDecomposePoly(); + case vmIntrinsics::_kyberNtt: + return inline_kyberNtt(); + case vmIntrinsics::_kyberInverseNtt: + return inline_kyberInverseNtt(); + case vmIntrinsics::_kyberNttMult: + return inline_kyberNttMult(); + case vmIntrinsics::_kyberAddPoly_2: + return inline_kyberAddPoly_2(); + case vmIntrinsics::_kyberAddPoly_3: + return inline_kyberAddPoly_3(); + case vmIntrinsics::_kyber12To16: + return inline_kyber12To16(); + case vmIntrinsics::_kyberBarrettReduce: + return inline_kyberBarrettReduce(); case vmIntrinsics::_base64_encodeBlock: return inline_base64_encodeBlock(); case vmIntrinsics::_base64_decodeBlock: @@ -7807,6 +7821,245 @@ bool LibraryCallKit::inline_dilithiumDecomposePoly() { return true; } +//------------------------------inline_kyberNtt +bool LibraryCallKit::inline_kyberNtt() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 2, "kyberNtt has 2 parameters"); + + stubAddr = StubRoutines::kyberNtt(); + stubName = "kyberNtt"; + if (!stubAddr) return false; + + Node* coeffs = argument(0); + Node* ntt_zetas = argument(1); + + coeffs = must_be_not_null(coeffs, true); + ntt_zetas = must_be_not_null(ntt_zetas, true); + + Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT); + assert(coeffs_start, "coeffs is null"); + Node* ntt_zetas_start = array_element_address(ntt_zetas, intcon(0), T_SHORT); + assert(ntt_zetas_start, "ntt_zetas is null"); + Node* kyberNtt = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyberNtt_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + coeffs_start, ntt_zetas_start); + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberNtt, TypeFunc::Parms)); + set_result(retvalue); + return true; +} + +//------------------------------inline_kyberInverseNtt +bool LibraryCallKit::inline_kyberInverseNtt() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 2, "kyberInverseNtt has 2 parameters"); + + stubAddr = StubRoutines::kyberInverseNtt(); + stubName = "kyberInverseNtt"; + if (!stubAddr) return false; + + Node* coeffs = argument(0); + Node* zetas = argument(1); + + coeffs = must_be_not_null(coeffs, true); + zetas = must_be_not_null(zetas, true); + + Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT); + assert(coeffs_start, "coeffs is null"); + Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT); + assert(zetas_start, "inverseNtt_zetas is null"); + Node* kyberInverseNtt = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyberInverseNtt_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + coeffs_start, zetas_start); + + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberInverseNtt, TypeFunc::Parms)); + set_result(retvalue); + return true; +} + +//------------------------------inline_kyberNttMult +bool LibraryCallKit::inline_kyberNttMult() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 
4, "kyberNttMult has 4 parameters"); + + stubAddr = StubRoutines::kyberNttMult(); + stubName = "kyberNttMult"; + if (!stubAddr) return false; + + Node* result = argument(0); + Node* ntta = argument(1); + Node* nttb = argument(2); + Node* zetas = argument(3); + + result = must_be_not_null(result, true); + ntta = must_be_not_null(ntta, true); + nttb = must_be_not_null(nttb, true); + zetas = must_be_not_null(zetas, true); + Node* result_start = array_element_address(result, intcon(0), T_SHORT); + assert(result_start, "result is null"); + Node* ntta_start = array_element_address(ntta, intcon(0), T_SHORT); + assert(ntta_start, "ntta is null"); + Node* nttb_start = array_element_address(nttb, intcon(0), T_SHORT); + assert(nttb_start, "nttb is null"); + Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT); + assert(zetas_start, "nttMult_zetas is null"); + Node* kyberNttMult = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyberNttMult_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + result_start, ntta_start, nttb_start, + zetas_start); + + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberNttMult, TypeFunc::Parms)); + set_result(retvalue); + + return true; +} + +//------------------------------inline_kyberAddPoly_2 +bool LibraryCallKit::inline_kyberAddPoly_2() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 3, "kyberAddPoly_2 has 3 parameters"); + + stubAddr = StubRoutines::kyberAddPoly_2(); + stubName = "kyberAddPoly_2"; + if (!stubAddr) return false; + + Node* result = argument(0); + Node* a = argument(1); + Node* b = argument(2); + + result = must_be_not_null(result, true); + a = must_be_not_null(a, true); + b = must_be_not_null(b, true); + + Node* result_start = array_element_address(result, intcon(0), T_SHORT); + assert(result_start, "result is null"); + Node* a_start = array_element_address(a, intcon(0), T_SHORT); + assert(a_start, "a is null"); + Node* b_start = array_element_address(b, intcon(0), T_SHORT); + assert(b_start, "b is null"); + Node* kyberAddPoly_2 = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyberAddPoly_2_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + result_start, a_start, b_start); + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_2, TypeFunc::Parms)); + set_result(retvalue); + return true; +} + +//------------------------------inline_kyberAddPoly_3 +bool LibraryCallKit::inline_kyberAddPoly_3() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 4, "kyberAddPoly_3 has 4 parameters"); + + stubAddr = StubRoutines::kyberAddPoly_3(); + stubName = "kyberAddPoly_3"; + if (!stubAddr) return false; + + Node* result = argument(0); + Node* a = argument(1); + Node* b = argument(2); + Node* c = argument(3); + + result = must_be_not_null(result, true); + a = must_be_not_null(a, true); + b = must_be_not_null(b, true); + c = must_be_not_null(c, true); + + Node* result_start = array_element_address(result, intcon(0), T_SHORT); + assert(result_start, "result is null"); + Node* a_start = array_element_address(a, intcon(0), T_SHORT); + assert(a_start, "a is null"); + Node* b_start = array_element_address(b, intcon(0), T_SHORT); + assert(b_start, "b is null"); + Node* c_start = array_element_address(c, intcon(0), T_SHORT); + assert(c_start, "c is null"); + Node* kyberAddPoly_3 = make_runtime_call(RC_LEAF|RC_NO_FP, + 
OptoRuntime::kyberAddPoly_3_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + result_start, a_start, b_start, c_start); + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_3, TypeFunc::Parms)); + set_result(retvalue); + return true; +} + +//------------------------------inline_kyber12To16 +bool LibraryCallKit::inline_kyber12To16() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 4, "kyber12To16 has 4 parameters"); + + stubAddr = StubRoutines::kyber12To16(); + stubName = "kyber12To16"; + if (!stubAddr) return false; + + Node* condensed = argument(0); + Node* condensedOffs = argument(1); + Node* parsed = argument(2); + Node* parsedLength = argument(3); + + condensed = must_be_not_null(condensed, true); + parsed = must_be_not_null(parsed, true); + + Node* condensed_start = array_element_address(condensed, intcon(0), T_BYTE); + assert(condensed_start, "condensed is null"); + Node* parsed_start = array_element_address(parsed, intcon(0), T_SHORT); + assert(parsed_start, "parsed is null"); + Node* kyber12To16 = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyber12To16_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + condensed_start, condensedOffs, parsed_start, parsedLength); + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyber12To16, TypeFunc::Parms)); + set_result(retvalue); + return true; + +} + +//------------------------------inline_kyberBarrettReduce +bool LibraryCallKit::inline_kyberBarrettReduce() { + address stubAddr; + const char *stubName; + assert(UseKyberIntrinsics, "need Kyber intrinsics support"); + assert(callee()->signature()->size() == 1, "kyberBarrettReduce has 1 parameters"); + + stubAddr = StubRoutines::kyberBarrettReduce(); + stubName = "kyberBarrettReduce"; + if (!stubAddr) return false; + + Node* coeffs = argument(0); + + coeffs = must_be_not_null(coeffs, true); + + Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT); + assert(coeffs_start, "coeffs is null"); + Node* kyberBarrettReduce = make_runtime_call(RC_LEAF|RC_NO_FP, + OptoRuntime::kyberBarrettReduce_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + coeffs_start); + // return an int + Node* retvalue = _gvn.transform(new ProjNode(kyberBarrettReduce, TypeFunc::Parms)); + set_result(retvalue); + return true; +} + bool LibraryCallKit::inline_base64_encodeBlock() { address stubAddr; const char *stubName; diff --git a/src/hotspot/share/opto/library_call.hpp b/src/hotspot/share/opto/library_call.hpp index 1f83e28932b70..2b546f90f3cef 100644 --- a/src/hotspot/share/opto/library_call.hpp +++ b/src/hotspot/share/opto/library_call.hpp @@ -319,6 +319,13 @@ class LibraryCallKit : public GraphKit { bool inline_dilithiumNttMult(); bool inline_dilithiumMontMulByConstant(); bool inline_dilithiumDecomposePoly(); + bool inline_kyberNtt(); + bool inline_kyberInverseNtt(); + bool inline_kyberNttMult(); + bool inline_kyberAddPoly_2(); + bool inline_kyberAddPoly_3(); + bool inline_kyber12To16(); + bool inline_kyberBarrettReduce(); bool inline_base64_encodeBlock(); bool inline_base64_decodeBlock(); bool inline_poly1305_processBlocks(); diff --git a/src/hotspot/share/opto/runtime.cpp b/src/hotspot/share/opto/runtime.cpp index e32cedf6cee15..7cdb6795a1c83 100644 --- a/src/hotspot/share/opto/runtime.cpp +++ b/src/hotspot/share/opto/runtime.cpp @@ -242,13 +242,18 @@ const TypeFunc* OptoRuntime::_bigIntegerShift_Type = nullptr; const TypeFunc* 
OptoRuntime::_vectorizedMismatch_Type = nullptr; const TypeFunc* OptoRuntime::_ghash_processBlocks_Type = nullptr; const TypeFunc* OptoRuntime::_chacha20Block_Type = nullptr; - const TypeFunc* OptoRuntime::_dilithiumAlmostNtt_Type = nullptr; const TypeFunc* OptoRuntime::_dilithiumAlmostInverseNtt_Type = nullptr; const TypeFunc* OptoRuntime::_dilithiumNttMult_Type = nullptr; const TypeFunc* OptoRuntime::_dilithiumMontMulByConstant_Type = nullptr; const TypeFunc* OptoRuntime::_dilithiumDecomposePoly_Type = nullptr; - +const TypeFunc* OptoRuntime::_kyberNtt_Type = nullptr; +const TypeFunc* OptoRuntime::_kyberInverseNtt_Type = nullptr; +const TypeFunc* OptoRuntime::_kyberNttMult_Type = nullptr; +const TypeFunc* OptoRuntime::_kyberAddPoly_2_Type = nullptr; +const TypeFunc* OptoRuntime::_kyberAddPoly_3_Type = nullptr; +const TypeFunc* OptoRuntime::_kyber12To16_Type = nullptr; +const TypeFunc* OptoRuntime::_kyberBarrettReduce_Type = nullptr; const TypeFunc* OptoRuntime::_base64_encodeBlock_Type = nullptr; const TypeFunc* OptoRuntime::_base64_decodeBlock_Type = nullptr; const TypeFunc* OptoRuntime::_string_IndexOf_Type = nullptr; @@ -1450,7 +1455,6 @@ static const TypeFunc* make_dilithiumAlmostInverseNtt_Type() { // Dilithium NTT multiply function static const TypeFunc* make_dilithiumNttMult_Type() { int argcnt = 3; - const Type** fields = TypeTuple::fields(argcnt); int argp = TypeFunc::Parms; fields[argp++] = TypePtr::NOTNULL; // result @@ -1470,7 +1474,6 @@ static const TypeFunc* make_dilithiumNttMult_Type() { // Dilithium Montgomery multiply a polynome coefficient array by a constant static const TypeFunc* make_dilithiumMontMulByConstant_Type() { int argcnt = 2; - const Type** fields = TypeTuple::fields(argcnt); int argp = TypeFunc::Parms; fields[argp++] = TypePtr::NOTNULL; // coeffs @@ -1508,6 +1511,147 @@ static const TypeFunc* make_dilithiumDecomposePoly_Type() { return TypeFunc::make(domain, range); } +// Kyber NTT function +static const TypeFunc* make_kyberNtt_Type() { + int argcnt = 2; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // coeffs + fields[argp++] = TypePtr::NOTNULL; // NTT zetas + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + +// Kyber inverse NTT function +static const TypeFunc* make_kyberInverseNtt_Type() { + int argcnt = 2; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // coeffs + fields[argp++] = TypePtr::NOTNULL; // inverse NTT zetas + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + +// Kyber NTT multiply function +static const TypeFunc* make_kyberNttMult_Type() { + int argcnt = 4; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // result + fields[argp++] = TypePtr::NOTNULL; // ntta + fields[argp++] = TypePtr::NOTNULL; // nttb + 
fields[argp++] = TypePtr::NOTNULL; // NTT multiply zetas + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + +// Kyber add 2 polynomials function +static const TypeFunc* make_kyberAddPoly_2_Type() { + int argcnt = 3; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // result + fields[argp++] = TypePtr::NOTNULL; // a + fields[argp++] = TypePtr::NOTNULL; // b + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + + +// Kyber add 3 polynomials function +static const TypeFunc* make_kyberAddPoly_3_Type() { + int argcnt = 4; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // result + fields[argp++] = TypePtr::NOTNULL; // a + fields[argp++] = TypePtr::NOTNULL; // b + fields[argp++] = TypePtr::NOTNULL; // c + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + + +// Kyber XOF output parsing into polynomial coefficients candidates +// or decompress(12,...) 
function +static const TypeFunc* make_kyber12To16_Type() { + int argcnt = 4; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // condensed + fields[argp++] = TypeInt::INT; // condensedOffs + fields[argp++] = TypePtr::NOTNULL; // parsed + fields[argp++] = TypeInt::INT; // parsedLength + + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + +// Kyber Barrett reduce function +static const TypeFunc* make_kyberBarrettReduce_Type() { + int argcnt = 1; + + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // coeffs + assert(argp == TypeFunc::Parms + argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms + 0] = TypeInt::INT; + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields); + return TypeFunc::make(domain, range); +} + static const TypeFunc* make_base64_encodeBlock_Type() { int argcnt = 6; @@ -2120,13 +2264,18 @@ void OptoRuntime::initialize_types() { _vectorizedMismatch_Type = make_vectorizedMismatch_Type(); _ghash_processBlocks_Type = make_ghash_processBlocks_Type(); _chacha20Block_Type = make_chacha20Block_Type(); - _dilithiumAlmostNtt_Type = make_dilithiumAlmostNtt_Type(); _dilithiumAlmostInverseNtt_Type = make_dilithiumAlmostInverseNtt_Type(); _dilithiumNttMult_Type = make_dilithiumNttMult_Type(); _dilithiumMontMulByConstant_Type = make_dilithiumMontMulByConstant_Type(); _dilithiumDecomposePoly_Type = make_dilithiumDecomposePoly_Type(); - + _kyberNtt_Type = make_kyberNtt_Type(); + _kyberInverseNtt_Type = make_kyberInverseNtt_Type(); + _kyberNttMult_Type = make_kyberNttMult_Type(); + _kyberAddPoly_2_Type = make_kyberAddPoly_2_Type(); + _kyberAddPoly_3_Type = make_kyberAddPoly_3_Type(); + _kyber12To16_Type = make_kyber12To16_Type(); + _kyberBarrettReduce_Type = make_kyberBarrettReduce_Type(); _base64_encodeBlock_Type = make_base64_encodeBlock_Type(); _base64_decodeBlock_Type = make_base64_decodeBlock_Type(); _string_IndexOf_Type = make_string_IndexOf_Type(); diff --git a/src/hotspot/share/opto/runtime.hpp b/src/hotspot/share/opto/runtime.hpp index 96b7e9297d637..606889e14f666 100644 --- a/src/hotspot/share/opto/runtime.hpp +++ b/src/hotspot/share/opto/runtime.hpp @@ -185,6 +185,13 @@ class OptoRuntime : public AllStatic { static const TypeFunc* _dilithiumNttMult_Type; static const TypeFunc* _dilithiumMontMulByConstant_Type; static const TypeFunc* _dilithiumDecomposePoly_Type; + static const TypeFunc* _kyberNtt_Type; + static const TypeFunc* _kyberInverseNtt_Type; + static const TypeFunc* _kyberNttMult_Type; + static const TypeFunc* _kyberAddPoly_2_Type; + static const TypeFunc* _kyberAddPoly_3_Type; + static const TypeFunc* _kyber12To16_Type; + static const TypeFunc* _kyberBarrettReduce_Type; static const TypeFunc* _base64_encodeBlock_Type; static const TypeFunc* _base64_decodeBlock_Type; static const TypeFunc* _string_IndexOf_Type; @@ -468,6 +475,10 @@ class OptoRuntime : public AllStatic { return _unsafe_setmemory_Type; } +// static const TypeFunc* digestBase_implCompress_Type(bool is_sha3); +// 
static const TypeFunc* digestBase_implCompressMB_Type(bool is_sha3); +// static const TypeFunc* double_keccak_Type(); + static inline const TypeFunc* array_fill_Type() { assert(_array_fill_Type != nullptr, "should be initialized"); return _array_fill_Type; @@ -609,6 +620,41 @@ class OptoRuntime : public AllStatic { return _dilithiumDecomposePoly_Type; } + static const TypeFunc* kyberNtt_Type() { + assert(_kyberNtt_Type != nullptr, "should be initialized"); + return _kyberNtt_Type; + } + + static const TypeFunc* kyberInverseNtt_Type() { + assert(_kyberInverseNtt_Type != nullptr, "should be initialized"); + return _kyberInverseNtt_Type; + } + + static const TypeFunc* kyberNttMult_Type() { + assert(_kyberNttMult_Type != nullptr, "should be initialized"); + return _kyberNttMult_Type; + } + + static const TypeFunc* kyberAddPoly_2_Type() { + assert(_kyberAddPoly_2_Type != nullptr, "should be initialized"); + return _kyberAddPoly_2_Type; + } + + static const TypeFunc* kyberAddPoly_3_Type() { + assert(_kyberAddPoly_3_Type != nullptr, "should be initialized"); + return _kyberAddPoly_3_Type; + } + + static const TypeFunc* kyber12To16_Type() { + assert(_kyber12To16_Type != nullptr, "should be initialized"); + return _kyber12To16_Type; + } + + static const TypeFunc* kyberBarrettReduce_Type() { + assert(_kyberBarrettReduce_Type != nullptr, "should be initialized"); + return _kyberBarrettReduce_Type; + } + // Base64 encode function static inline const TypeFunc* base64_encodeBlock_Type() { assert(_base64_encodeBlock_Type != nullptr, "should be initialized"); diff --git a/src/hotspot/share/runtime/globals.hpp b/src/hotspot/share/runtime/globals.hpp index ab975b1b375c0..7cf9f3b551bc7 100644 --- a/src/hotspot/share/runtime/globals.hpp +++ b/src/hotspot/share/runtime/globals.hpp @@ -328,6 +328,9 @@ const int ObjectAlignmentInBytes = 8; product(bool, UseDilithiumIntrinsics, false, DIAGNOSTIC, \ "Use intrinsics for the vectorized version of Dilithium") \ \ + product(bool, UseKyberIntrinsics, false, DIAGNOSTIC, \ + "Use intrinsics for the vectorized version of Kyber") \ + \ product(bool, UseMD5Intrinsics, false, DIAGNOSTIC, \ "Use intrinsics for MD5 crypto hash function") \ \ diff --git a/src/hotspot/share/runtime/stubDeclarations.hpp b/src/hotspot/share/runtime/stubDeclarations.hpp index fd86f2ced3fad..04ed62eed0bba 100644 --- a/src/hotspot/share/runtime/stubDeclarations.hpp +++ b/src/hotspot/share/runtime/stubDeclarations.hpp @@ -693,6 +693,21 @@ do_stub(compiler, dilithiumDecomposePoly) \ do_entry(compiler, dilithiumDecomposePoly, \ dilithiumDecomposePoly, dilithiumDecomposePoly) \ + do_stub(compiler, kyberNtt) \ + do_entry(compiler, kyberNtt, kyberNtt, kyberNtt) \ + do_stub(compiler, kyberInverseNtt) \ + do_entry(compiler, kyberInverseNtt, kyberInverseNtt, kyberInverseNtt) \ + do_stub(compiler, kyberNttMult) \ + do_entry(compiler, kyberNttMult, kyberNttMult, kyberNttMult) \ + do_stub(compiler, kyberAddPoly_2) \ + do_entry(compiler, kyberAddPoly_2, kyberAddPoly_2, kyberAddPoly_2) \ + do_stub(compiler, kyberAddPoly_3) \ + do_entry(compiler, kyberAddPoly_3, kyberAddPoly_3, kyberAddPoly_3) \ + do_stub(compiler, kyber12To16) \ + do_entry(compiler, kyber12To16, kyber12To16, kyber12To16) \ + do_stub(compiler, kyberBarrettReduce) \ + do_entry(compiler, kyberBarrettReduce, kyberBarrettReduce, \ + kyberBarrettReduce) \ do_stub(compiler, data_cache_writeback) \ do_entry(compiler, data_cache_writeback, data_cache_writeback, \ data_cache_writeback) \ @@ -740,11 +755,11 @@ do_stub(compiler, sha3_implCompress) \ 
do_entry(compiler, sha3_implCompress, sha3_implCompress, \ sha3_implCompress) \ + do_stub(compiler, double_keccak) \ + do_entry(compiler, double_keccak, double_keccak, double_keccak) \ do_stub(compiler, sha3_implCompressMB) \ do_entry(compiler, sha3_implCompressMB, sha3_implCompressMB, \ sha3_implCompressMB) \ - do_stub(compiler, double_keccak) \ - do_entry(compiler, double_keccak, double_keccak, double_keccak) \ do_stub(compiler, updateBytesAdler32) \ do_entry(compiler, updateBytesAdler32, updateBytesAdler32, \ updateBytesAdler32) \ diff --git a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java index 9808a0133032e..cef2f30a9e639 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java @@ -28,6 +28,7 @@ import java.security.*; import java.util.Arrays; import javax.crypto.DecapsulateException; +import jdk.internal.vm.annotation.IntrinsicCandidate; import sun.security.provider.SHA3.SHAKE256; import sun.security.provider.SHA3Parallel.Shake128Parallel; @@ -71,6 +72,268 @@ public final class ML_KEM { -1599, -709, -789, -1317, -57, 1049, -584 }; + private static final short[] montZetasForVectorNttArr = new short[]{ + // level 0 + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + -758, -758, -758, -758, -758, -758, -758, -758, + // level 1 + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + // level 2 + 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, + 287, 287, 287, 287, 287, 287, 287, 287, + 287, 287, 287, 287, 287, 287, 287, 287, + 287, 287, 287, 287, 287, 287, 287, 287, + 287, 287, 287, 287, 287, 
287, 287, 287, + 202, 202, 202, 202, 202, 202, 202, 202, + 202, 202, 202, 202, 202, 202, 202, 202, + 202, 202, 202, 202, 202, 202, 202, 202, + 202, 202, 202, 202, 202, 202, 202, 202, + // level 3 + -171, -171, -171, -171, -171, -171, -171, -171, + -171, -171, -171, -171, -171, -171, -171, -171, + 622, 622, 622, 622, 622, 622, 622, 622, + 622, 622, 622, 622, 622, 622, 622, 622, + 1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577, + 1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577, + 182, 182, 182, 182, 182, 182, 182, 182, + 182, 182, 182, 182, 182, 182, 182, 182, + 962, 962, 962, 962, 962, 962, 962, 962, + 962, 962, 962, 962, 962, 962, 962, 962, + -1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202, + -1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202, + -1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474, + -1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474, + 1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468, + 1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468, + // level 4 + 573, 573, 573, 573, 573, 573, 573, 573, + -1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325, + 264, 264, 264, 264, 264, 264, 264, 264, + 383, 383, 383, 383, 383, 383, 383, 383, + -829, -829, -829, -829, -829, -829, -829, -829, + 1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458, + -1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602, + -130, -130, -130, -130, -130, -130, -130, -130, + -681, -681, -681, -681, -681, -681, -681, -681, + 1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017, + 732, 732, 732, 732, 732, 732, 732, 732, + 608, 608, 608, 608, 608, 608, 608, 608, + -1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542, + 411, 411, 411, 411, 411, 411, 411, 411, + -205, -205, -205, -205, -205, -205, -205, -205, + -1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571, + // level 5 + 1223, 1223, 1223, 1223, 652, 652, 652, 652, + -552, -552, -552, -552, 1015, 1015, 1015, 1015, + -1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491, + -282, -282, -282, -282, -1544, -1544, -1544, -1544, + 516, 516, 516, 516, -8, -8, -8, -8, + -320, -320, -320, -320, -666, -666, -666, -666, + 1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162, + 126, 126, 126, 126, 1469, 1469, 1469, 1469, + -853, -853, -853, -853, -90, -90, -90, -90, + -271, -271, -271, -271, 830, 830, 830, 830, + 107, 107, 107, 107, -1421, -1421, -1421, -1421, + -247, -247, -247, -247, -951, -951, -951, -951, + -398, -398, -398, -398, 961, 961, 961, 961, + -1508, -1508, -1508, -1508, -725, -725, -725, -725, + 448, 448, 448, 448, -1065, -1065, -1065, -1065, + 677, 677, 677, 677, -1275, -1275, -1275, -1275, + // level 6 + -1103, -1103, 430, 430, 555, 555, 843, 843, + -1251, -1251, 871, 871, 1550, 1550, 105, 105, + 422, 422, 587, 587, 177, 177, -235, -235, + -291, -291, -460, -460, 1574, 1574, 1653, 1653, + -246, -246, 778, 778, 1159, 1159, -147, -147, + -777, -777, 1483, 1483, -602, -602, 1119, 1119, + -1590, -1590, 644, 644, -872, -872, 349, 349, + 418, 418, 329, 329, -156, -156, -75, -75, + 817, 817, 1097, 1097, 603, 603, 610, 610, + 1322, 1322, -1285, -1285, -1465, -1465, 384, 384, + -1215, -1215, -136, -136, 1218, 1218, -1335, -1335, + -874, -874, 220, 220, -1187, -1187, 1670, 1670, + -1185, -1185, -1530, -1530, -1278, -1278, 794, 794, + -1510, -1510, -854, -854, -870, -870, 478, 478, + -108, -108, -308, -308, 996, 996, 991, 991, + 958, 958, -1460, -1460, 1522, 1522, 1628, 1628 + }; + private static final int[] MONT_ZETAS_FOR_INVERSE_NTT = new int[]{ + 584, -1049, 57, 1317, 789, 709, 1599, -1601, + -990, 604, 348, 857, 612, 474, 1177, -1014, + -88, -982, -191, 668, 
1386, 486, -1153, -534, + 514, 137, 586, -1178, 227, 339, -907, 244, + 1200, -833, 1394, -30, 1074, 636, -317, -1192, + -1259, -355, -425, -884, -977, 1430, 868, 607, + 184, 1448, 702, 1327, 431, 497, 595, -94, + 1649, -1497, -620, 42, -172, 1107, -222, 1003, + 426, -845, 395, -510, 1613, 825, 1269, -290, + -1429, 623, -567, 1617, 36, 1007, 1440, 332, + -201, 1313, -1382, -744, 669, -1538, 128, -1598, + 1401, 1183, -553, 714, 405, -1155, -445, 406, + -1496, -49, 82, 1369, 259, 1604, 373, 909, + -1249, -1000, -25, -52, 530, -895, 1226, 819, + -185, 281, -742, 1253, 417, 1400, 35, -593, + 97, -1263, 551, -585, 969, -914, -1188 + }; + + private static final short[] montZetasForVectorInverseNttArr = new short[]{ + // level 0 + -1628, -1628, -1522, -1522, 1460, 1460, -958, -958, + -991, -991, -996, -996, 308, 308, 108, 108, + -478, -478, 870, 870, 854, 854, 1510, 1510, + -794, -794, 1278, 1278, 1530, 1530, 1185, 1185, + 1659, 1659, 1187, 1187, -220, -220, 874, 874, + 1335, 1335, -1218, -1218, 136, 136, 1215, 1215, + -384, -384, 1465, 1465, 1285, 1285, -1322, -1322, + -610, -610, -603, -603, -1097, -1097, -817, -817, + 75, 75, 156, 156, -329, -329, -418, -418, + -349, -349, 872, 872, -644, -644, 1590, 1590, + -1119, -1119, 602, 602, -1483, -1483, 777, 777, + 147, 147, -1159, -1159, -778, -778, 246, 246, + -1653, -1653, -1574, -1574, 460, 460, 291, 291, + 235, 235, -177, -177, -587, -587, -422, -422, + -105, -105, -1550, -1550, -871, -871, 1251, 1251, + -843, -843, -555, -555, -430, -430, 1103, 1103, + // level 1 + 1275, 1275, 1275, 1275, -677, -677, -677, -677, + 1065, 1065, 1065, 1065, -448, -448, -448, -448, + 725, 725, 725, 725, 1508, 1508, 1508, 1508, + -961, -961, -961, -961, 398, 398, 398, 398, + 951, 951, 951, 951, 247, 247, 247, 247, + 1421, 1421, 1421, 1421, -107, -107, -107, -107, + -830, -830, -830, -830, 271, 271, 271, 271, + 90, 90, 90, 90, 853, 853, 853, 853, + -1469, -1469, -1469, -1469, -126, -126, -126, -126, + 1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618, + 666, 666, 666, 666, 320, 320, 320, 320, + 8, 8, 8, 8, -516, -516, -516, -516, + 1544, 1544, 1544, 1544, 282, 282, 282, 282, + -1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293, + -1015, -1015, -1015, -1015, 552, 552, 552, 552, + -652, -652, -652, -652, -1223, -1223, -1223, -1223, + // level 2 + 1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571, + 205, 205, 205, 205, 205, 205, 205, 205, + -411, -411, -411, -411, -411, -411, -411, -411, + 1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542, + -608, -608, -608, -608, -608, -608, -608, -608, + -732, -732, -732, -732, -732, -732, -732, -732, + -1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017, + 681, 681, 681, 681, 681, 681, 681, 681, + 130, 130, 130, 130, 130, 130, 130, 130, + 1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602, + -1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458, + 829, 829, 829, 829, 829, 829, 829, 829, + -383, -383, -383, -383, -383, -383, -383, -383, + -264, -264, -264, -264, -264, -264, -264, -264, + 1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325, + -573, -573, -573, -573, -573, -573, -573, -573, + // level 3 + -1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468, + -1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468, + 1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474, + 1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474, + 1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202, + 1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202, + -962, -962, -962, -962, -962, -962, -962, -962, + -962, -962, -962, -962, -962, -962, -962, -962, + -182, -182, -182, -182, -182, -182, 
-182, -182, + -182, -182, -182, -182, -182, -182, -182, -182, + -1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577, + -1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577, + -622, -622, -622, -622, -622, -622, -622, -622, + -622, -622, -622, -622, -622, -622, -622, -622, + 171, 171, 171, 171, 171, 171, 171, 171, + 171, 171, 171, 171, 171, 171, 171, 171, + // level 4 + -202, -202, -202, -202, -202, -202, -202, -202, + -202, -202, -202, -202, -202, -202, -202, -202, + -202, -202, -202, -202, -202, -202, -202, -202, + -202, -202, -202, -202, -202, -202, -202, -202, + -287, -287, -287, -287, -287, -287, -287, -287, + -287, -287, -287, -287, -287, -287, -287, -287, + -287, -287, -287, -287, -287, -287, -287, -287, + -287, -287, -287, -287, -287, -287, -287, -287, + -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, + -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, + -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, + -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, + -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, + -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, + -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, + -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, + // level 5 + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + 359, 359, 359, 359, 359, 359, 359, 359, + // level 6 + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758, + 758, 758, 758, 758, 758, 758, 758, 758 + }; + private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{ -1003, 1003, 222, -222, -1107, 1107, 172, -172, -42, 42, 620, -620, 1497, -1497, -1649, 1649, @@ -89,6 +352,24 @@ public final class ML_KEM { 1601, -1601, -1599, 1599, -709, 709, -789, 789, -1317, 1317, -57, 57, 1049, -1049, -584, 584 }; + private static final short[] montZetasForVectorNttMultArr = new short[]{ + -1103, 1103, 430, -430, 555, -555, 843, -843, + -1251, 1251, 871, -871, 1550, -1550, 105, -105, + 422, -422, 587, -587, 177, -177, -235, 235, + -291, 291, -460, 460, 1574, -1574, 1653, -1653, + -246, 246, 778, -778, 1159, -1159, -147, 147, + -777, 777, 1483, -1483, -602, 602, 1119, -1119, + -1590, 1590, 644, -644, -872, 872, 349, -349, + 418, -418, 329, -329, -156, 156, -75, 75, + 817, -817, 1097, -1097, 603, -603, 610, -610, + 1322, -1322, -1285, 1285, -1465, 1465, 384, -384, + -1215, 1215, 
-136, 136, 1218, -1218, -1335, 1335, + -874, 874, 220, -220, -1187, 1187, 1670, 1659, + -1185, 1185, -1530, 1530, -1278, 1278, 794, -794, + -1510, 1510, -854, 854, -870, 870, 478, -478, + -108, 108, -308, 308, 996, -996, 991, -991, + 958, -958, -1460, 1460, 1522, -1522, 1628, -1628 + }; private final int mlKem_k; private final int mlKem_eta1; @@ -261,7 +542,7 @@ protected ML_KEM_EncapsulateResult encapsulate( try { mlKemH = MessageDigest.getInstance(HASH_H_NAME); mlKemG = MessageDigest.getInstance(HASH_G_NAME); - } catch (NoSuchAlgorithmException e){ + } catch (NoSuchAlgorithmException e) { // This should never happen. throw new RuntimeException(e); } @@ -527,7 +808,7 @@ private short[][][] generateA(byte[] rho, Boolean transposed) { for (int i = 0; i < mlKem_k; i++) { for (int j = 0; j < mlKem_k; j++) { - xofBufArr[parInd] = seedBuf.clone(); + System.arraycopy(seedBuf, 0, xofBufArr[parInd], 0, seedBuf.length); if (transposed) { xofBufArr[parInd][rhoLen] = (byte) i; xofBufArr[parInd][rhoLen + 1] = (byte) j; @@ -707,9 +988,13 @@ private short[][] mlKemVectorInverseNTT(short[][] vector) { return vector; } - // The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q] - // The elements of poly at return will be in the range of [0, ML_KEM_Q] - private void mlKemNTT(short[] poly) { + @IntrinsicCandidate + static int implKyberNtt(short[] poly, short[] ntt_zetas) { + implKyberNttJava(poly); + return 1; + } + + static void implKyberNttJava(short[] poly) { int[] coeffs = new int[ML_KEM_N]; for (int m = 0; m < ML_KEM_N; m++) { coeffs[m] = poly[m]; @@ -718,12 +1003,22 @@ private void mlKemNTT(short[] poly) { for (int m = 0; m < ML_KEM_N; m++) { poly[m] = (short) coeffs[m]; } + } + + // The elements of poly should be in the range [-mlKem_q, mlKem_q] + // The elements of poly at return will be in the range of [0, mlKem_q] + private void mlKemNTT(short[] poly) { + implKyberNtt(poly, montZetasForVectorNttArr); mlKemBarrettReduce(poly); } - // Works in place, but also returns its (modified) input so that it can - // be used in expressions - private short[] mlKemInverseNTT(short[] poly) { + @IntrinsicCandidate + static int implKyberInverseNtt(short[] poly, short[] zetas) { + implKyberInverseNttJava(poly); + return 1; + } + + static void implKyberInverseNttJava(short[] poly) { int[] coeffs = new int[ML_KEM_N]; for (int m = 0; m < ML_KEM_N; m++) { coeffs[m] = poly[m]; @@ -732,6 +1027,12 @@ private short[] mlKemInverseNTT(short[] poly) { for (int m = 0; m < ML_KEM_N; m++) { poly[m] = (short) coeffs[m]; } + } + + // Works in place, but also returns its (modified) input so that it can + // be used in expressions + private short[] mlKemInverseNTT(short[] poly) { + implKyberInverseNtt(poly, montZetasForVectorInverseNttArr); return poly; } @@ -822,11 +1123,16 @@ private short[] mlKemVectorScalarMult(short[][] a, short[][] b) { return result; } - // Multiplies two polynomials represented in the NTT domain. - // The result is a representation of the product still in the NTT domain. - // The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q). 
-    private void nttMult(short[] result, short[] ntta, short[] nttb) {
+    @IntrinsicCandidate
+    static int implKyberNttMult(short[] result, short[] ntta, short[] nttb,
+                                short[] zetas) {
+        implKyberNttMultJava(result, ntta, nttb);
+        return 1;
+    }
+
+    static void implKyberNttMultJava(short[] result, short[] ntta, short[] nttb) {
         for (int m = 0; m < ML_KEM_N / 2; m++) {
+
             int a0 = ntta[2 * m];
             int a1 = ntta[2 * m + 1];
             int b0 = nttb[2 * m];
@@ -839,6 +1145,13 @@ private void nttMult(short[] result, short[] ntta, short[] nttb) {
         }
     }
 
+    // Multiplies two polynomials represented in the NTT domain.
+    // The result is a representation of the product still in the NTT domain.
+    // The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
+    private void nttMult(short[] result, short[] ntta, short[] nttb) {
+        implKyberNttMult(result, ntta, nttb, montZetasForVectorNttMultArr);
+    }
+
     // Adds the vector of polynomials b to a in place, i.e. a will hold
     // the result. It also returns (the modified) a so that it can be used
     // in an expression.
@@ -853,15 +1166,40 @@ private short[][] mlKemAddVec(short[][] a, short[][] b) {
         return a;
     }
 
+    @IntrinsicCandidate
+    static int implKyberAddPoly(short[] result, short[] a, short[] b) {
+        implKyberAddPolyJava(result, a, b);
+        return 1;
+    }
+
+    static void implKyberAddPolyJava(short[] result, short[] a, short[] b) {
+        for (int m = 0; m < ML_KEM_N; m++) {
+            int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
+            result[m] = (short) r;
+        }
+    }
+
     // Adds the polynomial b to a in place, i.e. (the modified) a will hold
     // the result.
     // The coefficients are supposed be greater than -ML_KEM_Q in a and
     // greater than -ML_KEM_Q and less than ML_KEM_Q in b.
     // The coefficients in the result are greater than -ML_KEM_Q.
-    private void mlKemAddPoly(short[] a, short[] b) {
+    private short[] mlKemAddPoly(short[] a, short[] b) {
+        implKyberAddPoly(a, a, b);
+        return a;
+    }
+
+    @IntrinsicCandidate
+    static int implKyberAddPoly(short[] result, short[] a, short[] b, short[] c) {
+        implKyberAddPolyJava(result, a, b, c);
+        return 1;
+    }
+
+    static void implKyberAddPolyJava(short[] result, short[] a, short[] b, short[] c) {
         for (int m = 0; m < ML_KEM_N; m++) {
-            int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
-            a[m] = (short) r;
+            int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
+            result[m] = (short) r;
         }
     }
 
@@ -871,10 +1209,7 @@ private void mlKemAddPoly(short[] a, short[] b) {
     // greater than -ML_KEM_Q and less than ML_KEM_Q.
     // The coefficients in the result are nonnegative and less than ML_KEM_Q.
     private short[] mlKemAddPoly(short[] a, short[] b, short[] c) {
-        for (int m = 0; m < ML_KEM_N; m++) {
-            int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
-            a[m] = (short) r;
-        }
+        implKyberAddPoly(a, a, b, c);
         mlKemBarrettReduce(a);
         return a;
     }
@@ -997,15 +1332,13 @@ private short[][] decodeVector(int l, byte[] encodedVector) {
         return result;
     }
 
-    // The intrinsic implementations assume that the input and output buffers
-    // are such that condensed can be read in 192-byte chunks and
-    // parsed can be written in 128 shorts chunks.
In other words, - // if (i - 1) * 128 < parsedLengths <= i * 128 then - // parsed.size should be at least i * 128 and - // condensed.size should be at least index + i * 192 - private void twelve2Sixteen(byte[] condensed, int index, - short[] parsed, int parsedLength) { + @IntrinsicCandidate + private static int implKyber12To16(byte[] condensed, int index, short[] parsed, int parsedLength) { + implKyber12To16Java(condensed, index, parsed, parsedLength); + return 1; + } + private static void implKyber12To16Java(byte[] condensed, int index, short[] parsed, int parsedLength) { for (int i = 0; i < parsedLength * 3 / 2; i += 3) { parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) + 256 * (condensed[i + index + 1] & 0xf)); @@ -1014,6 +1347,28 @@ private void twelve2Sixteen(byte[] condensed, int index, } } + // The intrinsic implementations assume that the input and output buffers + // are such that condensed can be read in 96-byte chunks and + // parsed can be written in 64 shorts chunks except for the last chunk + // that can be either 48 or 64 shorts. In other words, + // if (i - 1) * 64 < parsedLengths <= i * 64 then + // parsed.length should be either i * 64 or (i-1) * 64 + 48 and + // condensed.length should be at least index + i * 96. + private void twelve2Sixteen(byte[] condensed, int index, + short[] parsed, int parsedLength) { + int i = parsedLength / 64; + int remainder = parsedLength - i * 64; + if (remainder != 0) { + i++; + } + if (((remainder != 0) && (remainder != 48)) || + index + i * 96 > condensed.length) { + // this should never happen + throw new ProviderException("Bad parameters"); + } + implKyber12To16(condensed, index, parsed, parsedLength); + } + private static void decodePoly5(byte[] condensed, int index, short[] parsed) { int j = index; for (int i = 0; i < ML_KEM_N; i += 8) { @@ -1152,6 +1507,19 @@ private static short[] decompressDecode(byte[] input) { return result; } + @IntrinsicCandidate + static int implKyberBarrettReduce(short[] coeffs) { + implKyberBarrettReduceJava(coeffs); + return 1; + } + + static void implKyberBarrettReduceJava(short[] poly) { + for (int m = 0; m < ML_KEM_N; m++) { + int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT; + poly[m] = (short) (poly[m] - tmp * ML_KEM_Q); + } + } + // The input elements can have any short value. // Modifies poly such that upon return poly[i] will be // in the range [0, ML_KEM_Q] and will be congruent with the original @@ -1161,11 +1529,8 @@ private static short[] decompressDecode(byte[] input) { // That means that if the original poly[i] > -ML_KEM_Q then at return it // will be in the range [0, ML_KEM_Q), i.e. it will be the canonical // representative of its residue class. - private void mlKemBarrettReduce(short[] poly) { - for (int m = 0; m < ML_KEM_N; m++) { - int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT; - poly[m] = (short) (poly[m] - tmp * ML_KEM_Q); - } + private static void mlKemBarrettReduce(short[] poly) { + implKyberBarrettReduce(poly); } // Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q