Skip to content

Commit

Permalink
[mlir][GPUTransforms] NFC - Refactor GPUTransforms.cpp in preparation…
Browse files Browse the repository at this point in the history
… for improvements.

Depends on: D145977

Differential Revision: https://reviews.llvm.org/D145980
  • Loading branch information
nicolasvasilache committed Mar 14, 2023
1 parent 710983a commit aafb52d
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 224 deletions.
26 changes: 13 additions & 13 deletions mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ class DialectRegistry;
namespace transform {
namespace gpu {

/// Map the top level `scf.forall` op to GPU Thread Blocks.
/// Mapping is one-to-one and the induction variables of `scf.forall` are
/// rewritten to gpu.block_id according to the thread_dim_apping attribute.
/// Dynamic, `scf.forall` trip counts are currently not supported.
/// Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapForallToBlocksImpl(
RewriterBase &rewriter, scf::ForallOp forallOp,
function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
blockIdGenerator,
SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);

/// Search `scf.forall` ops nested under `target` and map each such op to GPU
/// threads. Mapping is one-to-one and the induction variables of `scf.forall`
/// are rewritten to gpu.thread_id according to the thread_dim_mapping
Expand All @@ -43,24 +55,12 @@ namespace gpu {
/// Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
RewriterBase &rewriter, Operation *target,
const SmallVectorImpl<int64_t> &blockDim,
const SmallVectorImpl<int64_t> &kernelBlockDims,
function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
threadIdGenerator,
bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);

/// Map the top level `scf.forall` op to GPU Thread Blocks.
/// Mapping is one-to-one and the induction variables of `scf.forall` are
/// rewritten to gpu.block_id according to the thread_dim_apping attribute.
/// Dynamic, `scf.forall` trip counts are currently not supported.
/// Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapForallToBlocksImpl(
RewriterBase &rewriter, scf::ForallOp forallOp,
function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
blockIdGenerator,
SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);

/// Find the unique top level scf::ForallOp within a given target op.
DiagnosedSilenceableFailure
findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
Expand Down
9 changes: 0 additions & 9 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -652,15 +652,6 @@ def ForallOp : SCF_Op<"forall", [
/// Checks if the lbs are zeros and steps are ones.
bool isNormalized();

/// Helper to sort `values` according to matching `keys`.
/// Take a custom `compare` binary comparator which returns true if the first
/// element is smaller than the second (i.e. compatible with std::sort).
/// This is a helper typically used to sort numThreads values before they are
/// mapped to concrete physical dimensions of hardware.
static SmallVector<Value> getValuesSortedByKey(
ArrayRef<Attribute> keys, ValueRange values,
llvm::function_ref<bool(Attribute, Attribute)> compare);

// The ensureTerminator method generated by SingleBlockImplicitTerminator is
// unaware of the fact that our terminator also needs a region to be
// well-formed. We override it here to ensure that we do the right thing.
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);

/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);
SmallVector<OpFoldResult>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);

} // namespace mlir

#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
Loading

0 comments on commit aafb52d

Please sign in to comment.