diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 8f1249e3afaf0..20b1bc4a7a0fe 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
@@ -1107,6 +1108,54 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
   if (!areInnerBoundsInvariant(forOp))
     return failure();
 
+  // Refuse if the jammed iteration order would violate a memory dependence.
+  // Unroll-and-jam by factor F can only reorder a dependence whose outer
+  // distance satisfies `0 < d_outer < F` and that has some `d_inner_k < 0`.
+  // The check is conservative: it skips only `d_outer >= F` and rejects on any
+  // inner component whose upper bound is negative.
+  // NOTE(review): depComps[0] is treated as forOp's component, which presumes
+  // forOp has no enclosing affine.for -- confirm against callers.
+  SmallVector<AffineForOp, 4> band;
+  getPerfectlyNestedLoops(band, forOp);
+  if (band.size() > 1) {
+    // Gather every affine read/write nested under the band's outermost loop.
+    SmallVector<Operation *, 8> loadAndStoreOps;
+    band[0]->walk([&](Operation *op) {
+      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
+        loadAndStoreOps.push_back(op);
+    });
+    unsigned numLoops = band.size();
+    // Test every ordered access pair at each depth in [1, numLoops + 1].
+    for (unsigned d = 1; d <= numLoops + 1; ++d) {
+      for (Operation *srcOp : loadAndStoreOps) {
+        MemRefAccess srcAccess(srcOp);
+        for (Operation *dstOp : loadAndStoreOps) {
+          MemRefAccess dstAccess(dstOp);
+          SmallVector<DependenceComponent, 2> depComps;
+          DependenceResult result = checkMemrefAccessDependence(
+              srcAccess, dstAccess, d,
+              /*dependenceConstraints=*/nullptr, &depComps);
+          if (!hasDependence(result))
+            continue;
+          // d_outer >= F: stripe never spans this dep.
+          if (!depComps.empty() && depComps[0].lb.has_value() &&
+              *depComps[0].lb >= static_cast<int64_t>(unrollJamFactor))
+            continue;
+          // Bail out on the first inner component that is strictly negative.
+          for (unsigned k = 1; k < depComps.size(); ++k) {
+            const DependenceComponent &c = depComps[k];
+            if (c.lb.has_value() && c.ub.has_value() && *c.lb <= *c.ub &&
+                *c.ub < 0) {
+              LDBG() << "[failed] backward inner dep at depth " << d
+                     << " (k=" << k << "); factor " << unrollJamFactor;
+              return failure();
+            }
+          }
+        }
+      }
+    }
+  }
+
   // Gather all sub-blocks to jam upon the loop being unrolled.
   JamBlockGatherer jbg;
   jbg.walk(forOp);
diff --git a/mlir/test/Dialect/Affine/unroll-jam.mlir b/mlir/test/Dialect/Affine/unroll-jam.mlir
index 8ed7fccf7d251..aa29c42772001 100644
--- a/mlir/test/Dialect/Affine/unroll-jam.mlir
+++ b/mlir/test/Dialect/Affine/unroll-jam.mlir
@@ -550,3 +550,43 @@ func.func @unroll_jam_iter_args_addi(%arg0: memref<21xi32, 1>, %init : i32) {
 // CHECK-NEXT: [[LOAD3:%[0-9]+]] = affine.load {{.*}}[%[[CONST0]]]
 // CHECK-NEXT: [[ADD4:%[0-9]+]] = arith.addi [[ADD3]], [[LOAD3]] : i32
 // CHECK-NEXT: return
+
+// Verify that affine-loop-unroll-jam refuses to transform a perfectly nested
+// band when its dependence components would be violated by the new iteration
+// order. Similar to the gating used by `affine-loop-tile`, but factor-aware.
+
+// A flow dependence with distance vector (1, -1): unroll-and-jam at factor 2
+// would interchange the intra-stripe iv with the inner iv and turn the
+// (1, -1) dep into a backward distance. The pass must leave the nest
+// unchanged.
+
+// CHECK-LABEL: func @unroll_jam_illegal_flow_dep_1_neg1
+// CHECK: affine.for %{{.*}} = 1 to 5 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
+// CHECK-NOT: affine.for %{{.*}} = 1 to 5 step 2
+func.func @unroll_jam_illegal_flow_dep_1_neg1(%arr: memref<5x5xi32>) {
+  affine.for %i = 1 to 5 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %arr[%i - 1, %j + 1] : memref<5x5xi32>
+      affine.store %v, %arr[%i, %j] : memref<5x5xi32>
+    }
+  }
+  return
+}
+
+// Distance vector (2, -1) at factor 2: d_outer >= F, so the unroll-and-jam
+// stripe never spans the dep. The pass must still fire even though an
+// `isTilingValid`-style check (which ignores the factor) would refuse it.
+
+// CHECK-LABEL: func @unroll_jam_legal_outer_distance_geq_factor
+// CHECK: affine.for %{{.*}} = 2 to 6 step 2
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
+func.func @unroll_jam_legal_outer_distance_geq_factor(%arr: memref<6x6xi32>) {
+  affine.for %i = 2 to 6 {
+    affine.for %j = 0 to 5 {
+      %v = affine.load %arr[%i - 2, %j + 1] : memref<6x6xi32>
+      affine.store %v, %arr[%i, %j] : memref<6x6xi32>
+    }
+  }
+  return
+}