Skip to content

Commit

Permalink
Ref tests
Browse files — browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed May 4, 2024
1 parent 9fd4909 commit 4c18a15
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 33 deletions.
15 changes: 4 additions & 11 deletions onnx/backend/test/case/node/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,6 @@ def export() -> None:

vect_float32_to_float8e4m3 = np.vectorize(float32_to_float8e4m3)
vect_float32_to_float8e5m2 = np.vectorize(float32_to_float8e5m2)
vect_float32_to_uint4 = np.vectorize(
lambda x: subbyte.float32_to_4bit_unpacked(x, signed=False)
)
vect_float32_to_int4 = np.vectorize(
lambda x: subbyte.float32_to_4bit_unpacked(x, signed=True)
)

f8_types = ("FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2", "FLOAT8E5M2FNUZ")

for from_type, to_type in test_cases:
Expand Down Expand Up @@ -239,12 +232,12 @@ def export() -> None:
"x", TensorProto.FLOAT16, input_shape, input_values.tolist()
)
elif from_type == "UINT4":
input_values = vect_float32_to_uint4(np_fp32)
input_values = subbyte.cast_uint4(np_fp32)
input = make_tensor(
"x", TensorProto.UINT4, input_shape, input_values.tolist()
)
elif from_type == "INT4":
input_values = vect_float32_to_int4(np_fp32)
input_values = subbyte.cast_int4(np_fp32)
input = make_tensor(
"x", TensorProto.INT4, input_shape, input_values.tolist()
)
Expand All @@ -253,9 +246,9 @@ def export() -> None:
"Conversion from {from_type} to {to_type} is not tested."
)
if to_type == "UINT4":
expected = vect_float32_to_uint4(input_values).astype(custom.uint4)
expected = subbyte.cast_uint4(input_values).astype(custom.uint4)
elif to_type == "INT4":
expected = vect_float32_to_int4(input_values).astype(custom.int4)
expected = subbyte.cast_int4(input_values).astype(custom.int4)
elif to_type == "FLOAT16":
expected = input_values.astype(np.float16)
elif to_type == "FLOAT":
Expand Down
2 changes: 1 addition & 1 deletion onnx/numpy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def bfloat16_to_float32(
Returns:
A numpy array of float32 with the same dimension.
"""
return _left_shift_16_bits(data).view(np.float32)
return _left_shift_16_bits(data.astype(np.uint32)).view(np.float32)


def float8e4m3_to_float32(
Expand Down
8 changes: 3 additions & 5 deletions onnx/reference/ops/op_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,9 @@ def cast_to(x, to, saturate): # noqa: PLR0911

if to == tensor_type:
xf = x.astype(np.float32).ravel()
y = np.empty(xf.shape, dtype=np_type).ravel()
for i in range(y.shape[0]):
el = subbyte.float32_to_4bit_unpacked(xf[i], signed=signed)
y[i] = el
return y.reshape(x.shape)
if signed:
return subbyte.cast_int4(xf).view(np_type)
return subbyte.cast_uint4(xf).view(np_type)

f8back = {
TensorProto.FLOAT8E4M3FN: (
Expand Down
9 changes: 4 additions & 5 deletions onnx/reference/ops/op_quantize_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,10 @@ def _run(
if tensor_type in (TensorProto.UINT4, TensorProto.INT4):
xi = np.rint(x).astype(np.int32)
xi += zero_point
single_func = lambda x: subbyte.float32_to_4bit_unpacked( # noqa: E731
x, signed=(tensor_type == TensorProto.INT4)
)
func = np.vectorize(single_func)
i4 = func(xi)
if tensor_type == TensorProto.INT4:
i4 = subbyte.cast_int4(xi)
else:
i4 = subbyte.cast_uint4(xi)
return (i4,) # type: ignore[attr-defined]

raise ValueError(
Expand Down
21 changes: 10 additions & 11 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5647,12 +5647,12 @@ def test_cast_int4_output(self, cast_from, cast_to):
)
ref = ReferenceEvaluator(model)
data = np.array([0, 1, 2.4, 2.6, 4, 10], dtype=np.float32)
signed = cast_to == TensorProto.INT4
expected1 = np.array(
[subbyte.float32_to_4bit_unpacked(x, signed=signed) for x in data]
)
if cast_to == TensorProto.INT4:
expected = subbyte.cast_int4(data)
else:
expected = subbyte.cast_uint4(data)
got = ref.run(None, {"X": data})
self.assertEqual(expected1.tolist(), got[0].tolist())
self.assertEqual(expected.tolist(), got[0].tolist())

@parameterized.parameterized.expand(
itertools.product(
Expand All @@ -5675,13 +5675,12 @@ def test_cast_int4_input(self, cast_from, cast_to):
)
ref = ReferenceEvaluator(model)
data = np.array(range(7), dtype=np.float32)
cast_from_np = custom.uint4 if cast_from == TensorProto.UINT4 else custom.int4
data = data.astype(cast_from_np)
expected1 = np.array(
[subbyte.float32_to_4bit_unpacked(x, cast_from_np) for x in data]
cast_from_dtype = (
custom.uint4 if cast_from == TensorProto.UINT4 else custom.int4
)
got = ref.run(None, {"X": data})
self.assertEqual(expected1.tolist(), got[0].tolist())
expected = data
got = ref.run(None, {"X": data.astype(cast_from_dtype)})
self.assertEqual(expected.tolist(), got[0].tolist())

def test_a_function_calling_a_function_once(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, ["N"])
Expand Down

0 comments on commit 4c18a15

Please sign in to comment.