Skip to content

Commit b65fdf5

Browse files
author
Ian Graves
committed
8358768: [vectorapi] Make VectorOperators.SUADD an Associative
Reviewed-by: psandoz
1 parent d2082c5 commit b65fdf5

37 files changed

+3158
-1
lines changed

src/jdk.incubator.vector/share/classes/jdk/incubator/vector/VectorOperators.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ static boolean opKind(Operator op, int bit) {
577577
/** Produce saturating unsigned {@code a+b}. Integral only.
578578
* @see VectorMath#addSaturatingUnsigned(int, int)
579579
*/
580-
public static final Binary SUADD = binary("SUADD", "+", VectorSupport.VECTOR_OP_SUADD, VO_NOFP);
580+
public static final Associative SUADD = assoc("SUADD", "+", VectorSupport.VECTOR_OP_SUADD, VO_NOFP+VO_ASSOC);
581581
/** Produce saturating {@code a-b}. Integral only.
582582
* @see VectorMath#subSaturating(int, int)
583583
*/

test/jdk/jdk/incubator/vector/Byte128VectorTests.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,52 @@ static FBinMaskOp lift(FBinOp f) {
405405
}
406406
}
407407

408+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, FBinOp f) {
409+
int i = 0;
410+
try {
411+
for (; i < a.length; i++) {
412+
//Left associative
413+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i]), c[i]));
414+
415+
//Right associative
416+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i])));
417+
418+
//Results equal sanity check
419+
Assert.assertEquals(rl[i], rr[i]);
420+
}
421+
} catch (AssertionError e) {
422+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i]), c[i]), "left associative test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]);
423+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i])), "right associative test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]);
424+
Assert.assertEquals(rl[i], rr[i], "Result checks not equal at index #" + i + "leftRes = " + rl[i] + ", rightRes = " + rr[i]);
425+
}
426+
}
427+
428+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, boolean[] mask, FBinOp f) {
429+
assertArraysEqualsAssociative(rl, rr, a, b, c, mask, FBinMaskOp.lift(f));
430+
}
431+
432+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, boolean[] mask, FBinMaskOp f) {
433+
int i = 0;
434+
boolean mask_bit = false;
435+
try {
436+
for (; i < a.length; i++) {
437+
mask_bit = mask[i % SPECIES.length()];
438+
//Left associative
439+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i], mask_bit), c[i], mask_bit));
440+
441+
//Right associative
442+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i], mask_bit), mask_bit));
443+
444+
//Results equal sanity check
445+
Assert.assertEquals(rl[i], rr[i]);
446+
}
447+
} catch (AssertionError e) {
448+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i], mask_bit), c[i], mask_bit), "left associative masked test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i] + ", mask = " + mask_bit);
449+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i], mask_bit), mask_bit), "right associative masked test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i] + ", mask = " + mask_bit);
450+
Assert.assertEquals(rl[i], rr[i], "Result checks not equal at index #" + i + "leftRes = " + rl[i] + ", rightRes = " + rr[i]);
451+
}
452+
}
453+
408454
static void assertArraysEquals(byte[] r, byte[] a, byte[] b, FBinOp f) {
409455
int i = 0;
410456
try {
@@ -1016,6 +1062,21 @@ static byte bits(byte e) {
10161062
})
10171063
);
10181064

1065+
static final List<IntFunction<byte[]>> BYTE_SATURATING_GENERATORS_ASSOC = List.of(
1066+
withToString("byte[Byte.MAX_VALUE]", (int s) -> {
1067+
return fill(s * BUFFER_REPS,
1068+
i -> (byte)(Byte.MAX_VALUE));
1069+
}),
1070+
withToString("byte[Byte.MAX_VALUE - 100]", (int s) -> {
1071+
return fill(s * BUFFER_REPS,
1072+
i -> (byte)(Byte.MAX_VALUE - 100));
1073+
}),
1074+
withToString("byte[-1]", (int s) -> {
1075+
return fill(s * BUFFER_REPS,
1076+
i -> (byte)(-1));
1077+
})
1078+
);
1079+
10191080
// Create combinations of pairs
10201081
// @@@ Might be sensitive to order e.g. div by 0
10211082
static final List<List<IntFunction<byte[]>>> BYTE_GENERATOR_PAIRS =
@@ -1028,6 +1089,12 @@ static byte bits(byte e) {
10281089
flatMap(fa -> BYTE_SATURATING_GENERATORS.stream().skip(1).map(fb -> List.of(fa, fb))).
10291090
collect(Collectors.toList());
10301091

1092+
static final List<List<IntFunction<byte[]>>> BYTE_SATURATING_GENERATOR_TRIPLETS =
1093+
Stream.of(BYTE_GENERATORS.get(1))
1094+
.flatMap(fa -> BYTE_SATURATING_GENERATORS_ASSOC.stream().map(fb -> List.of(fa, fb)))
1095+
.flatMap(pair -> BYTE_SATURATING_GENERATORS_ASSOC.stream().map(f -> List.of(pair.get(0), pair.get(1), f)))
1096+
.collect(Collectors.toList());
1097+
10311098
@DataProvider
10321099
public Object[][] boolUnaryOpProvider() {
10331100
return BOOL_ARRAY_GENERATORS.stream().
@@ -1064,6 +1131,22 @@ public Object[][] byteSaturatingBinaryOpProvider() {
10641131
toArray(Object[][]::new);
10651132
}
10661133

1134+
@DataProvider
1135+
public Object[][] byteSaturatingBinaryOpAssocProvider() {
1136+
return BYTE_SATURATING_GENERATOR_TRIPLETS.stream().map(List::toArray).
1137+
toArray(Object[][]::new);
1138+
}
1139+
1140+
@DataProvider
1141+
public Object[][] byteSaturatingBinaryOpAssocMaskProvider() {
1142+
return BOOLEAN_MASK_GENERATORS.stream().
1143+
flatMap(fm -> BYTE_SATURATING_GENERATOR_TRIPLETS.stream().map(lfa -> {
1144+
return Stream.concat(lfa.stream(), Stream.of(fm)).toArray();
1145+
})).
1146+
toArray(Object[][]::new);
1147+
}
1148+
1149+
10671150
@DataProvider
10681151
public Object[][] byteIndexedOpProvider() {
10691152
return BYTE_GENERATOR_PAIRS.stream().map(List::toArray).
@@ -3420,6 +3503,51 @@ static void maxByte128VectorTestsBroadcastSmokeTest(IntFunction<byte[]> fa, IntF
34203503

34213504
assertBroadcastArraysEquals(r, a, b, Byte128VectorTests::max);
34223505
}
3506+
@Test(dataProvider = "byteSaturatingBinaryOpAssocProvider")
3507+
static void SUADDAssocByte128VectorTests(IntFunction<byte[]> fa, IntFunction<byte[]> fb, IntFunction<byte[]> fc) {
3508+
byte[] a = fa.apply(SPECIES.length());
3509+
byte[] b = fb.apply(SPECIES.length());
3510+
byte[] c = fc.apply(SPECIES.length());
3511+
byte[] rl = fr.apply(SPECIES.length());
3512+
byte[] rr = fr.apply(SPECIES.length());
3513+
3514+
for (int ic = 0; ic < INVOC_COUNT; ic++) {
3515+
for (int i = 0; i < a.length; i += SPECIES.length()) {
3516+
ByteVector av = ByteVector.fromArray(SPECIES, a, i);
3517+
ByteVector bv = ByteVector.fromArray(SPECIES, b, i);
3518+
ByteVector cv = ByteVector.fromArray(SPECIES, c, i);
3519+
av.lanewise(VectorOperators.SUADD, bv).lanewise(VectorOperators.SUADD, cv).intoArray(rl, i);
3520+
av.lanewise(VectorOperators.SUADD, bv.lanewise(VectorOperators.SUADD, cv)).intoArray(rr, i);
3521+
}
3522+
}
3523+
3524+
assertArraysEqualsAssociative(rl, rr, a, b, c, Byte128VectorTests::SUADD);
3525+
}
3526+
3527+
@Test(dataProvider = "byteSaturatingBinaryOpAssocMaskProvider")
3528+
static void SUADDAssocByte128VectorTestsMasked(IntFunction<byte[]> fa, IntFunction<byte[]> fb,
3529+
IntFunction<byte[]> fc, IntFunction<boolean[]> fm) {
3530+
byte[] a = fa.apply(SPECIES.length());
3531+
byte[] b = fb.apply(SPECIES.length());
3532+
byte[] c = fc.apply(SPECIES.length());
3533+
boolean[] mask = fm.apply(SPECIES.length());
3534+
byte[] rl = fr.apply(SPECIES.length());
3535+
byte[] rr = fr.apply(SPECIES.length());
3536+
3537+
VectorMask<Byte> vmask = VectorMask.fromArray(SPECIES, mask, 0);
3538+
3539+
for (int ic = 0; ic < INVOC_COUNT; ic++) {
3540+
for (int i = 0; i < a.length; i += SPECIES.length()) {
3541+
ByteVector av = ByteVector.fromArray(SPECIES, a, i);
3542+
ByteVector bv = ByteVector.fromArray(SPECIES, b, i);
3543+
ByteVector cv = ByteVector.fromArray(SPECIES, c, i);
3544+
av.lanewise(VectorOperators.SUADD, bv, vmask).lanewise(VectorOperators.SUADD, cv, vmask).intoArray(rl, i);
3545+
av.lanewise(VectorOperators.SUADD, bv.lanewise(VectorOperators.SUADD, cv, vmask), vmask).intoArray(rr, i);
3546+
}
3547+
}
3548+
3549+
assertArraysEqualsAssociative(rl, rr, a, b, c, mask, Byte128VectorTests::SUADD);
3550+
}
34233551

34243552
static byte ANDReduce(byte[] a, int idx) {
34253553
byte res = -1;

test/jdk/jdk/incubator/vector/Byte256VectorTests.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,52 @@ static FBinMaskOp lift(FBinOp f) {
405405
}
406406
}
407407

408+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, FBinOp f) {
409+
int i = 0;
410+
try {
411+
for (; i < a.length; i++) {
412+
//Left associative
413+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i]), c[i]));
414+
415+
//Right associative
416+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i])));
417+
418+
//Results equal sanity check
419+
Assert.assertEquals(rl[i], rr[i]);
420+
}
421+
} catch (AssertionError e) {
422+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i]), c[i]), "left associative test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]);
423+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i])), "right associative test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]);
424+
Assert.assertEquals(rl[i], rr[i], "Result checks not equal at index #" + i + "leftRes = " + rl[i] + ", rightRes = " + rr[i]);
425+
}
426+
}
427+
428+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, boolean[] mask, FBinOp f) {
429+
assertArraysEqualsAssociative(rl, rr, a, b, c, mask, FBinMaskOp.lift(f));
430+
}
431+
432+
static void assertArraysEqualsAssociative(byte[] rl, byte[] rr, byte[] a, byte[] b, byte[] c, boolean[] mask, FBinMaskOp f) {
433+
int i = 0;
434+
boolean mask_bit = false;
435+
try {
436+
for (; i < a.length; i++) {
437+
mask_bit = mask[i % SPECIES.length()];
438+
//Left associative
439+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i], mask_bit), c[i], mask_bit));
440+
441+
//Right associative
442+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i], mask_bit), mask_bit));
443+
444+
//Results equal sanity check
445+
Assert.assertEquals(rl[i], rr[i]);
446+
}
447+
} catch (AssertionError e) {
448+
Assert.assertEquals(rl[i], f.apply(f.apply(a[i], b[i], mask_bit), c[i], mask_bit), "left associative masked test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i] + ", mask = " + mask_bit);
449+
Assert.assertEquals(rr[i], f.apply(a[i], f.apply(b[i], c[i], mask_bit), mask_bit), "right associative masked test at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i] + ", mask = " + mask_bit);
450+
Assert.assertEquals(rl[i], rr[i], "Result checks not equal at index #" + i + "leftRes = " + rl[i] + ", rightRes = " + rr[i]);
451+
}
452+
}
453+
408454
static void assertArraysEquals(byte[] r, byte[] a, byte[] b, FBinOp f) {
409455
int i = 0;
410456
try {
@@ -1016,6 +1062,21 @@ static byte bits(byte e) {
10161062
})
10171063
);
10181064

1065+
static final List<IntFunction<byte[]>> BYTE_SATURATING_GENERATORS_ASSOC = List.of(
1066+
withToString("byte[Byte.MAX_VALUE]", (int s) -> {
1067+
return fill(s * BUFFER_REPS,
1068+
i -> (byte)(Byte.MAX_VALUE));
1069+
}),
1070+
withToString("byte[Byte.MAX_VALUE - 100]", (int s) -> {
1071+
return fill(s * BUFFER_REPS,
1072+
i -> (byte)(Byte.MAX_VALUE - 100));
1073+
}),
1074+
withToString("byte[-1]", (int s) -> {
1075+
return fill(s * BUFFER_REPS,
1076+
i -> (byte)(-1));
1077+
})
1078+
);
1079+
10191080
// Create combinations of pairs
10201081
// @@@ Might be sensitive to order e.g. div by 0
10211082
static final List<List<IntFunction<byte[]>>> BYTE_GENERATOR_PAIRS =
@@ -1028,6 +1089,12 @@ static byte bits(byte e) {
10281089
flatMap(fa -> BYTE_SATURATING_GENERATORS.stream().skip(1).map(fb -> List.of(fa, fb))).
10291090
collect(Collectors.toList());
10301091

1092+
static final List<List<IntFunction<byte[]>>> BYTE_SATURATING_GENERATOR_TRIPLETS =
1093+
Stream.of(BYTE_GENERATORS.get(1))
1094+
.flatMap(fa -> BYTE_SATURATING_GENERATORS_ASSOC.stream().map(fb -> List.of(fa, fb)))
1095+
.flatMap(pair -> BYTE_SATURATING_GENERATORS_ASSOC.stream().map(f -> List.of(pair.get(0), pair.get(1), f)))
1096+
.collect(Collectors.toList());
1097+
10311098
@DataProvider
10321099
public Object[][] boolUnaryOpProvider() {
10331100
return BOOL_ARRAY_GENERATORS.stream().
@@ -1064,6 +1131,22 @@ public Object[][] byteSaturatingBinaryOpProvider() {
10641131
toArray(Object[][]::new);
10651132
}
10661133

1134+
@DataProvider
1135+
public Object[][] byteSaturatingBinaryOpAssocProvider() {
1136+
return BYTE_SATURATING_GENERATOR_TRIPLETS.stream().map(List::toArray).
1137+
toArray(Object[][]::new);
1138+
}
1139+
1140+
@DataProvider
1141+
public Object[][] byteSaturatingBinaryOpAssocMaskProvider() {
1142+
return BOOLEAN_MASK_GENERATORS.stream().
1143+
flatMap(fm -> BYTE_SATURATING_GENERATOR_TRIPLETS.stream().map(lfa -> {
1144+
return Stream.concat(lfa.stream(), Stream.of(fm)).toArray();
1145+
})).
1146+
toArray(Object[][]::new);
1147+
}
1148+
1149+
10671150
@DataProvider
10681151
public Object[][] byteIndexedOpProvider() {
10691152
return BYTE_GENERATOR_PAIRS.stream().map(List::toArray).
@@ -3420,6 +3503,51 @@ static void maxByte256VectorTestsBroadcastSmokeTest(IntFunction<byte[]> fa, IntF
34203503

34213504
assertBroadcastArraysEquals(r, a, b, Byte256VectorTests::max);
34223505
}
3506+
@Test(dataProvider = "byteSaturatingBinaryOpAssocProvider")
3507+
static void SUADDAssocByte256VectorTests(IntFunction<byte[]> fa, IntFunction<byte[]> fb, IntFunction<byte[]> fc) {
3508+
byte[] a = fa.apply(SPECIES.length());
3509+
byte[] b = fb.apply(SPECIES.length());
3510+
byte[] c = fc.apply(SPECIES.length());
3511+
byte[] rl = fr.apply(SPECIES.length());
3512+
byte[] rr = fr.apply(SPECIES.length());
3513+
3514+
for (int ic = 0; ic < INVOC_COUNT; ic++) {
3515+
for (int i = 0; i < a.length; i += SPECIES.length()) {
3516+
ByteVector av = ByteVector.fromArray(SPECIES, a, i);
3517+
ByteVector bv = ByteVector.fromArray(SPECIES, b, i);
3518+
ByteVector cv = ByteVector.fromArray(SPECIES, c, i);
3519+
av.lanewise(VectorOperators.SUADD, bv).lanewise(VectorOperators.SUADD, cv).intoArray(rl, i);
3520+
av.lanewise(VectorOperators.SUADD, bv.lanewise(VectorOperators.SUADD, cv)).intoArray(rr, i);
3521+
}
3522+
}
3523+
3524+
assertArraysEqualsAssociative(rl, rr, a, b, c, Byte256VectorTests::SUADD);
3525+
}
3526+
3527+
@Test(dataProvider = "byteSaturatingBinaryOpAssocMaskProvider")
3528+
static void SUADDAssocByte256VectorTestsMasked(IntFunction<byte[]> fa, IntFunction<byte[]> fb,
3529+
IntFunction<byte[]> fc, IntFunction<boolean[]> fm) {
3530+
byte[] a = fa.apply(SPECIES.length());
3531+
byte[] b = fb.apply(SPECIES.length());
3532+
byte[] c = fc.apply(SPECIES.length());
3533+
boolean[] mask = fm.apply(SPECIES.length());
3534+
byte[] rl = fr.apply(SPECIES.length());
3535+
byte[] rr = fr.apply(SPECIES.length());
3536+
3537+
VectorMask<Byte> vmask = VectorMask.fromArray(SPECIES, mask, 0);
3538+
3539+
for (int ic = 0; ic < INVOC_COUNT; ic++) {
3540+
for (int i = 0; i < a.length; i += SPECIES.length()) {
3541+
ByteVector av = ByteVector.fromArray(SPECIES, a, i);
3542+
ByteVector bv = ByteVector.fromArray(SPECIES, b, i);
3543+
ByteVector cv = ByteVector.fromArray(SPECIES, c, i);
3544+
av.lanewise(VectorOperators.SUADD, bv, vmask).lanewise(VectorOperators.SUADD, cv, vmask).intoArray(rl, i);
3545+
av.lanewise(VectorOperators.SUADD, bv.lanewise(VectorOperators.SUADD, cv, vmask), vmask).intoArray(rr, i);
3546+
}
3547+
}
3548+
3549+
assertArraysEqualsAssociative(rl, rr, a, b, c, mask, Byte256VectorTests::SUADD);
3550+
}
34233551

34243552
static byte ANDReduce(byte[] a, int idx) {
34253553
byte res = -1;

0 commit comments

Comments
 (0)