Skip to content

Commit b25d894

Browse files
RealFYangArd Biesheuveldgbo
committed
8252204: AArch64: Implement SHA3 accelerator/intrinsic
Co-authored-by: Ard Biesheuvel <ard.biesheuvel@linaro.org> Co-authored-by: Dong Bo <dongbo4@huawei.com> Reviewed-by: aph, kvn
1 parent 7d3d4da commit b25d894

36 files changed

+1243
-255
lines changed

src/hotspot/cpu/aarch64/aarch64-asmtest.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,44 @@ def astr(self):
11101110
+ ('\t%s, %s, %s.2D' % (self.reg[0].astr("q"),
11111111
self.reg[1].astr("q"), self.reg[2].astr("v"))))
11121112

1113+
class SHA3SIMDOp(Instruction):
1114+
1115+
def generate(self):
1116+
if ((self._name == 'eor3') or (self._name == 'bcax')):
1117+
self.reg = [FloatRegister().generate(), FloatRegister().generate(),
1118+
FloatRegister().generate(), FloatRegister().generate()]
1119+
else:
1120+
self.reg = [FloatRegister().generate(), FloatRegister().generate(),
1121+
FloatRegister().generate()]
1122+
if (self._name == 'xar'):
1123+
self.imm6 = random.randint(0, 63)
1124+
return self
1125+
1126+
def cstr(self):
1127+
if ((self._name == 'eor3') or (self._name == 'bcax')):
1128+
return (super(SHA3SIMDOp, self).cstr()
1129+
+ ('%s, __ T16B, %s, %s, %s);' % (self.reg[0], self.reg[1], self.reg[2], self.reg[3])))
1130+
elif (self._name == 'rax1'):
1131+
return (super(SHA3SIMDOp, self).cstr()
1132+
+ ('%s, __ T2D, %s, %s);' % (self.reg[0], self.reg[1], self.reg[2])))
1133+
else:
1134+
return (super(SHA3SIMDOp, self).cstr()
1135+
+ ('%s, __ T2D, %s, %s, %s);' % (self.reg[0], self.reg[1], self.reg[2], self.imm6)))
1136+
1137+
def astr(self):
1138+
if ((self._name == 'eor3') or (self._name == 'bcax')):
1139+
return (super(SHA3SIMDOp, self).astr()
1140+
+ ('\t%s.16B, %s.16B, %s.16B, %s.16B' % (self.reg[0].astr("v"), self.reg[1].astr("v"),
1141+
self.reg[2].astr("v"), self.reg[3].astr("v"))))
1142+
elif (self._name == 'rax1'):
1143+
return (super(SHA3SIMDOp, self).astr()
1144+
+ ('\t%s.2D, %s.2D, %s.2D') % (self.reg[0].astr("v"), self.reg[1].astr("v"),
1145+
self.reg[2].astr("v")))
1146+
else:
1147+
return (super(SHA3SIMDOp, self).astr()
1148+
+ ('\t%s.2D, %s.2D, %s.2D, #%s') % (self.reg[0].astr("v"), self.reg[1].astr("v"),
1149+
self.reg[2].astr("v"), self.imm6))
1150+
11131151
class LSEOp(Instruction):
11141152
def __init__(self, args):
11151153
self._name, self.asmname, self.size, self.suffix = args
@@ -1441,8 +1479,6 @@ def generate(kind, names):
14411479
["fcmge", "fcmge", "2D"],
14421480
])
14431481

1444-
generate(SHA512SIMDOp, ["sha512h", "sha512h2", "sha512su0", "sha512su1"])
1445-
14461482
generate(SpecialCases, [["ccmn", "__ ccmn(zr, zr, 3u, Assembler::LE);", "ccmn\txzr, xzr, #3, LE"],
14471483
["ccmnw", "__ ccmnw(zr, zr, 5u, Assembler::EQ);", "ccmn\twzr, wzr, #5, EQ"],
14481484
["ccmp", "__ ccmp(zr, 1, 4u, Assembler::NE);", "ccmp\txzr, 1, #4, NE"],
@@ -1517,6 +1553,11 @@ def generate(kind, names):
15171553
["ldumin", "ldumin", size, suffix],
15181554
["ldumax", "ldumax", size, suffix]]);
15191555

1556+
# ARMv8.2A
1557+
generate(SHA3SIMDOp, ["bcax", "eor3", "rax1", "xar"])
1558+
1559+
generate(SHA512SIMDOp, ["sha512h", "sha512h2", "sha512su0", "sha512su1"])
1560+
15201561
generate(SVEVectorOp, [["add", "ZZZ"],
15211562
["sub", "ZZZ"],
15221563
["fadd", "ZZZ"],
@@ -1565,8 +1606,8 @@ def generate(kind, names):
15651606

15661607
outfile.close()
15671608

1568-
# compile for sve with 8.1 and sha2 because of lse atomics and sha512 crypto extension.
1569-
subprocess.check_call([AARCH64_AS, "-march=armv8.1-a+sha2+sve", "aarch64ops.s", "-o", "aarch64ops.o"])
1609+
# compile for sve with 8.2 and sha3 because of SHA3 crypto extension.
1610+
subprocess.check_call([AARCH64_AS, "-march=armv8.2-a+sha3+sve", "aarch64ops.s", "-o", "aarch64ops.o"])
15701611

15711612
print
15721613
print "/*"

src/hotspot/cpu/aarch64/assembler_aarch64.cpp

Lines changed: 184 additions & 177 deletions
Large diffs are not rendered by default.

src/hotspot/cpu/aarch64/assembler_aarch64.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,6 +2550,40 @@ void mvnw(Register Rd, Register Rm,
25502550

25512551
#undef INSN
25522552

2553+
#define INSN(NAME, opc) \
2554+
void NAME(FloatRegister Vd, SIMD_Arrangement T, FloatRegister Vn, FloatRegister Vm, FloatRegister Va) { \
2555+
starti; \
2556+
assert(T == T16B, "arrangement must be T16B"); \
2557+
f(0b11001110, 31, 24), f(opc, 23, 21), rf(Vm, 16), f(0b0, 15, 15), rf(Va, 10), rf(Vn, 5), rf(Vd, 0); \
2558+
}
2559+
2560+
INSN(eor3, 0b000);
2561+
INSN(bcax, 0b001);
2562+
2563+
#undef INSN
2564+
2565+
#define INSN(NAME, opc) \
2566+
void NAME(FloatRegister Vd, SIMD_Arrangement T, FloatRegister Vn, FloatRegister Vm, unsigned imm) { \
2567+
starti; \
2568+
assert(T == T2D, "arrangement must be T2D"); \
2569+
f(0b11001110, 31, 24), f(opc, 23, 21), rf(Vm, 16), f(imm, 15, 10), rf(Vn, 5), rf(Vd, 0); \
2570+
}
2571+
2572+
INSN(xar, 0b100);
2573+
2574+
#undef INSN
2575+
2576+
#define INSN(NAME, opc) \
2577+
void NAME(FloatRegister Vd, SIMD_Arrangement T, FloatRegister Vn, FloatRegister Vm) { \
2578+
starti; \
2579+
assert(T == T2D, "arrangement must be T2D"); \
2580+
f(0b11001110, 31, 24), f(opc, 23, 21), rf(Vm, 16), f(0b100011, 15, 10), rf(Vn, 5), rf(Vd, 0); \
2581+
}
2582+
2583+
INSN(rax1, 0b011);
2584+
2585+
#undef INSN
2586+
25532587
#define INSN(NAME, opc) \
25542588
void NAME(FloatRegister Vd, FloatRegister Vn) { \
25552589
starti; \

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,6 +3309,225 @@ class StubGenerator: public StubCodeGenerator {
33093309
return start;
33103310
}
33113311

3312+
// Arguments:
3313+
//
3314+
// Inputs:
3315+
// c_rarg0 - byte[] source+offset
3316+
// c_rarg1 - byte[] SHA.state
3317+
// c_rarg2 - int digest_length
3318+
// c_rarg3 - int offset
3319+
// c_rarg4 - int limit
3320+
//
3321+
address generate_sha3_implCompress(bool multi_block, const char *name) {
3322+
static const uint64_t round_consts[24] = {
3323+
0x0000000000000001L, 0x0000000000008082L, 0x800000000000808AL,
3324+
0x8000000080008000L, 0x000000000000808BL, 0x0000000080000001L,
3325+
0x8000000080008081L, 0x8000000000008009L, 0x000000000000008AL,
3326+
0x0000000000000088L, 0x0000000080008009L, 0x000000008000000AL,
3327+
0x000000008000808BL, 0x800000000000008BL, 0x8000000000008089L,
3328+
0x8000000000008003L, 0x8000000000008002L, 0x8000000000000080L,
3329+
0x000000000000800AL, 0x800000008000000AL, 0x8000000080008081L,
3330+
0x8000000000008080L, 0x0000000080000001L, 0x8000000080008008L
3331+
};
3332+
3333+
__ align(CodeEntryAlignment);
3334+
StubCodeMark mark(this, "StubRoutines", name);
3335+
address start = __ pc();
3336+
3337+
Register buf = c_rarg0;
3338+
Register state = c_rarg1;
3339+
Register digest_length = c_rarg2;
3340+
Register ofs = c_rarg3;
3341+
Register limit = c_rarg4;
3342+
3343+
Label sha3_loop, rounds24_loop;
3344+
Label sha3_512, sha3_384_or_224, sha3_256;
3345+
3346+
__ stpd(v8, v9, __ pre(sp, -64));
3347+
__ stpd(v10, v11, Address(sp, 16));
3348+
__ stpd(v12, v13, Address(sp, 32));
3349+
__ stpd(v14, v15, Address(sp, 48));
3350+
3351+
// load state
3352+
__ add(rscratch1, state, 32);
3353+
__ ld1(v0, v1, v2, v3, __ T1D, state);
3354+
__ ld1(v4, v5, v6, v7, __ T1D, __ post(rscratch1, 32));
3355+
__ ld1(v8, v9, v10, v11, __ T1D, __ post(rscratch1, 32));
3356+
__ ld1(v12, v13, v14, v15, __ T1D, __ post(rscratch1, 32));
3357+
__ ld1(v16, v17, v18, v19, __ T1D, __ post(rscratch1, 32));
3358+
__ ld1(v20, v21, v22, v23, __ T1D, __ post(rscratch1, 32));
3359+
__ ld1(v24, __ T1D, rscratch1);
3360+
3361+
__ BIND(sha3_loop);
3362+
3363+
// 24 keccak rounds
3364+
__ movw(rscratch2, 24);
3365+
3366+
// load round_constants base
3367+
__ lea(rscratch1, ExternalAddress((address) round_consts));
3368+
3369+
// load input
3370+
__ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
3371+
__ ld1(v29, v30, v31, __ T8B, __ post(buf, 24));
3372+
__ eor(v0, __ T8B, v0, v25);
3373+
__ eor(v1, __ T8B, v1, v26);
3374+
__ eor(v2, __ T8B, v2, v27);
3375+
__ eor(v3, __ T8B, v3, v28);
3376+
__ eor(v4, __ T8B, v4, v29);
3377+
__ eor(v5, __ T8B, v5, v30);
3378+
__ eor(v6, __ T8B, v6, v31);
3379+
3380+
// digest_length == 64, SHA3-512
3381+
__ tbnz(digest_length, 6, sha3_512);
3382+
3383+
__ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
3384+
__ ld1(v29, v30, __ T8B, __ post(buf, 16));
3385+
__ eor(v7, __ T8B, v7, v25);
3386+
__ eor(v8, __ T8B, v8, v26);
3387+
__ eor(v9, __ T8B, v9, v27);
3388+
__ eor(v10, __ T8B, v10, v28);
3389+
__ eor(v11, __ T8B, v11, v29);
3390+
__ eor(v12, __ T8B, v12, v30);
3391+
3392+
// digest_length == 28, SHA3-224; digest_length == 48, SHA3-384
3393+
__ tbnz(digest_length, 4, sha3_384_or_224);
3394+
3395+
// SHA3-256
3396+
__ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
3397+
__ eor(v13, __ T8B, v13, v25);
3398+
__ eor(v14, __ T8B, v14, v26);
3399+
__ eor(v15, __ T8B, v15, v27);
3400+
__ eor(v16, __ T8B, v16, v28);
3401+
__ b(rounds24_loop);
3402+
3403+
__ BIND(sha3_384_or_224);
3404+
__ tbz(digest_length, 2, rounds24_loop); // bit 2 cleared? SHA-384
3405+
3406+
// SHA3-224
3407+
__ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
3408+
__ ld1(v29, __ T8B, __ post(buf, 8));
3409+
__ eor(v13, __ T8B, v13, v25);
3410+
__ eor(v14, __ T8B, v14, v26);
3411+
__ eor(v15, __ T8B, v15, v27);
3412+
__ eor(v16, __ T8B, v16, v28);
3413+
__ eor(v17, __ T8B, v17, v29);
3414+
__ b(rounds24_loop);
3415+
3416+
__ BIND(sha3_512);
3417+
__ ld1(v25, v26, __ T8B, __ post(buf, 16));
3418+
__ eor(v7, __ T8B, v7, v25);
3419+
__ eor(v8, __ T8B, v8, v26);
3420+
3421+
__ BIND(rounds24_loop);
3422+
__ subw(rscratch2, rscratch2, 1);
3423+
3424+
__ eor3(v29, __ T16B, v4, v9, v14);
3425+
__ eor3(v26, __ T16B, v1, v6, v11);
3426+
__ eor3(v28, __ T16B, v3, v8, v13);
3427+
__ eor3(v25, __ T16B, v0, v5, v10);
3428+
__ eor3(v27, __ T16B, v2, v7, v12);
3429+
__ eor3(v29, __ T16B, v29, v19, v24);
3430+
__ eor3(v26, __ T16B, v26, v16, v21);
3431+
__ eor3(v28, __ T16B, v28, v18, v23);
3432+
__ eor3(v25, __ T16B, v25, v15, v20);
3433+
__ eor3(v27, __ T16B, v27, v17, v22);
3434+
3435+
__ rax1(v30, __ T2D, v29, v26);
3436+
__ rax1(v26, __ T2D, v26, v28);
3437+
__ rax1(v28, __ T2D, v28, v25);
3438+
__ rax1(v25, __ T2D, v25, v27);
3439+
__ rax1(v27, __ T2D, v27, v29);
3440+
3441+
__ eor(v0, __ T16B, v0, v30);
3442+
__ xar(v29, __ T2D, v1, v25, (64 - 1));
3443+
__ xar(v1, __ T2D, v6, v25, (64 - 44));
3444+
__ xar(v6, __ T2D, v9, v28, (64 - 20));
3445+
__ xar(v9, __ T2D, v22, v26, (64 - 61));
3446+
__ xar(v22, __ T2D, v14, v28, (64 - 39));
3447+
__ xar(v14, __ T2D, v20, v30, (64 - 18));
3448+
__ xar(v31, __ T2D, v2, v26, (64 - 62));
3449+
__ xar(v2, __ T2D, v12, v26, (64 - 43));
3450+
__ xar(v12, __ T2D, v13, v27, (64 - 25));
3451+
__ xar(v13, __ T2D, v19, v28, (64 - 8));
3452+
__ xar(v19, __ T2D, v23, v27, (64 - 56));
3453+
__ xar(v23, __ T2D, v15, v30, (64 - 41));
3454+
__ xar(v15, __ T2D, v4, v28, (64 - 27));
3455+
__ xar(v28, __ T2D, v24, v28, (64 - 14));
3456+
__ xar(v24, __ T2D, v21, v25, (64 - 2));
3457+
__ xar(v8, __ T2D, v8, v27, (64 - 55));
3458+
__ xar(v4, __ T2D, v16, v25, (64 - 45));
3459+
__ xar(v16, __ T2D, v5, v30, (64 - 36));
3460+
__ xar(v5, __ T2D, v3, v27, (64 - 28));
3461+
__ xar(v27, __ T2D, v18, v27, (64 - 21));
3462+
__ xar(v3, __ T2D, v17, v26, (64 - 15));
3463+
__ xar(v25, __ T2D, v11, v25, (64 - 10));
3464+
__ xar(v26, __ T2D, v7, v26, (64 - 6));
3465+
__ xar(v30, __ T2D, v10, v30, (64 - 3));
3466+
3467+
__ bcax(v20, __ T16B, v31, v22, v8);
3468+
__ bcax(v21, __ T16B, v8, v23, v22);
3469+
__ bcax(v22, __ T16B, v22, v24, v23);
3470+
__ bcax(v23, __ T16B, v23, v31, v24);
3471+
__ bcax(v24, __ T16B, v24, v8, v31);
3472+
3473+
__ ld1r(v31, __ T2D, __ post(rscratch1, 8));
3474+
3475+
__ bcax(v17, __ T16B, v25, v19, v3);
3476+
__ bcax(v18, __ T16B, v3, v15, v19);
3477+
__ bcax(v19, __ T16B, v19, v16, v15);
3478+
__ bcax(v15, __ T16B, v15, v25, v16);
3479+
__ bcax(v16, __ T16B, v16, v3, v25);
3480+
3481+
__ bcax(v10, __ T16B, v29, v12, v26);
3482+
__ bcax(v11, __ T16B, v26, v13, v12);
3483+
__ bcax(v12, __ T16B, v12, v14, v13);
3484+
__ bcax(v13, __ T16B, v13, v29, v14);
3485+
__ bcax(v14, __ T16B, v14, v26, v29);
3486+
3487+
__ bcax(v7, __ T16B, v30, v9, v4);
3488+
__ bcax(v8, __ T16B, v4, v5, v9);
3489+
__ bcax(v9, __ T16B, v9, v6, v5);
3490+
__ bcax(v5, __ T16B, v5, v30, v6);
3491+
__ bcax(v6, __ T16B, v6, v4, v30);
3492+
3493+
__ bcax(v3, __ T16B, v27, v0, v28);
3494+
__ bcax(v4, __ T16B, v28, v1, v0);
3495+
__ bcax(v0, __ T16B, v0, v2, v1);
3496+
__ bcax(v1, __ T16B, v1, v27, v2);
3497+
__ bcax(v2, __ T16B, v2, v28, v27);
3498+
3499+
__ eor(v0, __ T16B, v0, v31);
3500+
3501+
__ cbnzw(rscratch2, rounds24_loop);
3502+
3503+
if (multi_block) {
3504+
// block_size = 200 - 2 * digest_length, ofs += block_size
3505+
__ add(ofs, ofs, 200);
3506+
__ sub(ofs, ofs, digest_length, Assembler::LSL, 1);
3507+
3508+
__ cmp(ofs, limit);
3509+
__ br(Assembler::LE, sha3_loop);
3510+
__ mov(c_rarg0, ofs); // return ofs
3511+
}
3512+
3513+
__ st1(v0, v1, v2, v3, __ T1D, __ post(state, 32));
3514+
__ st1(v4, v5, v6, v7, __ T1D, __ post(state, 32));
3515+
__ st1(v8, v9, v10, v11, __ T1D, __ post(state, 32));
3516+
__ st1(v12, v13, v14, v15, __ T1D, __ post(state, 32));
3517+
__ st1(v16, v17, v18, v19, __ T1D, __ post(state, 32));
3518+
__ st1(v20, v21, v22, v23, __ T1D, __ post(state, 32));
3519+
__ st1(v24, __ T1D, state);
3520+
3521+
__ ldpd(v14, v15, Address(sp, 48));
3522+
__ ldpd(v12, v13, Address(sp, 32));
3523+
__ ldpd(v10, v11, Address(sp, 16));
3524+
__ ldpd(v8, v9, __ post(sp, 64));
3525+
3526+
__ ret(lr);
3527+
3528+
return start;
3529+
}
3530+
33123531
// Safefetch stubs.
33133532
void generate_safefetch(const char* name, int size, address* entry,
33143533
address* fault_pc, address* continuation_pc) {
@@ -6048,6 +6267,10 @@ class StubGenerator: public StubCodeGenerator {
60486267
StubRoutines::_sha512_implCompress = generate_sha512_implCompress(false, "sha512_implCompress");
60496268
StubRoutines::_sha512_implCompressMB = generate_sha512_implCompress(true, "sha512_implCompressMB");
60506269
}
6270+
if (UseSHA3Intrinsics) {
6271+
StubRoutines::_sha3_implCompress = generate_sha3_implCompress(false, "sha3_implCompress");
6272+
StubRoutines::_sha3_implCompressMB = generate_sha3_implCompress(true, "sha3_implCompressMB");
6273+
}
60516274

60526275
// generate Adler32 intrinsics code
60536276
if (UseAdler32Intrinsics) {

src/hotspot/cpu/aarch64/vm_version_aarch64.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ void VM_Version::initialize() {
194194
if (_features & CPU_AES) strcat(buf, ", aes");
195195
if (_features & CPU_SHA1) strcat(buf, ", sha1");
196196
if (_features & CPU_SHA2) strcat(buf, ", sha256");
197+
if (_features & CPU_SHA3) strcat(buf, ", sha3");
197198
if (_features & CPU_SHA512) strcat(buf, ", sha512");
198199
if (_features & CPU_LSE) strcat(buf, ", lse");
199200
if (_features & CPU_SVE) strcat(buf, ", sve");
@@ -275,7 +276,7 @@ void VM_Version::initialize() {
275276
FLAG_SET_DEFAULT(UseMD5Intrinsics, false);
276277
}
277278

278-
if (_features & (CPU_SHA1 | CPU_SHA2)) {
279+
if (_features & (CPU_SHA1 | CPU_SHA2 | CPU_SHA3 | CPU_SHA512)) {
279280
if (FLAG_IS_DEFAULT(UseSHA)) {
280281
FLAG_SET_DEFAULT(UseSHA, true);
281282
}
@@ -302,6 +303,16 @@ void VM_Version::initialize() {
302303
FLAG_SET_DEFAULT(UseSHA256Intrinsics, false);
303304
}
304305

306+
if (UseSHA && (_features & CPU_SHA3)) {
307+
// Do not auto-enable UseSHA3Intrinsics until it has been fully tested on hardware
308+
// if (FLAG_IS_DEFAULT(UseSHA3Intrinsics)) {
309+
// FLAG_SET_DEFAULT(UseSHA3Intrinsics, true);
310+
// }
311+
} else if (UseSHA3Intrinsics) {
312+
warning("Intrinsics for SHA3-224, SHA3-256, SHA3-384 and SHA3-512 crypto hash functions not available on this CPU.");
313+
FLAG_SET_DEFAULT(UseSHA3Intrinsics, false);
314+
}
315+
305316
if (UseSHA && (_features & CPU_SHA512)) {
306317
// Do not auto-enable UseSHA512Intrinsics until it has been fully tested on hardware
307318
// if (FLAG_IS_DEFAULT(UseSHA512Intrinsics)) {
@@ -312,7 +323,7 @@ void VM_Version::initialize() {
312323
FLAG_SET_DEFAULT(UseSHA512Intrinsics, false);
313324
}
314325

315-
if (!(UseSHA1Intrinsics || UseSHA256Intrinsics || UseSHA512Intrinsics)) {
326+
if (!(UseSHA1Intrinsics || UseSHA256Intrinsics || UseSHA3Intrinsics || UseSHA512Intrinsics)) {
316327
FLAG_SET_DEFAULT(UseSHA, false);
317328
}
318329

0 commit comments

Comments
 (0)