@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
3333
3434// / Options to use to control tiling.
3535struct SCFTilingOptions {
36+ // / Specify which loop construct to use for tile and fuse.
37+ enum class LoopType { ForOp, ForallOp, CustomOp };
38+ LoopType loopType = LoopType::ForOp;
39+ SCFTilingOptions &setLoopType (LoopType type) {
40+ loopType = type;
41+ return *this ;
42+ }
43+
3644 // / Computation function that returns the tile sizes to use for each loop.
3745 // / Returning a tile size of zero implies no tiling for that loop. If the
3846 // / size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
5058 // / proper interaction with folding.
5159 SCFTilingOptions &setTileSizes (ArrayRef<OpFoldResult> tileSizes);
5260
61+ // / The interchange vector to reorder the tiled loops.
62+ SmallVector<int64_t > interchangeVector = {};
63+ SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
64+ interchangeVector = llvm::to_vector (interchange);
65+ return *this ;
66+ }
67+
68+ // -------------------------------------------------------------------------//
69+ // Options related to tiling using `scf.forall`.
70+ // -------------------------------------------------------------------------//
71+
5372 // / Computation function that returns the number of threads to use for
5473 // / each loop. Returning a num threads of zero implies no tiling for that
5574 // / loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
7089 // / function that computes num threads at the point they are needed.
7190 SCFTilingOptions &setNumThreads (ArrayRef<OpFoldResult> numThreads);
7291
73- // / The interchange vector to reorder the tiled loops.
74- SmallVector<int64_t > interchangeVector = {};
75- SCFTilingOptions &setInterchange (ArrayRef<int64_t > interchange) {
76- interchangeVector = llvm::to_vector (interchange);
77- return *this ;
78- }
79-
80- // / Specify which loop construct to use for tile and fuse.
81- enum class LoopType { ForOp, ForallOp };
82- LoopType loopType = LoopType::ForOp;
83- SCFTilingOptions &setLoopType (LoopType type) {
84- loopType = type;
85- return *this ;
86- }
87-
8892 // / Specify mapping of loops to devices. This is only respected when the loop
8993 // / constructs support such a mapping (like `scf.forall`). Will be ignored
9094 // / when using loop constructs that don't support such a mapping (like
@@ -117,6 +121,98 @@ struct SCFTilingOptions {
117121 reductionDims.insert (dims.begin (), dims.end ());
118122 return *this ;
119123 }
124+
125+ // -------------------------------------------------------------------------//
126+ // Options related to tiling using custom loop.
127+ // -------------------------------------------------------------------------//
128+
129+ // For generating the inter-tile loops using a custom loop, two callback
130+ // functions are needed
131+ // 1. That generates the "loop header", i.e. the loop that iterates over the
132+ // different tiles.
133+ // 2. That generates the loop terminator
134+ //
135+ // For `scf.forall` case the callback to generate loop header would generate
136+ //
137+ // ```mlir
138+ // scf.forall (...) = ... {
139+ // ..
140+ // }
141+ // ```
142+ //
143+ // and the callback to generate the loop terminator would generate the
144+ // `scf.in_parallel` region
145+ //
146+ // ```mlir
147+ // scf.forall (...) = ... {
148+ // scf.in_parallel {
149+ // tensor.parallel_insert_slice ...
150+ // }
151+ // }
152+ // ```
153+ //
154+
155+ // Information that is to be returned by the callback to generate the loop
156+ // header needed for the rest of the tiled code generation.
157+ // - `loops`: The generated loops
158+ // - `tileOffset`: The values that represent the offset of the iteration space
159+ // tile
160+ // - `tileSizes` : The values that represent the size of the iteration space
161+ // tile.
162+ // - `destinationTensors` : The tensors to use as destinations during tiling.
163+ struct CustomLoopHeaderInfo {
164+ SmallVector<LoopLikeOpInterface> loops;
165+ SmallVector<OpFoldResult> tileOffset;
166+ SmallVector<OpFoldResult> tileSizes;
167+ SmallVector<Value> destinationTensors;
168+ };
169+
170+ // Type of the callback function that generates the loop headers.
171+ // - `loopRanges` : Values that represent the full size of the iteration space
172+ // being tiled.
173+ // - `givenTileSizes` : The tile sizes that are to be used to tile the
174+ // iteration
175+ // space.
176+ // - `destinationTensors` : The tensors to use as destinations for the results
177+ // of the tiled loop for loops that implement
178+ // `DestinationStyleOpInterface`.
179+ // Returns the `CustomLoopHeaderInfo` object (described above). It is expected
180+ // that this function sets the insertion point of `rewriter` to the program
181+ // point where the intra-tile loop computation is to be generated.
182+ using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
183+ RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
184+ ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
185+
186+ // Type of the callback function that generates the loop terminator.
187+ // - `tiledResults` : Tiles of the result computed for the iteration space
188+ // tile
189+ // - `resultOffsets` : For each of the `tiledResults`, the offset at which
190+ // the result tile is to be "inserted" back into the
191+ // destination tensor.
192+ // - `resultSizes` : For each of the `tiledResults`, the size of the result
193+ // tile
194+ // that is to be "inserted" back into the destination
195+ // tensor.
196+ // Returns a `LogicalResult` indicating success or failure of the generation.
197+ using GenerateLoopTerminatorFn = std::function<LogicalResult(
198+ RewriterBase &rewriter, Location loc, ValueRange tiledResults,
199+ ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
200+ ArrayRef<SmallVector<OpFoldResult>> resultSizes,
201+ ValueRange destinationTensors)>;
202+
203+ // Callback function to generate the inter-tile loop header.
204+ GenerateLoopHeaderFn generateLoopHeaderFn = nullptr ;
205+ // Callback function to generate the inter-tile loop terminator.
206+ GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr ;
207+ // Helper function to set the callbacks for inter-tile loop header and
208+ // terminator functions when using a custom operation for the loop.
209+ SCFTilingOptions &
210+ setCustomLoopGenerationFns (GenerateLoopHeaderFn headerFn,
211+ GenerateLoopTerminatorFn terminatorFn) {
212+ generateLoopHeaderFn = std::move (headerFn);
213+ generateLoopTerminatorFn = std::move (terminatorFn);
214+ return *this ;
215+ }
120216};
121217
122218// / Transformation information returned after tiling.
0 commit comments