Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix several issues regarding recent mapping update #4551

Merged
merged 7 commits into from
Sep 30, 2022

Conversation

jcwchen
Copy link
Member

@jcwchen jcwchen commented Sep 28, 2022

Description

Motivation and Context

Recent #4270 refactored existing code, but I forgot to add the case for SequenceProto.MAP in to_list.

cc @gramalingam Please review this PR. Sorry that I should catch this in advance before my PR got merged. Thank you!

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@jcwchen jcwchen requested a review from a team as a code owner September 28, 2022 16:33
@jcwchen jcwchen changed the title Add missing case for SequenceProto.MAP in numpy_helper.to_list [WIP] Add missing case for SequenceProto.MAP in numpy_helper.to_list Sep 28, 2022
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@jcwchen jcwchen changed the title [WIP] Add missing case for SequenceProto.MAP in numpy_helper.to_list Fix several issues regarding recent mapping update Sep 28, 2022
Copy link
Member

@linkerzhang linkerzhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@xadupre
Copy link
Contributor

xadupre commented Sep 29, 2022

It is possible to add a unit test to check the warning is raised when using the dictionary and not raised when using the function?

@jcwchen jcwchen changed the title Fix several issues regarding recent mapping update [WIP] Fix several issues regarding recent mapping update Sep 29, 2022
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@jcwchen jcwchen changed the title [WIP] Fix several issues regarding recent mapping update Fix several issues regarding recent mapping update Sep 29, 2022
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@jcwchen
Copy link
Member Author

jcwchen commented Sep 29, 2022

It is possible to add a unit test to check the warning is raised when using the dictionary and not raised when using the function?

Good idea. I just added tests by TestHelperMappingFunctions and actually np_dtype_to_tensor_dtype needs to be fixed as well. With the added tests, now I am confident that these new introduced helpers should not unexpectedly throw a deprecation warnings. Please review it. Thank you!

@@ -724,6 +724,26 @@ def test_make_tensor_raw(tensor_dtype: int) -> None:
np.testing.assert_equal(np_array, numpy_helper.to_array(tensor))


# TODO (#4554): remove this test after the deprecation period
# Test these new functions should not raise any depreaction warnings
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misspelling deprecation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected. Thanks!

def test_tensor_dtype_to_np_dtype_not_throw_warning(self) -> None:
_ = helper.tensor_dtype_to_np_dtype(TensorProto.FLOAT)

@pytest.mark.filterwarnings("error::DeprecationWarning")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it is worth checking what happens in case of bfloat16.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests for bfloat16 looks useful since the mapping for bfloat16 is quite confusing. Adding these tests can prevent future regression. Just added them. PTAL. Thanks!

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@@ -542,7 +542,7 @@ message TensorProto {
// float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that I think BFLOAT16 is missing here.

helper.tensor_dtype_to_field(TensorProto.BFLOAT16), "int32_data"
)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can be done in another PR but I would check all the types:

    def test_numeric_types(self):  # type: ignore
        dtypes = [
            np.float16,
            np.float32,
            np.float64,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
            np.complex64,
            np.complex128,
        ]
        for dt in dtypes:
            with self.subTest(dtype=dt):
                t = np.array([0, 1, 2], dtype=dt)
                ot = from_array(t)
                u = to_array(ot)
                self.assertEqual(t.dtype, u.dtype)
                assert_almost_equal(t, u)

    def test_make_tensor(self):  # type: ignore
        for pt, dt in TENSOR_TYPE_TO_NP_TYPE.items():
            if pt == TensorProto.BFLOAT16:
                continue
            with self.subTest(dt=dt, pt=pt, raw=False):
                if pt == TensorProto.STRING:
                    t = np.array([["i0", "i1", "i2"], ["i6", "i7", "i8"]], dtype=dt)
                else:
                    t = np.array([[0, 1, 2], [6, 7, 8]], dtype=dt)
                ot = make_tensor("test", pt, t.shape, t, raw=False)
                self.assertFalse(ot is None)
                u = to_array(ot)
                self.assertEqual(t.dtype, u.dtype)
                self.assertEqual(t.tolist(), u.tolist())
            with self.subTest(dt=dt, pt=pt, raw=True):
                t = np.array([[0, 1, 2], [6, 7, 8]], dtype=dt)
                if pt == TensorProto.STRING:
                    with self.assertRaises(TypeError):
                        make_tensor("test", pt, t.shape, t.tobytes(), raw=True)
                else:
                    ot = make_tensor("test", pt, t.shape, t.tobytes(), raw=True)
                    self.assertFalse(ot is None)
                    u = to_array(ot)
                    self.assertEqual(t.dtype, u.dtype)
                    assert_almost_equal(t, u)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @xadupre for the suggestion. It looks good to me, but I am inclined to have it in another PR (this PR more focuses on bug fixes) Feel free to let me know if you still have other concern.

@jcwchen jcwchen merged commit 46b9627 into onnx:main Sep 30, 2022
@jcwchen jcwchen deleted the jcw/fix-map-helper branch September 30, 2022 18:54
broune pushed a commit to broune/onnx that referenced this pull request May 6, 2023
* add missing case for Squence.MAP

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* remove warning and improve printable_graph

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* fix lint and mypy

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* _NP_TYPE_TO_TENSOR_TYPE and mention available in ONNX 1.13

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add test to prevent throwing deprecation in new functions

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add bfloat test  for tensor_dtype_to_field

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add more tests for bfloat16

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

printable_graph error and INVALID_GRAPH (ONNXRuntime) when using bf16 datatype
4 participants