diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp index 69407df201cfa..4d8027c604cdf 100644 --- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp +++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp @@ -27,13 +27,11 @@ struct SincosFusionPattern : OpRewritePattern { mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath(); math::CosOp cosOp = nullptr; - sinOp->getBlock()->walk([&](math::CosOp op) { + for (auto op : sinOp->getBlock()->getOps()) if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) { cosOp = op; - return WalkResult::interrupt(); + break; } - return WalkResult::advance(); - }); if (!cosOp) return failure(); diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir index 29fb9f12475b8..cf16f9f02f63a 100644 --- a/mlir/test/Dialect/Math/sincos-fusion.mlir +++ b/mlir/test/Dialect/Math/sincos-fusion.mlir @@ -74,6 +74,29 @@ func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 { func.return %0 : f32 } +// CHECK-LABEL: func.func @sincos_no_fusion_nested_region( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> (f32, f32) { +// CHECK: %[[SIN:.*]] = math.sin %[[ARG0]] : f32 +// CHECK: %[[IF:.*]] = scf.if %[[ARG1]] -> (f32) { +// CHECK: %[[COS:.*]] = math.cos %[[ARG0]] : f32 +// CHECK: scf.yield %[[COS]] : f32 +// CHECK: } else { +// CHECK: scf.yield %[[SIN]] : f32 +// CHECK: } +// CHECK: return %[[SIN]], %[[IF]] : f32, f32 +// CHECK: } +func.func @sincos_no_fusion_nested_region(%arg0 : f32, %flag : i1) -> (f32, f32) { + %s = math.sin %arg0 : f32 + %r = scf.if %flag -> f32 { + %c = math.cos %arg0 : f32 + scf.yield %c : f32 + } else { + scf.yield %s : f32 + } + func.return %s, %r : f32, f32 +} + // CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath( // CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) { // CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath : f32