2 changes: 1 addition & 1 deletion src/hotspot/cpu/aarch64/register_aarch64.hpp
@@ -506,7 +506,7 @@ template<int N> bool vs_write_before_read(const VSeq<N>& vout, const VSeq<N>& vi
return false;
}

-// convenience methods for splitting 8-way of 4-way vector register
+// convenience methods for splitting 8-way or 4-way vector register
// sequences in half -- needed because vector operations can normally
// benefit from 4-way instruction parallelism or, occasionally, 2-way
// parallelism
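// As a usage sketch (assuming the split accessors introduced here):
// given VSeq<8> vs(0) covering v0..v7, vs_front(vs) names v0..v3 and
// vs_back(vs) names v4..v7, letting a caller run one 8-way operation
// as two 4-way halves.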
147 changes: 117 additions & 30 deletions src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp
@@ -5009,7 +5009,7 @@ class StubGenerator: public StubCodeGenerator {
assert(vs_disjoint(va, vtmp), "va and vtmp overlap");
assert(!va.is_constant(), "output vector must identify 2 different registers");

-// schedule 2 streams of i<nstructions across the vector sequences
+// schedule 2 streams of instructions across the vector sequences
for (int i = 0; i < 2; i++) {
__ sqdmulh(vtmp[i], T, vb[i], vc[i]); // aHigh = hi32(2 * b * c)
__ mulv(va[i], T, vb[i], vc[i]); // aLow = lo32(b * c)
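// For reference, a scalar sketch of the signed Montgomery
// multiplication these streams compute per 16-bit lane, assuming the
// standard ML-KEM constants q = 3329 and qinv = 62209 = q^-1 mod 2^16:
//
//   int32_t a = (int32_t)b * c;                        // aHigh:aLow
//   int16_t m = (int16_t)a * qinv;                     // aLow * qinv mod 2^16
//   int16_t r = (int16_t)((a - (int32_t)m * q) >> 16); // (a - m * q) / 2^16
//
// so that r == b * c * 2^-16 mod q.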
@@ -5163,11 +5163,17 @@ class StubGenerator: public StubCodeGenerator {
// that we can use the 16-bit arithmetic in the vector unit.
//
// On each level, we fill up the vector registers in such a way that the
-// array elements that need to be multiplied by the zetas be in one
+// array elements that need to be multiplied by the zetas go into one
// set of vector registers while the corresponding ones that don't need to
-// be multiplied, in another set. We can do 32 Montgomery multiplications
-// in parallel, using 12 vector registers interleaving the steps of 4
-// identical computations, each done on 8 16-bit values per register.
+// be multiplied, go into another set.
+// We can do 32 Montgomery multiplications in parallel, using 12 vector
+// registers interleaving the steps of 4 identical computations,
+// each done on 8 16-bit values per register.

// At levels 0-3 the coefficients that are multiplied by the zetas,
// and those they are added to or subtracted from, occur in discrete
// blocks whose size is some multiple of 32.
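// For orientation, the seven levels below follow the usual forward
// NTT loop structure (a scalar sketch, assuming zetas[] holds the
// constants the stub streams in from the zetas address):
//
//   for (int len = 128, k = 1; len >= 2; len >>= 1) {   // levels 0..6
//     for (int start = 0; start < 256; start += 2 * len) {
//       int16_t zeta = zetas[k++];
//       for (int j = start; j < start + len; j++) {
//         int16_t t = montmul(zeta, r[j + len]);
//         r[j + len] = r[j] - t;
//         r[j] = r[j] + t;
//       }
//     }
//   }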

// level 0
__ add(tmpAddr, coeffs, 256);
load64shorts(vs1, tmpAddr);
@@ -5274,6 +5280,9 @@ class StubGenerator: public StubCodeGenerator {
vs_stpq_indexed(vs3, coeffs, 256 + 32, offsets2);

// level 4
// At level 4 coefficients occur in 8 discrete blocks of size 16
// so they are loaded using an ldr at 8 distinct offsets.

vs_ldpq(vq, kyberConsts);
int offsets3[8] = { 0, 32, 64, 96, 128, 160, 192, 224 };
vs_ldr_indexed(vs1, __ Q, coeffs, 16, offsets3);
@@ -5296,6 +5305,9 @@ class StubGenerator: public StubCodeGenerator {
vs_str_indexed(vs3, __ Q, coeffs, 256 + 16, offsets3);

// level 5
// At level 5 related coefficients occur in discrete blocks of size 8 so
// need to be loaded interleaved using an ld2 operation with arrangement 2D.
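// (ld2 with arrangement 2D de-interleaves at 64-bit granularity:
// given consecutive 64-bit memory blocks d0, d1, d2, d3 it loads
// {d0, d2} into its first register and {d1, d3} into its second, so
// each block of 4 related shorts lands beside its partner block.)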

vs_ldpq(vq, kyberConsts);
int offsets4[4] = { 0, 32, 64, 96 };
vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4);
@@ -5317,6 +5329,9 @@ class StubGenerator: public StubCodeGenerator {
vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4);

// level 6
// At level 6 related coefficients occur in discrete blocks of size 4 so
// need to be loaded interleaved using an ld2 operation with arrangement 4S.

vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4);
load32shorts(vs_front(vs2), zetas);
kyber_montmul32_sub_add(vs_even(vs1), vs_odd(vs1), vs_front(vs2), vtmp, vq);
@@ -5373,6 +5388,9 @@ class StubGenerator: public StubCodeGenerator {
ExternalAddress((address) StubRoutines::aarch64::_kyberConsts));

// level 0
// At level 0 related coefficients occur in discrete blocks of size 4 so
// need to be loaded interleaved using an ld2 operation with arrangement 4S.
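// For reference, each inverse butterfly below has the scalar shape
// (a sketch; the stub schedules its modular reductions separately):
//
//   int16_t t = r[j];
//   r[j] = t + r[j + len];                       // add half
//   r[j + len] = montmul(zeta, r[j + len] - t);  // sub half, scaled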

vs_ldpq(vq, kyberConsts);
int offsets4[4] = { 0, 32, 64, 96 };
vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, 0, offsets4);
@@ -5397,6 +5415,9 @@ class StubGenerator: public StubCodeGenerator {
vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, 384, offsets4);

// level 1
// At level 1 related coefficients occur in discrete blocks of size 8 so
// need to be loaded interleaved using an ld2 operation with arrangement 2D.

vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, 0, offsets4);
load32shorts(vs_front(vs2), zetas);
kyber_sub_add_montmul32(vs_even(vs1), vs_odd(vs1),
@@ -5420,6 +5441,9 @@ class StubGenerator: public StubCodeGenerator {
vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, 384, offsets4);

// level 2
// At level 2 coefficients occur in 8 discrete blocks of size 16
// so they are loaded using an ldr at 8 distinct offsets.

int offsets3[8] = { 0, 32, 64, 96, 128, 160, 192, 224 };
vs_ldr_indexed(vs1, __ Q, coeffs, 0, offsets3);
vs_ldr_indexed(vs2, __ Q, coeffs, 16, offsets3);
@@ -5462,6 +5486,9 @@ class StubGenerator: public StubCodeGenerator {
vs_str_indexed(vs1, __ Q, coeffs, 256, offsets3);

// level 3
// From level 3 upwards coefficients occur in discrete blocks whose size is
// some multiple of 32 so can be loaded using ldpq and suitable indexes.

int offsets2[4] = { 0, 64, 128, 192 };
vs_ldpq_indexed(vs1, coeffs, 0, offsets2);
vs_ldpq_indexed(vs2, coeffs, 32, offsets2);
@@ -5484,6 +5511,7 @@ class StubGenerator: public StubCodeGenerator {
vs_stpq_indexed(vs2, coeffs, 256 + 32, offsets2);

// level 4

int offsets1[4] = { 0, 32, 128, 160 };
vs_ldpq_indexed(vs1, coeffs, 0, offsets1);
vs_ldpq_indexed(vs2, coeffs, 64, offsets1);
@@ -5506,6 +5534,7 @@ class StubGenerator: public StubCodeGenerator {
vs_stpq_indexed(vs2, coeffs, 256 + 64, offsets1);

// level 5

__ add(tmpAddr, coeffs, 0);
load64shorts(vs1, tmpAddr);
__ add(tmpAddr, coeffs, 128);
@@ -5547,6 +5576,7 @@ class StubGenerator: public StubCodeGenerator {
vs_stpq_indexed(vs_front(vs1), coeffs, 0, offsets0);

// level 6

__ add(tmpAddr, coeffs, 0);
load64shorts(vs1, tmpAddr);
__ add(tmpAddr, coeffs, 256);
@@ -5588,16 +5618,19 @@ class StubGenerator: public StubCodeGenerator {
__ add(tmpAddr, coeffs, 0);
store64shorts(vs2, tmpAddr);

// now tmpAddr contains coeffs + 128 because store64shorts advanced it
load64shorts(vs1, tmpAddr);
kyber_montmul64(vs2, vs1, vq3, vtmp, vq);
__ add(tmpAddr, coeffs, 128);
store64shorts(vs2, tmpAddr);

// now tmpAddr contains coeffs + 256
load64shorts(vs1, tmpAddr);
kyber_montmul64(vs2, vs1, vq3, vtmp, vq);
__ add(tmpAddr, coeffs, 256);
store64shorts(vs2, tmpAddr);

// now tmpAddr contains coeffs + 384
load64shorts(vs1, tmpAddr);
kyber_montmul64(vs2, vs1, vq3, vtmp, vq);
__ add(tmpAddr, coeffs, 384);
@@ -5637,9 +5670,9 @@ class StubGenerator: public StubCodeGenerator {

VSeq<4> vs1(0), vs2(4); // 4 sets of 8x8H inputs/outputs/tmps
VSeq<4> vs3(16), vs4(20);
-VSeq<2> vq(30); // pair of constants for montmul
+VSeq<2> vq(30); // pair of constants for montmul: q, qinv
VSeq<2> vz(28); // pair of zetas
-VSeq<4> vc(27, 0); // constant sequence for montmul
+VSeq<4> vc(27, 0); // constant sequence for montmul: montRSquareModQ
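// n.b. assuming montRSquareModQ = R^2 mod q with R = 2^16 as the name
// suggests, montmul(a, R^2 mod q) = a * R^2 * R^-1 = a * R mod q,
// i.e. it lifts a into Montgomery form.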

__ lea(kyberConsts,
ExternalAddress((address) StubRoutines::aarch64::_kyberConsts));
@@ -5930,7 +5963,34 @@ class StubGenerator: public StubCodeGenerator {
// load 96 (6 x 16B) byte values
vs_ld3_post(vin, __ T16B, condensed);

-// expand groups of input bytes in vin to shorts in va and vb
// The front half of sequence vin (vin[0], vin[1] and vin[2])
// holds 48 (16x3) contiguous bytes from memory striped
// horizontally across each of the 16 byte lanes. Equivalently,
// that is 16 pairs of 12-bit integers. Likewise the back half
// holds the next 48 bytes in the same arrangement.

// Each vector in the front half can also be viewed as a vertical
// strip across the 16 pairs of 12 bit integers. Each byte in
// vin[0] stores the low 8 bits of the first int in a pair. Each
// byte in vin[1] stores the high 4 bits of the first int and the
// low 4 bits of the second int. Each byte in vin[2] stores the
// high 8 bits of the second int. Likewise for the vectors in the
// second half.

// Converting the data to 16-bit shorts requires first of all
// expanding each of the 6 x 16B vectors into 6 corresponding
// pairs of 8H vectors. Mask, shift and add operations on the
// resulting vector pairs can be used to combine 4 and 8 bit
// parts of related 8H vector elements.
//
// The middle vectors (vin[1] and vin[4]) are actually expanded
// twice, one copy manipulated to provide the higher 4 bits
// belonging to the first short in a pair and another copy
// manipulated to provide the lower 4 bits belonging to the
// second short in a pair. This is why the vector sequences va
// and vb used to hold the expanded 8H elements are of length 8.
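// As a scalar sketch, each group of three input bytes b0, b1, b2
// decodes to two 12-bit shorts as
//
//   s0 = b0 | ((b1 & 0x0f) << 8);   // lo 8 bits, then hi 4 bits
//   s1 = (b1 >> 4) | (b2 << 4);     // lo 4 bits, then hi 8 bits
//
// which is what the expand, mask, shift and add steps below compute
// 64 lanes at a time.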

// Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]
// n.b. target elements 2 and 3 duplicate elements 4 and 5
__ ushll(va[0], __ T8H, vin[0], __ T8B, 0);
__ ushll2(va[1], __ T8H, vin[0], __ T16B, 0);
@@ -5939,28 +5999,32 @@ class StubGenerator: public StubCodeGenerator {
__ ushll(va[4], __ T8H, vin[1], __ T8B, 0);
__ ushll2(va[5], __ T8H, vin[1], __ T16B, 0);

// likewise expand vin[3] into vb[0:1], and vin[4] into vb[2:3]
// and vb[4:5]
__ ushll(vb[0], __ T8H, vin[3], __ T8B, 0);
__ ushll2(vb[1], __ T8H, vin[3], __ T16B, 0);
__ ushll(vb[2], __ T8H, vin[4], __ T8B, 0);
__ ushll2(vb[3], __ T8H, vin[4], __ T16B, 0);
__ ushll(vb[4], __ T8H, vin[4], __ T8B, 0);
__ ushll2(vb[5], __ T8H, vin[4], __ T16B, 0);

-// offset duplicated elements in va and vb by 8
+// shift lo byte of copy 1 of the middle stripe into the high byte
__ shl(va[2], __ T8H, va[2], 8);
__ shl(va[3], __ T8H, va[3], 8);
__ shl(vb[2], __ T8H, vb[2], 8);
__ shl(vb[3], __ T8H, vb[3], 8);

-// expand remaining input bytes in vin to shorts in va and vb
-// but this time pre-shifted by 4
+// expand vin[2] into va[6:7] and vin[5] into vb[6:7] but this
+// time pre-shifted by 4 to ensure top bits of input 12-bit int
+// are in bit positions [4..11].
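// (ushll/ushll2 zero-extend the low/high 8 byte lanes of a vector
// into 8 halfword lanes, left-shifting each value as it widens, so a
// shift of 4 leaves an 8-bit input in bits [4..11] of its lane.)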
__ ushll(va[6], __ T8H, vin[2], __ T8B, 4);
__ ushll2(va[7], __ T8H, vin[2], __ T16B, 4);
__ ushll(vb[6], __ T8H, vin[5], __ T8B, 4);
__ ushll2(vb[7], __ T8H, vin[5], __ T16B, 4);

-// split the duplicated 8 bit values into two distinct 4 bit
-// upper and lower halves using a mask or a shift
+// mask hi 4 bits of the 1st 12-bit int in a pair from copy1 and
+// shift lo 4 bits of the 2nd 12-bit int in a pair to the bottom of
+// copy2
__ andr(va[2], __ T16B, va[2], v31);
__ andr(va[3], __ T16B, va[3], v31);
__ ushr(va[4], __ T8H, va[4], 4);
@@ -5970,8 +6034,12 @@ class StubGenerator: public StubCodeGenerator {
__ ushr(vb[4], __ T8H, vb[4], 4);
__ ushr(vb[5], __ T8H, vb[5], 4);

-// sum resulting short values into the front halves of va and
-// vb pairing registers offset by stride 2
+// sum hi 4 bits and lo 8 bits of the 1st 12-bit int in each pair and
+// hi 8 bits plus lo 4 bits of the 2nd 12-bit int in each pair
+// n.b. the ordering ensures: i) inputs are consumed before they
+// are overwritten ii) the order of 16-bit results across successive
+// pairs of vectors in va and then vb reflects the order of the
+// corresponding 12-bit inputs
__ addv(va[0], __ T8H, va[0], va[2]);
__ addv(va[2], __ T8H, va[1], va[3]);
__ addv(va[1], __ T8H, va[4], va[6]);
@@ -5981,7 +6049,7 @@ class StubGenerator: public StubCodeGenerator {
__ addv(vb[1], __ T8H, vb[4], vb[6]);
__ addv(vb[3], __ T8H, vb[5], vb[7]);

-// store results interleaved as shorts
+// store 64 results interleaved as shorts
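// (st2 re-interleaves its two source vectors lane by lane on the way
// out, so the first and second shorts of each pair return to their
// original adjacent order in memory.)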
vs_st2_post(vs_front(va), __ T8H, parsed);
vs_st2_post(vs_front(vb), __ T8H, parsed);

@@ -5990,13 +6058,14 @@ class StubGenerator: public StubCodeGenerator {
__ br(Assembler::GE, L_loop);
__ cbz(parsedLength, L_end);

-// if anything is left it should be a final 72 bytes. so we
-// load 48 bytes into both lanes of front(vin) and 24 bytes
-// into the lower lane of back(vin)
+// if anything is left it should be a final 72 bytes of input
+// i.e. a final 48 12-bit values. so we handle this by loading
+// 48 bytes into all 16B lanes of front(vin) and only 24
+// bytes into the lower 8B lane of back(vin)
vs_ld3_post(vs_front(vin), __ T16B, condensed);
vs_ld3(vs_back(vin), __ T8B, condensed);

-// expand groups of input bytes in vin to shorts in va and vb
+// Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]
// n.b. target elements 2 and 3 of va duplicate elements 4 and
// 5 and target element 2 of vb duplicates element 4.
__ ushll(va[0], __ T8H, vin[0], __ T8B, 0);
@@ -6006,40 +6075,50 @@ class StubGenerator: public StubCodeGenerator {
__ ushll(va[4], __ T8H, vin[1], __ T8B, 0);
__ ushll2(va[5], __ T8H, vin[1], __ T16B, 0);

// This time expand just the lower 8 lanes
__ ushll(vb[0], __ T8H, vin[3], __ T8B, 0);
__ ushll(vb[2], __ T8H, vin[4], __ T8B, 0);
__ ushll(vb[4], __ T8H, vin[4], __ T8B, 0);

-// offset duplicated elements in va and vb by 8
+// shift lo byte of copy 1 of the middle stripe into the high byte
__ shl(va[2], __ T8H, va[2], 8);
__ shl(va[3], __ T8H, va[3], 8);
__ shl(vb[2], __ T8H, vb[2], 8);

-// expand remaining input bytes in vin to shorts in va and vb
-// but this time pre-shifted by 4
+// expand vin[2] into va[6:7] and lower 8 lanes of vin[5] into
+// vb[6] pre-shifted by 4 to ensure top bits of the input 12-bit
+// int are in bit positions [4..11].
__ ushll(va[6], __ T8H, vin[2], __ T8B, 4);
__ ushll2(va[7], __ T8H, vin[2], __ T16B, 4);
__ ushll(vb[6], __ T8H, vin[5], __ T8B, 4);

-// split the duplicated 8 bit values into two distinct 4 bit
-// upper and lower halves using a mask or a shift
+// mask hi 4 bits of each 1st 12-bit int in pair from copy1 and
+// shift lo 4 bits of each 2nd 12-bit int in pair to bottom of
+// copy2
__ andr(va[2], __ T16B, va[2], v31);
__ andr(va[3], __ T16B, va[3], v31);
__ ushr(va[4], __ T8H, va[4], 4);
__ ushr(va[5], __ T8H, va[5], 4);
__ andr(vb[2], __ T16B, vb[2], v31);
__ ushr(vb[4], __ T8H, vb[4], 4);

-// sum resulting short values into the front halves of va and
-// vb pairing registers offset by stride 2
+// sum hi 4 bits and lo 8 bits of each 1st 12-bit int in pair and
+// hi 8 bits plus lo 4 bits of each 2nd 12-bit int in pair
+// n.b. ordering ensures: i) inputs are consumed before they are
+// overwritten ii) order of 16-bit results across successive
+// pairs of vectors in va and then lower half of vb reflects order
+// of corresponding 12-bit inputs
__ addv(va[0], __ T8H, va[0], va[2]);
__ addv(va[2], __ T8H, va[1], va[3]);
__ addv(va[1], __ T8H, va[4], va[6]);
__ addv(va[3], __ T8H, va[5], va[7]);
__ addv(vb[0], __ T8H, vb[0], vb[2]);
__ addv(vb[1], __ T8H, vb[4], vb[6]);

-// store results interleaved as shorts
+// store 48 results interleaved as shorts
vs_st2_post(vs_front(va), __ T8H, parsed);
vs_st2_post(vs_front(vs_front(vb)), __ T8H, parsed);

@@ -6085,13 +6164,14 @@ class StubGenerator: public StubCodeGenerator {
FloatRegister vs2_3 = v29;

// we also need a pair of corresponding constant sequences

VSeq<8> vc1_1(30, 0);
VSeq<2> vc1_2(30, 0);
-FloatRegister vc1_3 = v30;
+FloatRegister vc1_3 = v30; // for kyber_q

VSeq<8> vc2_1(31, 0);
VSeq<2> vc2_2(31, 0);
-FloatRegister vc2_3 = v31;
+FloatRegister vc2_3 = v31; // for kyberBarrettMultiplier
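// For reference, the loop below computes a Barrett reduction per
// 16-bit lane; as a scalar sketch, assuming kyberBarrettMultiplier
// holds the usual v = 20159, i.e. 2^26 / q rounded, for q = 3329:
//
//   int16_t t = (int16_t)(((int32_t)a * v) >> 26); // quotient estimate
//   a -= t * q;                                    // centered a mod q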

__ add(result, coeffs, 0);
__ lea(kyberConsts,
@@ -6108,21 +6188,28 @@ class StubGenerator: public StubCodeGenerator {
if (i < 2) {
__ ldr(vs1_3, __ Q, __ post(coeffs, 16));
}

// vs2 <- (2 * vs1 * kyberBarrettMultiplier) >> 16
vs_sqdmulh(vs2_1, __ T8H, vs1_1, vc2_1);
vs_sqdmulh(vs2_2, __ T8H, vs1_2, vc2_2);
if (i < 2) {
__ sqdmulh(vs2_3, __ T8H, vs1_3, vc2_3);
}

// vs2 <- (vs1 * kyberBarrettMultiplier) >> 26
vs_sshr(vs2_1, __ T8H, vs2_1, 11);
vs_sshr(vs2_2, __ T8H, vs2_2, 11);
if (i < 2) {
__ sshr(vs2_3, __ T8H, vs2_3, 11);
}

// vs1 <- vs1 - vs2 * kyber_q
vs_mlsv(vs1_1, __ T8H, vs2_1, vc1_1);
vs_mlsv(vs1_2, __ T8H, vs2_2, vc1_2);
if (i < 2) {
__ mlsv(vs1_3, __ T8H, vs2_3, vc1_3);
}

vs_stpq_post(vs1_1, result);
vs_stpq_post(vs1_2, result);
if (i < 2) {