diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 92d8846a4640d..ee893c14c9c6b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1056,12 +1056,6 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
     LDBG() << "Unable to compute source's domain";
     return std::nullopt;
   }
-  // As the set difference utility currently cannot handle symbols in its
-  // operands, validity of the slice cannot be determined.
-  if (srcConstraints.getNumSymbolVars() > 0) {
-    LDBG() << "Cannot handle symbols in source domain";
-    return std::nullopt;
-  }
   // TODO: Handle local vars in the source domains while using the 'projectOut'
   // utility below. Currently, aligning is not done assuming that there will be
   // no local vars in the source domain.
@@ -1082,6 +1076,13 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
   // domain completely in terms of source's IVs.
   sliceConstraints.projectOut(ivs.size(),
                               sliceConstraints.getNumVars() - ivs.size());
+  // Likewise keep only the first ivs.size() variables of the source domain.
+  // This eliminates any symbols (e.g. symbolic loop bounds), which the set
+  // difference utility cannot handle in its operands.
+  // NOTE(review): projecting symbols out relaxes both domains; confirm the
+  // comparison cannot accept a slice that is invalid for some symbol value.
+  srcConstraints.projectOut(ivs.size(),
+                            srcConstraints.getNumVars() - ivs.size());
 
   LDBG() << "Domain of the source of the slice:\n"
          << "Source constraints:" << srcConstraints
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 04c8c3ee809a1..d6884fb921ad2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -771,3 +771,62 @@ func.func @memref_cast_reused(%arg: memref<*xf32>) {
   // SIBLING-MAXIMAL-NEXT: affine.store
   return
 }
+
+// -----
+
+// Test fusion of loops with symbolic (non-constant) bounds.
+
+// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @fusion_non_constant_bounds_0
+func.func @fusion_non_constant_bounds_0(%arg0: memref<?xf32>) {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst_0 = arith.constant 2.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %dim = memref.dim %arg0, %c0 : memref<?xf32>
+  affine.for %arg1 = 0 to %dim {
+    %0 = affine.load %arg0[%arg1] : memref<?xf32>
+    %1 = arith.addf %0, %cst : f32
+    affine.store %1, %arg0[%arg1] : memref<?xf32>
+  }
+  affine.for %arg1 = 0 to %dim {
+    %0 = affine.load %arg0[%arg1] : memref<?xf32>
+    %1 = arith.addf %0, %cst_0 : f32
+    affine.store %1, %arg0[%arg1] : memref<?xf32>
+  }
+  // PRODUCER-CONSUMER-MAXIMAL: affine.for %[[idx:.*]] = 0 to %{{.*}} {
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.load %[[arr:.*]][%[[idx]]] : memref<?xf32>
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: arith.addf
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.store %{{.*}}, %[[arr]][%[[idx]]] : memref<?xf32>
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.load %[[arr:.*]][%[[idx]]] : memref<?xf32>
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: arith.addf
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.store %{{.*}}, %[[arr]][%[[idx]]] : memref<?xf32>
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: }
+  return
+}
+
+// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @fusion_non_constant_bounds_1
+func.func @fusion_non_constant_bounds_1(%N: index, %M: memref<?xf32>, %cst: f32) {
+  affine.for %i = 0 to %N {
+    affine.store %cst, %M[%i] : memref<?xf32>
+  }
+  affine.for %i = 0 to %N {
+    affine.load %M[%i] : memref<?xf32>
+  }
+  // Loops with identical symbolic bounds (0 to %N) should be fused.
+  // PRODUCER-CONSUMER-MAXIMAL: affine.for
+  // PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for
+
+  // Bounds do not match (0 vs. 1 lower bound). Still fused; source remains.
+  // PRODUCER-CONSUMER-MAXIMAL: affine.for
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.store
+  affine.for %i = 0 to %N {
+    affine.store %cst, %M[%i] : memref<?xf32>
+  }
+  // PRODUCER-CONSUMER-MAXIMAL: affine.for
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.store
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.load
+  affine.for %i = 1 to %N {
+    affine.load %M[%i] : memref<?xf32>
+  }
+
+  return
+}