diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 78428fe1953c..8f57de53c499 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -552,53 +552,37 @@ static int64_t getMaxDistributionTileSize(int64_t lb, int64_t ub,
 }
 
 /// Computes the maximum tile size that can be used to vectorize (or unroll) a
-/// dimension based on its number of iterations and the native vector size of
-/// the target. The resulting tile size will be a multiple of the provided
-/// vector size, except when `allowIncompleteTile` is set to true. If
-/// `enforcePowerOfTwo` is set to true, the resulting tile size will be a power
-/// of two.
-static int64_t getMaxVectorTileSize(int64_t lb, int64_t ub, int64_t maxSize,
+/// dimension based on its number of elements and the native vector size of
+/// the target. If `enforcePowerOfTwo` is set to true, the resulting tile size
+/// will be a power of two.
+static int64_t getMaxVectorTileSize(int64_t numElem, int64_t tileSize,
                                     int64_t vectorSize,
-                                    bool allowIncompleteTile = false,
                                     bool enforcePowerOfTwo = false) {
-  if (ShapedType::isDynamic(ub) || ShapedType::isDynamic(lb)) {
-    return roundUpToPow2(maxSize, enforcePowerOfTwo);
+  if (ShapedType::isDynamic(numElem)) {
+    return roundUpToPow2(tileSize, enforcePowerOfTwo);
   }
-  int64_t numIters = ub - lb;
-  if (numIters <= maxSize && numIters < vectorSize) {
-    return roundUpToPow2(numIters, enforcePowerOfTwo);
+  if (numElem <= tileSize && numElem < vectorSize) {
+    return roundUpToPow2(numElem, enforcePowerOfTwo);
   }
 
   // Return the largest suitable power of two if power of two is enforced.
   if (enforcePowerOfTwo) {
-    return roundUpToPow2(std::min(maxSize, numIters), enforcePowerOfTwo);
+    return roundUpToPow2(std::min(tileSize, numElem), enforcePowerOfTwo);
   }
 
   // Try to find a tile size that is multiple of the vector size.
-  int64_t scaledUB = std::min(maxSize, numIters) / vectorSize * vectorSize;
+  int64_t scaledUB = std::min(tileSize, numElem) / vectorSize * vectorSize;
   for (int64_t i = scaledUB; i > 0; i -= vectorSize) {
-    if (numIters % i == 0) {
+    if (numElem % i == 0) {
       return i;
     }
   }
 
-  if (allowIncompleteTile) {
-    // Try to find a tile size that is not multiple of the vector size but
-    // multiple of the number of iterations. Otherwise, return `maxSize`.
-    int64_t start = std::min(maxSize, numIters);
-    int64_t end = start / 2;
-    for (int64_t i = start; i >= end; --i) {
-      if (numIters % i == 0) {
-        return i;
-      }
-    }
-    return maxSize;
-  }
   // If it can't be a multiple of `vectorSize`, let's choose a factor of
-  // `numIters` sizes heuristically.
-  int64_t start = std::min(maxSize, numIters);
+  // `numElem` sizes heuristically.
+  int64_t start = std::min(tileSize, numElem);
   for (int64_t i = start; i > 0; --i) {
-    if (numIters % i == 0) {
+    if (numElem % i == 0) {
       return i;
     }
   }
@@ -814,7 +798,7 @@ static LogicalResult setMatmulPeelingRootConfig(
   auto lhsShapedType = llvm::cast<ShapedType>(op.lhs().getType());
   int64_t K = lhsShapedType.getShape().back();
   reductionTileSizes.push_back(
-      getMaxVectorTileSize(0, K, vecTileSizes.back(), vectorSize));
+      getMaxVectorTileSize(K, vecTileSizes.back(), vectorSize));
   SmallVector<int64_t> cacheParallelTileSizes(cacheTileSizes.begin(),
                                               cacheTileSizes.end());
 
@@ -854,9 +838,6 @@ static LogicalResult setMatmulNoPadRootConfig(
   const SmallVectorImpl<bool> &vecScalableDims = inputScalableTileFlags.back();
   SmallVector<int64_t> parallelTileSizes;
   SmallVector<bool> parallelScalableFlags;
-  bool allowIncompleteTile =
-      vecPreProcStrategy == VectorPreProcStrategy::Peeling ||
-      vecPreProcStrategy == VectorPreProcStrategy::Masking;
 
   for (auto [index, tileSize] : llvm::enumerate(vecTileSizes)) {
     int64_t sz = tileSize;
@@ -865,9 +846,8 @@ static LogicalResult setMatmulNoPadRootConfig(
 
     if (sz != 0) {
       sz = getMaxVectorTileSize(
-          /*lb=*/0, /*ub=*/shape[index],
-          /*maxTileSize=*/sz, vectorSize, allowIncompleteTile,
-          enforcePowerOfTwo);
+          /*numElem=*/shape[index],
+          /*tileSize=*/sz, vectorSize, enforcePowerOfTwo);
     }
     parallelTileSizes.push_back(sz);
     // 1x scalable vectors e.g. vector<[1]xty> are also poorly supported, so
@@ -1399,8 +1379,7 @@ static void setX86VectorTileSizes(linalg::GenericOp genericOp,
   for (auto loopNum : llvm::seq(0, numLoops)) {
     if (distTileSizes[loopNum]) {
       vecTileSizes[loopNum] = getMaxVectorTileSize(
-          0, distTileSizes[loopNum], minTileSizes[loopNum],
-          minTileSizes[loopNum], /*allowIncompleteTile=*/false,
+          distTileSizes[loopNum], minTileSizes[loopNum], minTileSizes[loopNum],
           /*enforcePowerOfTwo=*/vecPreProcStrategy ==
               VectorPreProcStrategy::Masking);
     } else {
@@ -1793,7 +1772,7 @@ static LogicalResult setConvRootConfig(func::FuncOp entryPointFn,
     // The ops will be decomposed to lower-rank named ops.
    if (parallelTileSizes[i] != 1) {
      parallelTileSizes[i] =
-          getMaxVectorTileSize(0, tileSize, parallelTileSizes[i], vectorSize);
+          getMaxVectorTileSize(tileSize, parallelTileSizes[i], vectorSize);
    }
  }
  SmallVector<int64_t> reductionTileSizes;
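
For reviewers who want to exercise the new heuristic outside the compiler, below is a minimal standalone sketch of the post-patch `getMaxVectorTileSize`, transcribed from the hunk above. `kDynamicSentinel` and the local `roundUpToPow2` are simplified stand-ins for `ShapedType::isDynamic` and IREE's `roundUpToPow2` helper, and the trailing `return 1;` fallback is assumed (the hunk truncates before the function's final return):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

// Stand-in for ShapedType::isDynamic; any sentinel works here since the
// heuristic only distinguishes dynamic from static extents.
constexpr int64_t kDynamicSentinel = std::numeric_limits<int64_t>::min();

// Simplified stand-in for IREE's roundUpToPow2 helper: identity unless
// `enforcePowerOfTwo` is set, in which case round up to the next power of two.
static int64_t roundUpToPow2(int64_t value, bool enforcePowerOfTwo) {
  if (!enforcePowerOfTwo)
    return value;
  int64_t pow2 = 1;
  while (pow2 < value)
    pow2 <<= 1;
  return pow2;
}

// Post-patch heuristic, mirroring the diff above.
static int64_t getMaxVectorTileSize(int64_t numElem, int64_t tileSize,
                                    int64_t vectorSize,
                                    bool enforcePowerOfTwo = false) {
  if (numElem == kDynamicSentinel)
    return roundUpToPow2(tileSize, enforcePowerOfTwo);
  if (numElem <= tileSize && numElem < vectorSize)
    return roundUpToPow2(numElem, enforcePowerOfTwo);
  if (enforcePowerOfTwo)
    return roundUpToPow2(std::min(tileSize, numElem), enforcePowerOfTwo);
  // Largest multiple of vectorSize (capped at tileSize) that divides numElem.
  int64_t scaledUB = std::min(tileSize, numElem) / vectorSize * vectorSize;
  for (int64_t i = scaledUB; i > 0; i -= vectorSize)
    if (numElem % i == 0)
      return i;
  // Otherwise the largest factor of numElem not exceeding tileSize.
  for (int64_t i = std::min(tileSize, numElem); i > 0; --i)
    if (numElem % i == 0)
      return i;
  return 1; // assumed fallback; the hunk cuts off before the final return
}

int main() {
  // 128 elements, max tile 32, vector size 8 -> 32 (a multiple of 8).
  std::printf("%lld\n", (long long)getMaxVectorTileSize(128, 32, 8));
  // 36 elements: no multiple of 8 divides 36, so fall back to factor 18.
  std::printf("%lld\n", (long long)getMaxVectorTileSize(36, 32, 8));
  // Dynamic extent with power-of-two enforcement (masking-style) -> 32.
  std::printf("%lld\n",
              (long long)getMaxVectorTileSize(kDynamicSentinel, 20, 8,
                                              /*enforcePowerOfTwo=*/true));
  return 0;
}
```

The behavioral delta versus the old `allowIncompleteTile=true` path is in the fallback: the old code only searched factors down to `start / 2` and then gave up with `maxSize`, while the new code always searches all factors down to 1, so the peeling/masking call sites no longer need the flag at all.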