From 2fcd68ac1ff8b8fdb9af62cdbe9fcbcd6466f485 Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Tue, 4 Nov 2025 16:58:59 +0000 Subject: [PATCH 1/3] [TOSA] Fix empty-dim reductions Teach the TorchToTosa reducer that an explicit empty dim list means "all dims" and cast the result back to the requested dtype. Add MLIR and e2e regression cases and update XFAILs. Change-Id: Ibd1be38d219ad5c1986eb4a641efbb9ff0cb6a55 --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 5 ++ .../TorchToTosa/TosaLegalizeCommon.cpp | 12 ++++- projects/pt1/e2e_testing/xfail_sets.py | 4 +- .../test_suite/reduction.py | 46 +++++++++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 23 ++++++++++ 5 files changed, 87 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0bc93f711ad6..0f19dad4cb1f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1089,6 +1089,11 @@ class ConvertAtenMultipleDimsReductionOp for (int64_t i = 0; i < inputRank; i++) reduceDims.push_back(i); } + // PyTorch treats an explicit empty list the same as "reduce all dims". + if (reduceDims.empty()) { + for (int64_t i = 0; i < inputRank; i++) + reduceDims.push_back(i); + } int64_t N = reduceDims.size(); for (unsigned i = 0; i < N; i++) { diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 02d1390ed148..444a2bdd2508 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -782,13 +782,23 @@ std::optional convertReduceOpCommon( // Optionally squeeze out the reduced axes. if (!keep_dims) { + auto squeezedType = + RankedTensorType::get(output_shape, reduce_element_type); auto reshape_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, val, + rewriter, op->getLoc(), squeezedType, val, tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape)); val = reshape_op.getResult(); } } + // Ensure the result element type matches the expected output type. + if (val.getType() != output_type) { + auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type); + if (!casted) + return std::nullopt; + val = casted.value(); + } + return val; } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 81071c6ab058..efbfaf259ac2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3434,6 +3434,8 @@ "ElementwiseClampMinModule_bfloat16", "ElementwiseClampModule_bfloat16", "ElementwiseReluModule_bfloat16", + # torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '' + "ReduceSumEmptyDimListInt8ToInt32Module_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): @@ -3846,7 +3848,6 @@ "MaxPool3dWithIndicesNonDefaultParamsModule_basic", "MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesStaticModule_basic", - "MeanDimEmptyDimModule_basic", "MlGroupNormManualModule_basic", "MlGroupNormModule_basic", "MlLayerNormManualModule_basic", @@ -3901,7 +3902,6 @@ "ReduceL3NormKeepDimComplexModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "ReduceSumDimIntListEmptyDimModule_basic", "RollModule_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 0eb0545e7f11..2e4ba9c4ccfc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int8, True), + ] + ) + def forward(self, a): + return torch.sum(a, dim=[], dtype=torch.int32) + + +@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module()) +def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8)) + + +# ============================================================================== + + +class ReduceSumEmptyDimListInt8Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int8, True), + ] + ) + def forward(self, a): + return torch.sum(a, dim=[]) + + +@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module()) +def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8)) + + +# ============================================================================== + + class ReduceSumElementTypeBoolModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index d100fe9dcfde..543dc09a65b2 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -311,6 +311,29 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> ! // ----- +// CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[],f32> +// CHECK: } +func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> { + %none = torch.constant.none + %false = torch.constant.bool false + %empty = torch.prim.ListConstruct : () -> !torch.list + %0 = torch.aten.sum.dim_IntList %arg0, %empty, %false, %none : !torch.vtensor<[2,3,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_linalg_vector_norm$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> From 3249880a938468f68bbb24bd9313a62d8475e9d1 Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Tue, 25 Nov 2025 13:29:17 +0000 Subject: [PATCH 2/3] fix tests Change-Id: I8015810a71e31adaae19b0d0a839f5fe3bebf8bb --- test/Conversion/TorchToTosa/basic.mlir | 56 ++++++++++++++++++-------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 543dc09a65b2..8276f3f89c8f 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -312,24 +312,48 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> ! // ----- // CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.none -// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.const_shape -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[],f32> +// CHECK-SAME: %[[INPUT_F32:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[INPUT_F32_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_F32]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[INPUT_F32_TENSOR]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32> +// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape +// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor +// CHECK: %[[RESULT_F32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[RESULT_F32]] : !torch.vtensor<[],f32> // CHECK: } func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> { - %none = torch.constant.none - %false = torch.constant.bool false - %empty = torch.prim.ListConstruct : () -> !torch.list - %0 = torch.aten.sum.dim_IntList %arg0, %empty, %false, %none : !torch.vtensor<[2,3,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> - return %0 : !torch.vtensor<[],f32> + %dtype_none = torch.constant.none + %keep_dims_false = torch.constant.bool false + %all_dims_list = torch.prim.ListConstruct : () -> !torch.list + %sum_all_dims = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_none : !torch.vtensor<[2,3,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + return %sum_all_dims : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_empty_dims_i8_to_i32$basic( +// CHECK-SAME: %[[INPUT_I8:.*]]: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> { +// CHECK: %[[INPUT_I8_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_I8]] : !torch.vtensor<[2,3,4],si8> -> tensor<2x3x4xi8> +// CHECK: %[[DTYPE_I32:.*]] = torch.constant.int 3 +// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[CAST_INPUT_TO_I32:.*]] = tosa.cast %[[INPUT_I8_TENSOR]] : (tensor<2x3x4xi8>) -> tensor<2x3x4xi32> +// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[CAST_INPUT_TO_I32]] {axis = 0 : i32} : (tensor<2x3x4xi32>) -> tensor<1x3x4xi32> +// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xi32>) -> tensor<1x1x4xi32> +// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xi32>) -> tensor<1x1x1xi32> +// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape +// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xi32>, !tosa.shape<0>) -> tensor +// CHECK: %[[RESULT_I32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor -> !torch.vtensor<[],si32> +// CHECK: return %[[RESULT_I32]] : !torch.vtensor<[],si32> +// CHECK: } +func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(%arg0: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> { + %dtype_i32 = torch.constant.int 3 + %keep_dims_false = torch.constant.bool false + %all_dims_list = torch.prim.ListConstruct : () -> !torch.list + %sum_all_dims_to_i32 = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_i32 : !torch.vtensor<[2,3,4],si8>, !torch.list, !torch.bool, !torch.int -> !torch.vtensor<[],si32> + return %sum_all_dims_to_i32 : !torch.vtensor<[],si32> } // ----- From bcfcef8fffd2d965124fc782acfc4d254962456b Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Tue, 25 Nov 2025 15:13:36 +0000 Subject: [PATCH 3/3] fix xfails Change-Id: I1248f00fb1aa07f8f1422b4ca913041e9f8a056e --- projects/pt1/e2e_testing/xfail_sets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 5f40c4e10b3d..d9eb6c05a957 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3824,7 +3824,6 @@ "MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesStaticModule_basic", "MaxPool3dSingleIntTupleDilationModule_basic", - "MeanDimEmptyDimModule_basic", "MlGroupNormManualModule_basic", "MlGroupNormModule_basic", "MlLayerNormManualModule_basic",