Skip to content

Commit

Permalink
[api] Adds uint16, uint32, uint64, int16, bf16 data type (#2570)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 27, 2023
1 parent 26ec36c commit 4a2d5fc
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 10 deletions.
16 changes: 15 additions & 1 deletion api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ NDManager getAlternativeManager() {
public static void validateBuffer(Buffer buffer, DataType dataType, int expected) {
boolean isByteBuffer = buffer instanceof ByteBuffer;
DataType type = DataType.fromBuffer(buffer);
if (type != dataType && !isByteBuffer) {
if (!isCompatible(type, dataType) && !isByteBuffer) {
// It's ok if type != datatype and buffer is ByteBuffer,
// since buffer will be copied into ByteBuffer
throw new IllegalArgumentException(
Expand All @@ -464,6 +464,20 @@ public static void validateBuffer(Buffer buffer, DataType dataType, int expected
}
}

private static boolean isCompatible(DataType type1, DataType type2) {
if (type1.getNumOfBytes() != type1.getNumOfBytes()) {
return false;
}
if (type1.getNumOfBytes() == 2) {
// fp16, bf16, int16, uint16 all uses ShortBuffer
return true;
}
if (type1.getFormat() == type2.getFormat()) {
return true;
}
return type1.isInteger() && type2.isInteger();
}

/**
* Copies data from the source {@code Buffer} to the target {@code ByteBuffer}.
*
Expand Down
46 changes: 40 additions & 6 deletions api/src/main/java/ai/djl/ndarray/types/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

/** An enum representing the underlying {@link NDArray}'s data type. */
public enum DataType {
// do not change order, pytorch engine rely on DataType.ordinal()
FLOAT32(Format.FLOATING, 4),
FLOAT64(Format.FLOATING, 8),
FLOAT16(Format.FLOATING, 2),
Expand All @@ -35,7 +36,12 @@ public enum DataType {
BOOLEAN(Format.BOOLEAN, 1),
COMPLEX64(Format.FLOATING, 4),
UNKNOWN(Format.UNKNOWN, 0),
STRING(Format.STRING, -1);
STRING(Format.STRING, -1),
BFLOAT16(Format.FLOATING, 2),
UINT64(Format.UINT, 8),
UINT32(Format.UINT, 4),
UINT16(Format.UINT, 2),
INT16(Format.INT, 2);

/** The general data type format categories. */
public enum Format {
Expand Down Expand Up @@ -147,12 +153,28 @@ public static DataType fromNumpy(String dtype) {
return FLOAT16;
case "|u1":
return UINT8;
case "<u2":
case ">u2":
case "=u2":
return UINT16;
case "<u4":
case ">u4":
case "=u4":
return UINT32;
case "<u8":
case ">u8":
case "=u8":
return UINT64;
case "|i1":
return INT8;
case "<i2":
case ">i2":
case "=i2":
return INT16;
case "<i4":
case ">i4":
case "=i4":
return INT32;
case "|i1":
return INT8;
case "<i8":
case ">i8":
case "=i8":
Expand All @@ -175,14 +197,17 @@ public static DataType fromNumpy(String dtype) {
public Buffer asDataType(ByteBuffer data) {
switch (this) {
case FLOAT16:
case BFLOAT16:
return data.asShortBuffer();
case FLOAT32:
return data.asFloatBuffer();
case FLOAT64:
return data.asDoubleBuffer();
case INT32:
case UINT32:
return data.asIntBuffer();
case INT64:
case UINT64:
return data.asLongBuffer();
case UINT8:
case INT8:
Expand All @@ -209,16 +234,25 @@ public String asNumpy() {
return order + "f2";
case UINT8:
return "|u1";
case INT32:
return order + "i4";
case UINT16:
return order + "u2";
case UINT32:
return order + "u4";
case UINT64:
return order + "u8";
case INT8:
return "|i1";
case INT16:
return order + "i2";
case INT32:
return order + "i4";
case INT64:
return "<i8";
return order + "i8";
case BOOLEAN:
return "|b1";
case STRING:
return "|S1";
case BFLOAT16:
case COMPLEX64:
case UNKNOWN:
default:
Expand Down
17 changes: 15 additions & 2 deletions api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@
import org.testng.Assert;
import org.testng.annotations.Test;

import java.nio.ByteOrder;

public class DataTypeTest {

@Test
public void numpyTest() {
Assert.assertEquals("|S1", DataType.STRING.asNumpy());
Assert.assertEquals(DataType.STRING, DataType.fromNumpy("|S1"));
char order = ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN ? '>' : '<';

Assert.assertEquals(DataType.INT16.asNumpy(), order + "i2");
Assert.assertEquals(DataType.UINT16.asNumpy(), order + "u2");
Assert.assertEquals(DataType.UINT32.asNumpy(), order + "u4");
Assert.assertEquals(DataType.UINT64.asNumpy(), order + "u8");
Assert.assertEquals(DataType.STRING.asNumpy(), "|S1");
Assert.assertEquals(DataType.fromNumpy("<i2"), DataType.INT16);
Assert.assertEquals(DataType.fromNumpy(">u2"), DataType.UINT16);
Assert.assertEquals(DataType.fromNumpy("=u4"), DataType.UINT32);
Assert.assertEquals(DataType.fromNumpy(">u8"), DataType.UINT64);
Assert.assertEquals(DataType.fromNumpy("|S1"), DataType.STRING);

Assert.expectThrows(IllegalArgumentException.class, DataType.BFLOAT16::asNumpy);
Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asNumpy);
Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromNumpy("|i8"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;

public class PassthroughNDManagerTest {

@Test
Expand All @@ -41,6 +47,41 @@ public void testPassthrough() {
array.toByteBuffer();
manager.from(array);

LongBuffer lb = LongBuffer.allocate(1);
lb.put(0, 10);
NDArray ui64 = manager.create(lb, new Shape(1, 1), DataType.UINT64);
ByteBuffer bb = ui64.toByteBuffer();
Assert.assertEquals(bb.remaining(), 8);
Assert.assertEquals(bb.getLong(), 10);

IntBuffer ib = IntBuffer.allocate(1);
ib.put(0, 11);
NDArray ui32 = manager.create(ib, new Shape(1, 1), DataType.UINT32);
bb = ui32.toByteBuffer();
Assert.assertEquals(bb.remaining(), 4);
Assert.assertEquals(bb.getInt(), 11);

ShortBuffer sb = ShortBuffer.allocate(1);
sb.put(0, (short) 12);
NDArray ui16 = manager.create(sb, new Shape(1, 1), DataType.UINT16);
bb = ui16.toByteBuffer();
Assert.assertEquals(bb.remaining(), 2);
Assert.assertEquals(bb.getShort(), 12);

sb.rewind();
sb.put(0, (short) 13);
NDArray i16 = manager.create(sb, new Shape(1, 1), DataType.INT16);
bb = i16.toByteBuffer();
Assert.assertEquals(bb.remaining(), 2);
Assert.assertEquals(bb.getShort(), 13);

sb.rewind();
sb.put(0, (short) 14);
NDArray bf16 = manager.create(sb, new Shape(1, 1), DataType.BFLOAT16);
bb = bf16.toByteBuffer();
Assert.assertEquals(bb.remaining(), 2);
Assert.assertEquals(bb.getShort(), 14);

PassthroughNDArray pa = manager.create((Object) "test");
Assert.assertThrows(pa::toByteBuffer);
Assert.assertEquals(pa.getObject(), "test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ inline jint GetDTypeFromScalarType(const torch::ScalarType& type) {
return 6;
} else if (torch::kBool == type) {
return 7;
} else {
} else if (torch::kComplexFloat == type) {
return 8;
} else {
return 9;
}
}

Expand Down

0 comments on commit 4a2d5fc

Please sign in to comment.