[CPU][NFC] Decouple distribution and vector tile size computation (#12785)

Small step towards decoupling tile size computation for distribution and vectorization.
dcaballe committed Mar 28, 2023
1 parent 6f24a2f commit c00e8b9
Showing 1 changed file with 71 additions and 20 deletions.
91 changes: 71 additions & 20 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -450,12 +450,59 @@ static int64_t roundUpToPow2(int64_t size, bool predicate) {
return llvm::PowerOf2Ceil(size);
}

/// Adjusts the workload per workgroup to be a multiple of vector size to ensure
/// that the op vectorizes.
static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize,
int64_t vectorSize,
bool allowIncompleteTile = false,
bool enforcePowerOfTwo = false) {
/// Computes the maximum tile size that can be used to distribute a dimension
/// based on its number of iterations and the native vector size of the
/// target. The resulting tile size will be a multiple of the provided vector
/// size, except when `allowIncompleteTile` is set to true.
static int64_t getMaxDistributionTileSize(int64_t lb, int64_t ub,
int64_t maxSize, int64_t vectorSize,
bool allowIncompleteTile = false) {
if (ub == ShapedType::kDynamic || lb == ShapedType::kDynamic) {
return maxSize;
}
int64_t numIters = ub - lb;
if (numIters <= maxSize && numIters < vectorSize) {
return numIters;
}

int64_t scaledUB = std::min(maxSize, numIters) / vectorSize * vectorSize;
for (int64_t i = scaledUB; i > 0; i -= vectorSize) {
if (numIters % i == 0) {
return i;
}
}
if (allowIncompleteTile) {
// Set the lower bound to half to avoid creating too many workgroups.
int64_t start = std::min(maxSize, numIters);
int64_t end = start / 2;
for (int64_t i = start; i >= end; --i) {
if (numIters % i == 0) {
return i;
}
}
return maxSize;
}
// If it can't be a multiple of `vectorSize`, let's choose a factor of
// `numIters` heuristically.
int64_t start = std::min(maxSize, numIters);
for (int64_t i = start; i > 0; --i) {
if (numIters % i == 0) {
return i;
}
}
return 1;
}
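To make the heuristic concrete, here is a minimal standalone sketch (not part of the commit) that replicates the static-shape path of getMaxDistributionTileSize above; the helper name maxDistTileSize and the sample values are invented for illustration.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Standalone copy of the static-shape path above; the dynamic-shape and
// `allowIncompleteTile` branches are omitted for brevity.
static int64_t maxDistTileSize(int64_t numIters, int64_t maxSize,
                               int64_t vectorSize) {
  if (numIters <= maxSize && numIters < vectorSize)
    return numIters;
  // Search downwards for the largest multiple of `vectorSize` that divides
  // `numIters` evenly.
  int64_t scaledUB = std::min(maxSize, numIters) / vectorSize * vectorSize;
  for (int64_t i = scaledUB; i > 0; i -= vectorSize)
    if (numIters % i == 0)
      return i;
  // Fall back to any factor of `numIters`.
  for (int64_t i = std::min(maxSize, numIters); i > 0; --i)
    if (numIters % i == 0)
      return i;
  return 1;
}

int main() {
  // 100 iterations, cap 64, vector size 4: 64, 60, 56, ... are rejected
  // until 20, the largest multiple of 4 that divides 100 evenly.
  assert(maxDistTileSize(100, 64, 4) == 20);
  // Fewer iterations than one vector (7 < 8): take all of them.
  assert(maxDistTileSize(7, 64, 8) == 7);
  return 0;
}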

/// Computes the maximum tile size that can be used to vectorize (or unroll) a
/// dimension based on its number of iterations and the native vector size of
/// the target. The resulting tile size will be a multiple of the provided
/// vector size, except when `allowIncompleteTile` is set to true. If
/// `enforcePowerOfTwo` is set to true, the resulting tile size will be a power
/// of two.
static int64_t getMaxVectorTileSize(int64_t lb, int64_t ub, int64_t maxSize,
int64_t vectorSize,
bool allowIncompleteTile = false,
bool enforcePowerOfTwo = false) {
if (ub == ShapedType::kDynamic || lb == ShapedType::kDynamic) {
return roundUpToPow2(maxSize, enforcePowerOfTwo);
}
@@ -469,14 +516,16 @@ static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize,
return roundUpToPow2(std::min(maxSize, numIters), enforcePowerOfTwo);
}

// Try to find a tile size that is a multiple of the vector size.
int64_t scaledUB = std::min(maxSize, numIters) / vectorSize * vectorSize;
for (int64_t i = scaledUB; i > 0; i -= vectorSize) {
if (numIters % i == 0) {
return i;
}
}
if (allowIncompleteTile) {
// Set the lower bound to half to avoid creating too many workgroups.
// Try to find a tile size that is not a multiple of the vector size but
// still divides the number of iterations evenly. Otherwise, return `maxSize`.
int64_t start = std::min(maxSize, numIters);
int64_t end = start / 2;
for (int64_t i = start; i >= end; --i) {
@@ -486,6 +535,7 @@ static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize,
}
return maxSize;
}

// If it can't be a multiple of `vectorSize`, let's choose a factor of
// `numIters` heuristically.
int64_t start = std::min(maxSize, numIters);
@@ -494,6 +544,7 @@ static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize,
return i;
}
}

return 1;
}

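Compared to the distribution variant, the vector variant differs mainly in its dynamic-shape fallback (roundUpToPow2(maxSize, enforcePowerOfTwo) instead of maxSize) and in the enforcePowerOfTwo mode used with masking. Below is a small sketch of the rounding behavior, assuming the collapsed lines of roundUpToPow2 return size unchanged when the predicate is false, and using C++20 std::bit_ceil in place of llvm::PowerOf2Ceil.

#include <bit>
#include <cstdint>

// Sketch of the `roundUpToPow2` helper referenced above (assumed behavior).
constexpr int64_t roundUpToPow2Sketch(int64_t size, bool predicate) {
  return predicate
             ? static_cast<int64_t>(std::bit_ceil(static_cast<uint64_t>(size)))
             : size;
}

// With masking (`enforcePowerOfTwo == true`), 7 remaining iterations round
// up to a tile size of 8 and the tail lanes are masked off; without it,
// 7 is returned as-is.
static_assert(roundUpToPow2Sketch(7, /*predicate=*/true) == 8);
static_assert(roundUpToPow2Sketch(7, /*predicate=*/false) == 7);
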
@@ -541,8 +592,8 @@ static SmallVector<int64_t> getDefaultDistributedLevelTileSizes(
for (auto i : llvm::seq<unsigned>(0, distributedTileSizes.size())) {
if (!distributedTileSizes[i]) continue;
distributedTileSizes[i] =
getMaxTileSize(lbs[i], ubs[i], distributedTileSizes[i], minTileSizes[i],
allowIncompleteTile);
getMaxDistributionTileSize(lbs[i], ubs[i], distributedTileSizes[i],
minTileSizes[i], allowIncompleteTile);
}
return distributedTileSizes;
}
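
As a worked example of the allowIncompleteTile path used here (illustrative values): for a prime trip count of 101 with maxSize = 64 and vectorSize = 4, no multiple of 4 divides 101, and no factor of 101 lies in [32, 64], so getMaxDistributionTileSize falls back to returning maxSize = 64 and the last tile is left incomplete.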
@@ -709,7 +760,7 @@ static LogicalResult setMatmulPadRootConfig(
auto lhsShapedType = op.lhs().getType().cast<ShapedType>();
int64_t K = lhsShapedType.getShape().back();
reductionTileSizes.push_back(
getMaxTileSize(0, K, workgroupTileSizes.back(), vectorSize));
getMaxVectorTileSize(0, K, workgroupTileSizes.back(), vectorSize));

TileSizesListType tileSizes;
tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
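
For example, with illustrative values K = 512, a reduction tile cap of 16, and vectorSize = 4, getMaxVectorTileSize(0, 512, 16, 4) returns 16: the first candidate, 16, is already a multiple of 4 and divides 512 evenly.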
@@ -759,7 +810,7 @@ static LogicalResult setMatmulNoPadRootConfig(
vecPreProcStrategy == VectorPreProcStrategy::Masking;

if (sz != 0) {
sz = getMaxTileSize(
sz = getMaxVectorTileSize(
/*lb=*/0, /*ub=*/shape[index],
/*maxTileSize=*/sz, vectorSize, allowIncompleteTile);
}
@@ -797,14 +848,14 @@ static LogicalResult setAArch64RootConfig(func::FuncOp entryPointFn,
auto shape = cast<linalg::LinalgOp>(op.getOperation()).getStaticLoopRanges();
for (auto [index, tileSize] : llvm::enumerate(flowTileSizes.drop_back())) {
parallelTileSizes.push_back(
getMaxTileSize(0, tileSize ? tileSize : shape[index],
workgroupTileSizes[index], vectorSize));
getMaxVectorTileSize(0, tileSize ? tileSize : shape[index],
workgroupTileSizes[index], vectorSize));
}

auto lhsShapedType = op.lhs().getType().cast<ShapedType>();
int64_t K = lhsShapedType.getShape().back();
parallelTileSizes.push_back(
getMaxTileSize(0, K, workgroupTileSizes.back(), vectorSize));
getMaxVectorTileSize(0, K, workgroupTileSizes.back(), vectorSize));

SmallVector<int64_t> reductionTileSizes;
splitParallelAndReductionTiles(op.getOperation(), parallelTileSizes,
@@ -1125,11 +1176,11 @@ static void setX86WorkgroupTileSizes(
SmallVector<int64_t, 4> staticLoopRanges = genericOp.getStaticLoopRanges();
for (auto loopNum : llvm::seq<unsigned>(0, numLoops)) {
if (flowTileSizes[loopNum]) {
workgroupTileSizes[loopNum] =
getMaxTileSize(0, flowTileSizes[loopNum], minTileSizes[loopNum],
minTileSizes[loopNum], /*allowIncompleteTile=*/false,
/*enforcePowerOfTwo=*/vecPreProcStrategy ==
VectorPreProcStrategy::Masking);
workgroupTileSizes[loopNum] = getMaxVectorTileSize(
0, flowTileSizes[loopNum], minTileSizes[loopNum],
minTileSizes[loopNum], /*allowIncompleteTile=*/false,
/*enforcePowerOfTwo=*/vecPreProcStrategy ==
VectorPreProcStrategy::Masking);
} else {
// If the flow level tile size is zero, and static loop range is 0 as
// well, set the tile sizes here to zero as well.
@@ -1498,7 +1549,7 @@ static LogicalResult setConvRootConfig(func::FuncOp entryPointFn,
// The ops will be decomposed to lower-rank named ops.
if (parallelTileSizes[i] != 1) {
parallelTileSizes[i] =
getMaxTileSize(0, tileSize, parallelTileSizes[i], vectorSize);
getMaxVectorTileSize(0, tileSize, parallelTileSizes[i], vectorSize);
}
}
SmallVector<int64_t> reductionTileSizes;
