Skip to content

Commit

Permalink
8313760: [REDO] Enhance AES performance
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Haley <aph@openjdk.org>
Reviewed-by: adinn, aph, sviswanathan, rhalade, kvn, dlong
  • Loading branch information
chhagedorn and Andrew Haley committed Aug 16, 2023
1 parent d46f0fb commit 49ddb19
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 37 deletions.
48 changes: 34 additions & 14 deletions src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,23 @@ class StubGenerator: public StubCodeGenerator {
return start;
}

// Big-endian 128-bit + 64-bit -> 128-bit addition.
// Inputs: 128-bits. in is preserved.
// The least-significant 64-bit word is in the upper dword of each vector.
// inc (the 64-bit increment) is preserved. Its lower dword must be zero.
// Output: result
void be_add_128_64(FloatRegister result, FloatRegister in,
FloatRegister inc, FloatRegister tmp) {
assert_different_registers(result, tmp, inc);

__ addv(result, __ T2D, in, inc); // Add inc to the least-significant dword of
// input
__ cm(__ HI, tmp, __ T2D, inc, result);// Check for result overflowing
__ ext(tmp, __ T16B, tmp, tmp, 0x08); // Swap LSD of comparison result to MSD and
// MSD == 0 (must be!) to LSD
__ subv(result, __ T2D, result, tmp); // Subtract -1 from MSD if there was an overflow
}

// CTR AES crypt.
// Arguments:
//
Expand Down Expand Up @@ -3053,13 +3070,16 @@ class StubGenerator: public StubCodeGenerator {
// Setup the counter
__ movi(v4, __ T4S, 0);
__ movi(v5, __ T4S, 1);
__ ins(v4, __ S, v5, 3, 3); // v4 contains { 0, 0, 0, 1 }
__ ins(v4, __ S, v5, 2, 2); // v4 contains { 0, 1 }

__ ld1(v0, __ T16B, counter); // Load the counter into v0
__ rev32(v16, __ T16B, v0);
__ addv(v16, __ T4S, v16, v4);
__ rev32(v16, __ T16B, v16);
__ st1(v16, __ T16B, counter); // Save the incremented counter back
// 128-bit big-endian increment
__ ld1(v0, __ T16B, counter);
__ rev64(v16, __ T16B, v0);
be_add_128_64(v16, v16, v4, /*tmp*/v5);
__ rev64(v16, __ T16B, v16);
__ st1(v16, __ T16B, counter);
// Previous counter value is in v0
// v4 contains { 0, 1 }

{
// We have fewer than bulk_width blocks of data left. Encrypt
Expand Down Expand Up @@ -3091,9 +3111,9 @@ class StubGenerator: public StubCodeGenerator {

// Increment the counter, store it back
__ orr(v0, __ T16B, v16, v16);
__ rev32(v16, __ T16B, v16);
__ addv(v16, __ T4S, v16, v4);
__ rev32(v16, __ T16B, v16);
__ rev64(v16, __ T16B, v16);
be_add_128_64(v16, v16, v4, /*tmp*/v5);
__ rev64(v16, __ T16B, v16);
__ st1(v16, __ T16B, counter); // Save the incremented counter back

__ b(inner_loop);
Expand Down Expand Up @@ -3141,7 +3161,7 @@ class StubGenerator: public StubCodeGenerator {
// Keys should already be loaded into the correct registers

__ ld1(v0, __ T16B, counter); // v0 contains the first counter
__ rev32(v16, __ T16B, v0); // v16 contains byte-reversed counter
__ rev64(v16, __ T16B, v0); // v16 contains byte-reversed counter

// AES/CTR loop
{
Expand All @@ -3151,12 +3171,12 @@ class StubGenerator: public StubCodeGenerator {
// Setup the counters
__ movi(v8, __ T4S, 0);
__ movi(v9, __ T4S, 1);
__ ins(v8, __ S, v9, 3, 3); // v8 contains { 0, 0, 0, 1 }
__ ins(v8, __ S, v9, 2, 2); // v8 contains { 0, 1 }

for (int i = 0; i < bulk_width; i++) {
FloatRegister v0_ofs = as_FloatRegister(v0->encoding() + i);
__ rev32(v0_ofs, __ T16B, v16);
__ addv(v16, __ T4S, v16, v8);
__ rev64(v0_ofs, __ T16B, v16);
be_add_128_64(v16, v16, v8, /*tmp*/v9);
}

__ ld1(v8, v9, v10, v11, __ T16B, __ post(in, 4 * 16));
Expand Down Expand Up @@ -3186,7 +3206,7 @@ class StubGenerator: public StubCodeGenerator {
}

// Save the counter back where it goes
__ rev32(v16, __ T16B, v16);
__ rev64(v16, __ T16B, v16);
__ st1(v16, __ T16B, counter);

__ pop(saved_regs, sp);
Expand Down
8 changes: 8 additions & 0 deletions src/hotspot/cpu/x86/assembler_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4431,6 +4431,14 @@ void Assembler::evpcmpuw(KRegister kdst, XMMRegister nds, XMMRegister src, Compa
emit_int24(0x3E, (0xC0 | encode), vcc);
}

void Assembler::evpcmpuq(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len) {
assert(VM_Version::supports_avx512vl(), "");
InstructionAttr attributes(vector_len, /* rex_w */ true, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true);
attributes.set_is_evex_instruction();
int encode = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes);
emit_int24(0x1E, (0xC0 | encode), vcc);
}

void Assembler::evpcmpuw(KRegister kdst, XMMRegister nds, Address src, ComparisonPredicate vcc, int vector_len) {
assert(VM_Version::supports_avx512vlbw(), "");
InstructionMark im(this);
Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/cpu/x86/assembler_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,8 @@ class Assembler : public AbstractAssembler {
void evpcmpuw(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len);
void evpcmpuw(KRegister kdst, XMMRegister nds, Address src, ComparisonPredicate vcc, int vector_len);

void evpcmpuq(KRegister kdst, XMMRegister nds, XMMRegister src, ComparisonPredicate vcc, int vector_len);

void pcmpeqw(XMMRegister dst, XMMRegister src);
void vpcmpeqw(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len);
void evpcmpeqw(KRegister kdst, XMMRegister nds, XMMRegister src, int vector_len);
Expand Down
11 changes: 11 additions & 0 deletions src/hotspot/cpu/x86/macroAssembler_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9257,6 +9257,17 @@ void MacroAssembler::evpandq(XMMRegister dst, XMMRegister nds, AddressLiteral sr
}
}

void MacroAssembler::evpaddq(XMMRegister dst, KRegister mask, XMMRegister nds, AddressLiteral src, bool merge, int vector_len, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");

if (reachable(src)) {
Assembler::evpaddq(dst, mask, nds, as_Address(src), merge, vector_len);
} else {
lea(rscratch, src);
Assembler::evpaddq(dst, mask, nds, Address(rscratch, 0), merge, vector_len);
}
}

void MacroAssembler::evporq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");

Expand Down
3 changes: 3 additions & 0 deletions src/hotspot/cpu/x86/macroAssembler_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,9 @@ class MacroAssembler: public Assembler {
using Assembler::evpandq;
void evpandq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch = noreg);

using Assembler::evpaddq;
void evpaddq(XMMRegister dst, KRegister mask, XMMRegister nds, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);

using Assembler::evporq;
void evporq(XMMRegister dst, XMMRegister nds, AddressLiteral src, int vector_len, Register rscratch = noreg);

Expand Down
3 changes: 2 additions & 1 deletion src/hotspot/cpu/x86/stubGenerator_x86_64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ class StubGenerator: public StubCodeGenerator {

// Utility routine for increase 128bit counter (iv in CTR mode)
void inc_counter(Register reg, XMMRegister xmmdst, int inc_delta, Label& next_block);

void ev_add128(XMMRegister xmmdst, XMMRegister xmmsrc1, XMMRegister xmmsrc2,
int vector_len, KRegister ktmp, Register rscratch = noreg);
void generate_aes_stubs();


Expand Down
68 changes: 46 additions & 22 deletions src/hotspot/cpu/x86/stubGenerator_x86_64_aes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ static address counter_mask_linc32_addr() {
return (address)COUNTER_MASK_LINC32;
}

ATTRIBUTE_ALIGNED(64) uint64_t COUNTER_MASK_ONES[] = {
0x0000000000000000UL, 0x0000000000000001UL,
0x0000000000000000UL, 0x0000000000000001UL,
0x0000000000000000UL, 0x0000000000000001UL,
0x0000000000000000UL, 0x0000000000000001UL,
};
static address counter_mask_ones_addr() {
return (address)COUNTER_MASK_ONES;
}

ATTRIBUTE_ALIGNED(64) static const uint64_t GHASH_POLYNOMIAL_REDUCTION[] = {
0x00000001C2000000UL, 0xC200000000000000UL,
0x00000001C2000000UL, 0xC200000000000000UL,
Expand Down Expand Up @@ -1623,6 +1633,17 @@ void StubGenerator::ev_load_key(XMMRegister xmmdst, Register key, int offset, Re
__ evshufi64x2(xmmdst, xmmdst, xmmdst, 0x0, Assembler::AVX_512bit);
}

// Add 128-bit integers in xmmsrc1 to xmmsrc2, then place the result in xmmdst.
// Clobber ktmp and rscratch.
// Used by aesctr_encrypt.
void StubGenerator::ev_add128(XMMRegister xmmdst, XMMRegister xmmsrc1, XMMRegister xmmsrc2,
int vector_len, KRegister ktmp, Register rscratch) {
__ vpaddq(xmmdst, xmmsrc1, xmmsrc2, vector_len);
__ evpcmpuq(ktmp, xmmdst, xmmsrc2, __ lt, vector_len);
__ kshiftlbl(ktmp, ktmp, 1);
__ evpaddq(xmmdst, ktmp, xmmdst, ExternalAddress(counter_mask_ones_addr()), /*merge*/true,
vector_len, rscratch);
}

// AES-ECB Encrypt Operation
void StubGenerator::aesecb_encrypt(Register src_addr, Register dest_addr, Register key, Register len) {
Expand Down Expand Up @@ -2046,7 +2067,6 @@ void StubGenerator::aesecb_decrypt(Register src_addr, Register dest_addr, Regist
}



// AES Counter Mode using VAES instructions
void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Register key, Register counter,
Register len_reg, Register used, Register used_addr, Register saved_encCounter_start) {
Expand Down Expand Up @@ -2104,14 +2124,17 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
// The counter is incremented after each block i.e. 16 bytes is processed;
// each zmm register has 4 counter values as its MSB
// the counters are incremented in parallel
__ vpaddd(xmm8, xmm8, ExternalAddress(counter_mask_linc0_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm9, xmm8, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm10, xmm9, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm11, xmm10, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm12, xmm11, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm13, xmm12, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm14, xmm13, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
__ vpaddd(xmm15, xmm14, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);

__ evmovdquq(xmm19, ExternalAddress(counter_mask_linc0_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
__ evmovdquq(xmm19, ExternalAddress(counter_mask_linc4_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
ev_add128(xmm9, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm10, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm11, xmm10, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm12, xmm11, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm13, xmm12, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm14, xmm13, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm15, xmm14, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);

// load linc32 mask in zmm register.linc32 increments counter by 32
__ evmovdquq(xmm19, ExternalAddress(counter_mask_linc32_addr()), Assembler::AVX_512bit, r15 /*rscratch*/);
Expand Down Expand Up @@ -2159,21 +2182,21 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
// This is followed by incrementing counter values in zmm8-zmm15.
// Since we will be processing 32 blocks at a time, the counter is incremented by 32.
roundEnc(xmm21, 7);
__ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm22, 7);
__ vpaddq(xmm9, xmm9, xmm19, Assembler::AVX_512bit);
ev_add128(xmm9, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm23, 7);
__ vpaddq(xmm10, xmm10, xmm19, Assembler::AVX_512bit);
ev_add128(xmm10, xmm10, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm24, 7);
__ vpaddq(xmm11, xmm11, xmm19, Assembler::AVX_512bit);
ev_add128(xmm11, xmm11, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm25, 7);
__ vpaddq(xmm12, xmm12, xmm19, Assembler::AVX_512bit);
ev_add128(xmm12, xmm12, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm26, 7);
__ vpaddq(xmm13, xmm13, xmm19, Assembler::AVX_512bit);
ev_add128(xmm13, xmm13, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm27, 7);
__ vpaddq(xmm14, xmm14, xmm19, Assembler::AVX_512bit);
ev_add128(xmm14, xmm14, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm28, 7);
__ vpaddq(xmm15, xmm15, xmm19, Assembler::AVX_512bit);
ev_add128(xmm15, xmm15, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
roundEnc(xmm29, 7);

__ cmpl(rounds, 52);
Expand Down Expand Up @@ -2251,8 +2274,8 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
__ vpshufb(xmm3, xmm11, xmm16, Assembler::AVX_512bit);
__ evpxorq(xmm3, xmm3, xmm20, Assembler::AVX_512bit);
// Increment counter values by 16
__ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
__ vpaddq(xmm9, xmm9, xmm19, Assembler::AVX_512bit);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
ev_add128(xmm9, xmm9, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
// AES encode rounds
roundEnc(xmm21, 3);
roundEnc(xmm22, 3);
Expand Down Expand Up @@ -2319,7 +2342,7 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
__ vpshufb(xmm1, xmm9, xmm16, Assembler::AVX_512bit);
__ evpxorq(xmm1, xmm1, xmm20, Assembler::AVX_512bit);
// increment counter by 8
__ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
// AES encode
roundEnc(xmm21, 1);
roundEnc(xmm22, 1);
Expand Down Expand Up @@ -2376,8 +2399,9 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
// XOR counter with first roundkey
__ vpshufb(xmm0, xmm8, xmm16, Assembler::AVX_512bit);
__ evpxorq(xmm0, xmm0, xmm20, Assembler::AVX_512bit);

// Increment counter
__ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_512bit);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_512bit, /*ktmp*/k1, r15 /*rscratch*/);
__ vaesenc(xmm0, xmm0, xmm21, Assembler::AVX_512bit);
__ vaesenc(xmm0, xmm0, xmm22, Assembler::AVX_512bit);
__ vaesenc(xmm0, xmm0, xmm23, Assembler::AVX_512bit);
Expand Down Expand Up @@ -2427,7 +2451,7 @@ void StubGenerator::aesctr_encrypt(Register src_addr, Register dest_addr, Regist
__ evpxorq(xmm0, xmm0, xmm20, Assembler::AVX_128bit);
__ vaesenc(xmm0, xmm0, xmm21, Assembler::AVX_128bit);
// Increment counter by 1
__ vpaddq(xmm8, xmm8, xmm19, Assembler::AVX_128bit);
ev_add128(xmm8, xmm8, xmm19, Assembler::AVX_128bit, /*ktmp*/k1, r15 /*rscratch*/);
__ vaesenc(xmm0, xmm0, xmm22, Assembler::AVX_128bit);
__ vaesenc(xmm0, xmm0, xmm23, Assembler::AVX_128bit);
__ vaesenc(xmm0, xmm0, xmm24, Assembler::AVX_128bit);
Expand Down

1 comment on commit 49ddb19

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.