diff --git a/src/Nncase.Importer/Onnx/Reduce.cs b/src/Nncase.Importer/Onnx/Reduce.cs index e3baa6822e..63b8432175 100644 --- a/src/Nncase.Importer/Onnx/Reduce.cs +++ b/src/Nncase.Importer/Onnx/Reduce.cs @@ -11,12 +11,12 @@ namespace Nncase.Importer { public partial class OnnxImporter { - private Expr VisitReduce(in NodeProto op, ReduceOp reduceOp, float initValue) + private Expr VisitReduce(in NodeProto op, ReduceOp reduceOp, Expr initValue) { return ReduceCore(op, reduceOp, initValue, expr => expr); } - private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, float initValue, Func f) + private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, Expr initValue, Func f) { var input = GetInputExpr(op, 0); Expr axis; @@ -51,6 +51,12 @@ private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, float initValue, Fun var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Int32 => F.Tensors.Reduce(reduceOp, f(input), axis, int.MinValue, keepDims), var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Int64 => F.Tensors.Reduce(reduceOp, f(input), axis, long.MaxValue, keepDims), var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Int32 => F.Tensors.Reduce(reduceOp, f(input), axis, int.MaxValue, keepDims), + var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Float32 => F.Tensors.Reduce(reduceOp, f(input), axis, float.MinValue, keepDims), + var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Float16 => F.Tensors.Reduce(reduceOp, f(input), axis, Half.MinValue, keepDims), + var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.BFloat16 => F.Tensors.Reduce(reduceOp, f(input), axis, BFloat16.RoundToBFloat16(float.MinValue), keepDims), + var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Float32 => F.Tensors.Reduce(reduceOp, f(input), axis, float.MaxValue, keepDims), + var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Float16 => F.Tensors.Reduce(reduceOp, f(input), axis, Half.MaxValue, keepDims), + var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.BFloat16 => F.Tensors.Reduce(reduceOp, f(input), axis, BFloat16.RoundToBFloat16(float.MaxValue), keepDims), _ => F.Tensors.Reduce(reduceOp, f(input), axis, F.Tensors.Cast(initValue, input.CheckedDataType), keepDims), }; } diff --git a/src/Nncase.Importer/Onnx/ReduceWindow2D.cs b/src/Nncase.Importer/Onnx/ReduceWindow2D.cs index 07273a38ae..25bcb23873 100644 --- a/src/Nncase.Importer/Onnx/ReduceWindow2D.cs +++ b/src/Nncase.Importer/Onnx/ReduceWindow2D.cs @@ -14,7 +14,7 @@ namespace Nncase.Importer public partial class OnnxImporter { // isGlobal used for GlobalXXXPool - private Expr VisitReduceWindow2D(in NodeProto op, ReduceOp reduceOp, float initValue, bool isGlobal = false) + private Expr VisitReduceWindow2D(in NodeProto op, ReduceOp reduceOp, Expr initValue, bool isGlobal = false) { // auto_pad had been DEPRECATED var input = GetInputExpr(op, 0); diff --git a/tests/importer/onnx_/basic/test_reduce.py b/tests/importer/onnx_/basic/test_reduce.py index 2404899474..879d7f9a7b 100644 --- a/tests/importer/onnx_/basic/test_reduce.py +++ b/tests/importer/onnx_/basic/test_reduce.py @@ -22,7 +22,7 @@ import numpy as np -def _make_module(in_shape, reduce_op, axes, keepdims, op_version): +def _make_module(in_shape, in_datatype, reduce_op, axes, keepdims, op_version): inputs = [] outputs = [] initializers = [] @@ -30,14 +30,14 @@ def _make_module(in_shape, reduce_op, axes, keepdims, op_version): nodes = [] # input - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, in_shape) + input = helper.make_tensor_value_info('input', in_datatype, in_shape) inputs.append('input') # output kd = 1 if keepdims is None else keepdims data = np.ones(in_shape) out_shape = np.prod(data, axis=tuple(axes), keepdims=kd).shape - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, out_shape) + output = helper.make_tensor_value_info('output', in_datatype, out_shape) outputs.append('output') # axes @@ -73,6 +73,11 @@ def _make_module(in_shape, reduce_op, axes, keepdims, op_version): [1, 3, 16, 16] ] +in_datatypes = [ + TensorProto.FLOAT, + TensorProto.FLOAT16 +] + reduce_ops = [ 'ReduceMax', 'ReduceMean', @@ -108,13 +113,14 @@ def _make_module(in_shape, reduce_op, axes, keepdims, op_version): @pytest.mark.parametrize('in_shape', in_shapes) +@pytest.mark.parametrize('in_datatype', in_datatypes) @pytest.mark.parametrize('reduce_op', reduce_ops) @pytest.mark.parametrize('axes', axes_list) @pytest.mark.parametrize('keepdims', keepdims_lists) @pytest.mark.parametrize('op_version', op_version_lists) -def test_reduce(in_shape, reduce_op, axes, keepdims, request, op_version): +def test_reduce(in_shape, in_datatype, reduce_op, axes, keepdims, request, op_version): if len(axes) <= len(in_shape): - model_def = _make_module(in_shape, reduce_op, axes, keepdims, op_version) + model_def = _make_module(in_shape, in_datatype, reduce_op, axes, keepdims, op_version) runner = OnnxTestRunner(request.node.name) model_file = runner.from_onnx_helper(model_def)