diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index f1203b2bdfee5..e3717aa9d940e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -94,7 +94,9 @@ static void specializeForLoopForUnrolling(ForOp op) { OpBuilder b(op); IRMapping map; - Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant); + Value constant = arith::ConstantOp::create( + b, op.getLoc(), + IntegerAttr::get(op.getUpperBound().getType(), minConstant)); Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq, bound, constant); map.map(bound, constant); @@ -150,6 +152,9 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, ValueRange{forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}); + if (splitBound.getType() != forOp.getLowerBound().getType()) + splitBound = b.createOrFold( + loc, forOp.getLowerBound().getType(), splitBound); // Create ForOp for partial iteration. b.setInsertionPointAfter(forOp); @@ -230,6 +235,9 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp, auto loc = forOp.getLoc(); Value splitBound = b.createOrFold( loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()}); + if (splitBound.getType() != forOp.getUpperBound().getType()) + splitBound = b.createOrFold( + loc, forOp.getUpperBound().getType(), splitBound); // Peel the first iteration. IRMapping map; diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir index f59b79603b489..03c446c11981b 100644 --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -67,6 +67,41 @@ func.func @fully_static_bounds() -> i32 { // ----- +// CHECK: func @fully_static_bounds_integers( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C16]] +// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0]]) -> (i32) +// CHECK: %[[MAP:.*]] = affine.min +// CHECK: %[[MAP_CAST:.*]] = arith.index_cast %[[MAP]] +// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[MAP_CAST]] : i32 +// CHECK: scf.yield %[[ADD]] +// CHECK: } +// CHECK: %[[RESULT:.*]] = arith.addi %[[LOOP]], %[[C1]] : i32 +// CHECK: return %[[RESULT]] +#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)> +func.func @fully_static_bounds_integers() -> i32 { + %c0_i32 = arith.constant 0 : i32 + %lb = arith.constant 0 : i32 + %step = arith.constant 4 : i32 + %ub = arith.constant 17 : i32 + %r = scf.for %iv = %lb to %ub step %step + iter_args(%arg = %c0_i32) -> i32 : i32 { + %ub_index = arith.index_cast %ub : i32 to index + %iv_index = arith.index_cast %iv : i32 to index + %step_index = arith.index_cast %step : i32 to index + %s = affine.min #map(%ub_index, %iv_index)[%step_index] + %casted = arith.index_cast %s : index to i32 + %0 = arith.addi %arg, %casted : i32 + scf.yield %0 : i32 + } + return %r : i32 +} + +// ----- + // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)> // CHECK: func @dynamic_upper_bound(