Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
8278171: [vectorapi] Mask incorrectly computed for zero extending cast
Reviewed-by: psandoz
  • Loading branch information
merykitty authored and Paul Sandoz committed Dec 3, 2021
1 parent fbf096e commit 2e30fa9
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 1 deletion.
Expand Up @@ -707,7 +707,7 @@ AbstractVector<F> convert0(char kind, AbstractSpecies<F> rsp) {
return v.convert0('C', rspi);
}
// extend in place, but remove unwanted sign extension
long mask = -1L >>> sizeChange;
long mask = -1L >>> -dsp.elementSize();
return (AbstractVector<F>)
v.convert0('C', rspi)
.lanewise(AND, rspi.broadcast(mask));
Expand Down
321 changes: 321 additions & 0 deletions test/jdk/jdk/incubator/vector/VectorReshapeTests.java
Expand Up @@ -3334,6 +3334,327 @@ static void testCastFromDoubleFail() {
}
}

@ForceInline
static
void testVectorUCastByteToShort(VectorSpecies<Byte> a, VectorSpecies<Short> b, byte[] input, short[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

ByteVector av = ByteVector.fromArray(a, input, 0);
ShortVector bv = (ShortVector) av.convertShape(VectorOperators.ZERO_EXTEND_B2S, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Byte.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (short)0);
}
}

@ForceInline
static
void testVectorUCastByteToInt(VectorSpecies<Byte> a, VectorSpecies<Integer> b, byte[] input, int[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

ByteVector av = ByteVector.fromArray(a, input, 0);
IntVector bv = (IntVector) av.convertShape(VectorOperators.ZERO_EXTEND_B2I, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Byte.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (int)0);
}
}

@ForceInline
static
void testVectorUCastByteToLong(VectorSpecies<Byte> a, VectorSpecies<Long> b, byte[] input, long[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

ByteVector av = ByteVector.fromArray(a, input, 0);
LongVector bv = (LongVector) av.convertShape(VectorOperators.ZERO_EXTEND_B2L, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Byte.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (long)0);
}
}

@ForceInline
static
void testVectorUCastShortToInt(VectorSpecies<Short> a, VectorSpecies<Integer> b, short[] input, int[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

ShortVector av = ShortVector.fromArray(a, input, 0);
IntVector bv = (IntVector) av.convertShape(VectorOperators.ZERO_EXTEND_S2I, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Short.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (int)0);
}
}

@ForceInline
static
void testVectorUCastShortToLong(VectorSpecies<Short> a, VectorSpecies<Long> b, short[] input, long[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

ShortVector av = ShortVector.fromArray(a, input, 0);
LongVector bv = (LongVector) av.convertShape(VectorOperators.ZERO_EXTEND_S2L, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Short.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (long)0);
}
}

@ForceInline
static
void testVectorUCastIntToLong(VectorSpecies<Integer> a, VectorSpecies<Long> b, int[] input, long[] output) {
assert(input.length == a.length());
assert(output.length == b.length());

IntVector av = IntVector.fromArray(a, input, 0);
LongVector bv = (LongVector) av.convertShape(VectorOperators.ZERO_EXTEND_I2L, b, 0);
bv.intoArray(output, 0);

for (int i = 0; i < Math.min(input.length, output.length); i++) {
Assert.assertEquals(output[i], Integer.toUnsignedLong(input[i]));
}
for(int i = input.length; i < output.length; i++) {
Assert.assertEquals(output[i], (long)0);
}
}

@Test(dataProvider = "byteUnaryOpProvider")
static void testUCastFromByte(IntFunction<byte[]> fa) {
byte[] bin64 = fa.apply(bspec64.length());
byte[] bin128 = fa.apply(bspec128.length());
byte[] bin256 = fa.apply(bspec256.length());
byte[] bin512 = fa.apply(bspec512.length());

short[] sout64 = new short[sspec64.length()];
short[] sout128 = new short[sspec128.length()];
short[] sout256 = new short[sspec256.length()];
short[] sout512 = new short[sspec512.length()];

int[] iout64 = new int[ispec64.length()];
int[] iout128 = new int[ispec128.length()];
int[] iout256 = new int[ispec256.length()];
int[] iout512 = new int[ispec512.length()];

long[] lout64 = new long[lspec64.length()];
long[] lout128 = new long[lspec128.length()];
long[] lout256 = new long[lspec256.length()];
long[] lout512 = new long[lspec512.length()];

for (int i = 0; i < NUM_ITER; i++) {
// B2S exact fit
testVectorUCastByteToShort(bspec64, sspec128, bin64, sout128);
testVectorUCastByteToShort(bspec128, sspec256, bin128, sout256);
testVectorUCastByteToShort(bspec256, sspec512, bin256, sout512);

// B2S expansion
testVectorUCastByteToShort(bspec64, sspec64, bin64, sout64);
testVectorUCastByteToShort(bspec128, sspec128, bin128, sout128);
testVectorUCastByteToShort(bspec256, sspec256, bin256, sout256);
testVectorUCastByteToShort(bspec512, sspec512, bin512, sout512);

testVectorUCastByteToShort(bspec128, sspec64, bin128, sout64);
testVectorUCastByteToShort(bspec256, sspec128, bin256, sout128);
testVectorUCastByteToShort(bspec512, sspec256, bin512, sout256);

testVectorUCastByteToShort(bspec256, sspec64, bin256, sout64);
testVectorUCastByteToShort(bspec512, sspec128, bin512, sout128);

testVectorUCastByteToShort(bspec512, sspec64, bin512, sout64);

// B2S contraction
testVectorUCastByteToShort(bspec64, sspec256, bin64, sout256);
testVectorUCastByteToShort(bspec128, sspec512, bin128, sout512);

testVectorUCastByteToShort(bspec64, sspec512, bin64, sout512);

// B2I exact fit
testVectorUCastByteToInt(bspec64, ispec256, bin64, iout256);
testVectorUCastByteToInt(bspec128, ispec512, bin128, iout512);

// B2I expansion
testVectorUCastByteToInt(bspec64, ispec128, bin64, iout128);
testVectorUCastByteToInt(bspec128, ispec256, bin128, iout256);
testVectorUCastByteToInt(bspec256, ispec512, bin256, iout512);

testVectorUCastByteToInt(bspec64, ispec64, bin64, iout64);
testVectorUCastByteToInt(bspec128, ispec128, bin128, iout128);
testVectorUCastByteToInt(bspec256, ispec256, bin256, iout256);
testVectorUCastByteToInt(bspec512, ispec512, bin512, iout512);

testVectorUCastByteToInt(bspec128, ispec64, bin128, iout64);
testVectorUCastByteToInt(bspec256, ispec128, bin256, iout128);
testVectorUCastByteToInt(bspec512, ispec256, bin512, iout256);

testVectorUCastByteToInt(bspec256, ispec64, bin256, iout64);
testVectorUCastByteToInt(bspec512, ispec128, bin512, iout128);

testVectorUCastByteToInt(bspec512, ispec64, bin512, iout64);

// B2I contraction
testVectorUCastByteToInt(bspec64, ispec512, bin64, iout512);

// B2L exact fit
testVectorUCastByteToLong(bspec64, lspec512, bin64, lout512);

// B2L expansion
testVectorUCastByteToLong(bspec64, lspec256, bin64, lout256);
testVectorUCastByteToLong(bspec128, lspec512, bin128, lout512);

testVectorUCastByteToLong(bspec64, lspec128, bin64, lout128);
testVectorUCastByteToLong(bspec128, lspec256, bin128, lout256);
testVectorUCastByteToLong(bspec256, lspec512, bin256, lout512);

testVectorUCastByteToLong(bspec64, lspec64, bin64, lout64);
testVectorUCastByteToLong(bspec128, lspec128, bin128, lout128);
testVectorUCastByteToLong(bspec256, lspec256, bin256, lout256);
testVectorUCastByteToLong(bspec512, lspec512, bin512, lout512);

testVectorUCastByteToLong(bspec128, lspec64, bin128, lout64);
testVectorUCastByteToLong(bspec256, lspec128, bin256, lout128);
testVectorUCastByteToLong(bspec512, lspec256, bin512, lout256);

testVectorUCastByteToLong(bspec256, lspec64, bin256, lout64);
testVectorUCastByteToLong(bspec512, lspec128, bin512, lout128);

testVectorUCastByteToLong(bspec512, lspec64, bin512, lout64);
}
}

@Test(dataProvider = "shortUnaryOpProvider")
static void testUCastFromShort(IntFunction<short[]> fa) {
short[] sin64 = fa.apply(sspec64.length());
short[] sin128 = fa.apply(sspec128.length());
short[] sin256 = fa.apply(sspec256.length());
short[] sin512 = fa.apply(sspec512.length());

int[] iout64 = new int[ispec64.length()];
int[] iout128 = new int[ispec128.length()];
int[] iout256 = new int[ispec256.length()];
int[] iout512 = new int[ispec512.length()];

long[] lout64 = new long[lspec64.length()];
long[] lout128 = new long[lspec128.length()];
long[] lout256 = new long[lspec256.length()];
long[] lout512 = new long[lspec512.length()];

for (int i = 0; i < NUM_ITER; i++) {
// S2I exact fit
testVectorUCastShortToInt(sspec64, ispec128, sin64, iout128);
testVectorUCastShortToInt(sspec128, ispec256, sin128, iout256);
testVectorUCastShortToInt(sspec256, ispec512, sin256, iout512);

// S2I expansion
testVectorUCastShortToInt(sspec64, ispec64, sin64, iout64);
testVectorUCastShortToInt(sspec128, ispec128, sin128, iout128);
testVectorUCastShortToInt(sspec256, ispec256, sin256, iout256);
testVectorUCastShortToInt(sspec512, ispec512, sin512, iout512);

testVectorUCastShortToInt(sspec128, ispec64, sin128, iout64);
testVectorUCastShortToInt(sspec256, ispec128, sin256, iout128);
testVectorUCastShortToInt(sspec512, ispec256, sin512, iout256);

testVectorUCastShortToInt(sspec256, ispec64, sin256, iout64);
testVectorUCastShortToInt(sspec512, ispec128, sin512, iout128);

testVectorUCastShortToInt(sspec512, ispec64, sin512, iout64);

// S2I contraction
testVectorUCastShortToInt(sspec64, ispec256, sin64, iout256);
testVectorUCastShortToInt(sspec128, ispec512, sin128, iout512);

testVectorUCastShortToInt(sspec64, ispec512, sin64, iout512);

// S2L exact fit
testVectorUCastShortToLong(sspec64, lspec256, sin64, lout256);
testVectorUCastShortToLong(sspec128, lspec512, sin128, lout512);

// S2L expansion
testVectorUCastShortToLong(sspec64, lspec128, sin64, lout128);
testVectorUCastShortToLong(sspec128, lspec256, sin128, lout256);
testVectorUCastShortToLong(sspec256, lspec512, sin256, lout512);

testVectorUCastShortToLong(sspec64, lspec64, sin64, lout64);
testVectorUCastShortToLong(sspec128, lspec128, sin128, lout128);
testVectorUCastShortToLong(sspec256, lspec256, sin256, lout256);
testVectorUCastShortToLong(sspec512, lspec512, sin512, lout512);

testVectorUCastShortToLong(sspec128, lspec64, sin128, lout64);
testVectorUCastShortToLong(sspec256, lspec128, sin256, lout128);
testVectorUCastShortToLong(sspec512, lspec256, sin512, lout256);

testVectorUCastShortToLong(sspec256, lspec64, sin256, lout64);
testVectorUCastShortToLong(sspec512, lspec128, sin512, lout128);

testVectorUCastShortToLong(sspec512, lspec64, sin512, lout64);

// S2L contraction
testVectorUCastShortToLong(sspec64, lspec512, sin64, lout512);
}
}

@Test(dataProvider = "intUnaryOpProvider")
static void testUCastFromInt(IntFunction<int[]> fa) {
int[] iin64 = fa.apply(ispec64.length());
int[] iin128 = fa.apply(ispec128.length());
int[] iin256 = fa.apply(ispec256.length());
int[] iin512 = fa.apply(ispec512.length());

long[] lout64 = new long[lspec64.length()];
long[] lout128 = new long[lspec128.length()];
long[] lout256 = new long[lspec256.length()];
long[] lout512 = new long[lspec512.length()];

// I2L exact fit
testVectorUCastIntToLong(ispec64, lspec128, iin64, lout128);
testVectorUCastIntToLong(ispec128, lspec256, iin128, lout256);
testVectorUCastIntToLong(ispec256, lspec512, iin256, lout512);

// I2L expansion
testVectorUCastIntToLong(ispec64, lspec64, iin64, lout64);
testVectorUCastIntToLong(ispec128, lspec128, iin128, lout128);
testVectorUCastIntToLong(ispec256, lspec256, iin256, lout256);
testVectorUCastIntToLong(ispec512, lspec512, iin512, lout512);

testVectorUCastIntToLong(ispec128, lspec64, iin128, lout64);
testVectorUCastIntToLong(ispec256, lspec128, iin256, lout128);
testVectorUCastIntToLong(ispec512, lspec256, iin512, lout256);

testVectorUCastIntToLong(ispec256, lspec64, iin256, lout64);
testVectorUCastIntToLong(ispec512, lspec128, iin512, lout128);

testVectorUCastIntToLong(ispec512, lspec64, iin512, lout64);

// I2L contraction
testVectorUCastIntToLong(ispec64, lspec256, iin64, lout256);
testVectorUCastIntToLong(ispec128, lspec512, iin128, lout512);

testVectorUCastIntToLong(ispec64, lspec512, iin64, lout512);
}

static
void testVectorCastByteMaxToByte(VectorSpecies<Byte> a, VectorSpecies<Byte> b,
byte[] input, byte[] output) {
Expand Down

1 comment on commit 2e30fa9

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.