diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index 828d8cfda91eb..d4e2b90ecfeb6 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3498,6 +3498,30 @@ void Assembler::vmovdqu(Address dst, XMMRegister src) { emit_operand(src, dst, 0); } +// Move Aligned 256bit Vector +void Assembler::vmovdqa(XMMRegister dst, Address src) { + assert(UseAVX > 0, ""); + InstructionMark im(this); + InstructionAttr attributes(AVX_256bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes); + emit_int8(0x6F); + emit_operand(dst, src, 0); +} + +void Assembler::vmovdqa(Address dst, XMMRegister src) { + assert(UseAVX > 0, ""); + InstructionMark im(this); + InstructionAttr attributes(AVX_256bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + attributes.reset_is_clear_context(); + // swap src<->dst for encoding + assert(src != xnoreg, "sanity"); + vex_prefix(dst, 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes); + emit_int8(0x7F); + emit_operand(src, dst, 0); +} + void Assembler::vpmaskmovd(XMMRegister dst, XMMRegister mask, Address src, int vector_len) { assert((VM_Version::supports_avx2() && vector_len == AVX_256bit), ""); InstructionMark im(this); @@ -3760,6 +3784,27 @@ void Assembler::evmovdquq(XMMRegister dst, KRegister mask, Address src, bool mer emit_operand(dst, src, 0); } +// Move Aligned 512bit Vector +void Assembler::evmovdqaq(XMMRegister dst, Address src, int vector_len) { + // Unmasked instruction + evmovdqaq(dst, k0, src, /*merge*/ false, vector_len); +} + +void Assembler::evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) { + assert(VM_Version::supports_evex(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ true, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + attributes.set_embedded_opmask_register_specifier(mask); + attributes.set_is_evex_instruction(); + if (merge) { + attributes.reset_is_clear_context(); + } + vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes); + emit_int8(0x6F); + emit_operand(dst, src, 0); +} + void Assembler::evmovntdquq(Address dst, XMMRegister src, int vector_len) { // Unmasked instruction evmovntdquq(dst, k0, src, /*merge*/ true, vector_len); diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 25be0d6a48d32..45d11c873a73f 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -1755,6 +1755,10 @@ class Assembler : public AbstractAssembler { void vmovdqu(XMMRegister dst, Address src); void vmovdqu(XMMRegister dst, XMMRegister src); + // Move Aligned 256bit Vector + void vmovdqa(XMMRegister dst, Address src); + void vmovdqa(Address dst, XMMRegister src); + // Move Unaligned 512bit Vector void evmovdqub(XMMRegister dst, XMMRegister src, int vector_len); void evmovdqub(XMMRegister dst, Address src, int vector_len); @@ -1788,6 +1792,10 @@ class Assembler : public AbstractAssembler { void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len); void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len); + // Move Aligned 512bit Vector + void evmovdqaq(XMMRegister dst, Address src, int vector_len); + void evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len); + // Move lower 64bit to high 64bit in 128bit register void movlhps(XMMRegister dst, XMMRegister src); diff --git a/src/hotspot/cpu/x86/macroAssembler_x86.cpp b/src/hotspot/cpu/x86/macroAssembler_x86.cpp index 0830b6b098387..841dcbc1cbbc0 100644 --- a/src/hotspot/cpu/x86/macroAssembler_x86.cpp +++ b/src/hotspot/cpu/x86/macroAssembler_x86.cpp @@ -2696,6 +2696,60 @@ void MacroAssembler::vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len } } +void MacroAssembler::vmovdqu(XMMRegister dst, XMMRegister src, int vector_len) { + if (vector_len == AVX_512bit) { + evmovdquq(dst, src, AVX_512bit); + } else if (vector_len == AVX_256bit) { + vmovdqu(dst, src); + } else { + movdqu(dst, src); + } +} + +void MacroAssembler::vmovdqu(Address dst, XMMRegister src, int vector_len) { + if (vector_len == AVX_512bit) { + evmovdquq(dst, src, AVX_512bit); + } else if (vector_len == AVX_256bit) { + vmovdqu(dst, src); + } else { + movdqu(dst, src); + } +} + +void MacroAssembler::vmovdqu(XMMRegister dst, Address src, int vector_len) { + if (vector_len == AVX_512bit) { + evmovdquq(dst, src, AVX_512bit); + } else if (vector_len == AVX_256bit) { + vmovdqu(dst, src); + } else { + movdqu(dst, src); + } +} + +void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch) { + assert(rscratch != noreg || always_reachable(src), "missing"); + + if (reachable(src)) { + vmovdqa(dst, as_Address(src)); + } + else { + lea(rscratch, src); + vmovdqa(dst, Address(rscratch, 0)); + } +} + +void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch) { + assert(rscratch != noreg || always_reachable(src), "missing"); + + if (vector_len == AVX_512bit) { + evmovdqaq(dst, src, AVX_512bit, rscratch); + } else if (vector_len == AVX_256bit) { + vmovdqa(dst, src, rscratch); + } else { + movdqa(dst, src, rscratch); + } +} + void MacroAssembler::kmov(KRegister dst, Address src) { if (VM_Version::supports_avx512bw()) { kmovql(dst, src); @@ -2820,6 +2874,29 @@ void MacroAssembler::evmovdquq(XMMRegister dst, AddressLiteral src, int vector_l } } +void MacroAssembler::evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch) { + assert(rscratch != noreg || always_reachable(src), "missing"); + + if (reachable(src)) { + Assembler::evmovdqaq(dst, mask, as_Address(src), merge, vector_len); + } else { + lea(rscratch, src); + Assembler::evmovdqaq(dst, mask, Address(rscratch, 0), merge, vector_len); + } +} + +void MacroAssembler::evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register 
rscratch) { + assert(rscratch != noreg || always_reachable(src), "missing"); + + if (reachable(src)) { + Assembler::evmovdqaq(dst, as_Address(src), vector_len); + } else { + lea(rscratch, src); + Assembler::evmovdqaq(dst, Address(rscratch, 0), vector_len); + } +} + + void MacroAssembler::movdqa(XMMRegister dst, AddressLiteral src, Register rscratch) { assert(rscratch != noreg || always_reachable(src), "missing"); diff --git a/src/hotspot/cpu/x86/macroAssembler_x86.hpp b/src/hotspot/cpu/x86/macroAssembler_x86.hpp index c6e5b2a115f03..4cc92fb9a495f 100644 --- a/src/hotspot/cpu/x86/macroAssembler_x86.hpp +++ b/src/hotspot/cpu/x86/macroAssembler_x86.hpp @@ -1347,6 +1347,14 @@ class MacroAssembler: public Assembler { void vmovdqu(XMMRegister dst, XMMRegister src); void vmovdqu(XMMRegister dst, AddressLiteral src, Register rscratch = noreg); void vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg); + void vmovdqu(XMMRegister dst, XMMRegister src, int vector_len); + void vmovdqu(XMMRegister dst, Address src, int vector_len); + void vmovdqu(Address dst, XMMRegister src, int vector_len); + + // AVX Aligned forms + using Assembler::vmovdqa; + void vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch = noreg); + void vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg); // AVX512 Unaligned void evmovdqu(BasicType type, KRegister kmask, Address dst, XMMRegister src, bool merge, int vector_len); @@ -1403,6 +1411,7 @@ class MacroAssembler: public Assembler { void evmovdquq(XMMRegister dst, Address src, int vector_len) { Assembler::evmovdquq(dst, src, vector_len); } void evmovdquq(Address dst, XMMRegister src, int vector_len) { Assembler::evmovdquq(dst, src, vector_len); } void evmovdquq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg); + void evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg); void evmovdquq(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { if (dst->encoding() != src->encoding() || mask != k0) { @@ -1412,6 +1421,7 @@ class MacroAssembler: public Assembler { void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { Assembler::evmovdquq(dst, mask, src, merge, vector_len); } void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) { Assembler::evmovdquq(dst, mask, src, merge, vector_len); } void evmovdquq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg); + void evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg); // Move Aligned Double Quadword void movdqa(XMMRegister dst, XMMRegister src) { Assembler::movdqa(dst, src); } diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64_poly_mont.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64_poly_mont.cpp index 1732d251c98a4..3ab7c71eefe3b 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64_poly_mont.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64_poly_mont.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Intel Corporation. All rights reserved. + * Copyright (c) 2024, 2025, Intel Corporation. All rights reserved. * * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* @@ -28,17 +28,17 @@ #define __ _masm-> -ATTRIBUTE_ALIGNED(64) uint64_t MODULUS_P256[] = { +ATTRIBUTE_ALIGNED(64) constexpr uint64_t MODULUS_P256[] = { 0x000fffffffffffffULL, 0x00000fffffffffffULL, 0x0000000000000000ULL, 0x0000001000000000ULL, 0x0000ffffffff0000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL, 0x0000000000000000ULL }; -static address modulus_p256() { - return (address)MODULUS_P256; +static address modulus_p256(int index = 0) { + return (address)&MODULUS_P256[index]; } -ATTRIBUTE_ALIGNED(64) uint64_t P256_MASK52[] = { +ATTRIBUTE_ALIGNED(64) constexpr uint64_t P256_MASK52[] = { 0x000fffffffffffffULL, 0x000fffffffffffffULL, 0x000fffffffffffffULL, 0x000fffffffffffffULL, 0xffffffffffffffffULL, 0xffffffffffffffffULL, @@ -48,7 +48,7 @@ static address p256_mask52() { return (address)P256_MASK52; } -ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1R[] = { +ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1R[] = { 0x0000000000000001ULL, 0x0000000000000002ULL, 0x0000000000000003ULL, 0x0000000000000004ULL, 0x0000000000000005ULL, 0x0000000000000006ULL, @@ -58,7 +58,7 @@ static address shift_1R() { return (address)SHIFT1R; } -ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1L[] = { +ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1L[] = { 0x0000000000000007ULL, 0x0000000000000000ULL, 0x0000000000000001ULL, 0x0000000000000002ULL, 0x0000000000000003ULL, 0x0000000000000004ULL, @@ -68,6 +68,14 @@ static address shift_1L() { return (address)SHIFT1L; } +ATTRIBUTE_ALIGNED(64) constexpr uint64_t MASKL5[] = { + 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0xFFFFFFFFFFFFFFFFULL, 0x0000000000000000ULL, +}; +static address mask_limb5() { + return (address)MASKL5; +} + /** * Unrolled Word-by-Word Montgomery Multiplication * r = a * b * 2^-260 (mod P) @@ -94,26 +102,25 @@ static address shift_1L() { * B = replicate(bLimbs[i]) |bi|bi|bi|bi|bi|bi|bi|bi| * +--+--+--+--+--+--+--+--+ * +--+--+--+--+--+--+--+--+ + * | 0| 0| 0|a5|a4|a3|a2|a1| + * Acc1 += A * B *|bi|bi|bi|bi|bi|bi|bi|bi| * Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1| - * *| 0| 0| 0|a5|a4|a3|a2|a1| - * Acc1 += A * B |bi|bi|bi|bi|bi|bi|bi|bi| * +--+--+--+--+--+--+--+--+ - * Acc2+=| 0| 0| 0| 0| 0| 0| 0| 0| - * *h| 0| 0| 0|a5|a4|a3|a2|a1| - * Acc2 += A *h B |bi|bi|bi|bi|bi|bi|bi|bi| + * | 0| 0| 0|a5|a4|a3|a2|a1| + * Acc2 += A *h B *h|bi|bi|bi|bi|bi|bi|bi|bi| + * Acc2+=| 0| 0| 0| d5|d4|d3|d2|d1| * +--+--+--+--+--+--+--+--+ * N = replicate(Acc1[0]) |n0|n0|n0|n0|n0|n0|n0|n0| * +--+--+--+--+--+--+--+--+ * +--+--+--+--+--+--+--+--+ - * Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1| - * *| 0| 0| 0|m5|m4|m3|m2|m1| - * Acc1 += M * N |n0|n0|n0|n0|n0|n0|n0|n0| Note: 52 low bits of Acc1[0] == 0 due to Montgomery! + * | 0| 0| 0|m5|m4|m3|m2|m1| + * Acc1 += M * N *|n0|n0|n0|n0|n0|n0|n0|n0| + * Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1| Note: 52 low bits of c1 == 0 due to Montgomery! 
* +--+--+--+--+--+--+--+--+ + * | 0| 0| 0|m5|m4|m3|m2|m1| + * Acc2 += M *h N *h|n0|n0|n0|n0|n0|n0|n0|n0| * Acc2+=| 0| 0| 0|d5|d4|d3|d2|d1| - * *h| 0| 0| 0|m5|m4|m3|m2|m1| - * Acc2 += M *h N |n0|n0|n0|n0|n0|n0|n0|n0| * +--+--+--+--+--+--+--+--+ - * if (i == 4) break; * // Combine high/low partial sums Acc1 + Acc2 * +--+--+--+--+--+--+--+--+ * carry = Acc1[0] >> 52 | 0| 0| 0| 0| 0| 0| 0|c1| @@ -124,13 +131,35 @@ static address shift_1L() { * +--+--+--+--+--+--+--+--+ * Acc1 = Acc1 + Acc2 * ---- done - * // Last Carry round: Combine high/low partial sums Acc1 + Acc1 + Acc2 - * carry = Acc1 >> 52 - * Acc1 = Acc1 shift one q element >> - * Acc1 = mask52(Acc1) - * Acc2 += carry - * Acc1 = Acc1 + Acc2 - * output to rLimbs + * + * At this point the result in Acc1 can overflow by 1 Modulus and needs carry + * propagation. Subtract one modulus, carry-propagate both results and select + * (constant-time) the positive number of the two + * + * Carry = Acc1[0] >> 52 + * Acc1L = Acc1[0] & mask52 + * Acc1 = Acc1 shift one q element>> + * Acc1 += Carry + * + * Carry = Acc2[0] >> 52 + * Acc2L = Acc2[0] & mask52 + * Acc2 = Acc2 shift one q element>> + * Acc2 += Carry + * + * for col:=1 to 4 + * Carry = Acc2[col]>>52 + * Carry = Carry shift one q element<< + * Acc2 += Carry + * + * Carry = Acc1[col]>>52 + * Carry = Carry shift one q element<< + * Acc1 += Carry + * done + * + * Acc1 &= mask52 + * Acc2 &= mask52 + * Mask = sign(Acc2) + * Result = select(Mask ? Acc1 or Acc2) */ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Register rLimbs, const Register tmp, MacroAssembler* _masm) { Register t0 = tmp; @@ -145,26 +174,30 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi XMMRegister Acc1 = xmm10; XMMRegister Acc2 = xmm11; XMMRegister N = xmm12; - XMMRegister carry = xmm13; + XMMRegister Carry = xmm13; // // Constants - XMMRegister modulus = xmm20; - XMMRegister shift1L = xmm21; - XMMRegister shift1R = xmm22; - XMMRegister mask52 = xmm23; - KRegister limb0 = k1; - KRegister allLimbs = k2; - - __ mov64(t0, 0x1); - __ kmovql(limb0, t0); + XMMRegister modulus = xmm5; + XMMRegister shift1L = xmm6; + XMMRegister shift1R = xmm7; + XMMRegister Mask52 = xmm8; + KRegister allLimbs = k1; + KRegister limb0 = k2; + KRegister masks[] = {limb0, k3, k4, k5}; + + for (int i=0; i<4; i++) { + __ mov64(t0, 1ULL<> 52 - __ evpsrlq(carry, limb0, Acc1, 52, true, Assembler::AVX_512bit); + __ evpsrlq(Carry, limb0, Acc1, 52, true, Assembler::AVX_512bit); // Acc2[0] += carry - __ evpaddq(Acc2, limb0, carry, Acc2, true, Assembler::AVX_512bit); + __ evpaddq(Acc2, limb0, Carry, Acc2, true, Assembler::AVX_512bit); // Acc1 = Acc1 shift one q element >> __ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit); @@ -213,26 +244,317 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi __ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit); } - // Last Carry round: Combine high/low partial sums Acc1 + Acc1 + Acc2 - // carry = Acc1 >> 52 - __ evpsrlq(carry, allLimbs, Acc1, 52, true, Assembler::AVX_512bit); - - // Acc1 = Acc1 shift one q element >> + // At this point the result is in Acc1, but needs to be normailized to 52bit + // limbs (i.e. needs carry propagation) It can also overflow by 1 modulus. 
+ // Subtract one modulus from Acc1 into Acc2 then carry propagate both + // simultaneously + + XMMRegister Acc1L = A; + XMMRegister Acc2L = B; + __ vpsubq(Acc2, Acc1, modulus, Assembler::AVX_512bit); + + // digit 0 carry out + // Also split Acc1 and Acc2 into two 256-bit vectors each {Acc1, Acc1L} and + // {Acc2, Acc2L} to use 256bit operations + __ evpsraq(Carry, limb0, Acc2, 52, false, Assembler::AVX_256bit); + __ evpandq(Acc2L, limb0, Acc2, Mask52, false, Assembler::AVX_256bit); + __ evpermq(Acc2, allLimbs, shift1R, Acc2, false, Assembler::AVX_512bit); + __ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit); + + __ evpsraq(Carry, limb0, Acc1, 52, false, Assembler::AVX_256bit); + __ evpandq(Acc1L, limb0, Acc1, Mask52, false, Assembler::AVX_256bit); __ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit); + __ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit); + + /* remaining digits carry + * Note1: Carry register contains just the carry for the particular + * column (zero-mask the rest) and gets progressively shifted left + * Note2: 'element shift' with vpermq is more expensive, so using vpalignr when + * possible. vpalignr shifts 'right' not left, so place the carry appropiately + * +--+--+--+--+ +--+--+--+--+ +--+--+ + * vpalignr(X, X, X, 8): |x4|x3|x2|x1| >> |x2|x1|x2|x1| |x1|x2| + * +--+--+--+--+ +--+--+--+--+ >> +--+--+ + * | +--+--+--+--+ +--+--+ + * | |x4|x3|x4|x3| |x3|x4| + * | +--+--+--+--+ +--+--+ + * | vv + * | +--+--+--+--+ + * (x3 and x1 is effectively shifted +------------------------> |x3|x4|x1|x2| + * left; zero-mask everything but one column of interest) +--+--+--+--+ + */ + for (int i = 1; i<4; i++) { + __ evpsraq(Carry, masks[i-1], Acc2, 52, false, Assembler::AVX_256bit); + if (i == 1 || i == 3) { + __ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit); + } else { + __ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit); + } + __ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit); + + __ evpsraq(Carry, masks[i-1], Acc1, 52, false, Assembler::AVX_256bit); + if (i == 1 || i == 3) { + __ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit); + } else { + __ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit); //0b-2-1-0-3 + } + __ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit); + } - // Acc1 = mask52(Acc1) - __ evpandq(Acc1, Acc1, mask52, Assembler::AVX_512bit); // Clear top 12 bits - - // Acc2 += carry - __ evpaddq(Acc2, allLimbs, carry, Acc2, true, Assembler::AVX_512bit); + // Iff Acc2 is negative, then Acc1 contains the result. 
+ // if Acc2 is negative, upper 12 bits will be set; arithmetic shift by 64 bits + // generates a mask from Acc2 sign bit + __ evpsraq(Carry, Acc2, 64, Assembler::AVX_256bit); + __ vpermq(Carry, Carry, 0b11111111, Assembler::AVX_256bit); //0b-3-3-3-3 + __ evpandq(Acc1, Acc1, Mask52, Assembler::AVX_256bit); + __ evpandq(Acc2, Acc2, Mask52, Assembler::AVX_256bit); - // Acc1 = Acc1 + Acc2 - __ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit); + // Acc2 = (Acc1 & Mask) | (Acc2 & !Mask) + __ vpandn(Acc2L, Carry, Acc2L, Assembler::AVX_256bit); + __ vpternlogq(Acc2L, 0xF8, Carry, Acc1L, Assembler::AVX_256bit); // A | B&C orAandBC + __ vpandn(Acc2, Carry, Acc2, Assembler::AVX_256bit); + __ vpternlogq(Acc2, 0xF8, Carry, Acc1, Assembler::AVX_256bit); // output to rLimbs (1 + 4 limbs) - __ movq(Address(rLimbs, 0), Acc1); - __ evpermq(Acc1, k0, shift1R, Acc1, true, Assembler::AVX_512bit); - __ evmovdquq(Address(rLimbs, 8), k0, Acc1, true, Assembler::AVX_256bit); + __ movq(Address(rLimbs, 0), Acc2L); + __ evmovdquq(Address(rLimbs, 8), Acc2, Assembler::AVX_256bit); + + // Cleanup + // Zero out zmm0-zmm15, higher registers not used by intrinsic. + __ vzeroall(); +} + +/** + * Unrolled Word-by-Word Montgomery Multiplication + * r = a * b * 2^-260 (mod P) + * + * Use vpmadd52{l,h}uq multiply for upper four limbs and use + * scalar mulq for the lowest limb. + * + * One has to be careful with mulq vs vpmadd52 'crossovers'; mulq high/low + * is split as 40:64 bits vs 52:52 in the vector version. Shifts are required + * to line up values before addition (see following ascii art) + * + * Pseudocode: + * + * +--+--+--+--+ +--+ + * M = load(*modulus_p256) |m5|m4|m3|m2| |m1| + * +--+--+--+--+ +--+ + * A = load(*aLimbs) |a5|a4|a3|a2| |a1| + * +--+--+--+--+ +--+ + * Acc1 = 0 | 0| 0| 0| 0| | 0| + * +--+--+--+--+ +--+ + * ---- for i = 0 to 4 + * +--+--+--+--+ +--+ + * Acc2 = 0 | 0| 0| 0| 0| | 0| + * +--+--+--+--+ +--+ + * B = replicate(bLimbs[i]) |bi|bi|bi|bi| |bi| + * +--+--+--+--+ +--+ + * +--+--+--+--+ +--+ + * |a5|a4|a3|a2| |a1| + * Acc1 += A * B *|bi|bi|bi|bi| |bi| + * Acc1+=|c5|c4|c3|c2| |c1| + * +--+--+--+--+ +--+ + * |a5|a4|a3|a2| |a1| + * Acc2 += A *h B *h|bi|bi|bi|bi| |bi| + * Acc2+=|d5|d4|d3|d2| |d1| + * +--+--+--+--+ +--+ + * N = replicate(Acc1[0]) |n0|n0|n0|n0| |n0| + * +--+--+--+--+ +--+ + * +--+--+--+--+ +--+ + * |m5|m4|m3|m2| |m1| + * Acc1 += M * N *|n0|n0|n0|n0| |n0| + * Acc1+=|c5|c4|c3|c2| |c1| Note: 52 low bits of c1 == 0 due to Montgomery! + * +--+--+--+--+ +--+ + * |m5|m4|m3|m2| |m1| + * Acc2 += M *h N *h|n0|n0|n0|n0| |n0| + * Acc2+=|d5|d4|d3|d2| |d1| + * +--+--+--+--+ +--+ + * // Combine high/low partial sums Acc1 + Acc2 + * +--+ + * carry = Acc1[0] >> 52 |c1| + * +--+ + * Acc2[0] += carry |d1| + * +--+ + * +--+--+--+--+ +--+ + * Acc1 = Acc1 shift one q element>> | 0|c5|c4|c3| |c2| + * +|d5|d4|d3|d2| |d1| + * Acc1 = Acc1 + Acc2 Acc1+=|c5|c4|c3|c2| |c1| + * +--+--+--+--+ +--+ + * ---- done + * +--+--+--+--+ +--+ + * Acc2 = Acc1 - M |d5|d4|d3|d2| |d1| + * +--+--+--+--+ +--+ + * Carry propagate Acc2 + * Carry propagate Acc1 + * Mask = sign(Acc2) + * Result = select(Mask ? Acc1 or Acc2) + * + * Acc1 can overflow by one modulus (hence Acc2); Either Acc1 or Acc2 contain + * the correct result. However, they both need carry propagation (i.e. normalize + * limbs down to 52 bits each). 
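+ *
+ * Scalar sketch of this reduce-and-select step (illustrative Java, not the
+ * actual JDK code; assumes acc[] already holds carry-propagated 52-bit limbs,
+ * LIMB_MASK == (1L << 52) - 1, and acc is at most one modulus too large):
+ *
+ *   long[] sub = new long[5];
+ *   long borrow = 0;
+ *   for (int i = 0; i < 5; i++) {
+ *     sub[i] = acc[i] - modulus[i] + borrow;
+ *     borrow = sub[i] >> 52;            // arithmetic shift: borrow is 0 or -1
+ *     sub[i] &= LIMB_MASK;
+ *   }
+ *   long mask = borrow;                 // all ones iff acc < modulus
+ *   for (int i = 0; i < 5; i++) {
+ *     r[i] = (acc[i] & mask) | (sub[i] & ~mask);  // constant-time select
+ *   }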
+ * + * Carry propagation would require relatively expensive vector lane operations, + * so instead dump to memory and read as scalar registers + * + * Note: the order of reduce-then-propagate vs propagate-then-reduce is different + * in Java + */ +void montgomeryMultiplyAVX2(const Register aLimbs, const Register bLimbs, const Register rLimbs, + const Register tmp_rax, const Register tmp_rdx, const Register tmp1, const Register tmp2, + const Register tmp3, const Register tmp4, const Register tmp5, const Register tmp6, + const Register tmp7, MacroAssembler* _masm) { + Register rscratch = tmp1; + + // Inputs + Register a = tmp1; + XMMRegister A = xmm0; + XMMRegister B = xmm1; + + // Intermediates + Register acc1 = tmp2; + XMMRegister Acc1 = xmm3; + Register acc2 = tmp3; + XMMRegister Acc2 = xmm4; + XMMRegister N = xmm5; + XMMRegister Carry = xmm6; + + // Constants + Register modulus = tmp4; + XMMRegister Modulus = xmm7; + Register mask52 = tmp5; + XMMRegister Mask52 = xmm8; + XMMRegister MaskLimb5 = xmm9; + XMMRegister Zero = xmm10; + + __ mov64(mask52, P256_MASK52[0]); + __ movq(Mask52, mask52); + __ vpbroadcastq(Mask52, Mask52, Assembler::AVX_256bit); + __ vmovdqa(MaskLimb5, ExternalAddress(mask_limb5()), Assembler::AVX_256bit, rscratch); + __ vpxor(Zero, Zero, Zero, Assembler::AVX_256bit); + + // M = load(*modulus_p256) + __ movq(modulus, mask52); + __ vmovdqu(Modulus, ExternalAddress(modulus_p256(1)), Assembler::AVX_256bit, rscratch); + + // A = load(*aLimbs); + __ movq(a, Address(aLimbs, 0)); + __ vmovdqu(A, Address(aLimbs, 8)); //Assembler::AVX_256bit + + // Acc1 = 0 + __ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit); + for (int i = 0; i< 5; i++) { + // Acc2 = 0 + __ vpxor(Acc2, Acc2, Acc2, Assembler::AVX_256bit); + + // B = replicate(bLimbs[i]) + __ movq(tmp_rax, Address(bLimbs, i*8)); //(b==rax) + __ vpbroadcastq(B, Address(bLimbs, i*8), Assembler::AVX_256bit); + + // Acc1 += A * B + // Acc2 += A *h B + __ mulq(a); // rdx:rax = a*rax + if (i == 0) { + __ movq(acc1, tmp_rax); + __ movq(acc2, tmp_rdx); + } else { + // Careful with limb size/carries; from mulq, tmp_rax uses full 64 bits + __ xorq(acc2, acc2); + __ addq(acc1, tmp_rax); + __ adcq(acc2, tmp_rdx); + } + __ vpmadd52luq(Acc1, A, B, Assembler::AVX_256bit); + __ vpmadd52huq(Acc2, A, B, Assembler::AVX_256bit); + + // N = replicate(Acc1[0]) + if (i != 0) { + __ movq(tmp_rax, acc1); // (n==rax) + } + __ andq(tmp_rax, mask52); + __ movq(N, acc1); // masking implicit in vpmadd52 + __ vpbroadcastq(N, N, Assembler::AVX_256bit); + + // Acc1 += M * N + __ mulq(modulus); // rdx:rax = modulus*rax + __ vpmadd52luq(Acc1, Modulus, N, Assembler::AVX_256bit); + __ addq(acc1, tmp_rax); //carry flag set! 
+ + // Acc2 += M *h N + __ adcq(acc2, tmp_rdx); + __ vpmadd52huq(Acc2, Modulus, N, Assembler::AVX_256bit); + + // Combine high/low partial sums Acc1 + Acc2 + + // carry = Acc1[0] >> 52 + __ shrq(acc1, 52); // low 52 of acc1 ignored, is zero, because Montgomery + + // Acc2[0] += carry + __ shlq(acc2, 12); + __ addq(acc2, acc1); + + // Acc1 = Acc1 shift one q element >> + __ movq(acc1, Acc1); + __ vpermq(Acc1, Acc1, 0b11111001, Assembler::AVX_256bit); + __ vpand(Acc1, Acc1, MaskLimb5, Assembler::AVX_256bit); + + // Acc1 = Acc1 + Acc2 + __ addq(acc1, acc2); + __ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_256bit); + } + + __ movq(acc2, acc1); + __ subq(acc2, modulus); + __ vpsubq(Acc2, Acc1, Modulus, Assembler::AVX_256bit); + __ vmovdqa(Address(rsp, 0), Acc2); //Assembler::AVX_256bit + + // Carry propagate the subtraction result Acc2 first (since the last carry is + // used to select result). Careful, following registers overlap: + // acc1 = tmp2; acc2 = tmp3; mask52 = tmp5 + // Note that Acc2 limbs are signed (i.e. result of a subtract with modulus) + // i.e. using signed shift is needed for correctness + Register limb[] = {acc2, tmp1, tmp4, tmp_rdx, tmp6}; + Register carry = tmp_rax; + for (int i = 0; i<5; i++) { + if (i > 0) { + __ movq(limb[i], Address(rsp, -8+i*8)); + __ addq(limb[i], carry); + } + __ movq(carry, limb[i]); + if (i==4) break; + __ sarq(carry, 52); + } + __ sarq(carry, 63); + __ notq(carry); //select + Register select = carry; + carry = tmp7; + + // Now carry propagate the multiply result and (constant-time) select correct + // output digit + Register digit = acc1; + __ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit + + for (int i = 0; i<5; i++) { + if (i>0) { + __ movq(digit, Address(rsp, -8+i*8)); + __ addq(digit, carry); + } + __ movq(carry, digit); + __ sarq(carry, 52); + + // long dummyLimbs = maskValue & (a[i] ^ b[i]); + // a[i] = dummyLimbs ^ a[i]; + __ xorq(limb[i], digit); + __ andq(limb[i], select); + __ xorq(digit, limb[i]); + + __ andq(digit, mask52); + __ movq(Address(rLimbs, i*8), digit); + } + + // Cleanup + // Zero out ymm0-ymm15. 
+ __ vzeroall(); + __ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit); + __ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit } address StubGenerator::generate_intpoly_montgomeryMult_P256() { @@ -241,13 +563,58 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() { address start = __ pc(); __ enter(); - // Register Map - const Register aLimbs = c_rarg0; // rdi | rcx - const Register bLimbs = c_rarg1; // rsi | rdx - const Register rLimbs = c_rarg2; // rdx | r8 - const Register tmp = r9; - - montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm); + if (EnableX86ECoreOpts && UseAVX > 1) { + __ push(r12); + __ push(r13); + __ push(r14); + #ifdef _WIN64 + __ push(rsi); + __ push(rdi); + #endif + __ push(rbp); + __ movq(rbp, rsp); + __ andq(rsp, -32); + __ subptr(rsp, 32); + + // Register Map + const Register aLimbs = c_rarg0; // c_rarg0: rdi | rcx + const Register bLimbs = rsi; // c_rarg1: rsi | rdx + const Register rLimbs = r8; // c_rarg2: rdx | r8 + const Register tmp1 = r9; + const Register tmp2 = r10; + const Register tmp3 = r11; + const Register tmp4 = r12; + const Register tmp5 = r13; + const Register tmp6 = r14; + #ifdef _WIN64 + const Register tmp7 = rdi; + __ movq(bLimbs, c_rarg1); // free-up rdx + #else + const Register tmp7 = rcx; + __ movq(rLimbs, c_rarg2); // free-up rdx + #endif + + montgomeryMultiplyAVX2(aLimbs, bLimbs, rLimbs, rax, rdx, + tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, _masm); + + __ movq(rsp, rbp); + __ pop(rbp); + #ifdef _WIN64 + __ pop(rdi); + __ pop(rsi); + #endif + __ pop(r14); + __ pop(r13); + __ pop(r12); + } else { + // Register Map + const Register aLimbs = c_rarg0; // rdi | rcx + const Register bLimbs = c_rarg1; // rsi | rdx + const Register rLimbs = c_rarg2; // rdx | r8 + const Register tmp = r9; + + montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm); + } __ leave(); __ ret(0); @@ -258,18 +625,35 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() { // Must be: // - constant time (i.e. no branches) // - no-side channel (i.e. 
all memory must always be accessed, and in same order) -void assign_avx(XMMRegister A, Address aAddr, XMMRegister B, Address bAddr, KRegister select, int vector_len, MacroAssembler* _masm) { - __ evmovdquq(A, aAddr, vector_len); - __ evmovdquq(B, bAddr, vector_len); - __ evmovdquq(A, select, B, true, vector_len); - __ evmovdquq(aAddr, A, vector_len); +void assign_avx(Register aBase, Register bBase, int offset, XMMRegister select, XMMRegister tmp, XMMRegister aTmp, int vector_len, MacroAssembler* _masm) { + if (vector_len == Assembler::AVX_512bit && UseAVX < 3) { + assign_avx(aBase, bBase, offset, select, tmp, aTmp, Assembler::AVX_256bit, _masm); + assign_avx(aBase, bBase, offset + 32, select, tmp, aTmp, Assembler::AVX_256bit, _masm); + return; + } + + Address aAddr = Address(aBase, offset); + Address bAddr = Address(bBase, offset); + + // Original java: + // long dummyLimbs = maskValue & (a[i] ^ b[i]); + // a[i] = dummyLimbs ^ a[i]; + __ vmovdqu(tmp, aAddr, vector_len); + __ vmovdqu(aTmp, tmp, vector_len); + __ vpxor(tmp, tmp, bAddr, vector_len); + __ vpand(tmp, tmp, select, vector_len); + __ vpxor(tmp, tmp, aTmp, vector_len); + __ vmovdqu(aAddr, tmp, vector_len); } -void assign_scalar(Address aAddr, Address bAddr, Register select, Register tmp, MacroAssembler* _masm) { +void assign_scalar(Register aBase, Register bBase, int offset, Register select, Register tmp, MacroAssembler* _masm) { // Original java: // long dummyLimbs = maskValue & (a[i] ^ b[i]); // a[i] = dummyLimbs ^ a[i]; + Address aAddr = Address(aBase, offset); + Address bAddr = Address(bBase, offset); + __ movq(tmp, aAddr); __ xorq(tmp, bAddr); __ andq(tmp, select); @@ -306,13 +690,18 @@ address StubGenerator::generate_intpoly_assign() { const Register length = c_rarg3; XMMRegister A = xmm0; XMMRegister B = xmm1; + XMMRegister select = xmm2; Register tmp = r9; - KRegister select = k1; Label L_Length5, L_Length10, L_Length14, L_Length16, L_Length19, L_DefaultLoop, L_Done; __ negq(set); - __ kmovql(select, set); + if (UseAVX > 2) { + __ evpbroadcastq(select, set, Assembler::AVX_512bit); + } else { + __ movq(select, set); + __ vpbroadcastq(select, select, Assembler::AVX_256bit); + } // NOTE! Crypto code cannot branch on user input. However; allowed to branch on number of limbs; // Number of limbs is a constant in each IntegerPolynomial (i.e. 
this side-channel branch leaks @@ -332,7 +721,7 @@ address StubGenerator::generate_intpoly_assign() { __ cmpl(length, 0); __ jcc(Assembler::lessEqual, L_Done); __ bind(L_DefaultLoop); - assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm); + assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm); __ subl(length, 1); __ lea(aLimbs, Address(aLimbs,8)); __ lea(bLimbs, Address(bLimbs,8)); @@ -341,31 +730,31 @@ address StubGenerator::generate_intpoly_assign() { __ jmp(L_Done); __ bind(L_Length5); // 1 + 4 - assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm); - assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_256bit, _masm); + assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm); + assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_256bit, _masm); __ jmp(L_Done); __ bind(L_Length10); // 2 + 8 - assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm); - assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_512bit, _masm); + assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm); + assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_512bit, _masm); __ jmp(L_Done); __ bind(L_Length14); // 2 + 4 + 8 - assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm); - assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_256bit, _masm); - assign_avx(A, Address(aLimbs, 48), B, Address(bLimbs, 48), select, Assembler::AVX_512bit, _masm); + assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm); + assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_256bit, _masm); + assign_avx(aLimbs, bLimbs, 48, select, A, B, Assembler::AVX_512bit, _masm); __ jmp(L_Done); __ bind(L_Length16); // 8 + 8 - assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_512bit, _masm); - assign_avx(A, Address(aLimbs, 64), B, Address(bLimbs, 64), select, Assembler::AVX_512bit, _masm); + assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_512bit, _masm); + assign_avx(aLimbs, bLimbs, 64, select, A, B, Assembler::AVX_512bit, _masm); __ jmp(L_Done); __ bind(L_Length19); // 1 + 2 + 8 + 8 - assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm); - assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_128bit, _masm); - assign_avx(A, Address(aLimbs, 24), B, Address(bLimbs, 24), select, Assembler::AVX_512bit, _masm); - assign_avx(A, Address(aLimbs, 88), B, Address(bLimbs, 88), select, Assembler::AVX_512bit, _masm); + assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm); + assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_128bit, _masm); + assign_avx (aLimbs, bLimbs, 24, select, A, B, Assembler::AVX_512bit, _masm); + assign_avx (aLimbs, bLimbs, 88, select, A, B, Assembler::AVX_512bit, _masm); __ bind(L_Done); __ leave(); diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index cc438ce951f96..395bcd7992438 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -1403,7 +1403,7 @@ void VM_Version::get_processor_features() { } #ifdef _LP64 - if (supports_avx512ifma() && supports_avx512vlbw()) { + if ((supports_avx512ifma() && supports_avx512vlbw()) || supports_avxifma()) { if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) { FLAG_SET_DEFAULT(UseIntPolyIntrinsics, true); } diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp 
b/src/hotspot/share/classfile/vmIntrinsics.hpp index 6ec0222324e06..32dd5c4c7f6fd 100644 --- a/src/hotspot/share/classfile/vmIntrinsics.hpp +++ b/src/hotspot/share/classfile/vmIntrinsics.hpp @@ -532,7 +532,7 @@ class methodHandle; /* support for sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256 */ \ do_class(sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, "sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256") \ do_intrinsic(_intpoly_montgomeryMult_P256, sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, intPolyMult_name, intPolyMult_signature, F_R) \ - do_name(intPolyMult_name, "multImpl") \ + do_name(intPolyMult_name, "mult") \ do_signature(intPolyMult_signature, "([J[J[J)V") \ \ do_class(sun_security_util_math_intpoly_IntegerPolynomial, "sun/security/util/math/intpoly/IntegerPolynomial") \ diff --git a/src/java.base/share/classes/sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256.java b/src/java.base/share/classes/sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256.java index e50890bd976e7..954713bea5fe1 100644 --- a/src/java.base/share/classes/sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256.java +++ b/src/java.base/share/classes/sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -159,14 +159,8 @@ protected void square(long[] a, long[] r) { * numAdds to reuse existing overflow logic. */ @Override - protected void mult(long[] a, long[] b, long[] r) { - multImpl(a, b, r); - reducePositive(r); - } - - @ForceInline @IntrinsicCandidate - private void multImpl(long[] a, long[] b, long[] r) { + protected void mult(long[] a, long[] b, long[] r) { long aa0 = a[0]; long aa1 = a[1]; long aa2 = a[2]; @@ -398,17 +392,43 @@ private void multImpl(long[] a, long[] b, long[] r) { dd4 += Math.unsignedMultiplyHigh(n, modulus[4]) << shift1 | (n4 >>> shift2); d4 += n4 & LIMB_MASK; + // Final carry propagate c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB); - c6 += d2 + dd1; - c7 += d3 + dd2; - c8 += d4 + dd3; - c9 = dd4; - - r[0] = c5; - r[1] = c6; - r[2] = c7; - r[3] = c8; - r[4] = c9; + c6 += d2 + dd1 + (c5 >>> BITS_PER_LIMB); + c7 += d3 + dd2 + (c6 >>> BITS_PER_LIMB); + c8 += d4 + dd3 + (c7 >>> BITS_PER_LIMB); + c9 = dd4 + (c8 >>> BITS_PER_LIMB); + + c5 &= LIMB_MASK; + c6 &= LIMB_MASK; + c7 &= LIMB_MASK; + c8 &= LIMB_MASK; + + // At this point, the result {c5, c6, c7, c8, c9} could overflow by + // one modulus. Subtract one modulus (with carry propagation), into + // {c0, c1, c2, c3, c4}. Note that in this calculation, limbs are + // signed + c0 = c5 - modulus[0]; + c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB); + c0 &= LIMB_MASK; + c2 = c7 - modulus[2] + (c1 >> BITS_PER_LIMB); + c1 &= LIMB_MASK; + c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB); + c2 &= LIMB_MASK; + c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB); + c3 &= LIMB_MASK; + + // We now must select a result that is in range of [0,modulus). i.e. + // either {c0-4} or {c5-9}. Iff {c0-4} is negative, then {c5-9} contains + // the result. (After carry propagation) IF c4 is negative, {c0-4} is + // negative. Arithmetic shift by 64 bits generates a mask from c4 that + // can be used to select 'constant time' either {c0-4} or {c5-9}. 
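+ // (Concretely: if c4 is negative then mask == -1, all ones, and the original
+ // {c5-9} limbs are written to r; otherwise mask == 0 and the reduced {c0-4}
+ // limbs are written, with no branch on secret data.)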
+ long mask = c4 >> 63; + r[0] = ((c5 & mask) | (c0 & ~mask)); + r[1] = ((c6 & mask) | (c1 & ~mask)); + r[2] = ((c7 & mask) | (c2 & ~mask)); + r[3] = ((c8 & mask) | (c3 & ~mask)); + r[4] = ((c9 & mask) | (c4 & ~mask)); } @Override diff --git a/test/jdk/com/sun/security/util/math/intpoly/MontgomeryPolynomialFuzzTest.java b/test/jdk/com/sun/security/util/math/intpoly/MontgomeryPolynomialFuzzTest.java index 8c76a312f5339..f2e91d71c2168 100644 --- a/test/jdk/com/sun/security/util/math/intpoly/MontgomeryPolynomialFuzzTest.java +++ b/test/jdk/com/sun/security/util/math/intpoly/MontgomeryPolynomialFuzzTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Intel Corporation. All rights reserved. + * Copyright (c) 2024, 2025, Intel Corporation. All rights reserved. * * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -23,9 +23,8 @@ */ import java.util.Random; -import sun.security.util.math.IntegerMontgomeryFieldModuloP; -import sun.security.util.math.ImmutableIntegerModuloP; import java.math.BigInteger; +import sun.security.util.math.*; import sun.security.util.math.intpoly.*; /* @@ -35,7 +34,7 @@ * java.base/sun.security.util.math.intpoly * @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:-UseIntPolyIntrinsics * MontgomeryPolynomialFuzzTest - * @summary Unit test MontgomeryPolynomialFuzzTest. + * @summary Unit test MontgomeryPolynomialFuzzTest without intrinsic, plain java */ /* @@ -45,10 +44,11 @@ * java.base/sun.security.util.math.intpoly * @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:+UseIntPolyIntrinsics * MontgomeryPolynomialFuzzTest - * @summary Unit test MontgomeryPolynomialFuzzTest. + * @summary Unit test MontgomeryPolynomialFuzzTest with intrinsic enabled */ -// This test case is NOT entirely deterministic, it uses a random seed for pseudo-random number generator +// This test case is NOT entirely deterministic, it uses a random seed for +// pseudo-random number generator // If a failure occurs, hardcode the seed to make the test case deterministic public class MontgomeryPolynomialFuzzTest { public static void main(String[] args) throws Exception { @@ -60,15 +60,38 @@ public static void main(String[] args) throws Exception { System.out.println("Fuzz Success"); } - private static void check(BigInteger reference, + private static void checkOverflow(String opMsg, ImmutableIntegerModuloP testValue, long seed) { - if (!reference.equals(testValue.asBigInteger())) { - throw new RuntimeException("SEED: " + seed); + long limbs[] = testValue.getLimbs(); + BigInteger mod = MontgomeryIntegerPolynomialP256.ONE.MODULUS; + BigInteger ref = BigInteger.ZERO; + for (int i = 0; i