Skip to content

Commit

Permalink
[mlir] Introduce device mapper attribute for thread_dim_map and `ma…
Browse files Browse the repository at this point in the history
…pped to dims`

`scf.foreach_thread` defines mapping its loops to processors via an integer array, see an example below. A lowering can use this mapping. However, expressing the mapping as an integer array is very confusing, especially when there are multiple levels of parallelism. In addition, the op does not verify the integer array. This change introduces a device mapping attribute to make the mapping descriptive and verifiable. Then it makes the GPU transform dialect use it.

```
scf.foreach_thread (%i, %j) in (%c1, %c2) {
	scf.foreach_thread (%i2, %j2) in (%c1, %c2)
	{...} { thread_dim_mapping = [0, 1]}
} { thread_dim_mapping = [0, 1]}
```

It first introduces a `DeviceMappingInterface`, which is an attribute interface. `scf.foreach_thread` defines its mapping via this interface. A lowering must define its own attributes and implement this interface as well. This approach gives us clear validation.

The change also introduces two new attributes (`#gpu.thread<x/y/z>` and `#gpu.block<x/y/z>`). After this change, the above code prints as below; as seen there, this clarifies the loop mappings. The change also implements consumption of these two new attributes by the transform dialect. The transform dialect binds the outermost loops to the thread blocks and the innermost loops to threads.

```
scf.foreach_thread (%i, %j) in (%c1, %c2) {
	scf.foreach_thread (%i2, %j2) in (%c1, %c2)
	{...} { thread_dim_mapping = [#gpu.thread<x>, #gpu.thread<y>]}
} { thread_dim_mapping = [#gpu.block<x>, #gpu.block<y>]}
```

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137413
  • Loading branch information
grypp committed Nov 11, 2022
1 parent 99d3ead commit 6663f34
Show file tree
Hide file tree
Showing 29 changed files with 398 additions and 140 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
Expand Up @@ -172,6 +172,8 @@ void addAsyncDependency(Operation *op, Value token);

#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"

#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Expand Up @@ -16,6 +16,7 @@
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/Dialect/GPU/IR/GPUBase.td"
include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
Expand Up @@ -4,3 +4,8 @@ mlir_tablegen(GPUTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRGPUTransformOpsIncGen)

add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc)

# Generate the GPU device-mapper enum declarations/definitions (the DimX/DimY/
# DimZ cases used by the #gpu.thread / #gpu.block mapping attributes) from
# GPUDeviceMappingAttr.td.
set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td)
mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls)
mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen)
65 changes: 65 additions & 0 deletions mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -0,0 +1,65 @@
//===-- GPUDeviceMappingAttr.td - Attribute definition -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the attributes used to map loops to GPU processing units.
//
//===----------------------------------------------------------------------===//

#ifndef GPU_DEVICE_MAPPING_ATTR
#define GPU_DEVICE_MAPPING_ATTR

include "mlir/Dialect/GPU/IR/GPUBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"

// The three hardware dimensions, shared by both the thread and the block
// mapping attributes.
def DimX : I64EnumAttrCase<"DimX", 0, "x">;
def DimY : I64EnumAttrCase<"DimY", 1, "y">;
def DimZ : I64EnumAttrCase<"DimZ", 2, "z">;

def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
    DimX, DimY, DimZ]> {
  let cppNamespace = "::mlir::gpu";
}

def GPUThreadMappingAttr
    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
  let parameters = (ins
    EnumParameter<ThreadsEnum>:$thread
  );
  let assemblyFormat = "`<` params `>`";
  let description = [{
    An attribute that allows defining thread parallelism for GPU devices.

    Threads (aka work items) are grouped into thread blocks where a block may
    be described by a 1-, 2-, or 3-dimensional rectangle. This attribute
    indicates that thread parallelism is desired. It can be consumed by
    lowering to generate GPU code.
  }];
}

// Summary intentionally says "blocks": the original "threads" here was a
// copy-paste slip from ThreadsEnum above.
def BlocksEnum : I64EnumAttr<"Blocks", "blocks for loop mapping", [
    DimX, DimY, DimZ]> {
  let cppNamespace = "::mlir::gpu";
}

def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] > {
  let parameters = (ins
    EnumParameter<BlocksEnum>:$block
  );
  let assemblyFormat = "`<` params `>`";
  let description = [{
    An attribute that allows defining thread block parallelism for GPU devices.

    Thread blocks (aka work-groups) are grouped into a grid where the grid may
    be described by a 1-, 2-, or 3-dimensional rectangle. This attribute
    indicates that thread block parallelism is desired. It can be consumed by
    lowering to generate GPU code.
  }];
}

#endif // GPU_DEVICE_MAPPING_ATTR
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
Expand Up @@ -29,7 +29,7 @@ def MapNestedForeachToThreads :
The operation searches for `scf.foreach_thread` ops nested under `target`
and maps each such op to GPU threads. Mapping is one-to-one and the
induction variables of `scf.foreach_thread` are rewritten to
`gpu.thread_id` according to the `thread_dim_mapping` attribute.
`gpu.thread_id` according to the `mapping` attribute.

Sibling `scf.foreach_thread` are supported in which case, the union of
the number of threads is computed and may result in predication.
Expand Down Expand Up @@ -73,10 +73,10 @@ def MapNestedForeachToThreads :
threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) {
scf.foreach_thread (%i, %j) in (7, 9) {
... // body 1
} {thread_dim_mapping = [1, 0, 2]}
} {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]}
scf.foreach_thread (%i) in (12) {
... // body 2
}
} {mapping = [#gpu.thread<x>]}
gpu.terminator
}
```
Expand Down
Expand Up @@ -47,7 +47,7 @@ DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, ArrayRef<Operation *> targets,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
} // namespace transform

Expand Down
Expand Up @@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -792,7 +793,7 @@ def TileToForeachThreadOp :
a valid tiling specification (i.e. that only tiles parallel dimensions,
e.g. in the Linalg case).

If non-empty, the `thread_dim_mapping` is added as an attribute to the
If non-empty, the `mapping` is added as an attribute to the
resulting `scf.foreach_thread`.

#### Return modes
Expand Down Expand Up @@ -832,7 +833,7 @@ def TileToForeachThreadOp :
Variadic<PDL_Operation>:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);

Expand All @@ -841,22 +842,22 @@ def TileToForeachThreadOp :
"ArrayRef<int64_t>":$staticTileSizes,
CArg<"::mlir::transform::TileSizesSpec",
"::mlir::transform::TileSizesSpec()">,
CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedTileSizes,
CArg<"::mlir::transform::TileSizesSpec",
"::mlir::transform::TileSizesSpec()">,
CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$staticNumThreads,
CArg<"::mlir::transform::NumThreadsSpec",
"::mlir::transform::NumThreadsSpec()">,
CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedNumThreads,
CArg<"::mlir::transform::NumThreadsSpec",
"::mlir::transform::NumThreadsSpec()">,
CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
];

let assemblyFormat = [{
Expand All @@ -867,7 +868,7 @@ def TileToForeachThreadOp :
`tile_sizes` custom<DynamicIndexList>($tile_sizes,
$static_tile_sizes,
"ShapedType::kDynamicSize"))
(`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
(`(` `mapping` `=` $mapping^ `)`)? attr-dict
}];
let hasVerifier = 1;

Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -423,7 +423,7 @@ computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,

/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
/// tiling by `numThreads`.
/// If non-empty, the `threadDimMapping` is added as an attribute to the
/// If non-empty, the `mapping` is added as an attribute to the
/// resulting `scf.foreach_thread`.
/// Zero tile sizes indicate that the dimension is not tiled, and can be
/// thought of as tiling by the full size of data. It is the user's
Expand All @@ -436,14 +436,14 @@ struct ForeachThreadTilingResult {
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
ArrayRef<OpFoldResult> numThreads,
ArrayRef<int64_t> threadDimMapping = {});
Optional<ArrayAttr> mapping);

/// Same as `tileToForeachThreadOp`, but calculate the number of threads
/// required using the given tileSizes.
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> threadDimMapping = {});
Optional<ArrayAttr> mapping);

/// All indices returned by IndexOp should be invariant with respect to
/// tiling. Therefore, if an operation is tiled, we have to transform the
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
@@ -1,3 +1,10 @@
add_mlir_dialect(SCFOps scf Ops)
add_mlir_doc(SCFOps SCFDialect Dialects/ -gen-dialect-doc)

# Generate the DeviceMappingAttrInterface declarations/definitions and the
# attrdef hooks from DeviceMappingInterface.td. mlir-generic-headers is made
# to depend on the generated target so the .inc files exist before any
# consumer compiles.
set(LLVM_TARGET_DEFINITIONS DeviceMappingInterface.td)
mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen)
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h
@@ -0,0 +1,22 @@
//===- DeviceMappingInterface.h ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the device mapping interface defined in
// `DeviceMappingInterface.td`.
//
//===----------------------------------------------------------------------===//

// Guard mirrors the full include path (mlir/Dialect/SCF/IR/...) per the LLVM
// coding standard for header guards; the original short guard risked clashes.
#ifndef MLIR_DIALECT_SCF_IR_DEVICEMAPPINGINTERFACE_H
#define MLIR_DIALECT_SCF_IR_DEVICEMAPPINGINTERFACE_H

#include "mlir/IR/OpDefinition.h"

/// Include the generated interface declarations.
#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc"

#endif // MLIR_DIALECT_SCF_IR_DEVICEMAPPINGINTERFACE_H
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -0,0 +1,43 @@
//===- DeviceMappingInterface.td - Device mapping interfaces*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interfaces for the device mapping specification for the loops.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DEVICEMAPPINGINTERFACE
#define MLIR_DEVICEMAPPINGINTERFACE

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Attribute interfaces
//===----------------------------------------------------------------------===//

def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
let cppNamespace = "::mlir";
let description = [{
Attribute interface describing how to map a region to a processing unit.

It is intended to be a generic mechanism for binding regions to execution
units of an actual or virtual device. Each device first expresses its own
mappings, and those mappings must implement this interface. These mappings
can be used by the device-specific code generators and the desired regions
can be connected to the given processing unit.

Currently, `scf.foreach_thread` uses this interface to express the mapping
of the loops it contains to the GPU's parallelism units such as threads and
thread blocks.
}];
}

// Array of attributes, each implementing the device mapping interface above;
// this is the type of e.g. the `mapping` attribute of `scf.foreach_thread`.
def DeviceMappingArrayAttr :
TypedArrayAttrBase<DeviceMappingAttrInterface,
"Device Mapping array attribute"> { }

#endif // MLIR_DEVICEMAPPINGINTERFACE
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCF.h
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SCF_SCF_H
#define MLIR_DIALECT_SCF_SCF_H

#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionKindInterface.h"
Expand Down
40 changes: 21 additions & 19 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Expand Up @@ -16,6 +16,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -378,14 +379,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
application per thread. Further lowerings are responsible for specifying
how this is materialized on concrete hardware resources.

An optional thread_dim_mapping index array attribute specifies for each
virtual thread dimension, how it remaps 1-1 to a set of concrete processing
An optional `mapping` is an attribute array that specifies processing units
with their dimension, how it remaps 1-1 to a set of concrete processing
element resources (e.g. a CUDA grid dimension or a level of concrete nested
async parallelism). At this time, the specification is backend-dependent and
is not verified by the op, beyond being an index array attribute.
It is the reponsibility of the lowering to interpret the index array in the
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.
async parallelism). It is expressed via any attribute that implements the
device mapping interface. It is the responsibility of the lowering mechanism
to interpret the `mapping` attributes in the context of the concrete target
the op is lowered to, or to ignore it when the specification is ill-formed
or unsupported for a particular target.

The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
`scf.foreach_thread` returns one value per `shared_out` operand. The
Expand Down Expand Up @@ -440,11 +441,12 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
//
```

Example with thread_dim_mapping attribute:
Example with mapping attribute:

```mlir
//
// Sequential context.
// Sequential context. Here `mapping` is expressed as GPU thread mapping
// attributes
//
%matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
(%num_threads_1, %numthread_id_2) shared_outs(...)
Expand All @@ -456,7 +458,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
scf.foreach_thread.perform_concurrently {
...
}
} { thread_dim_mapping = [1, 0] }
} { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
// Implicit synchronization point.
// Sequential context.
//
Expand All @@ -480,7 +482,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
}];
let arguments = (ins Variadic<Index>:$num_threads,
Variadic<AnyRankedTensor>:$outputs,
DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
OptionalAttr<DeviceMappingArrayAttr>:$mapping);

let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
Expand All @@ -493,10 +495,10 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
let builders = [
// Bodyless builder, outputs must be specified.
OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
"Optional<ArrayAttr>":$mapping)>,
// Builder that takes a bodyBuilder lambda.
OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
"ArrayRef<int64_t>":$thread_dim_mapping,
"ArrayRef<Attribute>":$mapping,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
];
let extraClassDeclaration = [{
Expand Down Expand Up @@ -535,14 +537,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
}

/// Return the thread indices in the order specified by the
/// thread_dim_mapping attribute. Return failure is
/// thread_dim_mapping is not a valid permutation.
FailureOr<SmallVector<Value>> getPermutedThreadIndices();
/// given mapping argument. Return failure if
/// mapping is not a valid permutation.
FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping);

/// Return the number of threads in the order specified by the
/// thread_dim_mapping attribute.
/// Return failure is thread_dim_mapping is not a valid permutation.
FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b);
/// given mapping argument.
/// Return failure if mapping is not a valid permutation.
FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping);

// The ensureTerminator method generated by SingleBlockImplicitTerminator is
// unaware of the fact that our terminator also needs a region to be
Expand Down

0 comments on commit 6663f34

Please sign in to comment.