diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 235dedd604017..ab85cebee1782 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -136,18 +136,18 @@ void applyPermutationToVector(SmallVector &inVec, } /// Utility class used to generate nested loops with ranges described by -/// `loopRanges` and loop type described by the `iteratorTypes`. `allIvs` is -/// populated with induction variables for all generated loops on return, with -/// `fun` used to generate the body of the innermost loop. +/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn` +/// is used to generate the body of the innermost loop. It is passed a range +/// of loop induction variables. template struct GenerateLoopNest { using IndexedValueTy = typename std::conditional::value, AffineIndexedValue, StdIndexedValue>::type; - static void doit(MutableArrayRef allIvs, - ArrayRef loopRanges, + + static void doit(ArrayRef loopRanges, ArrayRef iteratorTypes, - std::function fun); + function_ref bodyBuilderFn); }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 98baef1057635..5ccf2a469dcee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -459,10 +459,6 @@ Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { auto linalgOp = cast(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto nPar = linalgOp.getNumParallelLoops(); - auto nRed = linalgOp.getNumReductionLoops(); - auto nWin = linalgOp.getNumWindowLoops(); - auto nLoops = nPar + nRed + nWin; auto mapsRange = linalgOp.indexing_maps().template getAsRange(); auto maps = llvm::to_vector<8>( @@ -475,15 +471,14 @@ Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { return LinalgLoops(); } - SmallVector allIvs(nLoops); + SmallVector allIvs; auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap, getViewSizes(builder, linalgOp)); - assert(loopRanges.size() == allIvs.size()); GenerateLoopNest::doit( - allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] { - SmallVector allIvValues(allIvs.begin(), allIvs.end()); - emitScalarImplementation(allIvValues, linalgOp); + loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) { + allIvs.append(ivs.begin(), ivs.end()); + emitScalarImplementation(allIvs, linalgOp); }); // Number of loop ops might be different from the number of ivs since some // loops like affine.parallel and scf.parallel have multiple ivs. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 0dac957396795..ac6903b4bd884 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -375,29 +375,31 @@ Optional static tileLinalgOpImpl( // 3. Create the tiled loops. LinalgOp res = op; - SmallVector ivs(loopRanges.size()); + SmallVector ivs; SmallVector iteratorTypes = llvm::to_vector<4>(op.iterator_types().cast().getValue()); if (!options.interchangeVector.empty()) applyPermutationToVector(iteratorTypes, options.interchangeVector); - GenerateLoopNest::doit(ivs, loopRanges, iteratorTypes, [&] { - auto &b = ScopedContext::getBuilderRef(); - auto loc = ScopedContext::getLocation(); - SmallVector ivValues(ivs.begin(), ivs.end()); - - // If we have to apply a permutation to the tiled loop nest, we have to - // reorder the induction variables This permutation is the right one - // assuming that loopRanges have previously been permuted by - // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation - // of that one: (d0,d1,d2)->(d2,d0,d1) - if (!options.interchangeVector.empty()) - ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues); - - auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes); - auto operands = getAssumedNonViewOperands(op); - views.append(operands.begin(), operands.end()); - res = op.clone(b, loc, views); - }); + GenerateLoopNest::doit( + loopRanges, iteratorTypes, [&](ValueRange localIvs) { + auto &b = ScopedContext::getBuilderRef(); + auto loc = ScopedContext::getLocation(); + ivs.assign(localIvs.begin(), localIvs.end()); + SmallVector ivValues(ivs.begin(), ivs.end()); + + // If we have to apply a permutation to the tiled loop nest, we have to + // reorder the induction variables This permutation is the right one + // assuming that loopRanges have previously been permuted by + // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation + // of that one: (d0,d1,d2)->(d2,d0,d1) + if (!options.interchangeVector.empty()) + ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues); + + auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes); + auto operands = getAssumedNonViewOperands(op); + views.append(operands.begin(), operands.end()); + res = op.clone(b, loc, views); + }); // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling. transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c48b87aaa4e44..5bba11420d086 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -129,70 +129,125 @@ template struct mlir::linalg::GenerateLoopNest; template struct mlir::linalg::GenerateLoopNest; template struct mlir::linalg::GenerateLoopNest; +/// Given a list of subview ranges, extract individual values for lower, upper +/// bounds and steps and put them into the corresponding vectors. +static void unpackRanges(ArrayRef ranges, + SmallVectorImpl &lbs, + SmallVectorImpl &ubs, + SmallVectorImpl &steps) { + for (SubViewOp::Range range : ranges) { + lbs.emplace_back(range.offset); + ubs.emplace_back(range.size); + steps.emplace_back(range.stride); + } +} + namespace mlir { namespace linalg { -/// Specialization of loop nest generator for scf.parallel loops to handle -/// iterator types that are not parallel. These are generated as sequential -/// loops. + +/// Specialization to build an scf "for" nest. template <> -void GenerateLoopNest::doit(MutableArrayRef allIvs, - ArrayRef loopRanges, - ArrayRef iteratorTypes, - std::function fun) { - edsc::GenericLoopNestRangeBuilder(allIvs, loopRanges)(fun); +void GenerateLoopNest::doit( + ArrayRef loopRanges, ArrayRef iteratorTypes, + function_ref bodyBuilderFn) { + SmallVector lbs, ubs, steps; + unpackRanges(loopRanges, lbs, ubs, steps); + edsc::loopNestBuilder(lbs, ubs, steps, bodyBuilderFn); } +/// Specialization to build affine "for" nest. template <> -void GenerateLoopNest::doit(MutableArrayRef allIvs, - ArrayRef loopRanges, - ArrayRef iteratorTypes, - std::function fun) { - edsc::GenericLoopNestRangeBuilder(allIvs, loopRanges)(fun); +void GenerateLoopNest::doit( + ArrayRef loopRanges, ArrayRef iteratorTypes, + function_ref bodyBuilderFn) { + SmallVector lbs, ubs, steps; + unpackRanges(loopRanges, lbs, ubs, steps); + + // Affine loops require constant steps. + SmallVector constantSteps; + constantSteps.reserve(steps.size()); + for (Value v : steps) { + auto op = v.getDefiningOp(); + assert(op && "Affine loops require constant steps"); + constantSteps.push_back(op.getValue()); + } + + edsc::affineLoopNestBuilder(lbs, ubs, constantSteps, bodyBuilderFn); } -template <> -void GenerateLoopNest::doit( - MutableArrayRef allIvs, ArrayRef loopRanges, - ArrayRef iteratorTypes, std::function fun) { - // Check if there is nothing to do here. This is also the recursion - // termination. - if (loopRanges.empty()) - return; - size_t nOuterPar = iteratorTypes.take_front(loopRanges.size()) - .take_while(isParallelIteratorType) - .size(); - if (nOuterPar == 0 && loopRanges.size() == 1) - // Generate the sequential for loop for the remaining non-parallel loop. - return GenerateLoopNest::doit(allIvs, loopRanges, iteratorTypes, - fun); +/// Generates a loop nest consisting of scf.parallel and scf.for, depending on +/// the `iteratorTypes.` Consecutive parallel loops create a single scf.parallel +/// operation; each sequential loop creates a new scf.for operation. The body +/// of the innermost loop is populated by `bodyBuilderFn` that accepts a range +/// of induction variables for all loops. `ivStorage` is used to store the +/// partial list of induction variables. +// TODO(zinenko,ntv): this function can be made iterative instead. However, it +// will have at most as many recursive calls as nested loops, which rarely +// exceeds 10. +static void +generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps, + ArrayRef iteratorTypes, + function_ref bodyBuilderFn, + SmallVectorImpl &ivStorage) { + assert(lbs.size() == ubs.size()); + assert(lbs.size() == steps.size()); + assert(lbs.size() == iteratorTypes.size()); + + // If there are no (more) loops to be generated, generate the body and be + // done with it. + if (iteratorTypes.empty()) + return bodyBuilderFn(ivStorage); + + // Find the outermost parallel loops and drop their types from the list. + unsigned nLoops = iteratorTypes.size(); + iteratorTypes = iteratorTypes.drop_while(isParallelIteratorType); + unsigned nOuterPar = nLoops - iteratorTypes.size(); + + // If there are no outer parallel loops, generate one sequential loop and + // recurse. Note that we wouldn't have dropped anything from `iteratorTypes` + // in this case. if (nOuterPar == 0) { - // The immediate outer loop is not parallel. Generate a scf.for op for this - // loop, but there might be subsequent loops that are parallel. Use - // recursion to find those. - auto nestedFn = [&]() { - GenerateLoopNest::doit(allIvs.drop_front(), - loopRanges.drop_front(), - iteratorTypes.drop_front(), fun); - }; - return GenerateLoopNest::doit(allIvs[0], loopRanges[0], - iteratorTypes[0], nestedFn); - } - if (nOuterPar == loopRanges.size()) { - // All loops are parallel, so generate the scf.parallel op. - return edsc::GenericLoopNestRangeBuilder(allIvs, - loopRanges)(fun); + edsc::loopNestBuilder(lbs[0], ubs[0], steps[0], [&](Value iv) { + ivStorage.push_back(iv); + generateParallelLoopNest(lbs.drop_front(), ubs.drop_front(), + steps.drop_front(), iteratorTypes.drop_front(), + bodyBuilderFn, ivStorage); + }); + return; } - // Generate scf.parallel for the outer parallel loops. The next inner loop is - // sequential, but there might be more parallel loops after that. So recurse - // into the same method. - auto nestedFn = [&]() { - GenerateLoopNest::doit( - allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar), - iteratorTypes.drop_front(nOuterPar), fun); - }; - return GenerateLoopNest::doit( - allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar), - iteratorTypes.take_front(nOuterPar), nestedFn); + + // Generate a single parallel loop-nest operation for all outermost parallel + // loops and recurse. + edsc::OperationBuilder( + lbs.take_front(nOuterPar), ubs.take_front(nOuterPar), + steps.take_front(nOuterPar), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { + edsc::ScopedContext context(nestedBuilder, nestedLoc); + ivStorage.append(localIvs.begin(), localIvs.end()); + generateParallelLoopNest(lbs.drop_front(nOuterPar), + ubs.drop_front(nOuterPar), + steps.drop_front(nOuterPar), iteratorTypes, + bodyBuilderFn, ivStorage); + }); +} + +/// Specialization for generating a mix of parallel and sequential scf loops. +template <> +void GenerateLoopNest::doit( + ArrayRef loopRanges, ArrayRef iteratorTypes, + function_ref bodyBuilderFn) { + SmallVector lbsStorage, ubsStorage, stepsStorage, ivs; + unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage); + ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage); + + // This function may be passed more iterator types than ranges. + assert(iteratorTypes.size() >= loopRanges.size() && + "expected iterator type for all ranges"); + iteratorTypes = iteratorTypes.take_front(loopRanges.size()); + ivs.reserve(iteratorTypes.size()); + generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs); + assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); } + } // namespace linalg } // namespace mlir