Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 111 additions & 15 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =

/// Options to use to control tiling.
struct SCFTilingOptions {
/// Kind of loop construct generated when tiling (and fusing).
/// - `ForallOp` selects tiling using `scf.forall`.
/// - `CustomOp` selects a custom loop operation; the header/terminator
///   callbacks are supplied through `setCustomLoopGenerationFns`.
/// - NOTE(review): `ForOp` presumably selects `scf.for` loops - confirm
///   against the implementation.
enum class LoopType { ForOp, ForallOp, CustomOp };
/// Currently selected loop construct; `LoopType::ForOp` unless overridden.
LoopType loopType = LoopType::ForOp;
/// Selects the loop construct to generate. Returns this options object so
/// that setter calls can be chained.
SCFTilingOptions &setLoopType(LoopType type) {
  this->loopType = type;
  return *this;
}

/// Computation function that returns the tile sizes to use for each loop.
/// Returning a tile size of zero implies no tiling for that loop. If the
/// size of the returned vector is smaller than the number of loops, the inner
Expand All @@ -50,6 +58,17 @@ struct SCFTilingOptions {
/// proper interaction with folding.
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);

/// Interchange vector used to reorder the tiled loops. Empty by default.
SmallVector<int64_t> interchangeVector = {};
/// Stores a copy of `interchange` as the interchange vector. Returns this
/// options object so that setter calls can be chained.
SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
  // Materialize an owned copy first; `interchange` is a non-owning view.
  auto ownedInterchange = llvm::to_vector(interchange);
  this->interchangeVector = std::move(ownedInterchange);
  return *this;
}

//-------------------------------------------------------------------------//
// Options related to tiling using `scf.forall`.
//-------------------------------------------------------------------------//

/// Computation function that returns the number of threads to use for
/// each loop. Returning a num threads of zero implies no tiling for that
/// loop. If the size of the returned vector is smaller than the number of
Expand All @@ -70,21 +89,6 @@ struct SCFTilingOptions {
/// function that computes num threads at the point they are needed.
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);

/// Interchange vector used to reorder the tiled loops. Empty by default.
SmallVector<int64_t> interchangeVector = {};
/// Stores a copy of `interchange` as the interchange vector. Returns this
/// options object so that setter calls can be chained.
SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
  // Materialize an owned copy first; `interchange` is a non-owning view.
  auto ownedInterchange = llvm::to_vector(interchange);
  this->interchangeVector = std::move(ownedInterchange);
  return *this;
}

/// Kind of loop construct generated when tiling (and fusing).
/// - `ForallOp` selects tiling using `scf.forall`.
/// - NOTE(review): `ForOp` presumably selects `scf.for` loops - confirm
///   against the implementation.
enum class LoopType { ForOp, ForallOp };
/// Currently selected loop construct; `LoopType::ForOp` unless overridden.
LoopType loopType = LoopType::ForOp;
/// Selects the loop construct to generate. Returns this options object so
/// that setter calls can be chained.
SCFTilingOptions &setLoopType(LoopType type) {
  this->loopType = type;
  return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that don't support such a mapping (like
Expand Down Expand Up @@ -117,6 +121,98 @@ struct SCFTilingOptions {
reductionDims.insert(dims.begin(), dims.end());
return *this;
}

//-------------------------------------------------------------------------//
// Options related to tiling using custom loop.
//-------------------------------------------------------------------------//

// For generating the inter-tile loops using a custom loop, two callback
// functions are needed:
// 1. One that generates the "loop header", i.e. the loop that iterates over
//    the different tiles.
// 2. One that generates the loop terminator.
//
// For the `scf.forall` case, the callback to generate the loop header would
// generate
//
// ```mlir
// scf.forall (...) = ... {
// ..
// }
// ```
//
// and the callback to generate the loop terminator would generate the
// `scf.in_parallel` region
//
// ```mlir
// scf.forall (...) = ... {
// scf.in_parallel {
// tensor.parallel_insert_slice ...
// }
// }
// ```
//

// Information returned by the callback that generates the loop header; it is
// needed for the rest of the tiled code generation.
// - `loops`: The generated loops.
// - `tileOffset`: The values that represent the offset of the iteration space
//   tile.
// - `tileSizes`: The values that represent the size of the iteration space
//   tile.
// - `destinationTensors`: The tensors to use as destinations during tiling.
struct CustomLoopHeaderInfo {
// The loops generated by the loop header callback.
SmallVector<LoopLikeOpInterface> loops;
// Offset of the iteration space tile.
SmallVector<OpFoldResult> tileOffset;
// Size of the iteration space tile.
SmallVector<OpFoldResult> tileSizes;
// Tensors to use as destinations during tiling.
SmallVector<Value> destinationTensors;
};

// Type of the callback function that generates the loop headers.
// - `loopRanges` : Values that represent the full size of the iteration space
//                  being tiled.
// - `givenTileSizes` : The tile sizes that are to be used to tile the
//                      iteration space.
// - `destinationTensors` : The tensors to use as destinations for the results
//                          of the tiled loop for loops that implement
//                          `DestinationStyleOpInterface`.
// On success, returns the `CustomLoopHeaderInfo` object (described above);
// the `FailureOr` wrapper lets the callback report failure instead. It is
// expected that this function sets the insertion point of `rewriter` to the
// program point where the intra-tile loop computation is to be generated.
using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;

// Type of the callback function that generates the loop terminator.
// - `tiledResults` : Tiles of the result computed for the iteration space
//                    tile.
// - `resultOffsets` : For each of the `tiledResults`, the offset at which
//                     the result tile is to be "inserted" back into the
//                     destination tensor.
// - `resultSizes` : For each of the `tiledResults`, the size of the result
//                   tile that is to be "inserted" back into the destination
//                   tensor.
// - `destinationTensors` : NOTE(review): presumably the destination tensors
//                          that the result tiles are inserted into (e.g. the
//                          ones reported by the loop header callback) -
//                          confirm against the implementation.
// Returns a `LogicalResult` (not a `CustomLoopHeaderInfo`) that reports
// whether the callback succeeded.
using GenerateLoopTerminatorFn = std::function<LogicalResult(
RewriterBase &rewriter, Location loc, ValueRange tiledResults,
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ValueRange destinationTensors)>;

// Callback that generates the inter-tile loop header. Unset (nullptr) by
// default.
GenerateLoopHeaderFn generateLoopHeaderFn = nullptr;
// Callback that generates the inter-tile loop terminator. Unset (nullptr) by
// default.
GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr;
// Installs the callbacks used to generate the inter-tile loop header and
// terminator when a custom operation is used for the loop. Both callbacks are
// taken by value and moved into the options. Returns this options object so
// that setter calls can be chained.
SCFTilingOptions &
setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
                           GenerateLoopTerminatorFn terminatorFn) {
  this->generateLoopHeaderFn = std::move(headerFn);
  this->generateLoopTerminatorFn = std::move(terminatorFn);
  return *this;
}
};

/// Transformation information returned after tiling.
Expand Down
Loading