Skip to content

Commit

Permalink
[CPU] Remove unnecessary factors from getMaxVectorTileSize. (#15843)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhanW committed Dec 12, 2023
1 parent 5a4e764 commit 2529fb3
Showing 1 changed file with 19 additions and 40 deletions.
59 changes: 19 additions & 40 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
@@ -552,53 +552,37 @@ static int64_t getMaxDistributionTileSize(int64_t lb, int64_t ub,
}

/// Computes the maximum tile size that can be used to vectorize (or unroll) a
/// dimension based on its number of iterations and the native vector size of
/// the target. The resulting tile size will be a multiple of the provided
/// vector size, except when `allowIncompleteTile` is set to true. If
/// `enforcePowerOfTwo` is set to true, the resulting tile size will be a power
/// of two.
static int64_t getMaxVectorTileSize(int64_t lb, int64_t ub, int64_t maxSize,
/// dimension based on its number of elements and the native vector size of
/// the target. If `enforcePowerOfTwo` is set to true, the resulting tile size
/// will be a power of two.
static int64_t getMaxVectorTileSize(int64_t numElem, int64_t tileSize,
int64_t vectorSize,
bool allowIncompleteTile = false,
bool enforcePowerOfTwo = false) {
if (ShapedType::isDynamic(ub) || ShapedType::isDynamic(lb)) {
return roundUpToPow2(maxSize, enforcePowerOfTwo);
if (ShapedType::isDynamic(numElem)) {
return roundUpToPow2(tileSize, enforcePowerOfTwo);
}
int64_t numIters = ub - lb;
if (numIters <= maxSize && numIters < vectorSize) {
return roundUpToPow2(numIters, enforcePowerOfTwo);
if (numElem <= tileSize && numElem < vectorSize) {
return roundUpToPow2(numElem, enforcePowerOfTwo);
}

// Return the largest suitable power of two if power of two is enforced.
if (enforcePowerOfTwo) {
return roundUpToPow2(std::min(maxSize, numIters), enforcePowerOfTwo);
return roundUpToPow2(std::min(tileSize, numElem), enforcePowerOfTwo);
}

// Try to find a tile size that is multiple of the vector size.
int64_t scaledUB = std::min(maxSize, numIters) / vectorSize * vectorSize;
int64_t scaledUB = std::min(tileSize, numElem) / vectorSize * vectorSize;
for (int64_t i = scaledUB; i > 0; i -= vectorSize) {
if (numIters % i == 0) {
if (numElem % i == 0) {
return i;
}
}
if (allowIncompleteTile) {
// Try to find a tile size that is not multiple of the vector size but
// multiple of the number of iterations. Otherwise, return `maxSize`.
int64_t start = std::min(maxSize, numIters);
int64_t end = start / 2;
for (int64_t i = start; i >= end; --i) {
if (numIters % i == 0) {
return i;
}
}
return maxSize;
}

// If it can't be a multiple of `vectorSize`, let's choose a factor of
// `numIters` sizes heuristically.
int64_t start = std::min(maxSize, numIters);
// `numElem` sizes heuristically.
int64_t start = std::min(tileSize, numElem);
for (int64_t i = start; i > 0; --i) {
if (numIters % i == 0) {
if (numElem % i == 0) {
return i;
}
}
@@ -814,7 +798,7 @@ static LogicalResult setMatmulPeelingRootConfig(
auto lhsShapedType = llvm::cast<ShapedType>(op.lhs().getType());
int64_t K = lhsShapedType.getShape().back();
reductionTileSizes.push_back(
getMaxVectorTileSize(0, K, vecTileSizes.back(), vectorSize));
getMaxVectorTileSize(K, vecTileSizes.back(), vectorSize));

SmallVector<int64_t> cacheParallelTileSizes(cacheTileSizes.begin(),
cacheTileSizes.end());
@@ -854,9 +838,6 @@ static LogicalResult setMatmulNoPadRootConfig(
const SmallVectorImpl<bool> &vecScalableDims = inputScalableTileFlags.back();
SmallVector<int64_t> parallelTileSizes;
SmallVector<bool> parallelScalableFlags;
bool allowIncompleteTile =
vecPreProcStrategy == VectorPreProcStrategy::Peeling ||
vecPreProcStrategy == VectorPreProcStrategy::Masking;

for (auto [index, tileSize] : llvm::enumerate(vecTileSizes)) {
int64_t sz = tileSize;
@@ -865,9 +846,8 @@ static LogicalResult setMatmulNoPadRootConfig(

if (sz != 0) {
sz = getMaxVectorTileSize(
/*lb=*/0, /*ub=*/shape[index],
/*maxTileSize=*/sz, vectorSize, allowIncompleteTile,
enforcePowerOfTwo);
/*numElem=*/shape[index],
/*tileSize=*/sz, vectorSize, enforcePowerOfTwo);
}
parallelTileSizes.push_back(sz);
// 1x scalable vectors e.g. vector<[1]xty> are also poorly supported, so
@@ -1399,8 +1379,7 @@ static void setX86VectorTileSizes(linalg::GenericOp genericOp,
for (auto loopNum : llvm::seq<unsigned>(0, numLoops)) {
if (distTileSizes[loopNum]) {
vecTileSizes[loopNum] = getMaxVectorTileSize(
0, distTileSizes[loopNum], minTileSizes[loopNum],
minTileSizes[loopNum], /*allowIncompleteTile=*/false,
distTileSizes[loopNum], minTileSizes[loopNum], minTileSizes[loopNum],
/*enforcePowerOfTwo=*/vecPreProcStrategy ==
VectorPreProcStrategy::Masking);
} else {
@@ -1793,7 +1772,7 @@ static LogicalResult setConvRootConfig(func::FuncOp entryPointFn,
// The ops will be decomposed to lower-rank named ops.
if (parallelTileSizes[i] != 1) {
parallelTileSizes[i] =
getMaxVectorTileSize(0, tileSize, parallelTileSizes[i], vectorSize);
getMaxVectorTileSize(tileSize, parallelTileSizes[i], vectorSize);
}
}
SmallVector<int64_t> reductionTileSizes;
Expand Down

0 comments on commit 2529fb3

Please sign in to comment.