diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp index ba33194c11f9..7a03bc81510e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp @@ -39,8 +39,11 @@ struct TransposeGenericOpPattern : public OpRewritePattern { std::optional mapForInterchange; for (auto operand : genericOp.getDpsInputOperands()) { - auto producer = operand->get().getDefiningOp(); - if (!producer || !llvm::hasSingleElement(producer->getUsers())) + // Check that the producer is a named op or a reduction op (i.e. not + // elementwise op) with a single use. + auto producer = operand->get().getDefiningOp(); + if (!producer || !llvm::hasSingleElement(producer->getUsers()) || + linalg::isElementwise(producer)) continue; // check if the generic op has a non-identity map for the operand. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir index cc9fe7a619cc..9da0a964aabf 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir @@ -20,9 +20,38 @@ util.func @supported_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x1 } -> tensor<2x320x128x128xf16> util.return %truncf : tensor<2x320x128x128xf16> } -// CHECK-LABEL: func public @supported_conv +// CHECK-LABEL: func public @supported_conv( // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>] // CHECK-SAME: ins(%[[CONV]] : // CHECK: return %[[GENERIC]] + +// ----- + +util.func @generalize_to_any_linalg_op(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor, %arg3 : tensor, %arg4 : tensor) -> tensor { + %c0_i64 = arith.constant 0 : i64 + %0 = linalg.conv_2d_nhwc_hwcf_q { + dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} + ins(%arg0, %arg1, %c0_i64, %c0_i64 : tensor, tensor, i64, i64) + outs(%arg2 : tensor) -> tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0 : tensor) outs(%arg4 : tensor) { + ^bb0(%in: i64, %out: i8): + %3 = arith.trunci %in : i64 to i32 + %4 = arith.sitofp %3 : i32 to f32 + %5 = arith.fptosi %4 : f32 to i8 + linalg.yield %5 : i8 + } -> tensor + util.return %2 : tensor +} +// CHECK-LABEL: func public @generalize_to_any_linalg_op( +// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>] +// CHECK: return %[[RESULT]]