Support lowering of imperfectly nested loops into GPU dialect.
The current lowering of loops to GPU only supports loop nests where the loops mapped to workgroups and workitems are perfectly nested. This change adds a new lowering that handles an imperfectly nested loop body with the following properties:
1) The loops partitioned to workgroups are perfectly nested.
2) The loop body of the innermost loop partitioned to workgroups can contain one or more loop nests that are to be partitioned across workitems. Each individual loop nest partitioned to workitems must also be perfectly nested.
3) The number of workgroups and workitems is not deduced from the loop bounds but is passed in by the caller of the lowering as values.
4) For non-loop statements within the perfectly nested loop nest partitioned across workgroups, it must be valid for all threads to execute them. This is NOT verified.
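For illustration, a minimal sketch of the kind of imperfectly nested structure this lowering accepts, written as pseudo-IR in the style of the loop dialect (the bounds, steps, and placeholders are made up, not taken from this change):

```
// Properties 1-2: %i0/%i1 are perfectly nested and partitioned to workgroups;
// the body of %i1 holds two separate perfect loop nests partitioned to workitems.
loop.for %i0 = %lb0 to %ub0 step %s0 {
  loop.for %i1 = %lb1 to %ub1 step %s1 {
    loop.for %j0 = %lb2 to %ub2 step %s2 {   // first workitem loop nest
      loop.for %j1 = %lb3 to %ub3 step %s3 {
        ...
      }
    }
    ...                                      // property 4: non-loop statements
                                             // executed by all threads (unverified)
    loop.for %k0 = %lb4 to %ub4 step %s4 {   // second workitem loop nest
      loop.for %k1 = %lb5 to %ub5 step %s5 {
        ...
      }
    }
  }
}
```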

PiperOrigin-RevId: 277958868
Mahesh Ravishankar authored and tensorflower-gardener committed Nov 1, 2019
1 parent bd94a10 commit 9cbbd8f
Showing 12 changed files with 691 additions and 37 deletions.
29 changes: 29 additions & 0 deletions mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
@@ -17,9 +17,12 @@
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_

#include "mlir/Support/LLVM.h"

namespace mlir {
class AffineForOp;
struct LogicalResult;
class Value;

namespace loop {
class ForOp;
@@ -52,6 +55,32 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims);

/// Convert a loop operation into a GPU launch with the values provided in
/// `numWorkGroups` as the grid size and the values provided in `workGroupSizes`
/// as the block size. The sizes of `numWorkGroups` and `workGroupSizes` must each
/// be less than or equal to 3. The loop operation can be an imperfectly nested
/// computation with the following restrictions:
/// 1) The loop nest must contain as many perfectly nested loops as the number
/// of values passed in through `numWorkGroups`. This corresponds to the number
/// of grid dimensions of the launch. All loops within the loop nest must be
/// parallel.
/// 2) The body of the innermost loop of the above perfectly nested loops must
/// contain statements that each satisfy one of the two conditions below:
/// a) A perfect loop nest of depth greater than or equal to the number of
/// values passed in through `workGroupSizes`, i.e. the number of thread
/// dimensions of the launch. Loops at depth less than or equal to the size of
/// `workGroupSizes` must be parallel. Loops nested deeper can be sequential
/// and are retained as such in the generated GPU launch code.
/// b) Statements that are safe to be executed by all threads within the
/// workgroup. No checks are performed that this is indeed the case.
/// TODO(ravishankarm): Add checks that verify 2(b) above.
/// The above conditions are assumed to be satisfied by the computation rooted
/// at `forOp`.
LogicalResult convertLoopToGPULaunch(loop::ForOp forOp,
ArrayRef<Value *> numWorkGroups,
ArrayRef<Value *> workGroupSizes);

} // namespace mlir

#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
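To make the restrictions above concrete, here is a hedged pseudo-IR sketch of an input accepted by convertLoopToGPULaunch when `numWorkGroups` and `workGroupSizes` each hold two values; the grid/block annotations in the comments are illustrative assumptions about the mapping, not output of the conversion:

```
loop.for %b0 = %lb0 to %ub0 step %s0 {          // one outer parallel loop per
  loop.for %b1 = %lb1 to %ub1 step %s1 {        // entry in numWorkGroups (cond. 1)
    loop.for %t0 = %lb2 to %ub2 step %s2 {      // perfect inner nest; first two
      loop.for %t1 = %lb3 to %ub3 step %s3 {    // loops are parallel (cond. 2a)
        loop.for %r = %lb4 to %ub4 step %s4 {   // deeper than workGroupSizes:
          ...                                   // retained as a sequential loop
        }
      }
    }
    ...   // cond. 2b: non-loop statements must be safe for every thread (unchecked)
  }
}
```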
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
@@ -17,6 +17,8 @@
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_

#include "mlir/Support/LLVM.h"

#include <memory>

namespace mlir {
@@ -33,6 +35,16 @@ template <typename T> class OpPassBase;
/// calling the conversion.
std::unique_ptr<OpPassBase<FuncOp>>
createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims);

/// Create a pass that converts every loop operation within the body of the
/// FuncOp into a GPU launch. The number of workgroups and the workgroup size for
/// the implementation are controlled by SSA values passed into the conversion
/// method. For testing, the values are set as constants obtained from a
/// command-line flag. See convertLoopToGPULaunch for a description of the required
/// semantics of the converted loop operation.
std::unique_ptr<OpPassBase<FuncOp>>
createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
ArrayRef<int64_t> workGroupSize);
} // namespace mlir

#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
4 changes: 2 additions & 2 deletions mlir/include/mlir/Transforms/LoopUtils.h
@@ -224,8 +224,8 @@ void coalesceLoops(MutableArrayRef<loop::ForOp> loops);
/// is rewritten into a version resembling the following pseudo-IR:
///
/// ```
/// loop.for %i = %lb + threadIdx.x + blockIdx.x * blockDim.x to %ub
/// step %gridDim.x * blockDim.x {
/// loop.for %i = %lb + %step * (threadIdx.x + blockIdx.x * blockDim.x)
/// to %ub step %gridDim.x * blockDim.x * %step {
/// ...
/// }
/// ```
