diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index df4145db90a61..daf7976118cea 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -547,10 +547,12 @@ struct MemRefRegion {
   /// use int64_t instead of uint64_t since index types can be at most
   /// int64_t. `lbs` are set to the lower bound maps for each of the rank
   /// dimensions where each of these maps is purely symbolic in the constraints
-  /// set's symbols.
-  std::optional<int64_t> getConstantBoundingSizeAndShape(
-      SmallVectorImpl<int64_t> *shape = nullptr,
-      SmallVectorImpl<AffineMap> *lbs = nullptr) const;
+  /// set's symbols. If `minShape` is provided, each computed bound is at least
+  /// `minShape[d]` for dimension `d`.
+  std::optional<int64_t>
+  getConstantBoundingSizeAndShape(SmallVectorImpl<int64_t> *shape = nullptr,
+                                  SmallVectorImpl<AffineMap> *lbs = nullptr,
+                                  ArrayRef<int64_t> minShape = {}) const;
 
   /// Gets the lower and upper bound map for the dimensional variable at
   /// `pos`.
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index f38493bc9a96e..4e934a3b6e580 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <optional>
 
 #define DEBUG_TYPE "analysis-utils"
@@ -1158,10 +1159,12 @@ unsigned MemRefRegion::getRank() const {
 }
 
 std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
-    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
+    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs,
+    ArrayRef<int64_t> minShape) const {
   auto memRefType = cast<MemRefType>(memref.getType());
   MLIRContext *context = memref.getContext();
   unsigned rank = memRefType.getRank();
+  assert(minShape.empty() || minShape.size() == rank);
   if (shape)
     shape->reserve(rank);
 
@@ -1203,12 +1206,14 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
       lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                           /*result=*/getAffineConstantExpr(0, context));
     }
-    numElements *= diffConstant;
+    int64_t finalDiff =
+        minShape.empty() ? diffConstant : std::max(diffConstant, minShape[d]);
+    numElements *= finalDiff;
     // Populate outputs if available.
     if (lbs)
       lbs->push_back(lb);
     if (shape)
-      shape->push_back(diffConstant);
+      shape->push_back(finalDiff);
   }
   return numElements;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ff0157eb9e4f3..0fa140027b4c3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <iomanip>
 #include <optional>
 #include <sstream>
@@ -376,10 +377,28 @@ static Value createPrivateMemRef(AffineForOp forOp,
   SmallVector<int64_t, 4> newShape;
   SmallVector<AffineMap, 4> lbs;
   lbs.reserve(rank);
-  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
-  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  SmallVector<int64_t> minShape;
+  ArrayRef<int64_t> minShapeRef;
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(srcStoreOp)) {
+    ArrayRef<int64_t> vectorShape = vectorStore.getVectorType().getShape();
+    unsigned vectorRank = vectorShape.size();
+    if (vectorRank > rank) {
+      LDBG() << "Private memref creation unsupported for vector store with "
+             << "rank greater than memref rank";
+      return nullptr;
+    }
+    minShape.assign(rank, 0);
+    for (unsigned i = 0; i < vectorRank; ++i) {
+      unsigned memDim = rank - vectorRank + i;
+      int64_t vecDim = vectorShape[i];
+      assert(!ShapedType::isDynamic(vecDim) &&
+             "vector store should have static shape");
+      minShape[memDim] = std::max(minShape[memDim], vecDim);
+    }
+    minShapeRef = minShape;
+  }
   std::optional<int64_t> numElements =
-      region.getConstantBoundingSizeAndShape(&newShape, &lbs);
+      region.getConstantBoundingSizeAndShape(&newShape, &lbs, minShapeRef);
   assert(numElements && "non-constant number of elts in local buffer");
 
   const FlatAffineValueConstraints *cst = region.getConstraints();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index c6abb0d734d88..48d1db15d84cb 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -31,6 +31,15 @@
 using namespace mlir;
 using namespace mlir::affine;
 
+/// Returns the vector type associated with an affine vector load/store op.
+static std::optional<VectorType> getAffineVectorType(Operation *op) {
+  if (auto vectorLoad = dyn_cast<AffineVectorLoadOp>(op))
+    return vectorLoad.getVectorType();
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(op))
+    return vectorStore.getVectorType();
+  return std::nullopt;
+}
+
 // Gathers all load and store memref accesses in 'opA' into 'values', where
 // 'values[memref] == true' for each store operation.
 static void getLoadAndStoreMemRefAccesses(Operation *opA,
@@ -334,6 +343,40 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
       break;
   }
 
+  // Guard vector fusion by matching producer/consumer vector shapes on actual
+  // dependence pairs (here we duplicate the early dependence check used in
+  // `computeSliceUnion` to avoid rejecting disjoint accesses).
+  for (Operation *srcOp : strategyOpsA) {
+    MemRefAccess srcAccess(srcOp);
+    auto srcVectorType = getAffineVectorType(srcOp);
+    bool srcIsRead = isa<AffineReadOpInterface>(srcOp);
+    for (Operation *dstOp : opsB) {
+      MemRefAccess dstAccess(dstOp);
+      if (srcAccess.memref != dstAccess.memref)
+        continue;
+      bool dstIsRead = isa<AffineReadOpInterface>(dstOp);
+      bool readReadAccesses = srcIsRead && dstIsRead;
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
+          /*dependenceConstraints=*/nullptr,
+          /*dependenceComponents=*/nullptr, readReadAccesses);
+      if (result.value == DependenceResult::Failure) {
+        LDBG() << "Dependency check failed";
+        return FusionResult::FailPrecondition;
+      }
+      if (result.value == DependenceResult::NoDependence)
+        continue;
+      if (readReadAccesses)
+        continue;
+      auto dstVectorType = getAffineVectorType(dstOp);
+      if (srcVectorType && dstVectorType &&
+          srcVectorType->getShape() != dstVectorType->getShape()) {
+        LDBG() << "Mismatching vector shapes between producer and consumer";
+        return FusionResult::FailPrecondition;
+      }
+    }
+  }
+
   // Compute union of computation slices computed between all pairs of ops
   // from 'forOpA' and 'forOpB'.
   SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
diff --git a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
new file mode 100644
index 0000000000000..f5dd13c36f8d3
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING
+
+// CHECK-LABEL: func.func @skip_fusing_mismatched_vectors
+// CHECK:       affine.for %{{.*}} = 0 to 8 {
+// CHECK:         affine.vector_store {{.*}} : memref<64x512xf32>, vector<64x64xf32>
+// CHECK:       }
+// CHECK:       affine.for %{{.*}} = 0 to 8 {
+// CHECK:         affine.vector_load {{.*}} : memref<64x512xf32>, vector<64x512xf32>
+// CHECK:       }
+func.func @skip_fusing_mismatched_vectors(%a: memref<64x512xf32>, %b: memref<64x512xf32>, %c: memref<64x512xf32>, %d: memref<64x4096xf32>, %e: memref<64x4096xf32>) {
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+    affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+    %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+    %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+    affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_private_memref
+// CHECK:       memref.alloc() : memref<1x64xf32>
+// CHECK-NOT:   memref<1x1xf32>
+// CHECK:       affine.vector_store {{.*}} : memref<1x64xf32>, vector<64xf32>
+func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
+  %tmp = memref.alloc() : memref<10x64xf32>
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %dst[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_scalar_vector
+// CHECK:       %[[TMP:.*]] = memref.alloc() : memref<64xf32>
+// CHECK:       affine.for %[[I:.*]] = 0 to 16 {
+// CHECK:         %[[S0:.*]] = affine.load %[[SRC:.*]][%[[I]] * 4] : memref<64xf32>
+// CHECK:         affine.store %[[S0]], %[[TMP]][%[[I]] * 4] : memref<64xf32>
+// CHECK:         %[[V:.*]] = affine.vector_load %[[TMP]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK:         affine.vector_store %[[V]], %[[DST:.*]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK:       }
+func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
+  %tmp = memref.alloc() : memref<64xf32>
+  affine.for %i = 0 to 16 {
+    %s0 = affine.load %src[%i * 4] : memref<64xf32>
+    affine.store %s0, %tmp[%i * 4] : memref<64xf32>
+    %s1 = affine.load %src[%i * 4 + 1] : memref<64xf32>
+    affine.store %s1, %tmp[%i * 4 + 1] : memref<64xf32>
+    %s2 = affine.load %src[%i * 4 + 2] : memref<64xf32>
+    affine.store %s2, %tmp[%i * 4 + 2] : memref<64xf32>
+    %s3 = affine.load %src[%i * 4 + 3] : memref<64xf32>
+    affine.store %s3, %tmp[%i * 4 + 3] : memref<64xf32>
+  }
+
+  affine.for %i = 0 to 16 {
+    %vec = affine.vector_load %tmp[%i * 4] : memref<64xf32>, vector<4xf32>
+    affine.vector_store %vec, %dst[%i * 4] : memref<64xf32>, vector<4xf32>
+  }
+  memref.dealloc %tmp : memref<64xf32>
+  return
+}
+
+// -----
+
+// SIBLING-LABEL: func.func @sibling_vector_mismatch
+// SIBLING:       affine.for %{{.*}} = 0 to 10 {
+// SIBLING:         affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING:         affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
+// SIBLING:         affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING:       }
+func.func @sibling_vector_mismatch(%src: memref<10x16xf32>) {
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+  return
+}