diff --git a/src/Nncase.Importer/Onnx/Reduce.cs b/src/Nncase.Importer/Onnx/Reduce.cs index b0605c4f9..3020143ca 100644 --- a/src/Nncase.Importer/Onnx/Reduce.cs +++ b/src/Nncase.Importer/Onnx/Reduce.cs @@ -16,11 +16,12 @@ private Expr VisitReduce(in NodeProto op, ReduceOp reduceOp, Expr initValue) return ReduceCore(op, reduceOp, initValue, expr => expr); } - private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, Expr initValue, Func f) + private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, Expr initValue, Func f, long opVersion = 999) { var input = GetInputExpr(op, 0); Expr axis; - if ((reduceOp == ReduceOp.Sum && GetOpSet(op) < 13) || GetOpSet(op) < 18) + + if ((reduceOp == ReduceOp.Sum && opVersion < 13) || (reduceOp != ReduceOp.Sum && GetOpSet(op) < 18)) { axis = GetAxesAttribute(op, input); } @@ -63,7 +64,9 @@ private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, Expr initValue, Func private Expr ReduceSumZero(in NodeProto op, Func f) { - return ReduceCore(op, ReduceOp.Sum, 0f, f); + // Reduce_sum opVersion 13 == other reduce opVersion 18. Axis is not Attributes. + // If GetOpSet(op) > 13, use reduce_sum opVersion 11. Axis is Attributes. + return ReduceCore(op, ReduceOp.Sum, 0f, f, GetOpSet(op) >= 18 ? 13 : 11); } private Expr VisitReduceL1(in NodeProto op)