diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index be6fc14bc4e..46f7a4f4454 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -27,6 +27,7 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.nio.ShortBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -260,6 +261,42 @@ default float[] toFloatArray() { return ret; } + /** + * Converts this {@code NDArray} to an short array. + * + * @return an int array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + default short[] toShortArray() { + if (getDataType() != DataType.INT16) { + throw new IllegalStateException( + "DataType mismatch, Required int" + " Actual " + getDataType()); + } + ShortBuffer ib = toByteBuffer().asShortBuffer(); + short[] ret = new short[ib.remaining()]; + ib.get(ret); + return ret; + } + + /** + * Converts this {@code NDArray} to an short array. + * + * @return an int array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + default int[] toUnsignedShortArray() { + if (getDataType() != DataType.UINT16) { + throw new IllegalStateException( + "DataType mismatch, Required int" + " Actual " + getDataType()); + } + ShortBuffer ib = toByteBuffer().asShortBuffer(); + int[] ret = new int[ib.remaining()]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = ib.get() & 0xffff; + } + return ret; + } + /** * Converts this {@code NDArray} to an int array. * @@ -277,6 +314,25 @@ default int[] toIntArray() { return ret; } + /** + * Converts this {@code NDArray} to an unsigned int array. + * + * @return a long array + * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches + */ + default long[] toUnsignedIntArray() { + if (getDataType() != DataType.UINT32) { + throw new IllegalStateException( + "DataType mismatch, Required int" + " Actual " + getDataType()); + } + IntBuffer ib = toByteBuffer().asIntBuffer(); + long[] ret = new long[ib.remaining()]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = ib.get() & 0X00000000FFFFFFFFL; + } + return ret; + } + /** * Converts this {@code NDArray} to a long array. * @@ -370,6 +426,7 @@ default String[] toStringArray() { * * @return a Number array */ + @SuppressWarnings("PMD.AvoidArrayLoops") default Number[] toArray() { switch (getDataType()) { case FLOAT16: @@ -380,9 +437,21 @@ default Number[] toArray() { .toArray(Number[]::new); case FLOAT64: return Arrays.stream(toDoubleArray()).boxed().toArray(Double[]::new); + case INT16: + short[] buf = toShortArray(); + Short[] sbuf = new Short[buf.length]; + for (int i = 0; i < buf.length; ++i) { + sbuf[i] = buf[i]; + } + return sbuf; + case UINT16: + return Arrays.stream(toUnsignedShortArray()).boxed().toArray(Integer[]::new); case INT32: return Arrays.stream(toIntArray()).boxed().toArray(Integer[]::new); + case UINT32: + return Arrays.stream(toUnsignedIntArray()).boxed().toArray(Long[]::new); case INT64: + case UINT64: return Arrays.stream(toLongArray()).boxed().toArray(Long[]::new); case BOOLEAN: case INT8: