[mlir][gpu] NFC change to pass threadID ops to rewriteOneForeachThreadToGpuThreads

This allows the user to pass both the thread IDs and the dimensions of the threads to distribute on.
This means it can also be used to distribute on warps.
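For illustration, here is a rough sketch of how a caller could use the new parameter to distribute on warps instead of threads. It reuses the helpers visible in this diff (rewriteOneForeachThreadToGpuThreads, ThreadIdOp, Dimension, arith ops); names such as kWarpSize, blockDimX/Y/Z, syncAfterDistribute, transformOp and threadMappingAttributes are assumed to be in scope and are purely illustrative, this snippet is not part of the commit:

  // Sketch only: distribute scf.foreach_thread over warps by passing warp ids
  // as the "thread" ids. kWarpSize, blockDimX/Y/Z, syncAfterDistribute,
  // transformOp and threadMappingAttributes are hypothetical values assumed
  // to be in scope.
  rewriter.setInsertionPoint(foreachThreadOp);
  Location loc = foreachThreadOp.getLoc();
  IndexType indexType = rewriter.getIndexType();

  // Warp id along x: threadIdx.x / warpSize.
  Value tidX = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x);
  Value warpSize = rewriter.create<arith::ConstantIndexOp>(loc, kWarpSize);
  Value warpIdX = rewriter.create<arith::DivUIOp>(loc, tidX, warpSize);

  SmallVector<Value> ids{
      warpIdX, rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
  // The dimensions seen by the rewrite are expressed in warps along x.
  SmallVector<int64_t> dims{blockDimX / kWarpSize, blockDimY, blockDimZ};

  DiagnosedSilenceableFailure diag = rewriteOneForeachThreadToGpuThreads(
      rewriter, foreachThreadOp, dims, ids, syncAfterDistribute, transformOp,
      threadMappingAttributes);

Under these assumptions the existing predication logic (Step 4 in the diff) would apply unchanged, comparing warp ids against the requested warp counts.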

Reviewed By: harsh

Differential Revision: https://reviews.llvm.org/D143950
ThomasRaoux committed Feb 14, 2023
1 parent 69373a5 commit 288ae0b
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -366,7 +366,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 /// not supported. Dynamic block dim sizes are currently not supported.
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
+    const SmallVectorImpl<int64_t> &globalBlockDims,
+    const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
     std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
@@ -427,28 +428,26 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   // Step 3. Create the gpu.thread ops and map the induction variables to the
   // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Value> threadOps{
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
   // Replace ids of dimension size 1 by zero to simplify the IR.
+  SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
+  assert(threadOps.size() == globalBlockDims.size());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
     if (globalBlockDims[i] == 1)
-      threadOps[i] = zero;
+      threadOpsUpdated[i] = zero;
   }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(
-        blockIdx,
-        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
+    bvm.map(blockIdx,
+            threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
+                                 .getMappingId()]);
   }

   // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDims, globalBlockDims)) {
+       llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -519,9 +518,17 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
+      IndexType indexType = rewriter.getIndexType();
+      SmallVector<Value> threadOps{
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::x),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::y),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::z)};
       diag = rewriteOneForeachThreadToGpuThreads(
-          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
-          threadMappingAttributes);
+          rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
+          transformOp, threadMappingAttributes);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
