Skip to content

Commit

Permalink
[mlir:Async] Change async-parallel-for block size/count calculation
Browse files Browse the repository at this point in the history
Depends On D105037

Avoid creating too many tasks when the number of workers is large.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D105126
  • Loading branch information
ezhulenev committed Jun 29, 2021
1 parent f57b242 commit c1194c2
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Expand Up @@ -653,9 +653,19 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
for (size_t i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

// With large number of threads the value of creating many compute blocks
// is reduced because the problem typically becomes memory bound. For small
// number of threads it helps with stragglers.
float overshardingFactor = numWorkerThreads <= 4 ? 8.0
: numWorkerThreads <= 8 ? 4.0
: numWorkerThreads <= 16 ? 2.0
: numWorkerThreads <= 32 ? 1.0
: numWorkerThreads <= 64 ? 0.8
: 0.6;

// Do not overload worker threads with too many compute blocks.
Value maxComputeBlocks =
b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
Value maxComputeBlocks = b.create<ConstantIndexOp>(
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));

// Target block size from the pass parameters.
Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
Expand All @@ -668,7 +678,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);

// Compute balanced block size for the estimated block count.
Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);

// Create a parallel compute function that takes a block id and computes the
Expand Down

0 comments on commit c1194c2

Please sign in to comment.