diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 4382579a5368..42767c2cc332 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -352,9 +352,15 @@ class JSONAttrSetter : public AttrVisitor { template void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); - is >> *value; - if (is.fail()) { - LOG(FATAL) << "Wrong value format for field " << key; + if (is.str() == "inf") { + *value = std::numeric_limits::infinity(); + } else if (is.str() == "-inf") { + *value = -std::numeric_limits::infinity(); + } else { + is >> *value; + if (is.fail()) { + LOG(FATAL) << "Wrong value format for field " << key; + } } } void Visit(const char* key, double* value) final { ParseValue(key, value); } diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index 3a7318c75761..d375fa0f75c6 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -28,6 +28,16 @@ def test_const_saveload_json(): zz = tvm.ir.load_json(json_str) tvm.ir.assert_structural_equal(zz, z, map_free_vars=True) +def _test_infinity_value(value, dtype): + x = tvm.tir.const(value, dtype) + json_str = tvm.ir.save_json(x) + tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str)) + +def test_infinity_value(): + _test_infinity_value(float("inf"), 'float64') + _test_infinity_value(float("-inf"), 'float64') + _test_infinity_value(float("inf"), 'float32') + _test_infinity_value(float("-inf"), 'float32') def test_make_smap(): # save load json @@ -145,3 +155,4 @@ def test_dict(): test_make_sum() test_pass_config() test_dict() + test_infinity_value()