From 373d10cdc9abe5d2ad4bb78755eaa08b6b46f9ef Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Sun, 19 May 2024 22:10:18 -0700 Subject: [PATCH 01/21] add deep tile pass for matmul and tests --- include/gc/Transforms/Passes.td | 18 +- lib/gc/Transforms/CMakeLists.txt | 2 +- .../Transforms/DeepTileContractionNamedOp.cpp | 749 ++++++++++++++++++ lib/gc/Transforms/TileNamed.cpp | 49 -- .../deepTileContractionNamedOp.mlir | 48 ++ 5 files changed, 810 insertions(+), 56 deletions(-) create mode 100644 lib/gc/Transforms/DeepTileContractionNamedOp.cpp delete mode 100644 lib/gc/Transforms/TileNamed.cpp create mode 100644 test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 9d75ac2e9..2933b65ba 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -11,12 +11,6 @@ include "mlir/Pass/PassBase.td" -def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> { - let summary = "Tile linalg named operations."; - let dependentDialects = - ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"]; -} - #ifdef GC_HAS_ONEDNN_DIALECT def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { let summary = @@ -71,6 +65,18 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion", "Decide if enable cost model to control iterative fusion.">, ListOption<"defaultTileSize", "default-tile-size", "std::string", "Set default TileSize for the certain type of op, saying `matmul:{32,32}`">, + ]; +} +def DeepTileContractionNamedOp + : Pass<"deep-tile-contraction-named-op", "func::FuncOp"> { + let summary = "Tile linalg contraction named operation deeply"; + let description = + [{The pass tries to tile the linalg contraction named op deeply.}]; + let dependentDialects = [ + "func::FuncDialect", + "arith::ArithDialect", + "tensor::TensorDialect", + "linalg::LinalgDialect", ]; } diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d240f28c1..21f522224 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -12,10 +12,10 @@ get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) gc_add_mlir_library(GcPasses OneDNNGraphToLinalg.cpp Pipeline.cpp - TileNamed.cpp IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp + DeepTileContractionNamedOp.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp new file mode 100644 index 000000000..a99c6e0ad --- /dev/null +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -0,0 +1,749 @@ +//===----------------------------------------------------------------------===// +//===- DeepTileContractionNamedOp.cpp - the Fusion for any tilable MLIR +// operation --*- C++ +//-*-=// +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +#include "gc/Transforms/Passes.h" + +#include + +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_DEEPTILECONTRACTIONNAMEDOP +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct SystemDesc { + // get runtime OMP_NUM_THREADS + uint32_t getNumThreads(); + // get cache size by cacheLevel + size_t getCacheSize(uint8_t cacheLevel); +}; + +struct MatmulConfig { + int MBlock, NBlock, KBlock; + int MThreads, NThreads, KThreads; + int innerMostMBlock, innerMostNBlock, innerMostKBlock; +}; + +template inline T divAndCeil(T a, T b) { return (a - 1) / b + 1; } + +MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { + // TODO: build a more complex heuristic to determine the best tiling + auto M = linalgOp.getShape(linalgOp.getDpsInputOperand(0))[0]; + auto N = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[1]; + auto K = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[0]; + MatmulConfig cfg; + + // innermost Block + auto defaultBlock = 32; + cfg.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; + cfg.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; + cfg.innerMostKBlock = K % defaultBlock == 0 ? 
defaultBlock : K; + + // Number of block + auto MNumBlock = M / cfg.innerMostMBlock; + auto NNumBlock = N / cfg.innerMostNBlock; + auto KNumBlock = K / cfg.innerMostKBlock; + + // Threads + cfg.MThreads = 32; + cfg.NThreads = 1; + cfg.KThreads = 1; + + // Block + cfg.MBlock = divAndCeil((int)MNumBlock, cfg.MThreads) * cfg.innerMostMBlock; + cfg.NBlock = divAndCeil((int)NNumBlock, cfg.NThreads) * cfg.innerMostNBlock; + cfg.KBlock = divAndCeil((int)KNumBlock, cfg.KThreads) * cfg.innerMostKBlock; + + cfg.innerMostMBlock = 32; + cfg.innerMostNBlock = 32; + cfg.innerMostKBlock = 32; + cfg.MBlock = 64; + cfg.NBlock = 64; + cfg.KBlock = 64; + cfg.MThreads = 2; + cfg.NThreads = 1; + cfg.KThreads = 1; + return cfg; +} + +static Value tensorViewRankedTensor(RewriterBase &rewriter, + RankedTensorType outTensorType, + Value value) { + // TODO: add support for plain layout transpose + Value result, currentValue = value; + auto loc = currentValue.getLoc(); + auto inTensorType = cast(currentValue.getType()); + auto inShape = inTensorType.getShape(); + auto outShape = outTensorType.getShape(); + auto tensorElementType = inTensorType.getElementType(); + + if (inShape == outShape) { + return currentValue; + } + + if (outTensorType.getNumDynamicDims() != inTensorType.getNumDynamicDims()) { + SmallVector alignOutShape(outShape.begin(), outShape.end()); + if (outShape.size() < inShape.size()) { + SmallVector oneVector(inShape.size() - outShape.size(), 1); + alignOutShape.insert(alignOutShape.begin(), oneVector.begin(), + oneVector.end()); + } else { + alignOutShape.erase(alignOutShape.begin(), + alignOutShape.begin() + + (outShape.size() - inShape.size())); + } + auto type = RankedTensorType::get(alignOutShape, tensorElementType); + currentValue = rewriter.create(loc, type, currentValue); + if (type == outTensorType) { + return currentValue; + } + } + + if (outShape.size() < inShape.size()) { + SmallVector reassocIndices; + ReassociationIndices firstEntry; + for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) { + firstEntry.push_back(i); + } + reassocIndices.push_back(firstEntry); + for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size(); + i++) { + reassocIndices.push_back({(int)i}); + } + result = rewriter.create( + loc, outTensorType, currentValue, reassocIndices); + } else if (outShape.size() > inShape.size()) { + SmallVector reassocIndices; + ReassociationIndices firstEntry; + for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) { + firstEntry.push_back((int)i); + } + reassocIndices.push_back(firstEntry); + for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size(); + i++) { + reassocIndices.push_back({(int)i}); + } + result = rewriter.create( + loc, outTensorType, currentValue, reassocIndices); + } else { + result = rewriter.create(loc, outTensorType, currentValue); + } + return result; +} + +struct OuterLoopGenerationOption { + enum LoopType { ForOp, ForallOp }; + SmallVector> nestedTileSizes; + SmallVector loopType; + SmallVector> loopDim; +}; + +struct OuterLoopGenerationResult { + /// Tiled operations that are generated during tiling. The order does not + /// matter except the last op. The replacements are expected to be the results + /// of the last op. + SmallVector tiledOps; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; + /// Values to use as replacements for the untiled op. Is the same size as the + /// number of results of the untiled op. 
+ SmallVector replacements; +}; + +static FailureOr +generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, + const OuterLoopGenerationOption &option) { + // TODO: handle the return value + OuterLoopGenerationResult result; + auto nestedTileSizes = option.nestedTileSizes; + auto loopType = option.loopType; + auto loopDim = option.loopDim; + + if (loopType.size() != loopDim.size() || + loopDim.size() != nestedTileSizes.size()) { + return b.notifyMatchFailure( + linalgOp, + "loopType, loopDim and nestedTileSizes should have the same size"); + } + + if (linalgOp.hasPureBufferSemantics()) + return b.notifyMatchFailure( + linalgOp, "currentOp should not has pure buffer semantics"); + + linalg::LinalgOp currentOp = linalgOp; + for (auto iteratorType : llvm::enumerate(loopType)) { + auto [i, type] = iteratorType; + auto currentDim = loopDim[i]; + auto currentTileSize = nestedTileSizes[i]; + if (type == OuterLoopGenerationOption::LoopType::ForOp) { + scf::SCFTilingOptions tileOption; + SmallVector TileSizes( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + + for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + TileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); + } + tileOption.setTileSizes(TileSizes); + tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(currentOp); + auto tilingResult = scf::tileUsingSCF( + b, cast(currentOp.getOperation()), tileOption); + if (failed(tilingResult)) + return failure(); + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + } else if (type == OuterLoopGenerationOption::LoopType::ForallOp) { + SmallVector tileSizes( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + SmallVector threads( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); + SmallVector reductionDims; + currentOp.getReductionDims(reductionDims); + for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + if (llvm::find(reductionDims, d) != reductionDims.end() && + !dyn_cast(currentOp.getOperation())) + tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), 0); + else + tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); + } + + SmallVector numThreads; + SmallVector loopRanges = + cast(currentOp.getOperation()).getIterationDomain(b); + unsigned nLoops = loopRanges.size(); + numThreads.reserve(nLoops); + AffineExpr s0, s1; + bindSymbols(b.getContext(), s0, s1); + AffineExpr divExpr = s0.ceilDiv(s1); + for (const auto &it : llvm::zip(tileSizes, loopRanges)) { + OpFoldResult numTiles = std::get<0>(it); + if (!isConstantIntValue(numTiles, 0)) + numTiles = mlir::affine::makeComposedFoldedAffineApply( + b, currentOp.getLoc(), divExpr, + {std::get<1>(it).size, std::get<0>(it)}); + numThreads.push_back(numTiles); + } + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(currentOp); + // TODO: add split reduction support here + // if (auto partialInterface = + // dyn_cast(currentOp.getOperation())) + // { + // auto tilingResult = linalg::tileReductionUsingForall( + // b, cast(currentOp.getOperation()), + // numThreads, tileSizes, std::nullopt); + // if (failed(tilingResult)) + // return failure(); + // currentOp = + // dyn_cast(tilingResult->parallelTiledOp); + // } else + if (auto tilingInterface = + cast(currentOp.getOperation())) { + auto tilingResult = linalg::tileToForallOpUsingTileSizes( + b, tilingInterface, tileSizes, std::nullopt); + if 
(failed(tilingResult)) + return failure(); + b.replaceOp(currentOp, tilingResult->tileOp); + currentOp = dyn_cast(tilingResult->tiledOp); + } + } + } + result.tiledOps.emplace_back(currentOp); + return result; +} + +static void getMatmulParallelDims(linalg::LinalgOp linalgOp, + unsigned operandIdx, + SmallVectorImpl &dims) { + AffineMap map = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + ArrayRef results = map.getResults(); + for (auto dim : results) { + auto dimExpr = dyn_cast(dim); + if (dimExpr && iteratorTypes[dimExpr.getPosition()] == + mlir::utils::IteratorType::parallel) { + dims.push_back(dimExpr.getPosition()); + } + } +} + +static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, + unsigned operandIdx) { + Value Operand; + unsigned dimPos; + [[maybe_unused]] auto result = + linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); + return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; +} + +static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, + Operation *op, + bool isExtract, + SmallVector size, + int shrinDimNum = 0) { + if (auto extractSlice = dyn_cast(op)) { + SmallVector mixedOffsets = extractSlice.getMixedOffsets(); + SmallVector mixedSizes = extractSlice.getMixedSizes(); + SmallVector mixedStrides = extractSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + if (shrinDimNum > 0) { + rewriter.replaceOpWithNewOp( + extractSlice, + mlir::RankedTensorType::get( + SmallVector(size.begin() + shrinDimNum, size.end()), + extractSlice.getResult().getType().getElementType()), + extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); + } else { + rewriter.replaceOpWithNewOp( + extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, + mixedStrides); + } + } else { + return failure(); + } + return mlir::success(); +} + +static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, + Operation *op, Value source, + SmallVector size) { + if (auto insertSlice = dyn_cast(op)) { + SmallVector mixedOffsets = insertSlice.getMixedOffsets(); + SmallVector mixedSizes = insertSlice.getMixedSizes(); + SmallVector mixedStrides = insertSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + rewriter.replaceOpWithNewOp( + insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, + mixedStrides); + } else { + return failure(); + } + return success(); +} + +enum DimType { Batch, M, N, K }; + +static FailureOr>> +getOprandDimType(linalg::LinalgOp &linalgOp) { + // TODO: add more support for other linalg named matmul + if (isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (isa(linalgOp)) { + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + if (iteratorTypes.size() == 7UL) { + // 4Dx5D, brgemm vnni + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (iteratorTypes.size() == 6UL) { + // 4Dx4D + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, 
DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } + } else { + return failure(); + } + return failure(); +} + +/* +forall([PM, PN]: [MThreads, NThreads) { + for(PK : KThreads) { + CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock, + PN * NOuterBlock: (PN + 1) * NOuterBlock] + ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock * (PK ++ 1) * KOuterBlock] + BSlice = B[PK * KOuterBlock * (PK + 1) * KOuterBlock, PN * +NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM ++ 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock] + + MNumBlock = MOuterBlock / MBlock + NNumBlock = NOuterBlock / NBlock + KNumBlock = KOuterBlock / KBlovk + for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) { + ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) * +KBlock] + BSlice2 = BSlice[0, om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + +1) * KBlock] + CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock: +(on + 1) * NBlock] (init with 0 when ok == 0) + MNumInnerBlock = MBlock / iim_block_ + ... + for([im, in]: [MNumInnerBlock, NNumInnerBlock]) { + ASlice3 = ASlice2[im * iim_block_: (im + 1) * iim_block_, :] + BSlice3 = BSlice2[0, im * iim_block_: (im + 1) * iim_block_, :] + CSlice4 = CSlice3[0, im * iim_block_: (im + 1) * iim_block_, in * +iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0) + brgemm(bs=KNumInnerBlock, M=iim_block_, N=iin_block_, K=iik_block, +A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0)); + } + } + } + C = final_reduce(CSlice) +} +*/ +struct deepTileMatmul : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + FailureOr + outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + MatmulConfig cfg) const { + SmallVector KDimPos, MDimPos, NDimPos; + linalgOp.getReductionDims(KDimPos); + getMatmulParallelDims(linalgOp, 0, MDimPos); + getMatmulParallelDims(linalgOp, 1, NDimPos); + bool useBlockedLayout = KDimPos.size() > 1; + + OuterLoopGenerationOption option; + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1); + auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0); + auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1); + auto KParallelBlockSize = + useBlockedLayout + ? divAndCeil(KFirstDim, cfg.KThreads) + : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) * + cfg.KBlock; + auto MParallelBlockSize = + useBlockedLayout + ? divAndCeil(MFirstDim, cfg.MThreads) + : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) * + cfg.MBlock; + auto NParallelBlockSize = + useBlockedLayout + ? divAndCeil(NFirstDim, cfg.NThreads) + : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) * + cfg.NBlock; + auto KOuterBlockSize = useBlockedLayout + ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 + : cfg.KBlock; + auto MOuterBlockSize = useBlockedLayout + ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1 + : cfg.MBlock; + auto NOuterBlockSize = useBlockedLayout + ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 + : cfg.NBlock; + // Outer + option.nestedTileSizes.emplace_back(SmallVector{ + MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); + option.loopDim.emplace_back( + SmallVector{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]}); + // Middle + for (auto [tile, dim] : + llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, + KOuterBlockSize}, + SmallVector{(int)MDimPos[0], (int)NDimPos[0], + (int)KDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{dim}); + } + // Inner + if (!useBlockedLayout) { + option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{(int)KDimPos.back()}); + } + for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { + if (dim != MDimPos.back() && dim != NDimPos.back() && + iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { + option.nestedTileSizes.emplace_back(SmallVector{1}); + option.loopType.emplace_back( + OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{(int)dim}); + } + } + return generateOuterLoop(rewriter, linalgOp, option); + } + + struct innerBodyGenerationOption { + bool hasFillOp = false; + Value fillValue; + }; + + LogicalResult + innerBodyGeneration(RewriterBase &rewriter, linalg::LinalgOp originOp, + linalg::LinalgOp currentOp, + const innerBodyGenerationOption &option) const { + auto operandDimTypes = getOprandDimType(originOp); + MatmulConfig cfg = getDefaultMatmulConfig(originOp); + auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); + auto BShape = originOp.getShape(originOp.getDpsInputOperand(1)); + auto CShape = originOp.getShape(originOp.getDpsInitOperand(0)); + bool useBlockedLayout = BShape.size() > 2; + // TODO: support plain in/block out format + SmallVector AInnermostDims, BInnermostDims, CInnermostDims; + if (useBlockedLayout) { + bool firstM = true, firstK = true, firstN = true; + for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) { + if (iter == DimType::M && firstM) { + AInnermostDims.push_back(1); + firstM = false; + } else if (iter == DimType::Batch) { + AInnermostDims.push_back(1); + } else if (iter == DimType::K && firstK) { + AInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); + firstK = false; + } else { + AInnermostDims.push_back(AShape[idx]); + } + } + firstN = true; + firstK = true; + for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { + if (iter == DimType::N && firstN) { + BInnermostDims.push_back(1); + firstN = false; + } else if (iter == DimType::Batch) { + BInnermostDims.push_back(1); + } else if (iter == DimType::K && firstK) { + BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); + firstK = false; + } else { + BInnermostDims.push_back(BShape[idx]); + } + } + firstM = true; + firstN = true; + for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) { + if (iter == DimType::M && firstM) { + CInnermostDims.push_back(1); + firstM = false; + } else if (iter == DimType::Batch) { + CInnermostDims.push_back(1); + } else if (iter == DimType::N && firstN) { + CInnermostDims.push_back(1); + firstN = false; + } else { + CInnermostDims.push_back(CShape[idx]); + } + } + } else { + AInnermostDims = SmallVector{cfg.innerMostMBlock, + 
cfg.KBlock / cfg.innerMostKBlock * + cfg.innerMostKBlock}; + BInnermostDims = SmallVector{cfg.KBlock / cfg.innerMostKBlock * + cfg.innerMostKBlock, + cfg.innerMostNBlock}; + CInnermostDims = + SmallVector{cfg.innerMostMBlock, cfg.innerMostNBlock}; + } + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(currentOp); + auto dataType = + dyn_cast(currentOp.getDpsInputs()[0].getType()); + auto weightType = + dyn_cast(currentOp.getDpsInputs()[1].getType()); + auto resultType = + dyn_cast(currentOp.getDpsInits()[0].getType()); + // use shrink layout when it is able to be converted to brgemm + bool useShrinkedLayout = (BInnermostDims.size() == 4); + + // update the extractSlice to static size, replace it with + // useBlockedLayout when + if (failed(setStaticSizeForExtractSliceOp( + rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true, + CInnermostDims, useShrinkedLayout ? 2 : 0)) || + failed(setStaticSizeForExtractSliceOp( + rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true, + BInnermostDims, useShrinkedLayout)) || + failed(setStaticSizeForExtractSliceOp( + rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, + AInnermostDims, useShrinkedLayout))) { + return failure(); + } + + // View the tensor to brgemm required format + Value dataOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + useBlockedLayout + ? SmallVector(AInnermostDims.begin() + 1, + AInnermostDims.end()) + : SmallVector{1, AInnermostDims[0], AInnermostDims[1]}, + dataType.getElementType()), + currentOp.getDpsInputs()[0]); + Value weightOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + useBlockedLayout + ? SmallVector(BInnermostDims.begin() + 1, + BInnermostDims.end()) + : SmallVector{1, BInnermostDims[0], BInnermostDims[1]}, + weightType.getElementType()), + currentOp.getDpsInputs()[1]); + Value resultOprand = tensorViewRankedTensor( + rewriter, + mlir::RankedTensorType::get( + SmallVector(CInnermostDims.begin() + + (useBlockedLayout ? 
2 : 0), + CInnermostDims.end()), + resultType.getElementType()), + currentOp.getDpsInits()[0]); + + // Create the brgemm op + // TODO: use brgemm_vnni to replace generic when it is applicable + linalg::LinalgOp matmul; + if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) { + matmul = rewriter.create( + resultOprand.getLoc(), resultOprand.getType(), + ValueRange{dataOprand, weightOprand}, resultOprand); + } else { + IRMapping mapping; + matmul = dyn_cast( + *rewriter.clone(*(currentOp.getOperation()))); + } + Value result = matmul.getOperation()->getResult(0); + + // Insert the result back to the original tensor + for (Operation *user : currentOp->getResult(0).getUsers()) { + if (failed(setStaticSizeForInsertSliceOp(rewriter, user, result, + CInnermostDims))) { + return failure(); + } + } + rewriter.replaceOp(currentOp, matmul.getOperation()->getResult(0)); + currentOp = matmul; + + if (option.hasFillOp) { + // TODO: support partial K in sinsngle threads, control flow may need + // easy builder support + rewriter.setInsertionPointAfter(currentOp); + auto fillOp = rewriter.create( + currentOp->getLoc(), option.fillValue, currentOp.getDpsInits()[0]); + IRMapping mapping; + mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0)); + auto res = rewriter.clone(*(currentOp.getOperation()), mapping); + rewriter.replaceOp(currentOp, res); + currentOp = dyn_cast(res); + } + currentOp.getOperation()->getParentOfType().dump(); + return success(); + } + + LogicalResult matchAndRewrite(linalg::LinalgOp matmulOp, + PatternRewriter &rewriter) const override { + if (matmulOp.hasPureBufferSemantics()) + return failure(); + linalg::LinalgOp linalgOp; + linalgOp = dyn_cast(matmulOp.getOperation()); + if (linalgOp.getOperation()->getParentOfType()) + return failure(); + + // Step 1. Match and remove the init/fill operation + // Fuse the fill op manually before fusion support this case(fuse it into + // if-else block) + bool hasFillOp = false; + Value fillValue; + SmallVector KLoopHandle; + if (auto op = dyn_cast( + linalgOp.getDpsInits()[0].getDefiningOp())) { + hasFillOp = true; + fillValue = op.getDpsInputs()[0]; + rewriter.replaceOp(op, op.getDpsInits()[0]); + } + + // Step 2. The processes of outer Loop Generation + // 2.0 Get the iteration infomation first + MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); + // TODO: move the reduction dim to the front. 
(M, N, threads) -> + // (threads, M, N) + auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg); + if (failed(outerLoopResult)) { + return failure(); + } + linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); + + // Step 3 inner loop generation, convert the linalg.generic to brgemm + if (failed(innerBodyGeneration( + rewriter, matmulOp, linalgOp, + innerBodyGenerationOption{hasFillOp, fillValue}))) { + return failure(); + } + return success(); + } +}; + +struct DeepTileContractionNamedOp + : public impl::DeepTileContractionNamedOpBase { +public: + void runOnOperation() final { + auto &ctx = getContext(); + IRRewriter rewriter(&ctx); + RewritePatternSet patterns(&ctx); + + patterns.add(patterns.getContext()); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + linalg::ControlDropUnitDims options; + options.rankReductionStrategy = + linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + + for (auto *dialect : ctx.getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : ctx.getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &ctx); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/TileNamed.cpp b/lib/gc/Transforms/TileNamed.cpp deleted file mode 100644 index 43348685d..000000000 --- a/lib/gc/Transforms/TileNamed.cpp +++ /dev/null @@ -1,49 +0,0 @@ -//===-- TileNamed.cpp - Tile Named Linalg Ops -------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Transforms/Passes.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace mlir { -namespace gc { -#define GEN_PASS_DEF_TILELINALGNAMED -#include "gc/Transforms/Passes.h.inc" -} // namespace gc -} // namespace mlir - -namespace { -class TileLinalg : public mlir::gc::impl::TileLinalgNamedBase { - - void runOnOperation() override { - auto *ctx = &getContext(); - IRRewriter rewriter(ctx); - - llvm::SmallVector to_tile; - for (Operation &o : getOperation()->getRegion(0).front().getOperations()) { - if (isa(o)) { - to_tile.push_back(&o); - } - } - - for (Operation *o : to_tile) { - llvm::errs() << "func op body to tile: " << *o << "\n"; - } - } -}; - -} // namespace diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir new file mode 100644 index 000000000..209145f9e --- /dev/null +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -0,0 +1,48 @@ +// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s + +// ----- + +/// CHECK-LABEL: @blocked_matmul_f32 +func.func @blocked_matmul_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32xf32> + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x128x32x32xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32xf32>) outs(%1 : tensor<128x128x32x32xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %3 = arith.mulf %in, %in_1 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<128x128x32x32xf32> + return %2 : tensor<128x128x32x32xf32> +} + +// ----- + +/// CHECK-LABEL: @plain_matmul_f32 +func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<4096x4096xf32> + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<4096x4096xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + %2 = linalg.matmul ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + return %2 : tensor<4096x4096xf32> +} + +// ----- + +/// CHECK-LABEL: @blocked_matmul_bf16 +func.func @blocked_matmul_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x128x32x32xbf16> + %1 = linalg.fill 
ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %3 = arith.mulf %in, %in_1 : bf16 + %4 = arith.addf %out, %3 : bf16 + linalg.yield %4 : bf16 + } -> tensor<128x128x32x32xbf16> + return %2 : tensor<128x128x32x32xbf16> +} + From 8ec246b177ed4f62d9c12c308fe71f11d8614778 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Tue, 21 May 2024 22:56:59 -0700 Subject: [PATCH 02/21] Enhance upstream utility and merge all parallel into one forall --- .../Transforms/DeepTileContractionNamedOp.cpp | 24 +- lib/gc/Transforms/Tiling.cpp | 1035 +++++++++++++++++ lib/gc/Transforms/Tiling.hpp | 56 + 3 files changed, 1102 insertions(+), 13 deletions(-) create mode 100644 lib/gc/Transforms/Tiling.cpp create mode 100644 lib/gc/Transforms/Tiling.hpp diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index a99c6e0ad..334ac1902 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "./Tiling.hpp" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -262,19 +263,16 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); // TODO: add split reduction support here - // if (auto partialInterface = - // dyn_cast(currentOp.getOperation())) - // { - // auto tilingResult = linalg::tileReductionUsingForall( - // b, cast(currentOp.getOperation()), - // numThreads, tileSizes, std::nullopt); - // if (failed(tilingResult)) - // return failure(); - // currentOp = - // dyn_cast(tilingResult->parallelTiledOp); - // } else - if (auto tilingInterface = - cast(currentOp.getOperation())) { + if (auto partialInterface = + dyn_cast(currentOp.getOperation())) { + auto tilingResult = linalgX::tileAllUsingForall( + b, cast(currentOp.getOperation()), + numThreads, tileSizes, std::nullopt); + if (failed(tilingResult)) + return failure(); + currentOp = dyn_cast(tilingResult->parallelTiledOp); + } else if (auto tilingInterface = + cast(currentOp.getOperation())) { auto tilingResult = linalg::tileToForallOpUsingTileSizes( b, tilingInterface, tileSizes, std::nullopt); if (failed(tilingResult)) diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp new file mode 100644 index 000000000..b9de4a777 --- /dev/null +++ b/lib/gc/Transforms/Tiling.cpp @@ -0,0 +1,1035 @@ +//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Tiling pass. 
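+// The utilities below are largely adapted from the upstream MLIR Linalg
+// tiling implementation; the partial-reduction tiling is extended here so
+// that reduction loops can be mapped to additional parallel dimensions of a
+// single scf.forall, per this patch's goal of merging all parallel loops
+// into one forall.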
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include +#include + +namespace mlir { +#define GEN_PASS_DEF_LINALGTILINGPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::affine; +using namespace mlir::linalg; +using namespace mlir::scf; + +#define DEBUG_TYPE "linalg-tiling" + +namespace mlir { +namespace linalgX { + +struct LinalgOpPartialReductionInterface { + static FailureOr> generateInitialTensorForPartialReduction( + Operation *op, OpBuilder &b, Location loc, ArrayRef sizes, + ArrayRef reductionDims, ArrayRef newParallelDims) { + auto linalgOp = cast(op); + OpBuilder::InsertionGuard guard(b); + + if (newParallelDims.empty()) + newParallelDims = reductionDims; + + if (linalgOp.hasPureBufferSemantics()) + return op->emitOpError("expected operation to have tensor semantics"); + // Insert the new parallel dimension based on the index of the reduction + // loops. This could be controlled by user for more flexibility. + SmallVector inits; + for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e; + ++initIdx) { + SmallVector combinerOps; + if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) || + combinerOps.size() != 1) + return op->emitOpError("Failed to anaysis the reduction operation."); + + Operation *reductionOp = combinerOps[0]; + std::optional identity = arith::getNeutralElement(reductionOp); + if (!identity.has_value()) + return op->emitOpError( + "Failed to get an identity value for the reduction operation."); + + ArrayRef oldShape = + linalgOp.getShape(linalgOp.getDpsInitOperand(0)); + + // Extend tile size vector to the rank of the output tensor. + SmallVector tileSizeVector = + getValueOrCreateConstantIndexOp(b, loc, sizes); + if (tileSizeVector.size() < oldShape.size()) { + auto zero = b.create(loc, 0); + tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero); + } + + // Calculate the new shape, we insert the new dimensions based on the + // index of the reduction dimensions. 
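+      // For illustration (values are hypothetical): with an init of type
+      // tensor<128x128xf32>, a single reduction loop split across
+      // KThreads = 2 threads, and newParallelDims = {0}, the partial-result
+      // init becomes tensor<2x128x128xf32>, one slice per thread along the
+      // inserted parallel dimension.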
+ SmallVector newOutputShape; + SmallVector dynamicDims; + int64_t currReductionDims = 0; + DenseSet newParallelDimsSet(newParallelDims.begin(), + newParallelDims.end()); + for (int64_t idx : + llvm::seq(0, oldShape.size() + newParallelDims.size())) { + if (newParallelDimsSet.contains(idx)) { + dispatchIndexOpFoldResults(sizes[reductionDims[currReductionDims]], + dynamicDims, newOutputShape); + currReductionDims++; + continue; + } + int64_t oldIdx = idx - currReductionDims; + int64_t dim = oldShape[oldIdx]; + newOutputShape.push_back(dim); + if (ShapedType::isDynamic(dim)) + dynamicDims.push_back(b.create( + loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx)); + } + Value emptyTensor = b.create( + loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(), + dynamicDims); + Value constantOp = b.create(loc, *identity); + auto identityTensor = + b.create(loc, constantOp, emptyTensor); + inits.push_back(identityTensor.getResult(0)); + } + return inits; + } + + static Operation *tileToPartialReduction(Operation *op, OpBuilder &b, + Location loc, ValueRange init, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef reductionDims) { + OpBuilder::InsertionGuard guard(b); + auto linalgOp = cast(op); + + AffineMap oldOutputMap = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0)); + SmallVector outputExpr(oldOutputMap.getNumResults() + + reductionDims.size()); + + for (int idx : reductionDims) + outputExpr[idx] = b.getAffineDimExpr(idx); + int currExpr = 0; + for (int idx : llvm::seq(0, outputExpr.size())) { + if (outputExpr[idx]) + continue; + outputExpr[idx] = oldOutputMap.getResult(currExpr++); + } + + // Step 1: Extract a slice of the input operands. + SmallVector valuesToTile = linalgOp.getDpsInputs(); + SmallVector tiledOperands = makeTiledShapes( + b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); + + // Step 2: Extract the accumulator operands + SmallVector strides(offsets.size(), b.getIndexAttr(1)); + SmallVector outOffsets(offsets.size(), b.getIndexAttr(0)); + // TODO: use SubsetExtractOpInterface once it is available. + Value out = b.create(loc, init[0], outOffsets, + sizes, strides); + + // Step3. Create a generic op where the reduction dimensions are replaced + // by a parallel dimension of the size of reduction. + SmallVector newIteratorTypes = + linalgOp.getIteratorTypesArray(); + for (int dim : reductionDims) + newIteratorTypes[dim] = utils::IteratorType::parallel; + SmallVector newMaps = linalgOp.getIndexingMapsArray(); + newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr, + linalgOp.getContext()); + auto genericOp = + b.create(loc, TypeRange({out.getType()}), tiledOperands, + ValueRange({out}), newMaps, newIteratorTypes); + IRMapping mapping; + op->getRegion(0).cloneInto(&genericOp.getRegion(), + genericOp.getRegion().begin(), mapping); + return genericOp.getOperation(); + } + + static Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc, + ValueRange partialReduce, + ArrayRef reductionDims) { + auto linalgOp = cast(op); + + DenseSet reductionDimsSet(reductionDims.begin(), reductionDims.end()); + + // Then create a new reduction that only reduce the newly added dimensions + // from the previous op. 
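+    // Continuing the illustration above: the tensor<2x128x128xf32> of
+    // per-thread partial results is reduced over the inserted dimension with
+    // the combiner matched from the original op (e.g. arith.addf) to produce
+    // the final tensor<128x128xf32>.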
+ int64_t intermRank = cast(partialReduce[0].getType()).getRank(); + AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); + SmallVector reductionIteratorTypes; + SmallVector exprs; + + for (int64_t i : llvm::seq(0, intermRank)) { + if (reductionDimsSet.contains(i)) { + reductionIteratorTypes.push_back(utils::IteratorType::reduction); + } else { + exprs.push_back(b.getAffineDimExpr(i)); + reductionIteratorTypes.push_back(utils::IteratorType::parallel); + } + } + + AffineMap outputMap = + AffineMap::get(intermRank, 0, exprs, op->getContext()); + SmallVector reductionMaps = {inputMap, outputMap}; + + SmallVector combinerOps; + matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps); + Operation *reductionOp = combinerOps[0]; + + auto reduction = b.create( + loc, op->getResultTypes(), ValueRange({partialReduce[0]}), + linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes, + [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { + Operation *clonedReductionOp = b.clone(*reductionOp); + clonedReductionOp->setOperand(0, inputs[0]); + clonedReductionOp->setOperand(1, inputs[1]); + b.create(loc, clonedReductionOp->getResult(0)); + }); + return reduction.getOperation(); + } +}; + +std::tuple, LoopIndexToRangeIndexMap> +makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, + ArrayRef allShapeSizes, + ArrayRef allTileSizes) { + assert(allTileSizes.size() == map.getNumResults()); + // Apply `map` to get shape sizes in loop order. + SmallVector shapeSizes = + makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes); + SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); + + // Traverse the tile sizes, which are in loop order, erase zeros everywhere. + LoopIndexToRangeIndexMap loopIndexToRangeIndex; + for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) { + if (getConstantIntValue(tileSizes[idx - zerosCount]) == + static_cast(0)) { + shapeSizes.erase(shapeSizes.begin() + idx - zerosCount); + tileSizes.erase(tileSizes.begin() + idx - zerosCount); + ++zerosCount; + continue; + } + loopIndexToRangeIndex[idx] = idx - zerosCount; + } + + // Create a new range with the applied tile sizes. + SmallVector res; + for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) + res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]}); + return std::make_tuple(res, loopIndexToRangeIndex); +} + +void transformIndexOps(RewriterBase &b, LinalgOp op, + SmallVectorImpl &ivs, + const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + SmallVector allIvs(op.getNumLoops(), nullptr); + for (auto en : enumerate(allIvs)) { + auto rangeIndex = loopIndexToRangeIndex.find(en.index()); + if (rangeIndex == loopIndexToRangeIndex.end()) + continue; + en.value() = ivs[rangeIndex->second]; + } + offsetIndices(b, op, getAsOpFoldResult(allIvs)); +} + +/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less +/// than `iterationSize`. +static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, + OpFoldResult numThreads, + OpFoldResult iterationSize) { + std::optional tileSizeConst = getConstantIntValue(tileSize); + std::optional numThreadsConst = getConstantIntValue(numThreads); + std::optional iterSizeConst = getConstantIntValue(iterationSize); + if (!tileSizeConst || !numThreadsConst || !iterSizeConst) + return false; + return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; +} + +/// Build an `affine_max` of all the `vals`. 
+static OpFoldResult buildMax(OpBuilder &b, Location loc, + ArrayRef vals) { + return affine::makeComposedFoldedAffineMax( + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +/// Build an `affine_min` of all the `vals`. +static OpFoldResult buildMin(OpBuilder &b, Location loc, + ArrayRef vals) { + return affine::makeComposedFoldedAffineMin( + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given +/// number of threads. +static void calculateTileOffsetsAndSizes( + RewriterBase &b, Location loc, scf::ForallOp forallOp, + ArrayRef numThreads, SmallVector loopRanges, + bool omitTileOffsetBoundsCheck, + std::optional> nominalTileSizes, + SmallVector &tiledOffsets, + SmallVector &tiledSizes) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(forallOp.getBody(0)); + + ValueRange threadIds = forallOp.getInductionVars(); + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + int64_t nLoops = loopRanges.size(); + tiledOffsets.reserve(nLoops); + tiledSizes.reserve(nLoops); + for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { + bool overflow = loopIdx >= numThreads.size(); + bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); + // Degenerate case: take the whole domain. + if (overflow || isZero) { + tiledOffsets.push_back(loopRanges[loopIdx].offset); + tiledSizes.push_back(loopRanges[loopIdx].size); + continue; + } + + // Tiled case: compute the offset and size. + AffineExpr i, j, m, n, o; + bindDims(b.getContext(), i, j); + bindSymbols(b.getContext(), m, n, o); + OpFoldResult size = loopRanges[loopIdx].size; + OpFoldResult offset = loopRanges[loopIdx].offset; + OpFoldResult threadId = threadIds[threadIdIdx]; + // Symbolic fixed max size per thread. + // TODO: floor + 0/1 depending on case for better load-balancing. + OpFoldResult tileSizePerThread = + nominalTileSizes.has_value() + ? (*nominalTileSizes)[loopIdx] + : makeComposedFoldedAffineApply( + b, loc, m.ceilDiv(n), + ArrayRef{size, nonZeroNumThreads[threadIdIdx]}); + // Dynamic offset shifted by threadId * maxSizePerThread. + OpFoldResult offsetPerThread = makeComposedFoldedAffineApply( + b, loc, i + j * m, {offset, threadId, tileSizePerThread}); + // Dynamic upper-bound depending on the threadId. + OpFoldResult residualTileSize = makeComposedFoldedAffineApply( + b, loc, i + j * m - n, + {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); + if (!isConstantIntValue(residualTileSize, 0)) { + OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( + b, loc, -i + m, {offsetPerThread, size}); + tileSizePerThread = + buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread}); + } + + tiledOffsets.push_back(offsetPerThread); + // TODO: if tileSizePerThread <= 0 early exit. 
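+    // For illustration: with size = 10 and 8 non-zero threads, the nominal
+    // tileSizePerThread is ceil(10/8) = 2; thread 7 would start at offset 14,
+    // so the affine_min above shrinks its tile size to 10 - 14 = -4 and the
+    // affine_max below clamps it to 0.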
+ if (!omitTileOffsetBoundsCheck && + !canOmitTileOffsetInBoundsCheck(tileSizePerThread, + nonZeroNumThreads[threadIdIdx], size)) + tileSizePerThread = + buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread}); + + tiledSizes.push_back(tileSizePerThread); + ++threadIdIdx; + } +} + +template +static FailureOr +tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, + const LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + + auto nLoops = op.getNumLoops(); + // Initial tile sizes may be too big, only take the first nLoops. + tileSizes = tileSizes.take_front(nLoops); + + if (llvm::all_of(tileSizes, [](OpFoldResult ofr) { + return getConstantIntValue(ofr) == static_cast(0); + })) { + TiledLinalgOp tiledOp; + tiledOp.op = cast(b.clone(*op.getOperation())); + tiledOp.tensorResults.assign(tiledOp.op->result_begin(), + tiledOp.op->result_end()); + return tiledOp; + } + + // 1. Build the tiled loop ranges. + SmallVector allShapeSizes = + op.createFlatListOfOperandDims(b, op.getLoc()); + AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); + if (!shapeSizesToLoopsMap) + return failure(); + + auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( + b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); + + SmallVector iteratorTypes; + for (const auto &attr : enumerate(op.getIteratorTypesArray())) { + if (loopIndexToRangeIndex.count(attr.index())) + iteratorTypes.push_back(attr.value()); + } + // If interchangeVector is empty, use the identity. Build the permutation map + // otherwise. + auto invPermutationMap = + AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); + if (!options.interchangeVector.empty()) { + // Based on the pruned iterations (due to zero tile size), recompute the + // interchange vector. + SmallVector interchangeVector; + interchangeVector.reserve(options.interchangeVector.size()); + for (auto pos : options.interchangeVector) { + auto it = loopIndexToRangeIndex.find(pos); + if (it == loopIndexToRangeIndex.end()) + continue; + interchangeVector.push_back(it->second); + } + // Interchange vector is guaranteed to be a permutation, + // `inversePermutation` must succeed. + invPermutationMap = inversePermutation( + AffineMap::getPermutationMap(interchangeVector, b.getContext())); + assert(invPermutationMap); + SmallVector permutation(interchangeVector.begin(), + interchangeVector.end()); + applyPermutationToVector(loopRanges, permutation); + applyPermutationToVector(iteratorTypes, permutation); + } + + // Handle distribution. Create a vector of the same size of loops that are to + // be tiled. + SmallVector procInfo; + if (options.distribution) { + procInfo.resize( + iteratorTypes.size(), + linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None}); + // Collect loop ranges of tiled loops, loops that are parallel. + SmallVector parallelLoopRanges; + for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); + } + auto returnedProcInfo = + options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges); + unsigned procIdIdx = 0; + // Update the distribution information for the loops. + for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++]; + } + } + + // 2. Create the tiled loops. 
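+  // For each iteration of the generated loop nest, the body builder below
+  // extracts tiles of the operands (makeTiledShapes), clones the op onto the
+  // tiled operands, and writes the per-tile results back via insert_slice ops
+  // (insertSlicesBack).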
+ LinalgOp res = op; + SmallVector ivs, tensorResults; + auto tiledLoopBodyBuilder = + [&](OpBuilder &builder, Location loc, ValueRange localIvs, + ValueRange operandValuesToUse) -> scf::ValueVector { + ivs.assign(localIvs.begin(), localIvs.end()); + + // When an `interchangeVector` is present, it has been applied to the + // loop ranges and the iterator types. Apply its inverse to the + // resulting loop `ivs` to match the op definition. + SmallVector interchangedIvs; + if (!options.interchangeVector.empty()) { + for (AffineExpr result : invPermutationMap.getResults()) + interchangedIvs.push_back( + ivs[cast(result).getPosition()]); + } else { + interchangedIvs.assign(ivs.begin(), ivs.end()); + } + + // Tile the `operandValuesToUse` that either match the `op` operands + // themselves or the tile loop arguments forwarding them. + assert(operandValuesToUse.size() == + static_cast(op->getNumOperands()) && + "expect the number of operands and inputs and outputs to match"); + SmallVector valuesToTile = operandValuesToUse; + SmallVector sizeBounds = + makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap, + allShapeSizes); + SmallVector tiledOperands = makeTiledShapes( + b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes, + sizeBounds, + /*omitPartialTileCheck=*/false); + + SmallVector resultTensorTypes = + getTensorOutputTypes(op, tiledOperands); + res = clone(b, op, resultTensorTypes, tiledOperands); + tensorResults = + insertSlicesBack(builder, loc, op, tiledOperands, res->getResults()); + return scf::ValueVector(tensorResults.begin(), tensorResults.end()); + }; + GenerateLoopNest::doit(b, op.getLoc(), loopRanges, op, iteratorTypes, + tiledLoopBodyBuilder, procInfo); + + // 3. Transform IndexOp results w.r.t. the tiling. + linalg::transformIndexOps(b, res, ivs, loopIndexToRangeIndex); + + // 4. Gather the newly created loops and return them with the new op. + SmallVector loops; + loops.reserve(ivs.size()); + for (auto iv : ivs) { + if (isa(iv)) { + loops.push_back(cast(iv).getOwner()->getParentOp()); + assert(loops.back() && "no owner found for induction variable!"); + } else { + // TODO: Instead of doing this, try to recover the ops used instead of the + // loop. + loops.push_back(nullptr); + } + } + + // 5. Get the tensor results from the outermost loop if available. Otherwise + // use the previously captured `tensorResults`. + Operation *outermostLoop = nullptr; + for (Operation *loop : loops) + if ((outermostLoop = loop)) + break; + + return TiledLinalgOp{ + res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults}; +} + +FailureOr tileReductionUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef threadNums, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping) { + Location loc = op.getLoc(); + OpBuilder::InsertionGuard g(b); + + // Ops implementing PartialReductionOpInterface are expected to implement + // TilingInterface. + // TODO: proper core mechanism to tie interfaces together. + auto tilingInterfaceOp = cast(op.getOperation()); + + // Ops implementing PartialReductionOpInterface are not necessarily expected + // to implement TilingInterface.. This cast is unsafe atm. + // TODO: proper core mechanism to tie interfaces together. + // TODO: this function requires a pair of interfaces .. + auto destinationStyleOp = + dyn_cast(op.getOperation()); + if (!destinationStyleOp) + return b.notifyMatchFailure(op, "not a destination style op"); + + // Actually this only work for Linalg ops atm. 
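+  // Note on the arguments: if `threadNums` is empty, the thread count per
+  // loop is derived below as ceil(loopRange / tileSize) for every non-zero
+  // tile size; `newParallelDims` selects where the per-thread partial-result
+  // dimensions are inserted and defaults to the reduction dims.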
+ auto linalgOp = dyn_cast(op.getOperation()); + if (!linalgOp) + return b.notifyMatchFailure(op, "not a linalg op"); + + SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); + if (op->getNumResults() != 1) + return b.notifyMatchFailure( + op, "don't support ops with multiple results for now"); + + SmallVector iterators = + tilingInterfaceOp.getLoopIteratorTypes(); + SmallVector redDims; + for (auto [idx, iteratorType] : + llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { + if (iteratorType == utils::IteratorType::reduction) + redDims.push_back(idx); + } + + SmallVector numThreads(threadNums.begin(), threadNums.end()); + if (numThreads.empty()) { + SmallVector loopRanges = tilingInterfaceOp.getIterationDomain(b); + unsigned nLoops = loopRanges.size(); + numThreads.reserve(nLoops); + AffineExpr s0, s1; + bindSymbols(b.getContext(), s0, s1); + AffineExpr divExpr = s0.ceilDiv(s1); + for (const auto &it : llvm::zip(tileSizes, loopRanges)) { + OpFoldResult numTiles = std::get<0>(it); + if (!isConstantIntValue(numTiles, 0)) + numTiles = makeComposedFoldedAffineApply( + b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)}); + numThreads.push_back(numTiles); + } + } + + if (!tileSizes.empty() && tileSizes.size() != numThreads.size()) + return b.notifyMatchFailure(op, "if tile sizes are present it must have as " + "many elements as number of threads"); + + if ((unsigned)redDims.front() >= numThreads.size()) + return b.notifyMatchFailure( + op, "reduction dimension must be mapped to threads"); + SmallVector constantNewParallelDims; + for (auto dim : newParallelDims) { + if (getConstantIntValue(dim) == std::nullopt) + return b.notifyMatchFailure( + op, "Expected new parallel dims to be constant integers."); + constantNewParallelDims.push_back(*getConstantIntValue(dim)); + } + if (newParallelDims.empty()) + constantNewParallelDims = redDims; + if (constantNewParallelDims.size() != redDims.size()) + return b.notifyMatchFailure( + op, "reduction dimension must be mapped to new parallel dims"); + // 1. Create the inital tensor value. + FailureOr> maybeInitTensors = + LinalgOpPartialReductionInterface:: + generateInitialTensorForPartialReduction( + op, b, loc, numThreads, redDims, constantNewParallelDims); + if (failed(maybeInitTensors)) + return b.notifyMatchFailure( + op, "Failed to create inital tensors for partial reduction"); + SmallVector &initTensors = maybeInitTensors.value(); + + // Gather destination tensors. + SmallVector dest; + if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) + return b.notifyMatchFailure(op, "failed to get destination tensors"); + + Operation *tiledOp = nullptr; + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + SmallVector materializedNonZeroNumThreads = + getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); + // 2. Create the ForallOp with an empty region. + scf::ForallOp forallOp = b.create( + loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors, + mapping); + // 3. Calculate the tile offsets and sizes for the subsequent loop that will + // be nested under `forallOp`. 
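+  // With only thread counts given, each forall iteration covers roughly a
+  // ceil(range / numThreads) slice of every mapped loop; with explicit tile
+  // sizes (and no thread counts), those sizes are forwarded below as the
+  // nominal per-thread tile sizes instead.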
+ SmallVector tiledOffsets, tiledSizes; + std::optional> nominalTileSizes = std::nullopt; + if (!tileSizes.empty() && threadNums.empty()) { + nominalTileSizes = tileSizes; + } + calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain, + /*omitTileOffsetBoundsCheck =*/false, + /*nominalTileSizes=*/nominalTileSizes, + tiledOffsets, tiledSizes); + // 4. Clone the tileable op and update its destination operands to use the + // output bbArgs of the ForallOp. + SmallVector tilingResults; + ArrayRef destBbArgs = forallOp.getRegionIterArgs(); + { + // 4.a. RAII guard, inserting within forallOp, before terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector tiledDpsInitOperands; + for (Value initOperand : destinationStyleOp.getDpsInits()) { + auto *it = llvm::find(dest, initOperand); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + SmallVector strides(numThreads.size(), b.getIndexAttr(1)); + SmallVector outOffsets(numThreads.size(), + b.getIndexAttr(0)); + SmallVector sizes = tiledSizes; + + auto currentReductionIdx = 0; + for (const auto &iteratorType : llvm::enumerate(tiledSizes)) { + if (llvm::find(constantNewParallelDims, iteratorType.index()) != + constantNewParallelDims.end()) { + sizes[iteratorType.index()] = b.getIndexAttr(1); + currentReductionIdx++; + } else { + if (llvm::find(redDims, iteratorType.index() - currentReductionIdx) != + redDims.end()) { + currentReductionIdx--; + } + sizes[iteratorType.index()] = + tiledSizes[iteratorType.index() - currentReductionIdx]; + } + } + auto nonZeroDimIdx = 0; + for (const auto &iteratorType : llvm::enumerate(numThreads)) { + if (!isConstantIntValue(iteratorType.value(), 0)) { + outOffsets[constantNewParallelDims[nonZeroDimIdx]] = + forallOp.getInductionVars()[nonZeroDimIdx]; + nonZeroDimIdx++; + } + } + // TODO: use SubsetExtractOpInterface once it is available. + tiledDpsInitOperands.push_back(b.create( + loc, cast(initOperand.getType()), + destBbArgs[destNum], outOffsets, sizes, strides)); + } + + // 4.b. Clone the op and update init operands. + // We cannot use a IRMapping here because it can replace + // different OpOperands with the same value. + Operation *clonedOp = b.clone(*op.getOperation()); + b.modifyOpInPlace(clonedOp, [&]() { + for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( + cast(clonedOp).getDpsInitsMutable(), + tiledDpsInitOperands)) { + initOperandPtr.set(tiledInitValue); + } + }); + // 5. Tile the cloned op and delete the clone. 
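+      // Two paths: unless both tile sizes and thread counts are provided, use
+      // the generic TilingInterface implementation; otherwise tile via
+      // tileLinalgOpImpl and map the resulting loop onto the forall induction
+      // variables with mapLoopToProcessorIds.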
+ if (tileSizes.empty() || threadNums.empty()) { + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation( + b, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError("expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; + } else { + LinalgTilingOptions options; + FailureOr maybeTiled = tileLinalgOpImpl( + b, cast(clonedOp), tileSizes, options); + if (failed(maybeTiled)) + return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); + + SmallVector ids = forallOp.getInductionVars(); + mapLoopToProcessorIds(cast(maybeTiled->loops.back()), ids, + materializedNonZeroNumThreads); + if (maybeTiled->loops.size() != 1) { + return clonedOp->emitError("expected a single produced loop"); + } + tiledOp = maybeTiled->op; + tilingResults = maybeTiled->loops.front()->getResults(); + } + + b.eraseOp(clonedOp); + } + + // 6. Insert the partial reductions back into a new tensor. + for (auto [index, result, bbArg] : llvm::zip( + llvm::seq(0, dest.size()), tilingResults, destBbArgs)) { + // 6.a. Partial subset information is inserted just before the terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector resultOffsets, resultSizes; + if (failed(tilingInterfaceOp.getResultTilePosition( + b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector resultOffsetsRank, resultSizesRank; + int64_t offIdx = 0; + int64_t sizeIdx = 0; + int64_t nonZeroDimIdx = 0; + for (int64_t i = 0, e = numThreads.size(); i < e; ++i) { + if (llvm::find(constantNewParallelDims, i) != + constantNewParallelDims.end()) { + resultOffsetsRank.push_back(forallOp.getInductionVars()[nonZeroDimIdx]); + resultSizesRank.push_back(b.getIndexAttr(1)); + nonZeroDimIdx++; + continue; + } + if (!isConstantIntValue(numThreads[i], 0)) { + nonZeroDimIdx++; + } + resultOffsetsRank.push_back(resultOffsets[offIdx++]); + resultSizesRank.push_back(resultSizes[sizeIdx++]); + } + SmallVector strides(resultSizesRank.size(), + b.getIndexAttr(1)); + + // 6.b. Parallel insertions are inserted at the end of the combining + // terminator. + b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); + b.create( + loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); + } + // 7. Merge the partial reductions. + b.setInsertionPointAfter(forallOp); + Operation *mergeOp = op.mergeReductions(b, loc, forallOp->getResults(), + constantNewParallelDims); + b.replaceOp(op, mergeOp->getResults()); + // 8. Return. + ForallReductionTilingResult results; + results.initialValues = initTensors; + results.loops = forallOp; + results.parallelTiledOp = tiledOp; + results.mergeOp = mergeOp; + return results; +} + +template +FailureOr static tileLinalgOpImpl( + RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + + if (!options.tileSizeComputationFunction) + return failure(); + + // Enforce the convention that "tiling by zero" skips tiling a particular + // dimension. This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions. 
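+  // Pad the computed tile sizes with zeros up to the number of loops so any
+  // remaining dimensions are left untiled.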
+ auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + getAsOpFoldResult(options.tileSizeComputationFunction(b, op)); + if (tileSizeVector.size() < nLoops) { + tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0)); + } + + return tileLinalgOpImpl(b, op, tileSizeVector, options); +} + +FailureOr +tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, + ArrayRef tileSizes, + std::optional mapping) { + Location loc = op.getLoc(); + OpBuilder::InsertionGuard g(b); + + // Ops implementing PartialReductionOpInterface are expected to implement + // TilingInterface. + // TODO: proper core mechanism to tie interfaces together. + auto tilingInterfaceOp = cast(op.getOperation()); + + // Ops implementing PartialReductionOpInterface are not necessarily expected + // to implement TilingInterface.. This cast is unsafe atm. + // TODO: proper core mechanism to tie interfaces together. + // TODO: this function requires a pair of interfaces .. + auto destinationStyleOp = + dyn_cast(op.getOperation()); + if (!destinationStyleOp) + return b.notifyMatchFailure(op, "not a destination style op"); + + // Actually this only work for Linalg ops atm. + auto linalgOp = dyn_cast(op.getOperation()); + if (!linalgOp) + return b.notifyMatchFailure(op, "not a linalg op"); + + SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); + if (op->getNumResults() != 1) + return b.notifyMatchFailure( + op, "don't support ops with multiple results for now"); + + SmallVector iterators = + tilingInterfaceOp.getLoopIteratorTypes(); + SmallVector redDims; + for (auto [idx, iteratorType] : + llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { + if (iteratorType == utils::IteratorType::reduction) + redDims.push_back(idx); + } + bool hasReductionThreads = false; + for (auto dim : redDims) { + if (!isConstantIntValue(numThreads[dim], 0) && + !isConstantIntValue(numThreads[dim], 1)) { + hasReductionThreads = true; + break; + } + } + + if (!tileSizes.empty() && tileSizes.size() != numThreads.size()) + return b.notifyMatchFailure(op, "if tile sizes are present it must have as " + "many elements as number of threads"); + + if (redDims.front() >= numThreads.size()) + return b.notifyMatchFailure( + op, "reduction dimension must be mapped to threads"); + + // 1. Create the inital tensor value. + FailureOr> maybeInitTensors; + SmallVector initTensors; + if (hasReductionThreads) { + maybeInitTensors = LinalgOpPartialReductionInterface:: + generateInitialTensorForPartialReduction( + op, b, loc, numThreads, redDims, constantNewParallelDims); + if (failed(maybeInitTensors)) + return b.notifyMatchFailure( + op, "Failed to create inital tensors for partial reduction"); + initTensors = maybeInitTensors.value(); + } + + // Gather destination tensors. + SmallVector dest; + if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) + return b.notifyMatchFailure(op, "failed to get destination tensors"); + + Operation *tiledOp = nullptr; + + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + SmallVector materializedNonZeroNumThreads = + getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); + + // 2. Create the ForallOp with an empty region. + scf::ForallOp forallOp = b.create( + loc, getAsOpFoldResult(materializedNonZeroNumThreads), + hasReductionThreads ? initTensors : dest, mapping); + // 3. 
Calculate the tile offsets and sizes for the subsequent loop that will + // be nested under `forallOp`. + SmallVector tiledOffsets, tiledSizes; + calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain, + /*omitTileOffsetBoundsCheck =*/false, + /*nominalTileSizes=*/tileSizes, tiledOffsets, + tiledSizes); + + // 4. Clone the tileable op and update its destination operands to use the + // output bbArgs of the ForallOp. + SmallVector tilingResults; + ArrayRef destBbArgs = forallOp.getRegionIterArgs(); + { + // 4.a. RAII guard, inserting within forallOp, before terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector tiledDpsInitOperands; + for (Value initOperand : destinationStyleOp.getDpsInits()) { + if (hasReductionThreads) { + auto *it = llvm::find(dest, initOperand); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + SmallVector strides(numThreads.size(), b.getIndexAttr(1)); + SmallVector outOffsets(numThreads.size(), + b.getIndexAttr(0)); + SmallVector sizes; + for (auto s : + cast(destBbArgs[destNum].getType()).getShape()) { + sizes.emplace_back(getAsIndexOpFoldResult(b.getContext(), (int)s)); + } + for (auto dim : redDims) { + sizes[dim] = b.getIndexAttr(1); + } + + auto nonZeroDimIdx = 0; + for (auto dim = 0; dim < numThreads.size(); dim++) { + if (!isConstantIntValue(numThreads[dim], 0)) { + if (llvm::find(redDims, dim) != redDims.end()) + outOffsets[dim] = forallOp.getInductionVars()[nonZeroDimIdx]; + nonZeroDimIdx++; + } + } + // TODO: use SubsetExtractOpInterface once it is available. + tiledDpsInitOperands.push_back(b.create( + loc, cast(initOperand.getType()), + destBbArgs[destNum], outOffsets, sizes, strides)); + } else { + tiledDpsInitOperands.push_back(initOperand); + } + } + + // 4.b. Clone the op and update init operands. + // We cannot use a IRMapping here because it can replace + // different OpOperands with the same value. + Operation *clonedOp = b.clone(*op.getOperation()); + b.modifyOpInPlace(clonedOp, [&]() { + for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( + cast(clonedOp).getDpsInitsMutable(), + tiledDpsInitOperands)) { + initOperandPtr.set(tiledInitValue); + } + }); + + // 5. Tile the cloned op and delete the clone. + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation(b, tiledOffsets, + tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError("expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; + + b.eraseOp(clonedOp); + } + + // 6. Insert the partial reductions back into a new tensor. + for (auto [index, result, bbArg] : llvm::zip( + llvm::seq(0, dest.size()), tilingResults, destBbArgs)) { + // 6.a. Partial subset information is inserted just before the terminator. 
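+    // The offsets/sizes from getResultTilePosition are re-ranked below:
+    // reduction dimensions that are split across threads take the forall
+    // induction variable as offset with size 1; all other dimensions keep
+    // their computed offset and size.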
+ OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(forallOp.getTerminator()); + + SmallVector resultOffsets, resultSizes; + if (failed(tilingInterfaceOp.getResultTilePosition( + b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector resultOffsetsRank, resultSizesRank; + int64_t offIdx = 0; + int64_t sizeIdx = 0; + int64_t nonZeroDimIdx = 0; + for (int64_t i = 0; i < numThreads.size(); ++i) { + if (llvm::find(redDims, i) != redDims.end()) { + if (hasReductionThreads) { + resultOffsetsRank.push_back( + forallOp.getInductionVars()[nonZeroDimIdx]); + resultSizesRank.push_back(b.getIndexAttr(1)); + } + nonZeroDimIdx++; + continue; + } + if (!isConstantIntValue(numThreads[i], 0)) { + nonZeroDimIdx++; + } + resultOffsetsRank.push_back(resultOffsets[offIdx++]); + resultSizesRank.push_back(resultSizes[sizeIdx++]); + } + SmallVector strides(resultSizesRank.size(), + b.getIndexAttr(1)); + + // 6.b. Parallel insertions are inserted at the end of the combining + // terminator. + b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); + b.create( + loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); + } + + // 7. Merge the partial reductions. + Operation *mergeOp = nullptr; + b.setInsertionPointAfter(forallOp); + if (hasReductionThreads) { + Operation *mergeOp = + op.mergeReductions(b, loc, forallOp->getResults(), redDims); + b.replaceOp(op, mergeOp->getResults()); + } else { + b.replaceOp(op, forallOp->getResults()); + } + + // 8. Return. + ForallReductionTilingResult results; + results.initialValues = initTensors; + results.loops = forallOp; + results.parallelTiledOp = tiledOp; + results.mergeOp = mergeOp; + return results; +} + +} // namespace linalgX +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Tiling.hpp b/lib/gc/Transforms/Tiling.hpp new file mode 100644 index 000000000..e46720f93 --- /dev/null +++ b/lib/gc/Transforms/Tiling.hpp @@ -0,0 +1,56 @@ +//===- Tilig.hpp - Tiling ops using TilingInterface --*- C++ -*-===// +// +// This file is only temporarily used to extend upstream or upcoming utility in +// TilingInterface, which finally aims for upstream. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TEMPORARY_TILEUSINGINTERFACE_X_H +#define TEMPORARY_TILEUSINGINTERFACE_X_H + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include +#include +namespace mlir { +namespace linalgX { + +FailureOr tileReductionUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef threadNums, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping); + +FailureOr +tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, + ArrayRef tileSizes, + std::optional mapping); + +} // namespace linalgX +} // namespace mlir + +#endif \ No newline at end of file From 90c2b4b9dfaf7b938ed01fbd9c4545001ee2bb71 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 22 May 2024 23:08:51 -0700 Subject: [PATCH 03/21] add easy builder support --- include/gc/Dialect/Arith/Utils/EasyBuild.h | 443 +++++++++++++++++++++ include/gc/IR/EasyBuild.h | 108 +++++ include/gc/IR/EasyBuildSCF.h | 186 +++++++++ 3 files changed, 737 insertions(+) create mode 100644 include/gc/Dialect/Arith/Utils/EasyBuild.h create mode 100644 include/gc/IR/EasyBuild.h create mode 100644 include/gc/IR/EasyBuildSCF.h diff --git a/include/gc/Dialect/Arith/Utils/EasyBuild.h b/include/gc/Dialect/Arith/Utils/EasyBuild.h new file mode 100644 index 000000000..2a45ffcee --- /dev/null +++ b/include/gc/Dialect/Arith/Utils/EasyBuild.h @@ -0,0 +1,443 @@ +//===- EasyBuild.h - Easy Arith IR Builder utilities ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the easy-build utilities for arith dialects. 
It +// provides the utility functions, classes and operators to make it easir to +// program arith dialect operations in C++ +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H +#define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H +#include "gc/IR/EasyBuild.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include +#include +#include + +namespace mlir { +namespace easybuild { + +namespace impl { + +template struct ToFloatType {}; + +template <> struct ToFloatType<4> { + using type = Float32Type; +}; +template <> struct ToFloatType<8> { + using type = Float64Type; +}; + +inline Type getElementType(Value v) { + auto type = v.getType(); + if (type.isa() || type.isa()) { + type = type.cast().getElementType(); + } + return type; +} + +} // namespace impl + +struct EBUnsigned; + +struct EBArithValue : public EBValue { + template + static T toIndex(const impl::StatePtr &state, uint64_t v); + + template + static auto wrapOrFail(const impl::StatePtr &state, T &&v); + + template static auto wrap(const impl::StatePtr &state, T &&v) { + auto ret = wrapOrFail(state, std::forward(v)); + if (failed(ret)) { + llvm_unreachable("Bad wrap"); + } + return *ret; + } + +protected: + using EBValue::EBValue; +}; + +struct EBUnsigned : public EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, + Value v) { + auto type = impl::getElementType(v); + if (type.isUnsignedInteger() || type.isSignlessInteger() || + type.isIndex()) { + return EBUnsigned{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) { + if (val.getType().isIndex()) + return EBUnsigned{state, state->builder.create( + state->loc, val.getInt())}; + else + return EBUnsigned{state, state->builder.create( + state->loc, val.getInt(), val.getType())}; + } + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + +struct EBSigned : EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { + auto type = impl::getElementType(v); + if (type.isSignedInteger() || type.isSignlessInteger()) { + return EBSigned{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) + return EBSigned{state, state->builder.create( + state->loc, val.getInt(), val.getType())}; + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + +struct EBFloatPoint : EBArithValue { + static FailureOr wrapOrFail(const impl::StatePtr &state, + Value v) { + auto type = impl::getElementType(v); + if (type.isa()) { + return EBFloatPoint{state, v}; + } + return failure(); + } + static FailureOr wrapOrFail(const impl::StatePtr &state, + const OpFoldResult &v) { + if (v.is()) { + return wrapOrFail(state, v.get()); + } + auto attr = v.get(); + if (auto val = attr.dyn_cast()) + return EBFloatPoint{state, state->builder.create( + state->loc, val.getValue(), + val.getType().cast())}; + return failure(); + } + friend struct EBArithValue; + friend struct OperatorHandlers; + +protected: + using EBArithValue::EBArithValue; +}; + 
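+// Illustration only (not part of the API above): given an EasyBuilder `eb`
+// built from an OpBuilder and a Location, wrapped values compose with plain
+// C++ operators, e.g. `a + b` emits arith.addi/arith.addf and `a == b` emits
+// arith.cmpi/arith.cmpf at the current insertion point, as defined by the
+// operator macros below.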
+template +inline T EBArithValue::toIndex(const impl::StatePtr &state, uint64_t v) { + return EBUnsigned{ + state, state->builder.create(state->loc, v)}; +} + +template +inline auto EBArithValue::wrapOrFail(const impl::StatePtr &state, T &&v) { + using DT = std::decay_t; + static_assert(std::is_arithmetic_v
, "Expecting arithmetic types"); + if constexpr (std::is_same_v) { + if (state->u64AsIndex) { + return FailureOr{toIndex(state, v)}; + } + } + + if constexpr (std::is_same_v) { + return FailureOr{ + EBUnsigned{state, state->builder.create( + state->loc, static_cast(v), 1)}}; + } else if constexpr (std::is_integral_v
<DT>) { + if constexpr (!std::is_signed_v<DT>
) { + return FailureOr{EBUnsigned{ + state, state->builder.create( + state->loc, static_cast(v), sizeof(T) * 8)}}; + } else { + return FailureOr{EBSigned{ + state, state->builder.create( + state->loc, static_cast(v), sizeof(T) * 8)}}; + } + } else { + using DType = typename impl::ToFloatType::type; + return FailureOr{ + EBFloatPoint{state, state->builder.create( + state->loc, APFloat{v}, + DType::get(state->builder.getContext()))}}; + } +} + +struct OperatorHandlers { + template + static V handleBinary(const V &a, const V &b) { + assert(a.builder == b.builder); + return {a.builder, + a.builder->builder.template create(a.builder->loc, a.v, b.v)}; + } + + template + static V handleBinaryConst(const V &a, const T2 &b) { + return handleBinary(a, EBArithValue::wrap(a.builder, b)); + } + + template + static V handleBinaryConst(const T2 &a, const V &b) { + return handleBinary(EBArithValue::wrap(b.builder, a), b); + } + + template + static EBUnsigned handleCmp(const V &a, const V &b, Pred predicate) { + assert(a.builder == b.builder); + return {a.builder, a.builder->builder.template create( + a.builder->loc, predicate, a.v, b.v)}; + } + + template + static EBUnsigned handleCmpConst(const V &a, const T2 &b, Pred predicate) { + return handleCmp(a, EBArithValue::wrap(a.builder, b), predicate); + } + + template + static EBUnsigned handleCmpConst(const T2 &a, const V &b, Pred predicate) { + return handleCmp(EBArithValue::wrap(b.builder, a), b, predicate); + } + + template + static T create(const impl::StatePtr &state, Args &&...v) { + return {state, + state->builder.create(state->loc, std::forward(v)...)}; + } +}; + +#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, OPCLASS, TYPE) \ + inline TYPE operator OP(const TYPE &a, const TYPE &b) { \ + return OperatorHandlers::handleBinary(a, b); \ + } \ + template inline TYPE operator OP(const TYPE &a, T b) { \ + return OperatorHandlers::handleBinaryConst(a, b); \ + } \ + template inline TYPE operator OP(T a, const TYPE &b) { \ + return OperatorHandlers::handleBinaryConst(a, b); \ + } + +#define DEF_EASYBUILD_BINARY_OPERATOR(OP, SIGNED, UNSIGNED, FLOAT) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, FLOAT, EBFloatPoint) + +DEF_EASYBUILD_BINARY_OPERATOR(+, arith::AddIOp, arith::AddIOp, arith::AddFOp) +DEF_EASYBUILD_BINARY_OPERATOR(-, arith::SubIOp, arith::SubIOp, arith::SubFOp) +DEF_EASYBUILD_BINARY_OPERATOR(*, arith::MulIOp, arith::MulIOp, arith::MulFOp) +DEF_EASYBUILD_BINARY_OPERATOR(/, arith::DivSIOp, arith::DivUIOp, arith::DivFOp) +DEF_EASYBUILD_BINARY_OPERATOR(%, arith::RemSIOp, arith::RemUIOp, arith::RemFOp) + +#undef DEF_EASYBUILD_BINARY_OPERATOR +#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(OP, SIGNED, UNSIGNED) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ + DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) + +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(>>, arith::ShRSIOp, arith::ShRUIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(<<, arith::ShLIOp, arith::ShLIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(&, arith::AndIOp, arith::AndIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(|, arith::OrIOp, arith::OrIOp) +DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(^, arith::XOrIOp, arith::XOrIOp) + +#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT +#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE + +inline EBFloatPoint operator-(const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, a.v); +} + +#define 
DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \ + EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \ + return OperatorHandlers::handleCmp(a, b, PRED); \ + } \ + template EBUnsigned operator OP(const TYPE &a, T b) { \ + return OperatorHandlers::handleCmpConst(a, b, PRED); \ + } \ + template EBUnsigned operator OP(T a, const TYPE &b) { \ + return OperatorHandlers::handleCmpConst(a, b, PRED); \ + } + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ult) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ule) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ugt) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::uge) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::eq) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBUnsigned, + arith::CmpIPredicate::ne) + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::slt) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sle) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sgt) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::sge) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::eq) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBSigned, + arith::CmpIPredicate::ne) + +DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OLT) +DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OLE) +DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OGT) +DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OGE) +DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::OEQ) +DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpFOp, EBFloatPoint, + arith::CmpFPredicate::ONE) + +#undef DEF_EASYBUILD_CMP_OPERATOR + +namespace arithops { +inline EBFloatPoint castIntToFP(Type type, const EBSigned &v) { + return OperatorHandlers::create(v.builder, + type, v); +} + +inline EBFloatPoint castIntToFP(Type type, const EBUnsigned &v) { + return OperatorHandlers::create(v.builder, + type, v); +} + +template inline T castFPToInt(const EBFloatPoint &v) { + if constexpr (std::is_same_v) { + return OperatorHandlers::create(v.builder, v); + } else { + static_assert(std::is_same_v, + "Expecting EBUnsigned or EBSigned"); + return OperatorHandlers::create(v.builder, v); + } +} + +inline EBSigned ceildiv(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBUnsigned ceildiv(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBSigned floordiv(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBSigned extend(Type type, const EBSigned &a) { + return OperatorHandlers::create(a.builder, type, a); +} + +inline EBUnsigned extend(Type type, const EBUnsigned &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBFloatPoint extend(Type type, const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBSigned trunc(Type type, const EBSigned &a) { + return OperatorHandlers::create(a.builder, type, + a); +} + +inline EBFloatPoint trunc(Type type, const EBFloatPoint &a) { + return OperatorHandlers::create(a.builder, 
+ type, a); +} + +template +inline T select(const EBUnsigned &pred, const T &trueValue, + const T &falseValue) { + static_assert(std::is_base_of_v, + "Expecting T to be a subclass of EBArithValue"); + return OperatorHandlers::create(pred.builder, pred, + trueValue, falseValue); +} + +template +inline TyTo bitcast(Type type, const TyFrom &v) { + return OperatorHandlers::create(v.builder, type, v); +} + +inline EBSigned min(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBSigned max(const EBSigned &a, const EBSigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBUnsigned min(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBUnsigned max(const EBUnsigned &a, const EBUnsigned &b) { + return OperatorHandlers::create(a.builder, a, b); +} + +inline EBFloatPoint minnum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint maxnum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint minimum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +inline EBFloatPoint maximum(const EBFloatPoint &a, const EBFloatPoint &b) { + return OperatorHandlers::create(a.builder, a, + b); +} + +} // namespace arithops + +} // namespace easybuild +} // namespace mlir +#endif diff --git a/include/gc/IR/EasyBuild.h b/include/gc/IR/EasyBuild.h new file mode 100644 index 000000000..55fa5a06c --- /dev/null +++ b/include/gc/IR/EasyBuild.h @@ -0,0 +1,108 @@ +//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the easy-build utilities core data structures for +// building IR. 
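+// The central class is EasyBuilder, which bundles an OpBuilder and a Location
+// (EasyBuildState) so wrapped values can be created and combined without
+// threading the builder and location through every call.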
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EASYBUILD_H +#define MLIR_IR_EASYBUILD_H +#include "mlir/IR/Builders.h" +#include +#include +#include + +namespace mlir { +namespace easybuild { + +namespace impl { +struct EasyBuildState { + OpBuilder &builder; + Location loc; + bool u64AsIndex; + EasyBuildState(OpBuilder &builder, Location loc, bool u64AsIndex) + : builder{builder}, loc{loc}, u64AsIndex{u64AsIndex} {} +}; + +using StatePtr = std::shared_ptr; + +} // namespace impl + +struct EBValue { + std::shared_ptr builder; + Value v; + EBValue() = default; + EBValue(const impl::StatePtr &builder, Value v) : builder{builder}, v{v} {} + Value get() const { return v; } + operator Value() const { return v; } + + static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { + return EBValue{state, v}; + } +}; + +struct EBArithValue; + +struct EasyBuilder { + std::shared_ptr builder; + EasyBuilder(OpBuilder &builder, Location loc, bool u64AsIndex = false) + : builder{ + std::make_shared(builder, loc, u64AsIndex)} {} + EasyBuilder(const std::shared_ptr &builder) + : builder{builder} {} + void setLoc(const Location &l) { builder->loc = l; } + + template auto wrapOrFail(V &&v) { + return W::wrapOrFail(builder, std::forward(v)); + } + + Operation *getLastOperaion() { + return &*(--builder->builder.getInsertionPoint()); + } + + template auto wrap(V &&v) { + auto ret = wrapOrFail(std::forward(v)); + if (failed(ret)) { + llvm_unreachable("wrap failed!"); + } + return *ret; + } + + template auto operator()(V &&v) { + if constexpr (std::is_convertible_v) { + return EBValue{builder, std::forward(v)}; + } else { + return wrap(std::forward(v)); + } + } + + template auto toIndex(uint64_t v) const { + return W::toIndex(builder, v); + } + + template + auto F(Args &&...v) { + if constexpr (std::is_same_v) { + builder->builder.create(builder->loc, std::forward(v)...); + } else { + return wrap( + builder->builder.create(builder->loc, std::forward(v)...)); + } + } + + template + auto yield(Args &&...v) { + builder->builder.create(builder->loc, + ValueRange{std::forward(v)...}); + } +}; + +} // namespace easybuild +} // namespace mlir +#endif \ No newline at end of file diff --git a/include/gc/IR/EasyBuildSCF.h b/include/gc/IR/EasyBuildSCF.h new file mode 100644 index 000000000..0bd2ac980 --- /dev/null +++ b/include/gc/IR/EasyBuildSCF.h @@ -0,0 +1,186 @@ +//===- EasyBuildSCF.h - Easy IR Builder for general control flow *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the helper classes, functions and macros to help to +// build general structured control flow. Developers can use the utilities in +// this header to easily compose control flow IR. 
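+// The EB_for / EB_scf_if macros below drive range-based simulators that move
+// the builder's insertion point into the corresponding region and restore it
+// when the scope ends. For illustration, usage later in this series looks
+// roughly like:
+//   EB_scf_if(cond, {resultType}) { eb.yield(thenValue); }
+//   EB_else { eb.yield(elseValue); }
+// where resultType/thenValue/elseValue are placeholders.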
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EASYBUILDSCF_H +#define MLIR_IR_EASYBUILDSCF_H +#include "gc/IR/EasyBuild.h" +#include "mlir/Interfaces/LoopLikeInterface.h" + +namespace mlir { +namespace scf { +class IfOp; +} + +namespace easybuild { +namespace impl { + +struct ForRangeSimulatorImpl { + StatePtr s; + LoopLikeOpInterface op; + ForRangeSimulatorImpl(const StatePtr &s, LoopLikeOpInterface op) + : s{s}, op{op} { + s->builder.setInsertionPointToStart(&op.getLoopRegions().front()->front()); + } + ~ForRangeSimulatorImpl() { s->builder.setInsertionPointAfter(op); } +}; + +template +using NthTypeOf = typename std::tuple_element>::type; + +template struct ForVarBinding { + ForRangeSimulatorImpl *impl; + template auto get() { + using TOut = NthTypeOf; + if (auto wrapped = TOut::wrapOrFail( + impl->s, impl->op.getLoopRegions().front()->front().getArgument(I)); + succeeded(wrapped)) { + return *wrapped; + } + llvm_unreachable("Bad cast for the loop iterator"); + } +}; +} // namespace impl +} // namespace easybuild +} // namespace mlir + +namespace std { +template +struct tuple_size> + : std::integral_constant {}; + +template +struct tuple_element> { + using type = mlir::easybuild::impl::NthTypeOf; +}; +} // namespace std + +namespace mlir { +namespace easybuild { + +namespace impl { + +template struct ForRangeSimulator : ForRangeSimulatorImpl { + using ForRangeSimulatorImpl::ForRangeSimulatorImpl; + struct ForRangeIterator { + ForRangeSimulatorImpl *ptr; + bool consumed; + auto operator*() const { return ForVarBinding{ptr}; } + + ForRangeIterator &operator++() { + consumed = true; + return *this; + } + + bool operator!=(ForRangeIterator &other) const { + return consumed != other.consumed; + } + + ForRangeIterator(ForRangeSimulator *ptr) : ptr{ptr}, consumed{false} {} + ForRangeIterator() : ptr{nullptr}, consumed{true} {} + }; + + ForRangeIterator begin() { return ForRangeIterator(this); } + + ForRangeIterator end() { return ForRangeIterator(); } +}; +} // namespace impl + +template +auto forRangeIn(const impl::StatePtr &s, LoopLikeOpInterface op) { + return impl::ForRangeSimulator{s, op}; +} + +template +auto forRangeIn(const EasyBuilder &s, LoopLikeOpInterface op) { + return impl::ForRangeSimulator{s.builder, op}; +} + +#define EB_for for + +namespace impl { +struct IfSimulator; +struct IfIterator { + IfSimulator *ptr; + int index; + int operator*() const; + + IfIterator &operator++() { + index++; + return *this; + } + + bool operator!=(IfIterator &other) const { return index != other.index; } + + IfIterator(IfSimulator *ptr) : ptr{ptr}, index{0} {} + IfIterator(int numRegions) : ptr{nullptr}, index{numRegions} {} +}; + +struct IfSimulator { + StatePtr s; + Operation *op; + IfIterator begin() { return IfIterator(this); } + IfIterator end() { + int nonEmptyRegions = 0; + for (auto ® : op->getRegions()) { + if (reg.begin() != reg.end()) { + nonEmptyRegions++; + } + } + return IfIterator(nonEmptyRegions); + } + ~IfSimulator() { s->builder.setInsertionPointAfter(op); } +}; +inline int IfIterator::operator*() const { + auto &blocks = ptr->op->getRegion(index); + ptr->s->builder.setInsertionPointToStart(&blocks.back()); + return index; +} + +} // namespace impl + +impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) { + return impl::IfSimulator{s.builder, op}; +} + +template +impl::IfSimulator makeScfIfLikeRange(EBValue cond, TypeRange resultTypes) { + auto &s = cond.builder; + auto op = s->builder.create(s->loc, 
resultTypes, cond, true); + return impl::IfSimulator{s, op}; +} + +template +impl::IfSimulator makeScfIfLikeRange(EBValue cond, bool hasElse = true) { + auto &s = cond.builder; + auto op = s->builder.create(s->loc, TypeRange{}, cond, hasElse); + return impl::IfSimulator{s, op}; +} + +#define EB_if(BUILDER, ...) \ + for (auto &&eb_mlir_if_scope__ : \ + ::mlir::easybuild::makeIfRange(BUILDER, __VA_ARGS__)) \ + if (eb_mlir_if_scope__ == 0) + +// EB_scf_if(COND) +// EB_scf_if(COND, HAS_ELSE) +// EB_scf_if(COND, RESULT_TYPES) +#define EB_scf_if(...) \ + for (auto &&eb_mlir_if_scope__ : \ + ::mlir::easybuild::makeScfIfLikeRange(__VA_ARGS__)) \ + if (eb_mlir_if_scope__ == 0) +#define EB_else else + +} // namespace easybuild +} // namespace mlir +#endif \ No newline at end of file From c0c574993d19e2ab329354ed38de1325ae87597c Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 22 May 2024 23:08:29 -0700 Subject: [PATCH 04/21] Init C buffer with easy builder --- lib/gc/Transforms/CMakeLists.txt | 1 + .../Transforms/DeepTileContractionNamedOp.cpp | 108 ++++++++++-------- lib/gc/Transforms/Tiling.cpp | 6 +- 3 files changed, 63 insertions(+), 52 deletions(-) diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 21f522224..4da3ef5f8 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ gc_add_mlir_library(GcPasses TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp DeepTileContractionNamedOp.cpp + Tiling.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 334ac1902..afef12480 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -1,16 +1,15 @@ -//===----------------------------------------------------------------------===// -//===- DeepTileContractionNamedOp.cpp - the Fusion for any tilable MLIR -// operation --*- C++ -//-*-=// -//-*-===// -// +//===-- DeepTileContractionNamedOp.cpp - DESC -------------------*- C++ -*-===// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// +// //===----------------------------------------------------------------------===// #include "./Tiling.hpp" +#include "gc/Dialect/Arith/Utils/EasyBuild.h" +#include "gc/IR/EasyBuild.h" +#include "gc/IR/EasyBuildSCF.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -179,6 +178,7 @@ struct OuterLoopGenerationResult { SmallVector tiledOps; /// The `scf.for` operations that iterate over the tiles. SmallVector loops; + SmallVector reductionLoops; /// Values to use as replacements for the untiled op. Is the same size as the /// number of results of the untiled op. 
SmallVector replacements; @@ -192,6 +192,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, auto nestedTileSizes = option.nestedTileSizes; auto loopType = option.loopType; auto loopDim = option.loopDim; + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); if (loopType.size() != loopDim.size() || loopDim.size() != nestedTileSizes.size()) { @@ -228,6 +230,13 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return failure(); b.replaceOp(currentOp, tilingResult->replacements); currentOp = dyn_cast(tilingResult->tiledOps.back()); + + for (auto [dim, loop] : llvm::zip(currentDim, tilingResult->loops)) { + if (iteratorTypes[dim] == mlir::utils::IteratorType::reduction) { + result.reductionLoops.push_back(loop); + } + result.loops.push_back(loop); + } } else if (type == OuterLoopGenerationOption::LoopType::ForallOp) { SmallVector tileSizes( currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); @@ -262,7 +271,6 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); - // TODO: add split reduction support here if (auto partialInterface = dyn_cast(currentOp.getOperation())) { auto tilingResult = linalgX::tileAllUsingForall( @@ -395,8 +403,9 @@ getOprandDimType(linalg::LinalgOp &linalgOp) { } /* -forall([PM, PN]: [MThreads, NThreads) { - for(PK : KThreads) { +matmul(A, B) -> C +----------------> +forall([PM, PN, PK]: [MThreads, NThreads, KThreads]) { CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock] ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock * (PK @@ -404,7 +413,6 @@ forall([PM, PN]: [MThreads, NThreads) { BSlice = B[PK * KOuterBlock * (PK + 1) * KOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM + 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock] - MNumBlock = MOuterBlock / MBlock NNumBlock = NOuterBlock / NBlock KNumBlock = KOuterBlock / KBlovk @@ -426,9 +434,8 @@ iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0) A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0)); } } - } - C = final_reduce(CSlice) } +C = final_reduce(CSlice) */ struct deepTileMatmul : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; @@ -506,14 +513,15 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } struct innerBodyGenerationOption { - bool hasFillOp = false; - Value fillValue; + Operation *fillOp; + SmallVector KLoopHandles; }; - LogicalResult - innerBodyGeneration(RewriterBase &rewriter, linalg::LinalgOp originOp, - linalg::LinalgOp currentOp, - const innerBodyGenerationOption &option) const { + LogicalResult innerBodyGeneration(RewriterBase &rewriter, + linalg::LinalgOp originOp, + linalg::LinalgOp currentOp, + innerBodyGenerationOption &option) const { + mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; auto operandDimTypes = getOprandDimType(originOp); MatmulConfig cfg = getDefaultMatmulConfig(originOp); auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); @@ -655,19 +663,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { rewriter.replaceOp(currentOp, matmul.getOperation()->getResult(0)); currentOp = matmul; - if (option.hasFillOp) { - // TODO: support partial K in sinsngle threads, control flow may need - // easy builder support + if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { + auto fillValue = fillOp.getDpsInputs()[0]; + 
rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + rewriter.setInsertionPointAfter(currentOp); - auto fillOp = rewriter.create( - currentOp->getLoc(), option.fillValue, currentOp.getDpsInits()[0]); - IRMapping mapping; - mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0)); - auto res = rewriter.clone(*(currentOp.getOperation()), mapping); - rewriter.replaceOp(currentOp, res); - currentOp = dyn_cast(res); + auto cond = eb(true); + for (auto loop : option.KLoopHandles) { + auto induceVar = eb.wrap( + loop.getLoopRegions().front()->front().getArgument(0)); + auto currentCond = induceVar == eb.toIndex(0); + cond = cond & currentCond; + } + EB_scf_if(cond, {currentOp.getDpsInits()[0].getType()}) { + auto fillOp = rewriter.create( + currentOp->getLoc(), fillValue, currentOp.getDpsInits()[0]); + IRMapping mapping; + mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0)); + auto res = rewriter.clone(*(currentOp.getOperation()), mapping); + eb.yield(res->getResult(0)); + } + EB_else { + auto res = rewriter.clone(*(currentOp.getOperation())); + eb.yield(res->getResult(0)); + } + auto ifOp = eb.getLastOperaion(); + rewriter.replaceOp(currentOp, ifOp); + ifOp->getParentOfType().dump(); } - currentOp.getOperation()->getParentOfType().dump(); return success(); } @@ -680,34 +703,21 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { if (linalgOp.getOperation()->getParentOfType()) return failure(); - // Step 1. Match and remove the init/fill operation - // Fuse the fill op manually before fusion support this case(fuse it into - // if-else block) - bool hasFillOp = false; - Value fillValue; - SmallVector KLoopHandle; - if (auto op = dyn_cast( - linalgOp.getDpsInits()[0].getDefiningOp())) { - hasFillOp = true; - fillValue = op.getDpsInputs()[0]; - rewriter.replaceOp(op, op.getDpsInits()[0]); - } + Operation *fillOp = linalgOp.getDpsInits()[0].getDefiningOp(); - // Step 2. The processes of outer Loop Generation + // Step 1. generate the outer loop // 2.0 Get the iteration infomation first MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); - // TODO: move the reduction dim to the front. 
(M, N, threads) -> - // (threads, M, N) auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg); if (failed(outerLoopResult)) { return failure(); } linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); - // Step 3 inner loop generation, convert the linalg.generic to brgemm - if (failed(innerBodyGeneration( - rewriter, matmulOp, linalgOp, - innerBodyGenerationOption{hasFillOp, fillValue}))) { + // Step 2 generate inner loop body, convert the linalg.generic to brgemm + auto option = + innerBodyGenerationOption{fillOp, outerLoopResult->reductionLoops}; + if (failed(innerBodyGeneration(rewriter, matmulOp, linalgOp, option))) { return failure(); } return success(); diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index b9de4a777..8404d229f 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -855,7 +855,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, return b.notifyMatchFailure(op, "if tile sizes are present it must have as " "many elements as number of threads"); - if (redDims.front() >= numThreads.size()) + if ((unsigned)redDims.front() >= numThreads.size()) return b.notifyMatchFailure( op, "reduction dimension must be mapped to threads"); @@ -926,7 +926,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, } auto nonZeroDimIdx = 0; - for (auto dim = 0; dim < numThreads.size(); dim++) { + for (auto dim = 0UL; dim < numThreads.size(); dim++) { if (!isConstantIntValue(numThreads[dim], 0)) { if (llvm::find(redDims, dim) != redDims.end()) outOffsets[dim] = forallOp.getInductionVars()[nonZeroDimIdx]; @@ -985,7 +985,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, int64_t offIdx = 0; int64_t sizeIdx = 0; int64_t nonZeroDimIdx = 0; - for (int64_t i = 0; i < numThreads.size(); ++i) { + for (auto i = 0UL; i < numThreads.size(); ++i) { if (llvm::find(redDims, i) != redDims.end()) { if (hasReductionThreads) { resultOffsetsRank.push_back( From c4b777c653cf35258830b70370acc94eda962621 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Sun, 26 May 2024 19:12:35 -0700 Subject: [PATCH 05/21] support partial reduction --- include/gc/Dialect/Arith/Utils/EasyBuild.h | 19 +- include/gc/IR/EasyBuild.h | 10 +- include/gc/IR/EasyBuildSCF.h | 5 +- .../Transforms/DeepTileContractionNamedOp.cpp | 18 +- lib/gc/Transforms/Tiling.cpp | 189 +++++++++++------- lib/gc/Transforms/Tiling.hpp | 9 +- 6 files changed, 143 insertions(+), 107 deletions(-) diff --git a/include/gc/Dialect/Arith/Utils/EasyBuild.h b/include/gc/Dialect/Arith/Utils/EasyBuild.h index 2a45ffcee..74f664184 100644 --- a/include/gc/Dialect/Arith/Utils/EasyBuild.h +++ b/include/gc/Dialect/Arith/Utils/EasyBuild.h @@ -1,17 +1,10 @@ -//===- EasyBuild.h - Easy Arith IR Builder utilities ------------*- C++ -*-===// +//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This header file defines the easy-build utilities for arith dialects. 
It -// provides the utility functions, classes and operators to make it easir to -// program arith dialect operations in C++ -// -//===----------------------------------------------------------------------===// - #ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H #define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H #include "gc/IR/EasyBuild.h" @@ -28,12 +21,8 @@ namespace impl { template struct ToFloatType {}; -template <> struct ToFloatType<4> { - using type = Float32Type; -}; -template <> struct ToFloatType<8> { - using type = Float64Type; -}; +template <> struct ToFloatType<4> { using type = Float32Type; }; +template <> struct ToFloatType<8> { using type = Float64Type; }; inline Type getElementType(Value v) { auto type = v.getType(); diff --git a/include/gc/IR/EasyBuild.h b/include/gc/IR/EasyBuild.h index 55fa5a06c..4b6e72225 100644 --- a/include/gc/IR/EasyBuild.h +++ b/include/gc/IR/EasyBuild.h @@ -1,16 +1,10 @@ -//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===// +//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This header file defines the easy-build utilities core data structures for -// building IR. -// -//===----------------------------------------------------------------------===// - #ifndef MLIR_IR_EASYBUILD_H #define MLIR_IR_EASYBUILD_H #include "mlir/IR/Builders.h" diff --git a/include/gc/IR/EasyBuildSCF.h b/include/gc/IR/EasyBuildSCF.h index 0bd2ac980..3d7ce9d77 100644 --- a/include/gc/IR/EasyBuildSCF.h +++ b/include/gc/IR/EasyBuildSCF.h @@ -1,10 +1,11 @@ -//===- EasyBuildSCF.h - Easy IR Builder for general control flow *- C++ -*-===// +//===-- EasyBuildSCF.h - DESC -----------------------------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// + // // This header file defines the helper classes, functions and macros to help to // build general structured control flow. Developers can use the utilities in diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index afef12480..dec219ebe 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -1,9 +1,9 @@ //===-- DeepTileContractionNamedOp.cpp - DESC -------------------*- C++ -*-===// -// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// +// //===----------------------------------------------------------------------===// #include "./Tiling.hpp" @@ -273,9 +273,19 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, b.setInsertionPoint(currentOp); if (auto partialInterface = dyn_cast(currentOp.getOperation())) { + for (auto [idx, tile] : llvm::enumerate(tileSizes)) { + if (isConstantIntValue(tile, 0)) { + tileSizes[idx] = loopRanges[idx].size; + } + } + + SmallVector newParallelDims; + for (auto i = 0UL; i < reductionDims.size(); i++) { + newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i)); + } auto tilingResult = linalgX::tileAllUsingForall( - b, cast(currentOp.getOperation()), - numThreads, tileSizes, std::nullopt); + b, cast(currentOp.getOperation()), {}, + tileSizes, newParallelDims, std::nullopt); if (failed(tilingResult)) return failure(); currentOp = dyn_cast(tilingResult->parallelTiledOp); diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index 8404d229f..4c36e6661 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -181,36 +181,15 @@ struct LinalgOpPartialReductionInterface { ValueRange partialReduce, ArrayRef reductionDims) { auto linalgOp = cast(op); - - DenseSet reductionDimsSet(reductionDims.begin(), reductionDims.end()); - - // Then create a new reduction that only reduce the newly added dimensions - // from the previous op. - int64_t intermRank = cast(partialReduce[0].getType()).getRank(); - AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector reductionIteratorTypes; - SmallVector exprs; - - for (int64_t i : llvm::seq(0, intermRank)) { - if (reductionDimsSet.contains(i)) { - reductionIteratorTypes.push_back(utils::IteratorType::reduction); - } else { - exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(utils::IteratorType::parallel); - } - } - - AffineMap outputMap = - AffineMap::get(intermRank, 0, exprs, op->getContext()); - SmallVector reductionMaps = {inputMap, outputMap}; - + SmallVector reductionDimsInt64(reductionDims.begin(), + reductionDims.end()); SmallVector combinerOps; matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps); Operation *reductionOp = combinerOps[0]; - auto reduction = b.create( - loc, op->getResultTypes(), ValueRange({partialReduce[0]}), - linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes, + auto reduction = b.create( + loc, ValueRange({partialReduce[0]}), + ValueRange({linalgOp.getDpsInits()[0]}), reductionDimsInt64, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); clonedReductionOp->setOperand(0, inputs[0]); @@ -768,8 +747,9 @@ FailureOr tileReductionUsingForall( } // 7. Merge the partial reductions. b.setInsertionPointAfter(forallOp); - Operation *mergeOp = op.mergeReductions(b, loc, forallOp->getResults(), - constantNewParallelDims); + Operation *mergeOp = + linalgX::LinalgOpPartialReductionInterface::mergeReductions( + op, b, loc, forallOp->getResults(), constantNewParallelDims); b.replaceOp(op, mergeOp->getResults()); // 8. Return. 
ForallReductionTilingResult results; @@ -802,11 +782,10 @@ FailureOr static tileLinalgOpImpl( return tileLinalgOpImpl(b, op, tileSizeVector, options); } -FailureOr -tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, - ArrayRef numThreads, - ArrayRef tileSizes, - std::optional mapping) { +FailureOr tileAllUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef threadNums, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping) { Location loc = op.getLoc(); OpBuilder::InsertionGuard g(b); @@ -842,6 +821,24 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, if (iteratorType == utils::IteratorType::reduction) redDims.push_back(idx); } + + SmallVector numThreads(threadNums.begin(), threadNums.end()); + if (numThreads.empty()) { + SmallVector loopRanges = tilingInterfaceOp.getIterationDomain(b); + unsigned nLoops = loopRanges.size(); + numThreads.reserve(nLoops); + AffineExpr s0, s1; + bindSymbols(b.getContext(), s0, s1); + AffineExpr divExpr = s0.ceilDiv(s1); + for (const auto &it : llvm::zip(tileSizes, loopRanges)) { + OpFoldResult numTiles = std::get<0>(it); + if (!isConstantIntValue(numTiles, 0)) + numTiles = makeComposedFoldedAffineApply( + b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)}); + numThreads.push_back(numTiles); + } + } + bool hasReductionThreads = false; for (auto dim : redDims) { if (!isConstantIntValue(numThreads[dim], 0) && @@ -858,7 +855,18 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, if ((unsigned)redDims.front() >= numThreads.size()) return b.notifyMatchFailure( op, "reduction dimension must be mapped to threads"); - + SmallVector constantNewParallelDims; + for (auto dim : newParallelDims) { + if (getConstantIntValue(dim) == std::nullopt) + return b.notifyMatchFailure( + op, "Expected new parallel dims to be constant integers."); + constantNewParallelDims.push_back(*getConstantIntValue(dim)); + } + if (newParallelDims.empty()) + constantNewParallelDims = redDims; + if (constantNewParallelDims.size() != redDims.size()) + return b.notifyMatchFailure( + op, "reduction dimension must be mapped to new parallel dims"); // 1. Create the inital tensor value. FailureOr> maybeInitTensors; SmallVector initTensors; @@ -876,7 +884,6 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, SmallVector dest; if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) return b.notifyMatchFailure(op, "failed to get destination tensors"); - Operation *tiledOp = nullptr; SmallVector nonZeroNumThreads = @@ -885,7 +892,6 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, })); SmallVector materializedNonZeroNumThreads = getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); - // 2. Create the ForallOp with an empty region. scf::ForallOp forallOp = b.create( loc, getAsOpFoldResult(materializedNonZeroNumThreads), @@ -893,11 +899,14 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, // 3. Calculate the tile offsets and sizes for the subsequent loop that will // be nested under `forallOp`. SmallVector tiledOffsets, tiledSizes; + std::optional> nominalTileSizes = std::nullopt; + if (!tileSizes.empty() && threadNums.empty()) { + nominalTileSizes = tileSizes; + } calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain, /*omitTileOffsetBoundsCheck =*/false, - /*nominalTileSizes=*/tileSizes, tiledOffsets, - tiledSizes); - + /*nominalTileSizes=*/nominalTileSizes, + tiledOffsets, tiledSizes); // 4. 
Clone the tileable op and update its destination operands to use the // output bbArgs of the ForallOp. SmallVector tilingResults; @@ -916,20 +925,26 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, SmallVector strides(numThreads.size(), b.getIndexAttr(1)); SmallVector outOffsets(numThreads.size(), b.getIndexAttr(0)); - SmallVector sizes; - for (auto s : - cast(destBbArgs[destNum].getType()).getShape()) { - sizes.emplace_back(getAsIndexOpFoldResult(b.getContext(), (int)s)); - } - for (auto dim : redDims) { - sizes[dim] = b.getIndexAttr(1); + SmallVector sizes = tiledSizes; + for (const auto &iteratorType : llvm::enumerate( + cast(destBbArgs[destNum].getType()) + .getShape())) { + sizes[iteratorType.index()] = + getAsIndexOpFoldResult(b.getContext(), iteratorType.value()); + if (llvm::find(constantNewParallelDims, iteratorType.index()) != + constantNewParallelDims.end()) { + sizes[iteratorType.index()] = b.getIndexAttr(1); + } } auto nonZeroDimIdx = 0; - for (auto dim = 0UL; dim < numThreads.size(); dim++) { - if (!isConstantIntValue(numThreads[dim], 0)) { - if (llvm::find(redDims, dim) != redDims.end()) - outOffsets[dim] = forallOp.getInductionVars()[nonZeroDimIdx]; + auto currentReductionIdx = 0; + for (const auto &iteratorType : llvm::enumerate(numThreads)) { + if (!isConstantIntValue(iteratorType.value(), 0)) { + if (llvm::find(redDims, iteratorType.index()) != redDims.end()) { + outOffsets[constantNewParallelDims[currentReductionIdx++]] = + forallOp.getInductionVars()[nonZeroDimIdx]; + } nonZeroDimIdx++; } } @@ -938,7 +953,10 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, loc, cast(initOperand.getType()), destBbArgs[destNum], outOffsets, sizes, strides)); } else { - tiledDpsInitOperands.push_back(initOperand); + auto *it = llvm::find(dest, initOperand); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + tiledDpsInitOperands.push_back(destBbArgs[destNum]); } } @@ -953,19 +971,35 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, initOperandPtr.set(tiledInitValue); } }); - // 5. Tile the cloned op and delete the clone. 
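  // A minimal sketch (assumed kThreads = 2 and a 32x32 C tile) of the destination
  // slice built above for a reduction dim remapped to a new parallel dim: size 1
  // and the forall induction variable as offset along that dim, rank-reduced back
  // to the original init type:
  //   %acc_slice = tensor.extract_slice %acc_bbarg[%kt, 0, 0] [1, 32, 32]
  //                [1, 1, 1] : tensor<2x32x32xf32> to tensor<32x32xf32>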
- FailureOr tilingResult = - cast(clonedOp).getTiledImplementation(b, tiledOffsets, - tiledSizes); - if (failed(tilingResult)) - return clonedOp->emitError("Failed to tile op: "); - if (tilingResult->tiledOps.size() != 1) { - return clonedOp->emitError("expected a single produced tiled op, got ") - << tilingResult->tiledOps.size(); + if (tileSizes.empty() || threadNums.empty()) { + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation( + b, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError("expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; + } else { + LinalgTilingOptions options; + FailureOr maybeTiled = tileLinalgOpImpl( + b, cast(clonedOp), tileSizes, options); + if (failed(maybeTiled)) + return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); + + SmallVector ids = forallOp.getInductionVars(); + mapLoopToProcessorIds(cast(maybeTiled->loops.back()), ids, + materializedNonZeroNumThreads); + if (maybeTiled->loops.size() != 1) { + return clonedOp->emitError("expected a single produced loop"); + } + tiledOp = maybeTiled->op; + tilingResults = maybeTiled->loops.front()->getResults(); } - tiledOp = tilingResult->tiledOps.front(); - tilingResults = tilingResult->tiledValues; b.eraseOp(clonedOp); } @@ -983,23 +1017,33 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, return op->emitOpError("output offsets couldn't be calculated"); SmallVector resultOffsetsRank, resultSizesRank; int64_t offIdx = 0; - int64_t sizeIdx = 0; int64_t nonZeroDimIdx = 0; + SmallVector reductionInductionVars; for (auto i = 0UL; i < numThreads.size(); ++i) { - if (llvm::find(redDims, i) != redDims.end()) { + if (llvm::find(constantNewParallelDims, i) != + constantNewParallelDims.end()) { if (hasReductionThreads) { - resultOffsetsRank.push_back( - forallOp.getInductionVars()[nonZeroDimIdx]); + resultOffsetsRank.push_back(b.getIndexAttr(1)); resultSizesRank.push_back(b.getIndexAttr(1)); } - nonZeroDimIdx++; - continue; + } else { + resultOffsetsRank.push_back(resultOffsets[offIdx]); + resultSizesRank.push_back(resultSizes[offIdx++]); + } + if (llvm::find(redDims, i) != redDims.end()) { + reductionInductionVars.push_back( + forallOp.getInductionVars()[nonZeroDimIdx]); } if (!isConstantIntValue(numThreads[i], 0)) { nonZeroDimIdx++; } - resultOffsetsRank.push_back(resultOffsets[offIdx++]); - resultSizesRank.push_back(resultSizes[sizeIdx++]); + } + if (hasReductionThreads) { + for (auto [parallelDims, redVar] : + llvm::zip(constantNewParallelDims, reductionInductionVars)) { + resultOffsetsRank[parallelDims] = redVar; + resultSizesRank[parallelDims] = b.getIndexAttr(1); + } } SmallVector strides(resultSizesRank.size(), b.getIndexAttr(1)); @@ -1010,18 +1054,17 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, b.create( loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); } - // 7. Merge the partial reductions. Operation *mergeOp = nullptr; b.setInsertionPointAfter(forallOp); if (hasReductionThreads) { Operation *mergeOp = - op.mergeReductions(b, loc, forallOp->getResults(), redDims); + linalgX::LinalgOpPartialReductionInterface::mergeReductions( + op, b, loc, forallOp->getResults(), constantNewParallelDims); b.replaceOp(op, mergeOp->getResults()); } else { b.replaceOp(op, forallOp->getResults()); } - // 8. 
Return. ForallReductionTilingResult results; results.initialValues = initTensors; diff --git a/lib/gc/Transforms/Tiling.hpp b/lib/gc/Transforms/Tiling.hpp index e46720f93..7c4188096 100644 --- a/lib/gc/Transforms/Tiling.hpp +++ b/lib/gc/Transforms/Tiling.hpp @@ -44,11 +44,10 @@ FailureOr tileReductionUsingForall( ArrayRef threadNums, ArrayRef tileSizes, ArrayRef newParallelDims, std::optional mapping); -FailureOr -tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op, - ArrayRef numThreads, - ArrayRef tileSizes, - std::optional mapping); +FailureOr tileAllUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping); } // namespace linalgX } // namespace mlir From b98100b5dfc7de416066f3311671f39fe2a1870b Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Thu, 30 May 2024 20:00:23 -0700 Subject: [PATCH 06/21] support bf16 cast fuse --- .../Transforms/DeepTileContractionNamedOp.cpp | 332 +++++++++++++----- 1 file changed, 235 insertions(+), 97 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index dec219ebe..f9b0ea1b8 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -59,6 +59,48 @@ struct MatmulConfig { template inline T divAndCeil(T a, T b) { return (a - 1) / b + 1; } +enum DimType { Batch, M, N, K }; + +static FailureOr>> +getOprandDimType(linalg::LinalgOp &linalgOp) { + if (isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (isa(linalgOp)) { + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + if (iteratorTypes.size() == 7UL) { + // 4Dx5D, brgemm vnni + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (iteratorTypes.size() == 6UL) { + // 4Dx4D + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } + } else { + return failure(); + } + return failure(); +} + +[[maybe_unused]] static SmallVector +extractDimTypeIdx(ArrayRef tyList, DimType ty) { + SmallVector idxList; + for (auto [idx, type] : llvm::enumerate(tyList)) { + if (type == ty) { + idxList.push_back(idx); + } + } + return idxList; +} + MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { // TODO: build a more complex heuristic to determine the best tiling auto M = linalgOp.getShape(linalgOp.getDpsInputOperand(0))[0]; @@ -86,7 +128,6 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { cfg.MBlock = divAndCeil((int)MNumBlock, cfg.MThreads) * cfg.innerMostMBlock; cfg.NBlock = divAndCeil((int)NNumBlock, cfg.NThreads) * cfg.innerMostNBlock; cfg.KBlock = divAndCeil((int)KNumBlock, cfg.KThreads) * cfg.innerMostKBlock; - cfg.innerMostMBlock = 32; cfg.innerMostNBlock = 32; cfg.innerMostKBlock = 32; @@ -94,7 +135,7 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { cfg.NBlock = 64; cfg.KBlock = 64; cfg.MThreads = 2; - cfg.NThreads = 1; + cfg.NThreads = 2; cfg.KThreads = 1; return cfg; } @@ -169,6 +210,7 @@ struct OuterLoopGenerationOption { SmallVector> nestedTileSizes; 
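  // Each index of nestedTileSizes / loopType / loopDim (below) describes one
  // level of the loop nest to generate, outermost first. As an illustration
  // (the concrete values are not prescribed by this struct): a first entry
  // holding the per-thread block sizes with LoopType::ForallOp yields the
  // outermost scf.forall, while later single-dimension entries such as
  // {KBlock} with LoopType::ForOp yield sequential scf.for levels.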
SmallVector loopType; SmallVector> loopDim; + bool hasFillOp; }; struct OuterLoopGenerationResult { @@ -179,11 +221,108 @@ struct OuterLoopGenerationResult { /// The `scf.for` operations that iterate over the tiles. SmallVector loops; SmallVector reductionLoops; - /// Values to use as replacements for the untiled op. Is the same size as the - /// number of results of the untiled op. - SmallVector replacements; }; +static void buildLinalgRegion(Operation *op) { + SmallVector argTypes; + SmallVector argLocs; + for (const Value &opOperand : op->getOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.getType())); + argLocs.push_back(opOperand.getLoc()); + } + ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); + Region ®ion = op->getRegion(0); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); + b.setInsertionPointToStart(body); + auto *dialect = static_cast(op->getDialect()); + linalg::LinalgDialect::RegionBuilderFunType fun = + dialect->getRegionBuilder("linalg.matmul"); + fun(b, *body, op->getAttrs()); +} + +struct DtypeLegalizeResult { + Operation *linalgOp = nullptr; + Operation *castOp = nullptr; +}; + +// Split a low precision matmul(bf16xbf16->bf16) to a combination +// matmul(bf16xbf16->f32) + cast(f32->bf16) +static FailureOr +matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, + bool needCopyInit = true) { + + auto linalgOp = dyn_cast(op); + DtypeLegalizeResult result; + if (!linalgOp) + return failure(); + + auto dataType = + dyn_cast(linalgOp.getDpsInputs()[0].getType()) + .getElementType(); + auto resultType = + dyn_cast(linalgOp.getDpsInits()[0].getType()) + .getElementType(); + + if ((dataType.isBF16() || dataType.isF16()) && dataType == resultType) { + rewriter.setInsertionPoint(linalgOp); + IRMapping mapping; + auto initOp = linalgOp.getDpsInits()[0].getDefiningOp(); + auto initValue = initOp->getResult(0); + auto initType = cast(initValue.getType()); + auto tensorShape = initType.getShape(); + SmallVector mixedShape; + for (auto i = 0UL; i < tensorShape.size(); i++) { + if (initType.isDynamicDim(i)) { + Value val = + rewriter.create(linalgOp.getLoc(), initValue, i); + mixedShape.push_back(val); + } else { + mixedShape.push_back( + getAsIndexOpFoldResult(rewriter.getContext(), tensorShape[i])); + } + } + Operation *currentOp; + + currentOp = rewriter.create( + linalgOp.getLoc(), mixedShape, Float32Type::get(op->getContext())); + if (needCopyInit) { + currentOp = rewriter.create( + linalgOp.getLoc(), initOp->getResult(0), currentOp->getResult(0)); + } + SmallVector newOperands = linalgOp->getOperands(); + newOperands.back() = currentOp->getResult(0); + OperationState state(linalgOp->getLoc(), linalgOp->getName(), newOperands, + currentOp->getResult(0).getType(), + linalgOp->getAttrs()); + state.addRegion(); + currentOp = rewriter.create(state); + buildLinalgRegion(currentOp); + + auto castOp = rewriter.create( + linalgOp.getLoc(), currentOp->getResult(0), initOp->getResult(0)); + result.linalgOp = currentOp; + result.castOp = castOp; + } + + return result; +} + +static Operation *findParentFillOp(Value val) { + SmallVector skipOpList = {"tensor.pack", "tensor.pad"}; + auto currentOp = val.getDefiningOp(); + while (currentOp && + llvm::find(skipOpList, currentOp->getName().getStringRef()) != + skipOpList.end() && + !isa(currentOp)) { + currentOp = currentOp->getResult(0).getDefiningOp(); + } + if (isa(currentOp)) { + return currentOp; + } + + return nullptr; +} + static FailureOr generateOuterLoop(RewriterBase &b, linalg::LinalgOp 
linalgOp, const OuterLoopGenerationOption &option) { @@ -205,39 +344,43 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (linalgOp.hasPureBufferSemantics()) return b.notifyMatchFailure( linalgOp, "currentOp should not has pure buffer semantics"); - linalg::LinalgOp currentOp = linalgOp; - for (auto iteratorType : llvm::enumerate(loopType)) { - auto [i, type] = iteratorType; + for (auto loopTypeIter : llvm::enumerate(loopType)) { + auto [i, loopType] = loopTypeIter; auto currentDim = loopDim[i]; auto currentTileSize = nestedTileSizes[i]; - if (type == OuterLoopGenerationOption::LoopType::ForOp) { - scf::SCFTilingOptions tileOption; - SmallVector TileSizes( - currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); - + if (loopType == OuterLoopGenerationOption::LoopType::ForOp) { for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + scf::SCFTilingOptions tileOption; + SmallVector TileSizes( + currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); TileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); - } - tileOption.setTileSizes(TileSizes); - tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); - - OpBuilder::InsertionGuard guard(b); - b.setInsertionPoint(currentOp); - auto tilingResult = scf::tileUsingSCF( - b, cast(currentOp.getOperation()), tileOption); - if (failed(tilingResult)) - return failure(); - b.replaceOp(currentOp, tilingResult->replacements); - currentOp = dyn_cast(tilingResult->tiledOps.back()); - - for (auto [dim, loop] : llvm::zip(currentDim, tilingResult->loops)) { - if (iteratorTypes[dim] == mlir::utils::IteratorType::reduction) { - result.reductionLoops.push_back(loop); + tileOption.setTileSizes(TileSizes); + tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(currentOp); + // TODO: refactor here to use a callback function + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction && + tile != 0) { + auto result = matmulDtypeLegalize(b, currentOp.getOperation(), + !option.hasFillOp); + if (result->castOp && result->linalgOp) { + b.replaceOp(currentOp, result->castOp); + currentOp = dyn_cast(result->linalgOp); + } } - result.loops.push_back(loop); + auto tilingResult = scf::tileUsingSCF( + b, cast(currentOp.getOperation()), tileOption); + if (failed(tilingResult)) + return failure(); + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { + result.reductionLoops.push_back(tilingResult->loops.back()); + } + result.loops.push_back(tilingResult->loops.back()); } - } else if (type == OuterLoopGenerationOption::LoopType::ForallOp) { + } else if (loopType == OuterLoopGenerationOption::LoopType::ForallOp) { SmallVector tileSizes( currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); SmallVector threads( @@ -251,24 +394,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, else tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); } - - SmallVector numThreads; SmallVector loopRanges = cast(currentOp.getOperation()).getIterationDomain(b); - unsigned nLoops = loopRanges.size(); - numThreads.reserve(nLoops); - AffineExpr s0, s1; - bindSymbols(b.getContext(), s0, s1); - AffineExpr divExpr = s0.ceilDiv(s1); - for (const auto &it : llvm::zip(tileSizes, loopRanges)) { - OpFoldResult numTiles = std::get<0>(it); - if (!isConstantIntValue(numTiles, 0)) - numTiles = 
mlir::affine::makeComposedFoldedAffineApply( - b, currentOp.getLoc(), divExpr, - {std::get<1>(it).size, std::get<0>(it)}); - numThreads.push_back(numTiles); - } - OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); if (auto partialInterface = @@ -289,6 +416,13 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (failed(tilingResult)) return failure(); currentOp = dyn_cast(tilingResult->parallelTiledOp); + if (option.hasFillOp && tilingResult->mergeOp) { + auto fillOp = findParentFillOp(tilingResult->loops.getDpsInits()[0]); + if (fillOp) { + b.replaceOp(fillOp, dyn_cast(*fillOp) + .getDpsInits()[0]); + } + } } else if (auto tilingInterface = cast(currentOp.getOperation())) { auto tilingResult = linalg::tileToForallOpUsingTileSizes( @@ -304,6 +438,22 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return result; } +[[maybe_unused]] static LogicalResult +indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, + Value rollingIdx, Value maximumRange, Value step) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(insertBlock); + mlir::easybuild::EasyBuilder eb{b, loc}; + auto vWraped = eb.wrap(v); + auto rollingIdxWraped = eb.wrap(rollingIdx); + auto stepWraped = eb.wrap(step); + auto maximumRangeWraped = eb.wrap(step); + auto newV = (vWraped + rollingIdxWraped) * stepWraped % + (maximumRangeWraped / stepWraped * stepWraped); + v.replaceAllUsesWith(newV); + return failure(); +} + static void getMatmulParallelDims(linalg::LinalgOp linalgOp, unsigned operandIdx, SmallVectorImpl &dims) { @@ -380,38 +530,6 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, return success(); } -enum DimType { Batch, M, N, K }; - -static FailureOr>> -getOprandDimType(linalg::LinalgOp &linalgOp) { - // TODO: add more support for other linalg named matmul - if (isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K}, - SmallVector{DimType::K, DimType::N}, - SmallVector{DimType::M, DimType::N}}; - } else if (isa(linalgOp)) { - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - if (iteratorTypes.size() == 7UL) { - // 4Dx5D, brgemm vnni - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (iteratorTypes.size() == 6UL) { - // 4Dx4D - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } - } else { - return failure(); - } - return failure(); -} - /* matmul(A, B) -> C ----------------> @@ -452,7 +570,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { FailureOr outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - MatmulConfig cfg) const { + MatmulConfig cfg, bool hasFillOp) const { SmallVector KDimPos, MDimPos, NDimPos; linalgOp.getReductionDims(KDimPos); getMatmulParallelDims(linalgOp, 0, MDimPos); @@ -489,6 +607,14 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer + if (cfg.KThreads > 1) { + auto result = + matmulDtypeLegalize(rewriter, linalgOp.getOperation(), !hasFillOp); + if (result->castOp && result->linalgOp) { + rewriter.replaceOp(linalgOp, result->castOp); + linalgOp = dyn_cast(result->linalgOp); + } + } option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); @@ -519,6 +645,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopDim.emplace_back(SmallVector{(int)dim}); } } + option.hasFillOp = hasFillOp; return generateOuterLoop(rewriter, linalgOp, option); } @@ -649,8 +776,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { resultType.getElementType()), currentOp.getDpsInits()[0]); - // Create the brgemm op - // TODO: use brgemm_vnni to replace generic when it is applicable + // Create the brgemm op and replace the origin linalg op linalg::LinalgOp matmul; if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) { matmul = rewriter.create( @@ -673,6 +799,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { rewriter.replaceOp(currentOp, matmul.getOperation()->getResult(0)); currentOp = matmul; + // Fuse the fill op to the innermost body if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { auto fillValue = fillOp.getDpsInputs()[0]; rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); @@ -699,37 +826,49 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } auto ifOp = eb.getLastOperaion(); rewriter.replaceOp(currentOp, ifOp); - ifOp->getParentOfType().dump(); } return success(); } - LogicalResult matchAndRewrite(linalg::LinalgOp matmulOp, + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { - if (matmulOp.hasPureBufferSemantics()) + if (linalgOp.hasPureBufferSemantics()) return failure(); - linalg::LinalgOp linalgOp; - linalgOp = dyn_cast(matmulOp.getOperation()); - if (linalgOp.getOperation()->getParentOfType()) + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(linalgOp); + if (linalgOp.getOperation()->getParentOfType() || + !linalgOp || linalgOp.getNumDpsInputs() != 2) return failure(); - - Operation *fillOp = linalgOp.getDpsInits()[0].getDefiningOp(); - + Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); + linalg::LinalgOp originOp = + dyn_cast(*rewriter.clone(*(linalgOp.getOperation()))); // Step 1. 
generate the outer loop - // 2.0 Get the iteration infomation first MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); - auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg); + auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg, + isa(fillOp)); if (failed(outerLoopResult)) { return failure(); } linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); - - // Step 2 generate inner loop body, convert the linalg.generic to brgemm + // Step 2 index rolling + // if (failed(indexRolling(rewriter, linalgOp.getLoc(), + // outerLoopResult->reductionLoops[0].getInductionVar(), + // linalgOp.getLoopRanges()[0].size, cfg.MBlock)) + // || + // failed(indexRolling(rewriter, linalgOp.getLoc(), + // linalgOp.getDpsInputOperand(1), + // linalgOp.getLoopRanges()[1].size, cfg.KBlock))) + // { + // return failure(); + // } + + // Step 3 generate inner loop body, convert the linalg.generic to brgemm auto option = innerBodyGenerationOption{fillOp, outerLoopResult->reductionLoops}; - if (failed(innerBodyGeneration(rewriter, matmulOp, linalgOp, option))) { + if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) { return failure(); } + rewriter.eraseOp(originOp); return success(); } }; @@ -754,7 +893,6 @@ struct DeepTileContractionNamedOp dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx.getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, &ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); From 66a986b9fd1461ba8efd2c3e2507439b44857f47 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Mon, 3 Jun 2024 22:48:19 -0700 Subject: [PATCH 07/21] replace generic op with named op --- .../Transforms/DeepTileContractionNamedOp.cpp | 63 +++++++++++------- .../deepTileContractionNamedOp.mlir | 64 +++++++++++-------- 2 files changed, 78 insertions(+), 49 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index f9b0ea1b8..2d461e068 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -8,6 +8,7 @@ #include "./Tiling.hpp" #include "gc/Dialect/Arith/Utils/EasyBuild.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/IR/EasyBuild.h" #include "gc/IR/EasyBuildSCF.h" #include "mlir/AsmParser/AsmParser.h" @@ -68,24 +69,23 @@ getOprandDimType(linalg::LinalgOp &linalgOp) { SmallVector{DimType::M, DimType::K}, SmallVector{DimType::K, DimType::N}, SmallVector{DimType::M, DimType::N}}; - } else if (isa(linalgOp)) { - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - if (iteratorTypes.size() == 7UL) { - // 4Dx5D, brgemm vnni - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (iteratorTypes.size() == 6UL) { - // 4Dx4D - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } - } else { - return failure(); + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return 
SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; } return failure(); } @@ -136,7 +136,7 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { cfg.KBlock = 64; cfg.MThreads = 2; cfg.NThreads = 2; - cfg.KThreads = 1; + cfg.KThreads = 2; return cfg; } @@ -784,8 +784,9 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ValueRange{dataOprand, weightOprand}, resultOprand); } else { IRMapping mapping; - matmul = dyn_cast( - *rewriter.clone(*(currentOp.getOperation()))); + matmul = rewriter.create( + resultOprand.getLoc(), resultOprand.getType(), + ValueRange{dataOprand, weightOprand}, resultOprand); } Value result = matmul.getOperation()->getResult(0); @@ -830,18 +831,32 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { return success(); } + bool checkLinalgMatmulType(linalg::LinalgOp linalgOp) const { + return llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp) || + llvm::isa(linalgOp); + } + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { + if (!checkLinalgMatmulType(linalgOp)) + return failure(); if (linalgOp.hasPureBufferSemantics()) return failure(); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(linalgOp); + if (linalgOp.getOperation()->getParentOfType() || !linalgOp || linalgOp.getNumDpsInputs() != 2) return failure(); - Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(linalgOp); linalg::LinalgOp originOp = dyn_cast(*rewriter.clone(*(linalgOp.getOperation()))); + linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); + Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); + // Step 1. 
generate the outer loop MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg, diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index 209145f9e..d221557fc 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -1,26 +1,21 @@ // RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s -// ----- +// // ----- -/// CHECK-LABEL: @blocked_matmul_f32 -func.func @blocked_matmul_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32xf32> - %cst_0 = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<128x128x32x32xf32> - %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> - %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32xf32>) outs(%1 : tensor<128x128x32x32xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %3 = arith.mulf %in, %in_1 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<128x128x32x32xf32> - return %2 : tensor<128x128x32x32xf32> -} +// /// CHECK-LABEL: @matmul_4Dx4D_f32 +// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> { +// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32> +// %cst_0 = arith.constant 0.000000e+00 : f32 +// %0 = tensor.empty() : tensor<128x128x32x32xf32> +// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> +// %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> +// return %2 : tensor<128x128x32x32xf32> +// } // ----- -/// CHECK-LABEL: @plain_matmul_f32 -func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { +/// CHECK-LABEL: @matmul_2Dx2D_f32 +func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { %cst = arith.constant dense<1.000000e+00> : tensor<4096x4096xf32> %cst_0 = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<4096x4096xf32> @@ -29,20 +24,39 @@ func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf3 return %2 : tensor<4096x4096xf32> } +// // ----- + +// /// CHECK-LABEL: @matmul_2Dx4D_f32 +// func.func @matmul_4Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { +// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32> +// %cst_0 = arith.constant 0.000000e+00 : f32 +// %0 = tensor.empty() : tensor<4096x4096xf32> +// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> +// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> +// return %2 : tensor<4096x4096xf32> +// } + // ----- -/// CHECK-LABEL: @blocked_matmul_bf16 -func.func @blocked_matmul_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> { +/// CHECK-LABEL: 
@matmul_4Dx4D_bf16 +func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> { %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16> %cst_0 = arith.constant 0.000000e+00 : bf16 %0 = tensor.empty() : tensor<128x128x32x32xbf16> %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> - %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) { - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %3 = arith.mulf %in, %in_1 : bf16 - %4 = arith.addf %out, %3 : bf16 - linalg.yield %4 : bf16 - } -> tensor<128x128x32x32xbf16> + %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> return %2 : tensor<128x128x32x32xbf16> } +// // ----- + +// /// CHECK-LABEL: @matmul_2Dx4D_bf16 +// func.func @matmul_4Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> { +// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16> +// %cst_0 = arith.constant 0.000000e+00 : bf16 +// %0 = tensor.empty() : tensor<4096x4096xbf16> +// %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> +// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> +// return %2 : tensor<4096x4096xbf16> +// } + From 56624bb12da0e04ba243c5ab493083ba5c60d624 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Tue, 4 Jun 2024 19:28:46 -0700 Subject: [PATCH 08/21] support 2Dx4D/5D case --- .../Transforms/DeepTileContractionNamedOp.cpp | 201 +++++++++++------- lib/gc/Transforms/Tiling.cpp | 24 +-- .../deepTileContractionNamedOp.mlir | 45 ++-- 3 files changed, 159 insertions(+), 111 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 2d461e068..f73f213fa 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { cfg.KBlock = 64; cfg.MThreads = 2; cfg.NThreads = 2; - cfg.KThreads = 2; + cfg.KThreads = 1; return cfg; } -static Value tensorViewRankedTensor(RewriterBase &rewriter, - RankedTensorType outTensorType, - Value value) { +static Value +tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, + Value value, + ArrayRef permutation = SmallVector{}) { // TODO: add support for plain layout transpose Value result, currentValue = value; auto loc = currentValue.getLoc(); @@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter, if (outShape.size() < inShape.size()) { SmallVector reassocIndices; - ReassociationIndices firstEntry; - for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) { - firstEntry.push_back(i); - } - reassocIndices.push_back(firstEntry); - for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size(); - i++) { - reassocIndices.push_back({(int)i}); + uint64_t outIdx = 0UL, 
inIdx = 0UL; + while (inIdx < inShape.size() && outIdx < outShape.size()) { + ReassociationIndices firstEntry; + auto remaining = outShape[outIdx++]; + if (remaining == 1) { + firstEntry.push_back(inIdx++); + reassocIndices.push_back(firstEntry); + continue; + } + while (remaining > 1) { + remaining /= inShape[inIdx]; + firstEntry.push_back(inIdx++); + } + reassocIndices.push_back(firstEntry); } result = rewriter.create( loc, outTensorType, currentValue, reassocIndices); } else if (outShape.size() > inShape.size()) { SmallVector reassocIndices; - ReassociationIndices firstEntry; - for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) { - firstEntry.push_back((int)i); - } - reassocIndices.push_back(firstEntry); - for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size(); - i++) { - reassocIndices.push_back({(int)i}); + uint64_t outIdx = 0UL, inIdx = 0UL; + while (outIdx < outShape.size() && inIdx < inShape.size()) { + ReassociationIndices firstEntry; + auto remaining = inShape[inIdx++]; + if (remaining == 1) { + firstEntry.push_back(outIdx++); + reassocIndices.push_back(firstEntry); + continue; + } + while (remaining > 1) { + remaining /= outShape[outIdx]; + firstEntry.push_back(outIdx++); + } + reassocIndices.push_back(firstEntry); } result = rewriter.create( loc, outTensorType, currentValue, reassocIndices); } else { result = rewriter.create(loc, outTensorType, currentValue); } + + if (!permutation.empty()) { + SmallVector transposeShape; + for (auto idx : permutation) { + transposeShape.push_back(outShape[idx]); + } + auto initOp = rewriter.create(loc, transposeShape, + tensorElementType); + auto transposeOp = rewriter.create( + loc, result, initOp->getResult(0), permutation); + result = transposeOp->getResult(0); + } return result; } @@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return b.notifyMatchFailure( linalgOp, "currentOp should not has pure buffer semantics"); linalg::LinalgOp currentOp = linalgOp; + for (auto loopTypeIter : llvm::enumerate(loopType)) { auto [i, loopType] = loopTypeIter; auto currentDim = loopDim[i]; @@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, bool isExtract, SmallVector size, int shrinDimNum = 0) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); if (auto extractSlice = dyn_cast(op)) { SmallVector mixedOffsets = extractSlice.getMixedOffsets(); SmallVector mixedSizes = extractSlice.getMixedSizes(); @@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, Value source, SmallVector size) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); if (auto insertSlice = dyn_cast(op)) { SmallVector mixedOffsets = insertSlice.getMixedOffsets(); SmallVector mixedSizes = insertSlice.getMixedSizes(); @@ -575,7 +605,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalgOp.getReductionDims(KDimPos); getMatmulParallelDims(linalgOp, 0, MDimPos); getMatmulParallelDims(linalgOp, 1, NDimPos); - bool useBlockedLayout = KDimPos.size() > 1; OuterLoopGenerationOption option; auto iteratorTypes = linalgOp.getIteratorTypesArray(); @@ -583,27 +612,27 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0); auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1); auto KParallelBlockSize = - useBlockedLayout + 
KDimPos.size() > 1 ? divAndCeil(KFirstDim, cfg.KThreads) : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) * cfg.KBlock; auto MParallelBlockSize = - useBlockedLayout + MDimPos.size() > 1 ? divAndCeil(MFirstDim, cfg.MThreads) : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) * cfg.MBlock; auto NParallelBlockSize = - useBlockedLayout + NDimPos.size() > 1 ? divAndCeil(NFirstDim, cfg.NThreads) : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) * cfg.NBlock; - auto KOuterBlockSize = useBlockedLayout + auto KOuterBlockSize = KDimPos.size() > 1 ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 : cfg.KBlock; - auto MOuterBlockSize = useBlockedLayout + auto MOuterBlockSize = MDimPos.size() > 1 ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1 : cfg.MBlock; - auto NOuterBlockSize = useBlockedLayout + auto NOuterBlockSize = NDimPos.size() > 1 ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer @@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopDim.emplace_back(SmallVector{dim}); } // Inner - if (!useBlockedLayout) { + if (KDimPos.size() == 1) { option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{(int)KDimPos.back()}); } + if (MDimPos.size() == 1) { + option.nestedTileSizes.emplace_back( + SmallVector{cfg.innerMostMBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{(int)MDimPos.back()}); + } + if (NDimPos.size() == 1) { + option.nestedTileSizes.emplace_back( + SmallVector{cfg.innerMostNBlock}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); + option.loopDim.emplace_back(SmallVector{(int)NDimPos.back()}); + } for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { if (dim != MDimPos.back() && dim != NDimPos.back() && iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { @@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalg::LinalgOp originOp, linalg::LinalgOp currentOp, innerBodyGenerationOption &option) const { + mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; auto operandDimTypes = getOprandDimType(originOp); MatmulConfig cfg = getDefaultMatmulConfig(originOp); auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); auto BShape = originOp.getShape(originOp.getDpsInputOperand(1)); auto CShape = originOp.getShape(originOp.getDpsInitOperand(0)); - bool useBlockedLayout = BShape.size() > 2; + + auto MDimNum = std::count_if((*operandDimTypes)[0].begin(), + (*operandDimTypes)[0].end(), + [](DimType d) { return d == DimType::M; }); + auto NDimNum = std::count_if((*operandDimTypes)[1].begin(), + (*operandDimTypes)[1].end(), + [](DimType d) { return d == DimType::N; }); // TODO: support plain in/block out format SmallVector AInnermostDims, BInnermostDims, CInnermostDims; - if (useBlockedLayout) { - bool firstM = true, firstK = true, firstN = true; + bool firstM = true, firstK = true, firstN = true; + if (MDimNum > 1) { for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) { if (iter == DimType::M && firstM) { AInnermostDims.push_back(1); @@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { AInnermostDims.push_back(AShape[idx]); } } - firstN = true; - firstK = true; - for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { - if (iter == DimType::N && firstN) { - BInnermostDims.push_back(1); 
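  // Worked example (assumed sizes) for the *ParallelBlockSize formulas above:
  // for a plain 2-D matmul with M = 4096, cfg.MBlock = 64 and cfg.MThreads = 2,
  // MParallelBlockSize = divAndCeil(divAndCeil(4096, 64), 2) * 64 = 2048, i.e.
  // each thread owns a 2048-wide strip of M; for a blocked layout with 128
  // outer M blocks it is simply divAndCeil(128, 2) = 64 outer blocks.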
- firstN = false; - } else if (iter == DimType::Batch) { - BInnermostDims.push_back(1); - } else if (iter == DimType::K && firstK) { - BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); - firstK = false; - } else { - BInnermostDims.push_back(BShape[idx]); - } - } firstM = true; firstN = true; for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) { @@ -716,74 +749,94 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { AInnermostDims = SmallVector{cfg.innerMostMBlock, cfg.KBlock / cfg.innerMostKBlock * cfg.innerMostKBlock}; + CInnermostDims = + SmallVector{cfg.innerMostMBlock, cfg.innerMostNBlock}; + } + if (NDimNum > 1) { + firstN = true; + firstK = true; + for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { + if (iter == DimType::N && firstN) { + BInnermostDims.push_back(1); + firstN = false; + } else if (iter == DimType::Batch) { + BInnermostDims.push_back(1); + } else if (iter == DimType::K && firstK) { + BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock); + firstK = false; + } else { + BInnermostDims.push_back(BShape[idx]); + } + } + } else { BInnermostDims = SmallVector{cfg.KBlock / cfg.innerMostKBlock * cfg.innerMostKBlock, cfg.innerMostNBlock}; - CInnermostDims = - SmallVector{cfg.innerMostMBlock, cfg.innerMostNBlock}; } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(currentOp); auto dataType = - dyn_cast(currentOp.getDpsInputs()[0].getType()); + dyn_cast(currentOp.getDpsInputs()[0].getType()) + .getElementType(); auto weightType = - dyn_cast(currentOp.getDpsInputs()[1].getType()); + dyn_cast(currentOp.getDpsInputs()[1].getType()) + .getElementType(); auto resultType = - dyn_cast(currentOp.getDpsInits()[0].getType()); - // use shrink layout when it is able to be converted to brgemm - bool useShrinkedLayout = (BInnermostDims.size() == 4); + dyn_cast(currentOp.getDpsInits()[0].getType()) + .getElementType(); // update the extractSlice to static size, replace it with // useBlockedLayout when if (failed(setStaticSizeForExtractSliceOp( rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true, - CInnermostDims, useShrinkedLayout ? 2 : 0)) || + CInnermostDims, MDimNum > 1 ? 2 : 0)) || failed(setStaticSizeForExtractSliceOp( rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true, - BInnermostDims, useShrinkedLayout)) || + BInnermostDims, NDimNum > 1)) || failed(setStaticSizeForExtractSliceOp( rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, - AInnermostDims, useShrinkedLayout))) { + AInnermostDims, MDimNum > 1))) { return failure(); } - // View the tensor to brgemm required format Value dataOprand = tensorViewRankedTensor( rewriter, mlir::RankedTensorType::get( - useBlockedLayout - ? SmallVector(AInnermostDims.begin() + 1, - AInnermostDims.end()) - : SmallVector{1, AInnermostDims[0], AInnermostDims[1]}, - dataType.getElementType()), - currentOp.getDpsInputs()[0]); + MDimNum > 1 ? SmallVector(AInnermostDims.begin() + 1, + AInnermostDims.end()) + : SmallVector{cfg.innerMostMBlock, + cfg.KBlock / cfg.innerMostKBlock, + cfg.innerMostKBlock}, + dataType), + currentOp.getDpsInputs()[0], + MDimNum == 1 ? SmallVector{1, 0, 2} : SmallVector{}); Value weightOprand = tensorViewRankedTensor( rewriter, mlir::RankedTensorType::get( - useBlockedLayout - ? SmallVector(BInnermostDims.begin() + 1, - BInnermostDims.end()) - : SmallVector{1, BInnermostDims[0], BInnermostDims[1]}, - weightType.getElementType()), + NDimNum > 1 ? 
SmallVector(BInnermostDims.begin() + 1, + BInnermostDims.end()) + : SmallVector{cfg.KBlock / cfg.innerMostKBlock, + cfg.innerMostKBlock, + cfg.innerMostNBlock}, + weightType), currentOp.getDpsInputs()[1]); Value resultOprand = tensorViewRankedTensor( rewriter, mlir::RankedTensorType::get( - SmallVector(CInnermostDims.begin() + - (useBlockedLayout ? 2 : 0), + SmallVector(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0), CInnermostDims.end()), - resultType.getElementType()), + resultType), currentOp.getDpsInits()[0]); - // Create the brgemm op and replace the origin linalg op linalg::LinalgOp matmul; - if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) { + if (dyn_cast(weightOprand.getType()) + .getShape() + .size() == 3) { matmul = rewriter.create( resultOprand.getLoc(), resultOprand.getType(), ValueRange{dataOprand, weightOprand}, resultOprand); } else { - IRMapping mapping; matmul = rewriter.create( resultOprand.getLoc(), resultOprand.getType(), ValueRange{dataOprand, weightOprand}, resultOprand); diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index 4c36e6661..4462472bb 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -922,13 +922,14 @@ FailureOr tileAllUsingForall( auto *it = llvm::find(dest, initOperand); assert(it != dest.end() && "dest operand not found in dest"); unsigned destNum = std::distance(dest.begin(), it); - SmallVector strides(numThreads.size(), b.getIndexAttr(1)); - SmallVector outOffsets(numThreads.size(), + auto dest = destBbArgs[destNum]; + auto destShape = cast(dest.getType()).getShape(); + SmallVector strides(destShape.size(), b.getIndexAttr(1)); + SmallVector outOffsets(destShape.size(), b.getIndexAttr(0)); - SmallVector sizes = tiledSizes; + SmallVector sizes(destShape.size(), b.getIndexAttr(0)); for (const auto &iteratorType : llvm::enumerate( - cast(destBbArgs[destNum].getType()) - .getShape())) { + cast(dest.getType()).getShape())) { sizes[iteratorType.index()] = getAsIndexOpFoldResult(b.getContext(), iteratorType.value()); if (llvm::find(constantNewParallelDims, iteratorType.index()) != @@ -950,8 +951,8 @@ FailureOr tileAllUsingForall( } // TODO: use SubsetExtractOpInterface once it is available. 
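      // Sketch of the plain-layout (MDimNum == 1) A-operand view created above in
      // DeepTileContractionNamedOp.cpp, with assumed cfg values
      // innerMostMBlock = 32, KBlock = 64, innerMostKBlock = 32: the A tile
      //   tensor<32x64xf32>
      // is first expanded to tensor<32x2x32xf32> (reassociation [[0], [1, 2]])
      // and then transposed with permutation [1, 0, 2] into the brgemm operand
      // layout tensor<2x32x32xf32>.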
tiledDpsInitOperands.push_back(b.create( - loc, cast(initOperand.getType()), - destBbArgs[destNum], outOffsets, sizes, strides)); + loc, cast(initOperand.getType()), dest, + outOffsets, sizes, strides)); } else { auto *it = llvm::find(dest, initOperand); assert(it != dest.end() && "dest operand not found in dest"); @@ -1016,7 +1017,7 @@ FailureOr tileAllUsingForall( b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) return op->emitOpError("output offsets couldn't be calculated"); SmallVector resultOffsetsRank, resultSizesRank; - int64_t offIdx = 0; + uint64_t offIdx = 0; int64_t nonZeroDimIdx = 0; SmallVector reductionInductionVars; for (auto i = 0UL; i < numThreads.size(); ++i) { @@ -1026,7 +1027,7 @@ FailureOr tileAllUsingForall( resultOffsetsRank.push_back(b.getIndexAttr(1)); resultSizesRank.push_back(b.getIndexAttr(1)); } - } else { + } else if (offIdx < resultOffsets.size()) { resultOffsetsRank.push_back(resultOffsets[offIdx]); resultSizesRank.push_back(resultSizes[offIdx++]); } @@ -1058,9 +1059,8 @@ FailureOr tileAllUsingForall( Operation *mergeOp = nullptr; b.setInsertionPointAfter(forallOp); if (hasReductionThreads) { - Operation *mergeOp = - linalgX::LinalgOpPartialReductionInterface::mergeReductions( - op, b, loc, forallOp->getResults(), constantNewParallelDims); + mergeOp = linalgX::LinalgOpPartialReductionInterface::mergeReductions( + op, b, loc, forallOp->getResults(), constantNewParallelDims); b.replaceOp(op, mergeOp->getResults()); } else { b.replaceOp(op, forallOp->getResults()); diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index d221557fc..bfb9a52e9 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -1,62 +1,57 @@ // RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s -// // ----- +// ----- // /// CHECK-LABEL: @matmul_4Dx4D_f32 -// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> { -// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32> +// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>, %arg1 : tensor<128x128x32x32x1xf32>) -> tensor<128x128x32x32xf32> { // %cst_0 = arith.constant 0.000000e+00 : f32 // %0 = tensor.empty() : tensor<128x128x32x32xf32> // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> -// %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> +// %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> // return %2 : tensor<128x128x32x32xf32> // } // ----- /// CHECK-LABEL: @matmul_2Dx2D_f32 -func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<4096x4096xf32> +func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { %cst_0 = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<4096x4096xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> - %2 = linalg.matmul ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + %2 = linalg.matmul ins(%arg0, %arg1 
: tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> return %2 : tensor<4096x4096xf32> } -// // ----- +// ----- // /// CHECK-LABEL: @matmul_2Dx4D_f32 -// func.func @matmul_4Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { -// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32> +// func.func @matmul_4Dx4D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<128x128x32x32x1xf32>) -> tensor<4096x4096xf32> { // %cst_0 = arith.constant 0.000000e+00 : f32 // %0 = tensor.empty() : tensor<4096x4096xf32> // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> -// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> +// %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> // return %2 : tensor<4096x4096xf32> // } // ----- -/// CHECK-LABEL: @matmul_4Dx4D_bf16 -func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16> +// /// CHECK-LABEL: @matmul_4Dx4D_bf16 +func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> { %cst_0 = arith.constant 0.000000e+00 : bf16 %0 = tensor.empty() : tensor<128x128x32x32xbf16> %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> - %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> return %2 : tensor<128x128x32x32xbf16> } -// // ----- +// ----- -// /// CHECK-LABEL: @matmul_2Dx4D_bf16 -// func.func @matmul_4Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> { -// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16> -// %cst_0 = arith.constant 0.000000e+00 : bf16 -// %0 = tensor.empty() : tensor<4096x4096xbf16> -// %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> -// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> -// return %2 : tensor<4096x4096xbf16> -// } +/// CHECK-LABEL: @matmul_2Dx4D_bf16 +func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> { + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4096x4096xbf16> + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + return %2 : tensor<4096x4096xbf16> +} From b950edbb98d2e4364aa1e2277f901b13c32e7dbb Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 12 Jun 2024 19:59:51 -0700 Subject: [PATCH 09/21] support fusing cast to the innermost loop --- .../Transforms/DeepTileContractionNamedOp.cpp | 476 +++++++++++------- 1 file changed, 299 insertions(+), 177 deletions(-) diff --git 
a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index f73f213fa..eabc434ca 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -30,7 +31,6 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include #include "gc/Transforms/Passes.h" @@ -230,39 +230,44 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, return result; } -struct OuterLoopGenerationOption { - enum LoopType { ForOp, ForallOp }; - SmallVector> nestedTileSizes; - SmallVector loopType; - SmallVector> loopDim; - bool hasFillOp; -}; - -struct OuterLoopGenerationResult { - /// Tiled operations that are generated during tiling. The order does not - /// matter except the last op. The replacements are expected to be the results - /// of the last op. - SmallVector tiledOps; - /// The `scf.for` operations that iterate over the tiles. - SmallVector loops; - SmallVector reductionLoops; -}; +bool isDummyLoop(LoopLikeOpInterface loop) { + std::optional tripCount = mlir::constantTripCount( + *loop.getSingleLowerBound(), *loop.getSingleUpperBound(), + *loop.getSingleStep()); + if (tripCount) { + return *tripCount == 1; + } + return false; +} -static void buildLinalgRegion(Operation *op) { +static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { SmallVector argTypes; SmallVector argLocs; for (const Value &opOperand : op->getOperands()) { argTypes.push_back(getElementTypeOrSelf(opOperand.getType())); argLocs.push_back(opOperand.getLoc()); } + auto initSize = op->getResults().size(); ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - auto *dialect = static_cast(op->getDialect()); - linalg::LinalgDialect::RegionBuilderFunType fun = - dialect->getRegionBuilder("linalg.matmul"); - fun(b, *body, op->getAttrs()); + if (createTemporaryOp) { + auto argNum = body->getNumArguments(); + SmallVector vals; + for (auto i = initSize; i > 0; i--) { + vals.push_back(body->getArgument(argNum - i)); + } + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToEnd(body); + Location loc = b.getUnknownLoc(); + b.create(loc, ValueRange(vals)); + } else { + auto *dialect = static_cast(op->getDialect()); + linalg::LinalgDialect::RegionBuilderFunType fun = + dialect->getRegionBuilder("linalg.matmul"); + fun(b, *body, op->getAttrs()); + } } struct DtypeLegalizeResult { @@ -270,25 +275,29 @@ struct DtypeLegalizeResult { Operation *castOp = nullptr; }; +bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { + auto dataType = + dyn_cast(linalgOp.getDpsInputs()[0].getType()) + .getElementType(); + auto resultType = + dyn_cast(linalgOp.getDpsInits()[0].getType()) + .getElementType(); + return (dataType.isBF16() || dataType.isF16()) && dataType == resultType; +} + // Split a low precision matmul(bf16xbf16->bf16) to a combination // matmul(bf16xbf16->f32) + cast(f32->bf16) +// if needFurtherFuse=true, a middle temporary linalgOp(bf16xbf16->(f32,bf16)) +// will be created 
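// A rough IR-level sketch of the simple case (shapes and the concrete cast op
// below are illustrative assumptions, not taken from this patch):
//   %0 = linalg.matmul ins(%a, %b : tensor<64x64xbf16>, tensor<64x64xbf16>)
//                      outs(%c : tensor<64x64xbf16>) -> tensor<64x64xbf16>
// is rewritten into
//   %acc  = tensor.empty() : tensor<64x64xf32>
//   %init = linalg.copy ins(%c ...) outs(%acc ...)    // only when needCopyInit
//   %f32  = linalg.matmul ins(%a, %b ...) outs(%init ...) -> tensor<64x64xf32>
//   %0    = linalg.copy ins(%f32 ...) outs(%c ...) -> tensor<64x64xbf16>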
static FailureOr matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, - bool needCopyInit = true) { - + bool needCopyInit = true, bool needFurtherFuse = false) { auto linalgOp = dyn_cast(op); DtypeLegalizeResult result; if (!linalgOp) return failure(); - auto dataType = - dyn_cast(linalgOp.getDpsInputs()[0].getType()) - .getElementType(); - auto resultType = - dyn_cast(linalgOp.getDpsInits()[0].getType()) - .getElementType(); - - if ((dataType.isBF16() || dataType.isF16()) && dataType == resultType) { + if (needToLegalizeDtype(linalgOp)) { rewriter.setInsertionPoint(linalgOp); IRMapping mapping; auto initOp = linalgOp.getDpsInits()[0].getDefiningOp(); @@ -315,14 +324,30 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, linalgOp.getLoc(), initOp->getResult(0), currentOp->getResult(0)); } SmallVector newOperands = linalgOp->getOperands(); + auto oldInit = newOperands.back(); newOperands.back() = currentOp->getResult(0); + + auto indexingMaps = linalgOp.getIndexingMapsArray(); + indexingMaps.push_back(indexingMaps.back()); + SmallVector attrs(linalgOp->getAttrs()); + SmallVector types = {currentOp->getResult(0).getType()}; + if (needFurtherFuse) { + auto segmentSize = rewriter.getNamedAttr( + "operandSegmentSizes", rewriter.getDenseI32ArrayAttr({2, 2})); + for (auto &attr : attrs) { + if (attr.getName() == "indexing_maps") + attr.setValue(rewriter.getAffineMapArrayAttr(indexingMaps)); + if (attr.getName() == "operandSegmentSizes") + attr.setValue(segmentSize.getValue()); + } + types.push_back(oldInit.getType()); + newOperands.push_back(oldInit); + } OperationState state(linalgOp->getLoc(), linalgOp->getName(), newOperands, - currentOp->getResult(0).getType(), - linalgOp->getAttrs()); + types, attrs); state.addRegion(); currentOp = rewriter.create(state); - buildLinalgRegion(currentOp); - + buildLinalgRegion(currentOp, needFurtherFuse); auto castOp = rewriter.create( linalgOp.getLoc(), currentOp->getResult(0), initOp->getResult(0)); result.linalgOp = currentOp; @@ -348,6 +373,129 @@ static Operation *findParentFillOp(Value val) { return nullptr; } +[[maybe_unused]] static LogicalResult +indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, + Value rollingIdx, Value maximumRange, Value step) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(insertBlock); + mlir::easybuild::EasyBuilder eb{b, loc}; + auto vWraped = eb.wrap(v); + auto rollingIdxWraped = eb.wrap(rollingIdx); + auto stepWraped = eb.wrap(step); + auto maximumRangeWraped = eb.wrap(step); + auto newV = (vWraped + rollingIdxWraped) * stepWraped % + (maximumRangeWraped / stepWraped * stepWraped); + v.replaceAllUsesWith(newV); + return failure(); +} + +static void getMatmulParallelDims(linalg::LinalgOp linalgOp, + unsigned operandIdx, + SmallVectorImpl &dims) { + AffineMap map = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + ArrayRef results = map.getResults(); + for (auto dim : results) { + auto dimExpr = dyn_cast(dim); + if (dimExpr && iteratorTypes[dimExpr.getPosition()] == + mlir::utils::IteratorType::parallel) { + dims.push_back(dimExpr.getPosition()); + } + } +} + +static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, + unsigned operandIdx) { + Value Operand; + unsigned dimPos; + [[maybe_unused]] auto result = + linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); + return 
linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; +} + +static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, + Operation *op, + bool isExtract, + SmallVector size, + int shrinDimNum = 0) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto extractSlice = dyn_cast(op)) { + SmallVector mixedOffsets = extractSlice.getMixedOffsets(); + SmallVector mixedSizes = extractSlice.getMixedSizes(); + SmallVector mixedStrides = extractSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + if (shrinDimNum > 0) { + rewriter.replaceOpWithNewOp( + extractSlice, + mlir::RankedTensorType::get( + SmallVector(size.begin() + shrinDimNum, size.end()), + extractSlice.getResult().getType().getElementType()), + extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); + } else { + rewriter.replaceOpWithNewOp( + extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, + mixedStrides); + } + } else { + return failure(); + } + return mlir::success(); +} + +static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, + Operation *op, Value source, + SmallVector size) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto insertSlice = dyn_cast(op)) { + SmallVector mixedOffsets = insertSlice.getMixedOffsets(); + SmallVector mixedSizes = insertSlice.getMixedSizes(); + SmallVector mixedStrides = insertSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + rewriter.replaceOpWithNewOp( + insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, + mixedStrides); + } else { + return failure(); + } + return success(); +} + +using InnermostFullResultCallBackFn = std::function( + RewriterBase &rewriter, Location loc, linalg::LinalgOp linalgop)>; + +using FinalReduceCallBackFn = std::function( + RewriterBase &rewriter, Location loc, + linalg::ForallReductionTilingResult result)>; + +struct OuterLoopGenerationOption { + enum LoopType { ForOp, ForallOp }; + SmallVector> nestedTileSizes; + SmallVector loopType; + SmallVector> loopDim; + SmallVector innermostFullResultCallBacks; + SmallVector finalReduceCallBacks; + bool isPartialResult = false; +}; + +struct OuterLoopGenerationResult { + /// Tiled operations that are generated during tiling. The order does not + /// matter except the last op. The replacements are expected to be the results + /// of the last op. + SmallVector tiledOps; + /// The `scf.for` operations that iterate over the tiles. 
+ SmallVector loops; + SmallVector reductionLoops; +}; + static FailureOr generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, const OuterLoopGenerationOption &option) { @@ -371,6 +519,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, linalgOp, "currentOp should not has pure buffer semantics"); linalg::LinalgOp currentOp = linalgOp; + bool hasFullResult = !option.isPartialResult; for (auto loopTypeIter : llvm::enumerate(loopType)) { auto [i, loopType] = loopTypeIter; auto currentDim = loopDim[i]; @@ -385,26 +534,29 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); - // TODO: refactor here to use a callback function if (iteratorTypes[d] == mlir::utils::IteratorType::reduction && - tile != 0) { - auto result = matmulDtypeLegalize(b, currentOp.getOperation(), - !option.hasFillOp); - if (result->castOp && result->linalgOp) { - b.replaceOp(currentOp, result->castOp); - currentOp = dyn_cast(result->linalgOp); + tile != 0 && hasFullResult) { + for (const auto &fn : option.innermostFullResultCallBacks) { + auto result = fn(b, currentOp.getLoc(), currentOp); + if (succeeded(result)) { + currentOp = *result; + } } + hasFullResult = false; } auto tilingResult = scf::tileUsingSCF( b, cast(currentOp.getOperation()), tileOption); if (failed(tilingResult)) return failure(); - b.replaceOp(currentOp, tilingResult->replacements); - currentOp = dyn_cast(tilingResult->tiledOps.back()); - if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { - result.reductionLoops.push_back(tilingResult->loops.back()); + + if (!isDummyLoop(tilingResult->loops.back())) { + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { + result.reductionLoops.push_back(tilingResult->loops.back()); + } + result.loops.push_back(tilingResult->loops.back()); } - result.loops.push_back(tilingResult->loops.back()); } } else if (loopType == OuterLoopGenerationOption::LoopType::ForallOp) { SmallVector tileSizes( @@ -442,11 +594,12 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (failed(tilingResult)) return failure(); currentOp = dyn_cast(tilingResult->parallelTiledOp); - if (option.hasFillOp && tilingResult->mergeOp) { - auto fillOp = findParentFillOp(tilingResult->loops.getDpsInits()[0]); - if (fillOp) { - b.replaceOp(fillOp, dyn_cast(*fillOp) - .getDpsInits()[0]); + if (tilingResult->mergeOp) { + for (const auto &fn : option.finalReduceCallBacks) { + auto result = fn(b, currentOp.getLoc(), *tilingResult); + if (succeeded(result)) { + currentOp = *result; + } } } } else if (auto tilingInterface = @@ -464,102 +617,6 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return result; } -[[maybe_unused]] static LogicalResult -indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, - Value rollingIdx, Value maximumRange, Value step) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(insertBlock); - mlir::easybuild::EasyBuilder eb{b, loc}; - auto vWraped = eb.wrap(v); - auto rollingIdxWraped = eb.wrap(rollingIdx); - auto stepWraped = eb.wrap(step); - auto maximumRangeWraped = eb.wrap(step); - auto newV = (vWraped + rollingIdxWraped) * stepWraped % - (maximumRangeWraped / stepWraped * stepWraped); - v.replaceAllUsesWith(newV); - return failure(); -} - -static void 
getMatmulParallelDims(linalg::LinalgOp linalgOp, - unsigned operandIdx, - SmallVectorImpl &dims) { - AffineMap map = - linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); - SmallVector iteratorTypes = - linalgOp.getIteratorTypesArray(); - - ArrayRef results = map.getResults(); - for (auto dim : results) { - auto dimExpr = dyn_cast(dim); - if (dimExpr && iteratorTypes[dimExpr.getPosition()] == - mlir::utils::IteratorType::parallel) { - dims.push_back(dimExpr.getPosition()); - } - } -} - -static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, - unsigned operandIdx) { - Value Operand; - unsigned dimPos; - [[maybe_unused]] auto result = - linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); - return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; -} - -static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, - Operation *op, - bool isExtract, - SmallVector size, - int shrinDimNum = 0) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - if (auto extractSlice = dyn_cast(op)) { - SmallVector mixedOffsets = extractSlice.getMixedOffsets(); - SmallVector mixedSizes = extractSlice.getMixedSizes(); - SmallVector mixedStrides = extractSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } - if (shrinDimNum > 0) { - rewriter.replaceOpWithNewOp( - extractSlice, - mlir::RankedTensorType::get( - SmallVector(size.begin() + shrinDimNum, size.end()), - extractSlice.getResult().getType().getElementType()), - extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); - } else { - rewriter.replaceOpWithNewOp( - extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, - mixedStrides); - } - } else { - return failure(); - } - return mlir::success(); -} - -static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, - Operation *op, Value source, - SmallVector size) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - if (auto insertSlice = dyn_cast(op)) { - SmallVector mixedOffsets = insertSlice.getMixedOffsets(); - SmallVector mixedSizes = insertSlice.getMixedSizes(); - SmallVector mixedStrides = insertSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } - rewriter.replaceOpWithNewOp( - insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, - mixedStrides); - } else { - return failure(); - } - return success(); -} - /* matmul(A, B) -> C ----------------> @@ -636,14 +693,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer - if (cfg.KThreads > 1) { - auto result = - matmulDtypeLegalize(rewriter, linalgOp.getOperation(), !hasFillOp); - if (result->castOp && result->linalgOp) { - rewriter.replaceOp(linalgOp, result->castOp); - linalgOp = dyn_cast(result->linalgOp); - } - } option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); @@ -686,12 +735,45 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopDim.emplace_back(SmallVector{(int)dim}); } } - option.hasFillOp = hasFillOp; + + auto lowPrecisionCast = + [&](RewriterBase &rewriter, Location loc, + linalg::LinalgOp linalgop) -> FailureOr { + auto legalizedResult = matmulDtypeLegalize( + rewriter, linalgop.getOperation(), !hasFillOp, true); + if (legalizedResult->castOp && legalizedResult->linalgOp) { + auto linalgOp = legalizedResult->linalgOp; + rewriter.replaceOp(linalgop, + linalgOp->getResult(linalgOp->getNumResults() - 1)); + return dyn_cast(linalgOp); + } + return failure(); + }; + option.innermostFullResultCallBacks.push_back(lowPrecisionCast); + + if (hasFillOp) { + auto removeReduncantFill = + [&](RewriterBase &rewriter, Location loc, + const linalg::ForallReductionTilingResult &result) + -> FailureOr { + auto initValue = result.initialValues; + if (initValue.size() == 1 && + isa(initValue[0].getDefiningOp())) { + rewriter.replaceOp(initValue[0].getDefiningOp(), + dyn_cast( + initValue[0].getDefiningOp()) + .getDpsInits()[0]); + } + return dyn_cast(result.parallelTiledOp); + }; + option.finalReduceCallBacks.push_back(removeReduncantFill); + } return generateOuterLoop(rewriter, linalgOp, option); } struct innerBodyGenerationOption { Operation *fillOp; + bool needLowPrecisionCast; SmallVector KLoopHandles; }; @@ -796,7 +878,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { BInnermostDims, NDimNum > 1)) || failed(setStaticSizeForExtractSliceOp( rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, - AInnermostDims, MDimNum > 1))) { + AInnermostDims, MDimNum > 1)) || + (currentOp.getDpsInits().size() > 1 && + failed(setStaticSizeForExtractSliceOp( + rewriter, currentOp.getDpsInits()[1].getDefiningOp(), true, + CInnermostDims, MDimNum > 1 ? 
2 : 0)))) { return failure(); } // View the tensor to brgemm required format @@ -850,13 +936,47 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { return failure(); } } - rewriter.replaceOp(currentOp, matmul.getOperation()->getResult(0)); - currentOp = matmul; + if (option.needLowPrecisionCast) { + rewriter.setInsertionPointAfter(currentOp); + auto cond = eb(true); + for (auto loop : option.KLoopHandles) { + auto induceVar = + eb.wrap(*loop.getSingleInductionVar()); + auto upBound = + eb.wrap(*loop.getSingleUpperBound()); + auto step = eb.wrap(*loop.getSingleStep()); + auto currentCond = (induceVar + step) > upBound; + cond = cond & currentCond; + } + EB_scf_if(cond, {currentOp.getDpsInits().back().getType()}) { + auto castOp = rewriter.create( + matmul.getLoc(), matmul->getResult(0), + currentOp.getDpsInits().back()); + eb.yield(castOp->getResult(0)); + } + EB_else { eb.yield(currentOp.getDpsInits().back()); } + auto ifOp = eb.getLastOperaion(); + // set static size for the insertSliceOp of copyOp + for (Operation *user : currentOp->getResult(1).getUsers()) { + if (failed(setStaticSizeForInsertSliceOp( + rewriter, user, ifOp->getResult(0), CInnermostDims))) { + return failure(); + } + } + rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)}); + } else { + rewriter.replaceOp(currentOp, matmul->getResult(0)); + } + currentOp = matmul; // Fuse the fill op to the innermost body if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { auto fillValue = fillOp.getDpsInputs()[0]; - rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + if (cfg.KThreads <= 1) { + // if use k slicing, the fill op is still need to be kept for the reduce + // init + rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + } rewriter.setInsertionPointAfter(currentOp); auto cond = eb(true); @@ -910,29 +1030,31 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); - // Step 1. generate the outer loop + // Step 1. Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + + // cast(f32->bf16) if K slicing is needed MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); + bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); + if (cfg.KThreads > 1) { + auto result = matmulDtypeLegalize(rewriter, linalgOp.getOperation()); + if (result->castOp && result->linalgOp) { + rewriter.replaceOp(linalgOp, result->castOp); + linalgOp = dyn_cast(result->linalgOp); + } + needLowPrecisionCast = false; + } + + // Step 2. 
Outer loop generation auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg, isa(fillOp)); if (failed(outerLoopResult)) { return failure(); } linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); - // Step 2 index rolling - // if (failed(indexRolling(rewriter, linalgOp.getLoc(), - // outerLoopResult->reductionLoops[0].getInductionVar(), - // linalgOp.getLoopRanges()[0].size, cfg.MBlock)) - // || - // failed(indexRolling(rewriter, linalgOp.getLoc(), - // linalgOp.getDpsInputOperand(1), - // linalgOp.getLoopRanges()[1].size, cfg.KBlock))) - // { - // return failure(); - // } // Step 3 generate inner loop body, convert the linalg.generic to brgemm - auto option = - innerBodyGenerationOption{fillOp, outerLoopResult->reductionLoops}; + auto option = innerBodyGenerationOption{fillOp, needLowPrecisionCast, + outerLoopResult->reductionLoops}; + if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) { return failure(); } From dc9a1a4fb13187e9e5a784b60416f5cc8e834e2d Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 19 Jun 2024 19:41:56 -0700 Subject: [PATCH 10/21] enhance config --- include/gc/Analysis/MatmulConfigAnalysis.h | 123 ++++++ lib/gc/Analysis/CMakeLists.txt | 3 +- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 387 ++++++++++++++++++ .../Transforms/DeepTileContractionNamedOp.cpp | 162 ++------ 4 files changed, 547 insertions(+), 128 deletions(-) create mode 100644 include/gc/Analysis/MatmulConfigAnalysis.h create mode 100644 lib/gc/Analysis/MatmulConfigAnalysis.cpp diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h new file mode 100644 index 000000000..cbc259609 --- /dev/null +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -0,0 +1,123 @@ +//===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H +#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include +#include +#include + +namespace mlir { +namespace gc { + +using namespace mlir; + +struct SystemDesc { + // get runtime OMP_NUM_THREADS + uint32_t getNumThreads() { + char *numThreads = getenv("OMP_NUM_THREADS"); + if (numThreads) { + return std::stoi(numThreads); + } + return 1; + } + // get cache size by cacheLevel + size_t getCacheSize(uint8_t cacheLevel) { + if (cacheLevel == 1) { + char *cacheSize = getenv("L1_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } else if (cacheLevel == 2) { + char *cacheSize = getenv("L2_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } else if (cacheLevel == 3) { + char *cacheSize = getenv("L3_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } + return 0; + } + + SmallVector getContractionOperationMaxVectorLength() { + return {512UL, 512UL}; + } +}; + +struct MatmulConfig { + uint32_t MBlock, NBlock, KBlock; + uint32_t MThreads, NThreads, KThreads; + uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const MatmulConfig &config); +}; + +enum DimType { Batch, M, N, K }; + +[[maybe_unused]] static SmallVector +extractDimTypeIdx(ArrayRef tyList, DimType ty) { + SmallVector idxList; + for (auto [idx, type] : llvm::enumerate(tyList)) { + if (type == ty) { + idxList.push_back(idx); + } + } + return idxList; +} + +static FailureOr>> +getOprandDimType(linalg::LinalgOp &linalgOp) { + if (isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } + return failure(); +} + +struct MatmulConfigAnalysis { +public: + explicit MatmulConfigAnalysis(Operation *root); + MatmulConfig getConfig() { return config; } + +private: + MatmulConfig config; +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index c1c34ea50..51163823a 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp + MatmulConfigAnalysis.cpp DEPENDS GraphCompilerPassIncGen @@ -12,4 +13,4 @@ gc_add_mlir_library(GcAnalysis 
${mlir_dialect_libs} ${MLIR_LINK_COMPONENTS} GcInterface - ) \ No newline at end of file + ) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp new file mode 100644 index 000000000..de2067566 --- /dev/null +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -0,0 +1,387 @@ +//===-- MatmulConfigAnalysis.cpp - DESC -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "gc/Analysis/MatmulConfigAnalysis.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "matmul-config-analysis" + +#define MAX_THREADS (1024U * 1024U) + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const MatmulConfig &config) { + + ss << "MBlock: " << config.MBlock << ", NBlock: " << config.NBlock + << ", KBlock: " << config.KBlock << ", MThreads: " << config.MThreads + << ", NThreads: " << config.NThreads << ", KThreads: " << config.KThreads + << ", innerMostMBlock: " << config.innerMostMBlock + << ", innerMostNBlock: " << config.innerMostNBlock + << ", innerMostKBlock: " << config.innerMostKBlock; + return ss; +} + +std::vector getCandidate(uint32_t num, uint32_t floor, + uint32_t ceil) { + std::vector candidates; + for (uint32_t i = 1; i <= num; i++) { + if (num % i == 0 && i <= ceil && i >= floor) { + candidates.push_back(i); + } + } + auto candidate = 1U; + while (candidate < num && candidate <= ceil && candidate >= floor) { + candidates.push_back(candidate); + candidate *= 2; + } + auto last = std::unique(candidates.begin(), candidates.end()); + candidates.erase(last, candidates.end()); + return candidates; +} + +bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, + ArrayRef shape) { + if (config.innerMostMBlock == 0 || config.innerMostNBlock == 0 || + config.innerMostKBlock == 0) { + return false; + } + if (config.MBlock % config.innerMostMBlock != 0 || + config.NBlock % config.innerMostNBlock != 0 || + config.KBlock % config.innerMostKBlock != 0) { + return false; + } + auto threads = sysDesc.getNumThreads(); + if (config.MThreads * config.NThreads * config.KThreads != threads) { + return false; + } + + if (shape[0] % config.innerMostMBlock != 0 || + shape[1] % config.innerMostNBlock != 0 || + shape[2] % config.innerMostKBlock != 0) { + return false; + } + + return true; +} + +double threadUtilizationCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + auto threads = sysDesc.getNumThreads(); + auto actualThreads = + (float)(config.MThreads * config.NThreads * config.KThreads); + return threads >= actualThreads ? threads / actualThreads + : actualThreads / threads; +} + +double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto vectorLength = sysDesc.getContractionOperationMaxVectorLength(); + auto mMaxVectorLength = vectorLength[0] / dtypeSize; + auto kMaxVectorLength = + (vectorLength.size() > 1 ? 
vectorLength[1] : vectorLength[0]) / dtypeSize; + auto cost = (mMaxVectorLength - config.innerMostMBlock % mMaxVectorLength) % + mMaxVectorLength * 1.0 / config.innerMostMBlock + + (kMaxVectorLength - config.innerMostKBlock % kMaxVectorLength) % + kMaxVectorLength * 1.0 / config.innerMostKBlock + + (mMaxVectorLength - config.innerMostNBlock % mMaxVectorLength) % + mMaxVectorLength * 1.0 / config.innerMostNBlock; + return cost; +} + +double workloadBalancedCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + return 1; +} + +double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + SystemDesc &sysDesc) { + auto M = shape[0], N = shape[1], K = shape[2]; + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto penalty = 2.0 * (dtypeSize / 8); + auto memoryConsumptionPerThread = + M * K * 1.0 / config.MThreads / config.KThreads + + K * N * 1.0 / config.KThreads / config.NThreads + + M * N * ((config.KThreads - 1) * penalty + 1.0) / config.MThreads / + config.NThreads; + return memoryConsumptionPerThread; +} + +double computationIntensityOnL1Cache(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + SystemDesc &sysDesc) { + auto L1Cache = sysDesc.getCacheSize(2); + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto outOfCachePenalty = 1024; + double FLOPS = + 2.0 * config.innerMostMBlock * config.innerMostNBlock * config.KBlock; + double memoryConsumption = config.innerMostMBlock * config.innerMostNBlock + + config.innerMostNBlock * config.KBlock + + config.innerMostMBlock * config.KBlock; + double computationIntensity = FLOPS / memoryConsumption; + if (memoryConsumption * (dtypeSize / 8) > L1Cache) { + computationIntensity /= outOfCachePenalty; + } + return 1 / computationIntensity; +} + +using CostModelFn = + std::function shape, + MatmulConfig cfg, SystemDesc &sysDesc)>; + +std::vector +filterConfigByCostModel(std::vector configs, + linalg::LinalgOp &linalgOp, ArrayRef shape, + SystemDesc &sysDesc, const CostModelFn &costModel, + float eliminationRatio = 0.5, float threshold = -1) { + std::vector result; + std::vector costs; + std::vector idx; + for (auto [i, config] : llvm::enumerate(configs)) { + costs.push_back(costModel(linalgOp, shape, config, sysDesc)); + idx.push_back(i); + } + std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) { + return costs[i1] < costs[i2]; + }); + auto thresholdCost = costs[idx[(size_t)(eliminationRatio * configs.size())]]; + thresholdCost = + threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost; + for (size_t i = 0; i < configs.size(); i++) { + if (costs[idx[i]] <= thresholdCost) { + result.push_back(configs[idx[i]]); + } + } + llvm::outs() << "thresholdCost is: " << thresholdCost + << "\nbest with cost: " << costs[idx[0]] << "\n" + << configs[idx[0]] + << "\n worst with cost: " << costs[idx[configs.size() - 1]] + << "\n" + << configs[idx[configs.size() - 1]] << "\n"; + return !result.empty() ? 
result : configs; +} + +std::vector +prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, + ArrayRef shape, + ArrayRef givenInnermostBlock) { + std::vector configs; + auto threads = sysDesc.getNumThreads(); + auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto MBlockCandidates = + getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); + auto NBlockCandidates = getCandidate((uint32_t)shape[1], 1U, shape[1]); + auto KBlockCandidates = getCandidate((uint32_t)shape[2], 1U, shape[2]); + auto innerMostMBlockCandidates = + getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); + auto innerMostNBlockCandidates = + getCandidate((uint32_t)shape[1], 1U, (uint32_t)shape[1]); + auto innerMostKBlockCandidates = + getCandidate((uint32_t)shape[2], 1U, (uint32_t)shape[2]); + if (givenInnermostBlock.size() == 3) { + innerMostMBlockCandidates = + givenInnermostBlock[0] != 0 + ? std::vector{givenInnermostBlock[0]} + : innerMostMBlockCandidates; + innerMostNBlockCandidates = + givenInnermostBlock[1] != 0 + ? std::vector{givenInnermostBlock[1]} + : innerMostNBlockCandidates; + innerMostKBlockCandidates = + givenInnermostBlock[2] != 0 + ? std::vector{givenInnermostBlock[2]} + : innerMostKBlockCandidates; + } + llvm::outs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + << "\n"; + llvm::outs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + << "\n"; + llvm::outs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + << "\n"; + llvm::outs() << "MBlockCandidates size: " << MBlockCandidates.size() << "\n"; + llvm::outs() << "NBlockCandidates size: " << NBlockCandidates.size() << "\n"; + llvm::outs() << "KBlockCandidates size: " << KBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostMBlockCandidates size: " + << innerMostMBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostNBlockCandidates size: " + << innerMostNBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostKBlockCandidates size: " + << innerMostKBlockCandidates.size() << "\n"; + for (auto MThreads : MThreadsCandidates) { + for (auto NThreads : NThreadsCandidates) { + for (auto KThreads : KThreadsCandidates) { + for (auto MBlock : MBlockCandidates) { + for (auto NBlock : NBlockCandidates) { + for (auto KBlock : KBlockCandidates) { + for (auto innerMostMBlock : innerMostMBlockCandidates) { + for (auto innerMostNBlock : innerMostNBlockCandidates) { + for (auto innerMostKBlock : innerMostKBlockCandidates) { + MatmulConfig config{ + MBlock, NBlock, KBlock, + MThreads, NThreads, KThreads, + innerMostMBlock, innerMostNBlock, innerMostKBlock}; + + if (isValidConfig(config, sysDesc, shape)) { + configs.push_back(config); + } + } + } + } + } + } + } + } + } + } + return configs; +} + +/* +thread utilization +computation intensity +cache locality +memory requirements +computation unit efficiency +padding/pack cost +workload balance +communication +previous matmul +*/ +MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { + SystemDesc sysDesc; + if (auto linalgOp = dyn_cast(root)) { + // TODO: build a more complex heuristic to determine the best tiling + auto oprandDimType = *getOprandDimType(linalgOp); + // get the origin M,N,K size + auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); + auto KDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::K); + auto NDimTypeIdx = 
extractDimTypeIdx(oprandDimType[1], DimType::N); + uint32_t M = 1U, N = 1U, K = 1U; + for (auto [s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), + oprandDimType[0])) { + if (dimType == DimType::M) { + M *= s; + } + } + for (auto [s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)), + oprandDimType[1])) { + if (dimType == DimType::N) { + N *= s; + } else if (dimType == DimType::K) { + K *= s; + } + } + + // innermost Block, if the layout is blockied layout, the innermost block + // will derived from the layout directly + auto defaultBlock = 32; + config.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; + config.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; + config.innerMostKBlock = K % defaultBlock == 0 ? defaultBlock : K; + SmallVector givenInnermostBlock; + if (MDimTypeIdx.size() > 1) { + config.innerMostMBlock = 1; + for (auto i = 1UL; i < MDimTypeIdx.size(); i++) { + config.innerMostMBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostMBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (NDimTypeIdx.size() > 1) { + config.innerMostNBlock = 1; + for (auto i = 1UL; i < NDimTypeIdx.size(); i++) { + config.innerMostNBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostNBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (KDimTypeIdx.size() > 1) { + config.innerMostKBlock = 1; + for (auto i = 1UL; i < KDimTypeIdx.size(); i++) { + config.innerMostKBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostKBlock); + } else { + givenInnermostBlock.push_back(0); + } + + // Number of block + auto MNumBlock = M / config.innerMostMBlock; + auto NNumBlock = N / config.innerMostNBlock; + auto KNumBlock = K / config.innerMostKBlock; + + // Threads + config.MThreads = 32; + config.NThreads = 1; + config.KThreads = 1; + + // Block + config.MBlock = (int)llvm::divideCeil(MNumBlock, config.MThreads) * + config.innerMostMBlock; + config.NBlock = (int)llvm::divideCeil(NNumBlock, config.NThreads) * + config.innerMostNBlock; + config.KBlock = (int)llvm::divideCeil(KNumBlock, config.KThreads) * + config.innerMostKBlock; + config.MBlock = 128; + config.NBlock = 128; + config.KBlock = 128; + config.MThreads = 2; + config.NThreads = 2; + config.KThreads = 1; + + llvm::outs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; + + SmallVector> costModelList = { + {threadUtilizationCost, "threadUtilizationCost"}, + {hardwareEfficiencyCost, "hardwareEfficiencyCost"}, + {workloadBalancedCost, "workloadBalancedCost"}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost"}, + {computationIntensityOnL1Cache, "computationIntensityOnL1Cache"}}; + + auto configCandidates = + prepareConfigCandidates(root, sysDesc, {M, N, K}, givenInnermostBlock); + + for (auto [fn, name] : costModelList) { + llvm::outs() << name << "\n\n"; + configCandidates = filterConfigByCostModel(configCandidates, linalgOp, + {M, N, K}, sysDesc, fn, 0.5); + llvm::outs() << "ConfigCandidates size: " << configCandidates.size() + << "\n"; + } + + if (!configCandidates.empty()) { + config = configCandidates[0]; + } + + llvm::outs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() + << ", MatmulConfig: " << config << "\n"; + } +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git 
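To summarize the analysis implemented above: prepareConfigCandidates enumerates blocking and threading factors from divisors and powers of two of M, N and K, and filterConfigByCostModel is then applied once per cost model (thread utilization, hardware efficiency, workload balance, per-thread memory footprint, computation intensity relative to cache size), each pass keeping roughly the cheapest half of the surviving candidates and falling back to the unfiltered set when nothing passes. The driver loop boils down to the simplified sketch below; Config, CostFn and keepCheapest are stand-ins for MatmulConfig, the cost-model signature and the real threshold-based filter, so this is an illustration of the flow rather than the exact implementation.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

struct Config { /* blocking and threading factors */ };
using CostFn = std::function<double(const Config &)>;

// Keep roughly the cheapest fraction of candidates under one cost model.
static std::vector<Config> keepCheapest(std::vector<Config> cands, CostFn cost,
                                        double keepRatio = 0.5) {
  std::stable_sort(cands.begin(), cands.end(),
                   [&](const Config &a, const Config &b) {
                     return cost(a) < cost(b);
                   });
  std::size_t keep =
      std::max<std::size_t>(1, (std::size_t)(cands.size() * keepRatio));
  cands.resize(std::min(cands.size(), keep));
  return cands;
}

// Apply the cost models in sequence; the best remaining candidate wins.
static Config pickConfig(std::vector<Config> cands,
                         const std::vector<CostFn> &costModels) {
  for (const CostFn &model : costModels)
    if (!cands.empty())
      cands = keepCheapest(cands, model);
  return cands.empty() ? Config{} : cands.front();
}

int main() {
  std::vector<Config> candidates(16);
  std::vector<CostFn> models = {[](const Config &) { return 1.0; },
                                [](const Config &) { return 2.0; }};
  Config best = pickConfig(candidates, models);
  (void)best;
  return 0;
}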
a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index eabc434ca..5b7f34416 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "./Tiling.hpp" +#include "gc/Analysis/MatmulConfigAnalysis.h" #include "gc/Dialect/Arith/Utils/EasyBuild.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/IR/EasyBuild.h" @@ -45,101 +46,6 @@ namespace gc { namespace { -struct SystemDesc { - // get runtime OMP_NUM_THREADS - uint32_t getNumThreads(); - // get cache size by cacheLevel - size_t getCacheSize(uint8_t cacheLevel); -}; - -struct MatmulConfig { - int MBlock, NBlock, KBlock; - int MThreads, NThreads, KThreads; - int innerMostMBlock, innerMostNBlock, innerMostKBlock; -}; - -template inline T divAndCeil(T a, T b) { return (a - 1) / b + 1; } - -enum DimType { Batch, M, N, K }; - -static FailureOr>> -getOprandDimType(linalg::LinalgOp &linalgOp) { - if (isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K}, - SmallVector{DimType::K, DimType::N}, - SmallVector{DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::Batch, DimType::M, DimType::K}, - SmallVector{DimType::Batch, DimType::K, DimType::N}, - SmallVector{DimType::Batch, DimType::M, DimType::N}}; - } - return failure(); -} - -[[maybe_unused]] static SmallVector -extractDimTypeIdx(ArrayRef tyList, DimType ty) { - SmallVector idxList; - for (auto [idx, type] : llvm::enumerate(tyList)) { - if (type == ty) { - idxList.push_back(idx); - } - } - return idxList; -} - -MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { - // TODO: build a more complex heuristic to determine the best tiling - auto M = linalgOp.getShape(linalgOp.getDpsInputOperand(0))[0]; - auto N = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[1]; - auto K = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[0]; - MatmulConfig cfg; - - // innermost Block - auto defaultBlock = 32; - cfg.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; - cfg.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; - cfg.innerMostKBlock = K % defaultBlock == 0 ? 
defaultBlock : K; - - // Number of block - auto MNumBlock = M / cfg.innerMostMBlock; - auto NNumBlock = N / cfg.innerMostNBlock; - auto KNumBlock = K / cfg.innerMostKBlock; - - // Threads - cfg.MThreads = 32; - cfg.NThreads = 1; - cfg.KThreads = 1; - - // Block - cfg.MBlock = divAndCeil((int)MNumBlock, cfg.MThreads) * cfg.innerMostMBlock; - cfg.NBlock = divAndCeil((int)NNumBlock, cfg.NThreads) * cfg.innerMostNBlock; - cfg.KBlock = divAndCeil((int)KNumBlock, cfg.KThreads) * cfg.innerMostKBlock; - cfg.innerMostMBlock = 32; - cfg.innerMostNBlock = 32; - cfg.innerMostKBlock = 32; - cfg.MBlock = 64; - cfg.NBlock = 64; - cfg.KBlock = 64; - cfg.MThreads = 2; - cfg.NThreads = 2; - cfg.KThreads = 1; - return cfg; -} - static Value tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, Value value, @@ -478,9 +384,9 @@ using FinalReduceCallBackFn = std::function( struct OuterLoopGenerationOption { enum LoopType { ForOp, ForallOp }; - SmallVector> nestedTileSizes; + SmallVector> nestedTileSizes; SmallVector loopType; - SmallVector> loopDim; + SmallVector> loopDim; SmallVector innermostFullResultCallBacks; SmallVector finalReduceCallBacks; bool isPartialResult = false; @@ -657,7 +563,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { FailureOr outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - MatmulConfig cfg, bool hasFillOp) const { + gc::MatmulConfig cfg, bool hasFillOp) const { SmallVector KDimPos, MDimPos, NDimPos; linalgOp.getReductionDims(KDimPos); getMatmulParallelDims(linalgOp, 0, MDimPos); @@ -665,23 +571,26 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { OuterLoopGenerationOption option; auto iteratorTypes = linalgOp.getIteratorTypesArray(); - auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1); - auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0); - auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1); + auto KFirstDim = getOprandDim(linalgOp, KDimPos[0], 1); + auto MFirstDim = getOprandDim(linalgOp, MDimPos[0], 0); + auto NFirstDim = getOprandDim(linalgOp, NDimPos[0], 1); auto KParallelBlockSize = KDimPos.size() > 1 - ? divAndCeil(KFirstDim, cfg.KThreads) - : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) * + ? llvm::divideCeil(KFirstDim, cfg.KThreads) + : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), + cfg.KThreads) * cfg.KBlock; auto MParallelBlockSize = MDimPos.size() > 1 - ? divAndCeil(MFirstDim, cfg.MThreads) - : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) * + ? llvm::divideCeil(MFirstDim, cfg.MThreads) + : llvm::divideCeil(llvm::divideCeil(MFirstDim, cfg.MBlock), + cfg.MThreads) * cfg.MBlock; auto NParallelBlockSize = NDimPos.size() > 1 - ? divAndCeil(NFirstDim, cfg.NThreads) - : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) * + ? llvm::divideCeil(NFirstDim, cfg.NThreads) + : llvm::divideCeil(llvm::divideCeil(NFirstDim, cfg.NBlock), + cfg.NThreads) * cfg.NBlock; auto KOuterBlockSize = KDimPos.size() > 1 ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 @@ -693,46 +602,45 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer - option.nestedTileSizes.emplace_back(SmallVector{ + option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); option.loopDim.emplace_back( - SmallVector{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]}); + SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]}); // Middle for (auto [tile, dim] : - llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, - KOuterBlockSize}, - SmallVector{(int)MDimPos[0], (int)NDimPos[0], - (int)KDimPos[0]})) { - option.nestedTileSizes.emplace_back(SmallVector{tile}); + llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, + KOuterBlockSize}, + SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{dim}); + option.loopDim.emplace_back(SmallVector{dim}); } // Inner if (KDimPos.size() == 1) { - option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); + option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)KDimPos.back()}); + option.loopDim.emplace_back(SmallVector{KDimPos.back()}); } if (MDimPos.size() == 1) { option.nestedTileSizes.emplace_back( - SmallVector{cfg.innerMostMBlock}); + SmallVector{cfg.innerMostMBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)MDimPos.back()}); + option.loopDim.emplace_back(SmallVector{MDimPos.back()}); } if (NDimPos.size() == 1) { option.nestedTileSizes.emplace_back( - SmallVector{cfg.innerMostNBlock}); + SmallVector{cfg.innerMostNBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)NDimPos.back()}); + option.loopDim.emplace_back(SmallVector{NDimPos.back()}); } for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { if (dim != MDimPos.back() && dim != NDimPos.back() && iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { - option.nestedTileSizes.emplace_back(SmallVector{1}); + option.nestedTileSizes.emplace_back(SmallVector{1}); option.loopType.emplace_back( OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)dim}); + option.loopDim.emplace_back(SmallVector{dim}); } } @@ -784,7 +692,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; auto operandDimTypes = getOprandDimType(originOp); - MatmulConfig cfg = getDefaultMatmulConfig(originOp); + auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); auto BShape = originOp.getShape(originOp.getDpsInputOperand(1)); auto CShape = originOp.getShape(originOp.getDpsInitOperand(0)); @@ -946,7 +854,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { auto upBound = eb.wrap(*loop.getSingleUpperBound()); auto step = eb.wrap(*loop.getSingleStep()); - auto currentCond = (induceVar + step) > upBound; + auto currentCond = (induceVar + step) >= upBound; cond = cond & currentCond; } EB_scf_if(cond, {currentOp.getDpsInits().back().getType()}) { @@ -1027,12 +935,12 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { 
rewriter.setInsertionPoint(linalgOp); linalg::LinalgOp originOp = dyn_cast(*rewriter.clone(*(linalgOp.getOperation()))); - linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); // Step 1. Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + // cast(f32->bf16) if K slicing is needed - MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); + auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); if (cfg.KThreads > 1) { auto result = matmulDtypeLegalize(rewriter, linalgOp.getOperation()); From 9c9ff10c91836cb50cf0dc5df2207fbf5e744b57 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 26 Jun 2024 18:33:07 -0700 Subject: [PATCH 11/21] rebase to the latest llvm --- include/gc/Dialect/Arith/Utils/EasyBuild.h | 3 ++- lib/gc/Transforms/DeepTileContractionNamedOp.cpp | 10 ++++++---- lib/gc/Transforms/Tiling.cpp | 10 +++++----- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/gc/Dialect/Arith/Utils/EasyBuild.h b/include/gc/Dialect/Arith/Utils/EasyBuild.h index 74f664184..f7656370d 100644 --- a/include/gc/Dialect/Arith/Utils/EasyBuild.h +++ b/include/gc/Dialect/Arith/Utils/EasyBuild.h @@ -361,8 +361,9 @@ inline EBUnsigned extend(Type type, const EBUnsigned &a) { } inline EBFloatPoint extend(Type type, const EBFloatPoint &a) { + arith::FastMathFlagsAttr fastMathAttr; return OperatorHandlers::create(a.builder, type, - a); + a, fastMathAttr); } inline EBSigned trunc(Type type, const EBSigned &a) { diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 5b7f34416..818c20e59 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -497,10 +497,12 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, auto tilingResult = linalgX::tileAllUsingForall( b, cast(currentOp.getOperation()), {}, tileSizes, newParallelDims, std::nullopt); - if (failed(tilingResult)) + if (failed(tilingResult) && + tilingResult->parallelTiledOps.size() == 1UL) return failure(); - currentOp = dyn_cast(tilingResult->parallelTiledOp); - if (tilingResult->mergeOp) { + currentOp = + dyn_cast(tilingResult->parallelTiledOps.back()); + if (!tilingResult->mergeOps.empty()) { for (const auto &fn : option.finalReduceCallBacks) { auto result = fn(b, currentOp.getLoc(), *tilingResult); if (succeeded(result)) { @@ -672,7 +674,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { initValue[0].getDefiningOp()) .getDpsInits()[0]); } - return dyn_cast(result.parallelTiledOp); + return dyn_cast(result.parallelTiledOps.back()); }; option.finalReduceCallBacks.push_back(removeReduncantFill); } diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index 4462472bb..ea4d73722 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -284,7 +284,7 @@ static void calculateTileOffsetsAndSizes( OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(forallOp.getBody(0)); - ValueRange threadIds = forallOp.getInductionVars(); + SmallVector threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); @@ -755,8 +755,8 @@ FailureOr tileReductionUsingForall( ForallReductionTilingResult results; results.initialValues = 
initTensors; results.loops = forallOp; - results.parallelTiledOp = tiledOp; - results.mergeOp = mergeOp; + results.parallelTiledOps = {tiledOp}; + results.mergeOps = {mergeOp}; return results; } @@ -1069,8 +1069,8 @@ FailureOr tileAllUsingForall( ForallReductionTilingResult results; results.initialValues = initTensors; results.loops = forallOp; - results.parallelTiledOp = tiledOp; - results.mergeOp = mergeOp; + results.parallelTiledOps = SmallVector{tiledOp}; + results.mergeOps = SmallVector{mergeOp}; return results; } From 162466e216522838e09110a7583ba858b86cb94d Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 26 Jun 2024 22:30:03 -0700 Subject: [PATCH 12/21] fix deepTileMatmul --- .../Transforms/DeepTileContractionNamedOp.cpp | 70 ++++++++----------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 818c20e59..3cb54c191 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -272,7 +272,7 @@ static Operation *findParentFillOp(Value val) { !isa(currentOp)) { currentOp = currentOp->getResult(0).getDefiningOp(); } - if (isa(currentOp)) { + if (currentOp && isa(currentOp)) { return currentOp; } @@ -322,11 +322,10 @@ static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; } -static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, - Operation *op, - bool isExtract, - SmallVector size, - int shrinDimNum = 0) { +static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, + Operation *op, bool isExtract, + SmallVector size, + int shrinDimNum = 0) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); if (auto extractSlice = dyn_cast(op)) { @@ -348,15 +347,12 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); } - } else { - return failure(); } - return mlir::success(); } -static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, - Operation *op, Value source, - SmallVector size) { +static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, + Value source, + SmallVector size) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); if (auto insertSlice = dyn_cast(op)) { @@ -369,10 +365,7 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, rewriter.replaceOpWithNewOp( insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, mixedStrides); - } else { - return failure(); } - return success(); } using InnermostFullResultCallBackFn = std::function( @@ -691,7 +684,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalg::LinalgOp originOp, linalg::LinalgOp currentOp, innerBodyGenerationOption &option) const { - mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; auto operandDimTypes = getOprandDimType(originOp); auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); @@ -744,6 +736,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { CInnermostDims = SmallVector{cfg.innerMostMBlock, cfg.innerMostNBlock}; } + if (NDimNum > 1) { firstN = true; firstK = true; @@ -780,21 +773,17 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { // update the extractSlice to static size, replace it with // useBlockedLayout when 
- if (failed(setStaticSizeForExtractSliceOp( - rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true, - CInnermostDims, MDimNum > 1 ? 2 : 0)) || - failed(setStaticSizeForExtractSliceOp( - rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true, - BInnermostDims, NDimNum > 1)) || - failed(setStaticSizeForExtractSliceOp( - rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, - AInnermostDims, MDimNum > 1)) || - (currentOp.getDpsInits().size() > 1 && - failed(setStaticSizeForExtractSliceOp( - rewriter, currentOp.getDpsInits()[1].getDefiningOp(), true, - CInnermostDims, MDimNum > 1 ? 2 : 0)))) { - return failure(); + setStaticSizeForExtractSliceOp(rewriter, + currentOp.getDpsInputs()[1].getDefiningOp(), + true, BInnermostDims, NDimNum > 1); + setStaticSizeForExtractSliceOp(rewriter, + currentOp.getDpsInputs()[0].getDefiningOp(), + true, AInnermostDims, MDimNum > 1); + for (auto init : currentOp.getDpsInits()) { + setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true, + CInnermostDims, MDimNum > 1 ? 2 : 0); } + // View the tensor to brgemm required format Value dataOprand = tensorViewRankedTensor( rewriter, @@ -841,10 +830,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { // Insert the result back to the original tensor for (Operation *user : currentOp->getResult(0).getUsers()) { - if (failed(setStaticSizeForInsertSliceOp(rewriter, user, result, - CInnermostDims))) { - return failure(); - } + setStaticSizeForInsertSliceOp(rewriter, user, result, CInnermostDims); } if (option.needLowPrecisionCast) { @@ -869,10 +855,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { auto ifOp = eb.getLastOperaion(); // set static size for the insertSliceOp of copyOp for (Operation *user : currentOp->getResult(1).getUsers()) { - if (failed(setStaticSizeForInsertSliceOp( - rewriter, user, ifOp->getResult(0), CInnermostDims))) { - return failure(); - } + setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0), + CInnermostDims); } rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)}); } else { @@ -885,7 +869,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { if (cfg.KThreads <= 1) { // if use k slicing, the fill op is still need to be kept for the reduce // init - rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + rewriter.replaceUsesWithIf(fillOp.getResult(0), fillOp.getDpsInits()[0], + [&](OpOperand &operand) { + return isa( + operand.getOwner()); + }); } rewriter.setInsertionPointAfter(currentOp); @@ -954,8 +942,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } // Step 2. 
Outer loop generation - auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg, - isa(fillOp)); + auto outerLoopResult = outerLoopGeneration( + rewriter, linalgOp, cfg, fillOp && isa(fillOp)); if (failed(outerLoopResult)) { return failure(); } From b138713a1f3e08fc6ac9094f8230aa1e9307b031 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Sun, 30 Jun 2024 18:27:21 -0700 Subject: [PATCH 13/21] tune config --- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 286 ++++++++++++++--------- 1 file changed, 171 insertions(+), 115 deletions(-) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index de2067566..d0147ee2e 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include "gc/Analysis/MatmulConfigAnalysis.h" @@ -15,8 +16,6 @@ namespace gc { #define DEBUG_TYPE "matmul-config-analysis" -#define MAX_THREADS (1024U * 1024U) - llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const MatmulConfig &config) { @@ -29,19 +28,36 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, return ss; } -std::vector getCandidate(uint32_t num, uint32_t floor, - uint32_t ceil) { +template +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, std::vector arry) { + ss << "["; + for (auto [idx, a] : llvm::enumerate(arry)) { + if (idx != 0) { + ss << ", "; + } + ss << a; + } + ss << "]"; + return ss; +} + +std::vector +getCandidate(uint32_t num, uint32_t floor, + uint32_t ceil = std::numeric_limits::max()) { + // factor std::vector candidates; for (uint32_t i = 1; i <= num; i++) { if (num % i == 0 && i <= ceil && i >= floor) { candidates.push_back(i); } } + // the pow of 2 auto candidate = 1U; while (candidate < num && candidate <= ceil && candidate >= floor) { candidates.push_back(candidate); candidate *= 2; } + std::sort(candidates.begin(), candidates.end()); auto last = std::unique(candidates.begin(), candidates.end()); candidates.erase(last, candidates.end()); return candidates; @@ -53,15 +69,6 @@ bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, config.innerMostKBlock == 0) { return false; } - if (config.MBlock % config.innerMostMBlock != 0 || - config.NBlock % config.innerMostNBlock != 0 || - config.KBlock % config.innerMostKBlock != 0) { - return false; - } - auto threads = sysDesc.getNumThreads(); - if (config.MThreads * config.NThreads * config.KThreads != threads) { - return false; - } if (shape[0] % config.innerMostMBlock != 0 || shape[1] % config.innerMostNBlock != 0 || @@ -72,14 +79,13 @@ bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, return true; } -double threadUtilizationCost(linalg::LinalgOp &linalgOp, - ArrayRef shape, - const MatmulConfig &config, SystemDesc &sysDesc) { - auto threads = sysDesc.getNumThreads(); - auto actualThreads = - (float)(config.MThreads * config.NThreads * config.KThreads); - return threads >= actualThreads ? 
threads / actualThreads - : actualThreads / threads; +bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { + auto numThreads = sysDesc.getNumThreads(); + auto actualThreads = 1U; + for (auto t : threads) { + actualThreads *= t; + } + return actualThreads == numThreads; } double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, @@ -103,9 +109,21 @@ double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, double workloadBalancedCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - return 1; + auto M = shape[0], N = shape[1], K = shape[2]; + auto MTaskNum = llvm::divideCeil(M, config.MBlock); + auto NTaskNum = llvm::divideCeil(N, config.NBlock); + auto KTaskNum = llvm::divideCeil(K, config.KBlock); + auto cost = (MTaskNum % config.MThreads) * 1.0 / MTaskNum + + (NTaskNum % config.NThreads) * 1.0 / NTaskNum + + (KTaskNum % config.KThreads) * 1.0 / KTaskNum; + if (MTaskNum < config.MThreads || NTaskNum < config.NThreads || + KTaskNum < config.KThreads) { + auto threadNotFulllyUtilizedPenalty = 10.0; + cost *= threadNotFulllyUtilizedPenalty; + } + return cost; } - +constexpr unsigned bitPerByte = 8; double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, @@ -113,30 +131,34 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, auto M = shape[0], N = shape[1], K = shape[2]; auto dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); - auto penalty = 2.0 * (dtypeSize / 8); + // if use K split, there will be one more final reduce and break the post + // fusion + + auto KSplitPenalty = 8.0 * (dtypeSize / bitPerByte); auto memoryConsumptionPerThread = M * K * 1.0 / config.MThreads / config.KThreads + K * N * 1.0 / config.KThreads / config.NThreads + - M * N * ((config.KThreads - 1) * penalty + 1.0) / config.MThreads / + M * N * ((config.KThreads - 1) * KSplitPenalty + 1.0) / config.MThreads / config.NThreads; return memoryConsumptionPerThread; } -double computationIntensityOnL1Cache(linalg::LinalgOp &linalgOp, +double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - auto L1Cache = sysDesc.getCacheSize(2); + double simulationPenalty = 0.7; + auto L2Cache = sysDesc.getCacheSize(2); auto dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); auto outOfCachePenalty = 1024; - double FLOPS = - 2.0 * config.innerMostMBlock * config.innerMostNBlock * config.KBlock; - double memoryConsumption = config.innerMostMBlock * config.innerMostNBlock + - config.innerMostNBlock * config.KBlock + - config.innerMostMBlock * config.KBlock; + double FLOPS = 2.0 * config.MBlock * config.NBlock * config.KBlock; + double memoryConsumption = config.MBlock * config.NBlock + + config.NBlock * config.KBlock + + config.MBlock * config.KBlock; double computationIntensity = FLOPS / memoryConsumption; - if (memoryConsumption * (dtypeSize / 8) > L1Cache) { + if (memoryConsumption * (dtypeSize / bitPerByte) > + L2Cache * simulationPenalty) { computationIntensity /= outOfCachePenalty; } return 1 / computationIntensity; @@ -149,7 +171,7 @@ using CostModelFn = std::vector filterConfigByCostModel(std::vector configs, linalg::LinalgOp &linalgOp, ArrayRef shape, - SystemDesc &sysDesc, const CostModelFn &costModel, + SystemDesc &sysDesc, CostModelFn costModel, float eliminationRatio = 0.5, float threshold = -1) { std::vector 
result; std::vector costs; @@ -169,13 +191,13 @@ filterConfigByCostModel(std::vector configs, result.push_back(configs[idx[i]]); } } - llvm::outs() << "thresholdCost is: " << thresholdCost + llvm::errs() << "thresholdCost is: " << thresholdCost << "\nbest with cost: " << costs[idx[0]] << "\n" << configs[idx[0]] << "\n worst with cost: " << costs[idx[configs.size() - 1]] << "\n" << configs[idx[configs.size() - 1]] << "\n"; - return !result.empty() ? result : configs; + return result.size() > 0 ? result : configs; } std::vector @@ -184,19 +206,25 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ArrayRef givenInnermostBlock) { std::vector configs; auto threads = sysDesc.getNumThreads(); - auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto MBlockCandidates = - getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); - auto NBlockCandidates = getCandidate((uint32_t)shape[1], 1U, shape[1]); - auto KBlockCandidates = getCandidate((uint32_t)shape[2], 1U, shape[2]); - auto innerMostMBlockCandidates = - getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); - auto innerMostNBlockCandidates = - getCandidate((uint32_t)shape[1], 1U, (uint32_t)shape[1]); - auto innerMostKBlockCandidates = - getCandidate((uint32_t)shape[2], 1U, (uint32_t)shape[2]); + auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto noSmallBlockNeedThreshold = 8 * 8U; + auto MBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, + (uint32_t)shape[0]); + auto NBlockCandidates = + getCandidate((uint32_t)shape[1], + shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, shape[1]); + auto KBlockCandidates = + getCandidate((uint32_t)shape[2], + shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, shape[2]); + auto innerMostMBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); + auto innerMostNBlockCandidates = getCandidate( + (uint32_t)shape[1], shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); + auto innerMostKBlockCandidates = getCandidate( + (uint32_t)shape[2], shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); if (givenInnermostBlock.size() == 3) { innerMostMBlockCandidates = givenInnermostBlock[0] != 0 @@ -211,38 +239,56 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ? 
std::vector{givenInnermostBlock[2]} : innerMostKBlockCandidates; } - llvm::outs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + llvm::errs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + << MThreadsCandidates << "\n"; + llvm::errs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + << NThreadsCandidates << "\n"; + llvm::errs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + << KThreadsCandidates << "\n"; + llvm::errs() << "MBlockCandidates size: " << MBlockCandidates.size() + << MBlockCandidates << "\n"; + llvm::errs() << "NBlockCandidates size: " << NBlockCandidates.size() + << NBlockCandidates << "\n"; + llvm::errs() << "KBlockCandidates size: " << KBlockCandidates.size() + << KBlockCandidates << "\n"; + llvm::errs() << "innerMostMBlockCandidates size: " + << innerMostMBlockCandidates.size() << innerMostMBlockCandidates << "\n"; - llvm::outs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + llvm::errs() << "innerMostNBlockCandidates size: " + << innerMostNBlockCandidates.size() << innerMostNBlockCandidates << "\n"; - llvm::outs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + llvm::errs() << "innerMostKBlockCandidates size: " + << innerMostKBlockCandidates.size() << innerMostKBlockCandidates << "\n"; - llvm::outs() << "MBlockCandidates size: " << MBlockCandidates.size() << "\n"; - llvm::outs() << "NBlockCandidates size: " << NBlockCandidates.size() << "\n"; - llvm::outs() << "KBlockCandidates size: " << KBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostMBlockCandidates size: " - << innerMostMBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostNBlockCandidates size: " - << innerMostNBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostKBlockCandidates size: " - << innerMostKBlockCandidates.size() << "\n"; for (auto MThreads : MThreadsCandidates) { for (auto NThreads : NThreadsCandidates) { for (auto KThreads : KThreadsCandidates) { + if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) { + continue; + } for (auto MBlock : MBlockCandidates) { - for (auto NBlock : NBlockCandidates) { - for (auto KBlock : KBlockCandidates) { - for (auto innerMostMBlock : innerMostMBlockCandidates) { - for (auto innerMostNBlock : innerMostNBlockCandidates) { + for (auto innerMostMBlock : innerMostMBlockCandidates) { + if (MBlock % innerMostMBlock != 0 || + shape[0] % innerMostMBlock != 0) { + continue; + } + for (auto NBlock : NBlockCandidates) { + for (auto innerMostNBlock : innerMostNBlockCandidates) { + if (NBlock % innerMostNBlock != 0 || + shape[1] % innerMostNBlock != 0) { + continue; + } + for (auto KBlock : KBlockCandidates) { for (auto innerMostKBlock : innerMostKBlockCandidates) { + if (KBlock % innerMostKBlock != 0 || + shape[2] % innerMostKBlock != 0) { + continue; + } MatmulConfig config{ MBlock, NBlock, KBlock, MThreads, NThreads, KThreads, innerMostMBlock, innerMostNBlock, innerMostKBlock}; - - if (isValidConfig(config, sysDesc, shape)) { - configs.push_back(config); - } + configs.push_back(config); } } } @@ -252,9 +298,38 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, } } } + llvm::errs() << "Finish generating candidates. 
ConfigCandidates size: " + << configs.size() << "\n"; return configs; } +bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { + bool hasPredefinedConfig = false; + for (auto attr : attrs) { + if (attr.getName() == "KBlock") { + config.KBlock = cast(attr.getValue()).getInt(); + hasPredefinedConfig = true; + } else if (attr.getName() == "KThreads") { + config.KThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "NBlock") { + config.NBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "NThreads") { + config.NThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "MBlock") { + config.MBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "MThreads") { + config.MThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostMBlock") { + config.innerMostMBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostNBlock") { + config.innerMostNBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostKBlock") { + config.innerMostKBlock = cast(attr.getValue()).getInt(); + } + } + return hasPredefinedConfig; +} + /* thread utilization computation intensity @@ -269,7 +344,6 @@ previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SystemDesc sysDesc; if (auto linalgOp = dyn_cast(root)) { - // TODO: build a more complex heuristic to determine the best tiling auto oprandDimType = *getOprandDimType(linalgOp); // get the origin M,N,K size auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); @@ -292,7 +366,6 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { K *= s; } } - // innermost Block, if the layout is blockied layout, the innermost block // will derived from the layout directly auto defaultBlock = 32; @@ -331,56 +404,39 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { givenInnermostBlock.push_back(0); } - // Number of block - auto MNumBlock = M / config.innerMostMBlock; - auto NNumBlock = N / config.innerMostNBlock; - auto KNumBlock = K / config.innerMostKBlock; - - // Threads - config.MThreads = 32; - config.NThreads = 1; - config.KThreads = 1; - - // Block - config.MBlock = (int)llvm::divideCeil(MNumBlock, config.MThreads) * - config.innerMostMBlock; - config.NBlock = (int)llvm::divideCeil(NNumBlock, config.NThreads) * - config.innerMostNBlock; - config.KBlock = (int)llvm::divideCeil(KNumBlock, config.KThreads) * - config.innerMostKBlock; - config.MBlock = 128; - config.NBlock = 128; - config.KBlock = 128; - config.MThreads = 2; - config.NThreads = 2; - config.KThreads = 1; + llvm::errs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; - llvm::outs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; + SmallVector> costModelList = { + {workloadBalancedCost, "workloadBalancedCost", 1}, + {hardwareEfficiencyCost, "hardwareEfficiencyCost", -1}, + {computationIntensityOnL2Cache, "computationIntensityOnL2Cache", -1}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost", -1}}; - SmallVector> costModelList = { - {threadUtilizationCost, "threadUtilizationCost"}, - {hardwareEfficiencyCost, "hardwareEfficiencyCost"}, - {workloadBalancedCost, "workloadBalancedCost"}, - {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost"}, - {computationIntensityOnL1Cache, "computationIntensityOnL1Cache"}}; + SmallVector attrs(linalgOp->getAttrs()); + bool hasPredefinedConfig = readConfigFromAttrs(config, attrs); - auto configCandidates = - prepareConfigCandidates(root, 
sysDesc, {M, N, K}, givenInnermostBlock); - - for (auto [fn, name] : costModelList) { - llvm::outs() << name << "\n\n"; - configCandidates = filterConfigByCostModel(configCandidates, linalgOp, - {M, N, K}, sysDesc, fn, 0.5); - llvm::outs() << "ConfigCandidates size: " << configCandidates.size() - << "\n"; - } - - if (!configCandidates.empty()) { - config = configCandidates[0]; + if (!hasPredefinedConfig) { + llvm::errs() << "No predefined config\n"; + auto configCandidates = prepareConfigCandidates(root, sysDesc, {M, N, K}, + givenInnermostBlock); + for (auto [fn, name, threshold] : costModelList) { + llvm::errs() << "\n" << name << "\n"; + configCandidates = filterConfigByCostModel( + configCandidates, linalgOp, {M, N, K}, sysDesc, fn, 0.5, threshold); + llvm::errs() << "ConfigCandidates size: " << configCandidates.size() + << "\n"; + } + if (configCandidates.size() > 0) { + config = configCandidates[0]; + } } - llvm::outs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() + llvm::errs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() << ", MatmulConfig: " << config << "\n"; + for (auto [fn, name, threshold] : costModelList) { + auto cost = fn(linalgOp, {M, N, K}, config, sysDesc); + llvm::errs() << name << ": " << cost << "\n"; + } } } } // namespace gc From 8c7d155bd23fad0f0beda0d91da9cc6bf01b75d3 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Thu, 11 Jul 2024 23:54:15 -0700 Subject: [PATCH 14/21] add merge forall pass --- include/gc/Transforms/Passes.td | 13 + lib/gc/Transforms/CMakeLists.txt | 2 + .../Transforms/DeepTileContractionNamedOp.cpp | 38 +- lib/gc/Transforms/MergeNestedForall.cpp | 100 +++++ lib/gc/Transforms/Pipeline.cpp | 28 +- lib/gc/Transforms/SinkOpIntoInnerLoop.cpp | 51 +++ lib/gc/Transforms/Tiling.cpp | 360 ++---------------- 7 files changed, 241 insertions(+), 351 deletions(-) create mode 100644 lib/gc/Transforms/MergeNestedForall.cpp create mode 100644 lib/gc/Transforms/SinkOpIntoInnerLoop.cpp diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 2933b65ba..d5330851b 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -93,4 +93,17 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { ]; } +def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> { + let summary = "Sink operations into inner loops"; + let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization. 
+ }]; + let dependentDialects = []; +} + +def MergeNestedForall : Pass<"merge-nested-forall"> { + let summary = "Merge nested scf.forall operations"; + let description = [{The pass tries to merge nested forall operations.}]; + let dependentDialects = ["scf::SCFDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 4da3ef5f8..3673103a5 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -17,6 +17,8 @@ gc_add_mlir_library(GcPasses VerifyTargetDescription.cpp DeepTileContractionNamedOp.cpp Tiling.cpp + SinkOpIntoInnerLoop.cpp + MergeNestedForall.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 3cb54c191..084ce539d 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -464,30 +464,36 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); SmallVector reductionDims; currentOp.getReductionDims(reductionDims); + bool tileOnReduction = false; for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + if (llvm::find(reductionDims, d) != reductionDims.end()) { + tileOnReduction = true; + } if (llvm::find(reductionDims, d) != reductionDims.end() && - !dyn_cast(currentOp.getOperation())) + !dyn_cast(currentOp.getOperation())) { tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), 0); - else + tileOnReduction = false; + } else tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); } SmallVector loopRanges = cast(currentOp.getOperation()).getIterationDomain(b); OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); - if (auto partialInterface = - dyn_cast(currentOp.getOperation())) { + if (tileOnReduction) { + auto partialInterface = + dyn_cast(currentOp.getOperation()); for (auto [idx, tile] : llvm::enumerate(tileSizes)) { - if (isConstantIntValue(tile, 0)) { + if (isConstantIntValue(tile, 0) && + llvm::find(reductionDims, d) != reductionDims.end()) { tileSizes[idx] = loopRanges[idx].size; } } - SmallVector newParallelDims; for (auto i = 0UL; i < reductionDims.size(); i++) { newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i)); } - auto tilingResult = linalgX::tileAllUsingForall( + auto tilingResult = linalgX::tileReductionUsingForall( b, cast(currentOp.getOperation()), {}, tileSizes, newParallelDims, std::nullopt); if (failed(tilingResult) && @@ -503,8 +509,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, } } } - } else if (auto tilingInterface = - cast(currentOp.getOperation())) { + } else { + auto tilingInterface = cast(currentOp.getOperation()); auto tilingResult = linalg::tileToForallOpUsingTileSizes( b, tilingInterface, tileSizes, std::nullopt); if (failed(tilingResult)) @@ -597,11 +603,15 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? 
(cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer - option.nestedTileSizes.emplace_back(SmallVector{ - MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); - option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); - option.loopDim.emplace_back( - SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]}); + for (auto [tile, dim] : + llvm::zip(SmallVector{KParallelBlockSize, MParallelBlockSize, + NParallelBlockSize}, + SmallVector{KDimPos[0], MDimPos[0], NDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); + option.loopType.emplace_back( + OuterLoopGenerationOption::LoopType::ForallOp); + option.loopDim.emplace_back(SmallVector{dim}); + } // Middle for (auto [tile, dim] : llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, diff --git a/lib/gc/Transforms/MergeNestedForall.cpp b/lib/gc/Transforms/MergeNestedForall.cpp new file mode 100644 index 000000000..cd0442c4a --- /dev/null +++ b/lib/gc/Transforms/MergeNestedForall.cpp @@ -0,0 +1,100 @@ +//===-- MergeNestedForall.cpp - DESC -------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/ControlFlowSinkUtils.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_MERGENESTEDFORALL +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct MergeNestedForallLoops : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForallOp op, + PatternRewriter &rewriter) const override { + Block &outerBody = *op.getBody(); + if (!llvm::hasSingleElement(outerBody.without_terminator())) + return failure(); + + auto innerOp = dyn_cast(outerBody.front()); + if (!innerOp) + return failure(); + + for (auto val : outerBody.getArguments()) + if (llvm::is_contained(innerOp.getDynamicLowerBound(), val) || + llvm::is_contained(innerOp.getDynamicUpperBound(), val) || + llvm::is_contained(innerOp.getDynamicStep(), val)) + return failure(); + + // Reductions are not supported yet. 
+ if (!op.getInits().empty() || !innerOp.getInits().empty()) + return failure(); + + auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/, + ValueRange iterVals) { + Block &innerBody = *innerOp.getBody(); + assert(iterVals.size() == + (outerBody.getNumArguments() + innerBody.getNumArguments())); + IRMapping mapping; + mapping.map(outerBody.getArguments(), + iterVals.take_front(outerBody.getNumArguments())); + mapping.map(innerBody.getArguments(), + iterVals.take_back(innerBody.getNumArguments())); + for (Operation &op : innerBody) + builder.clone(op, mapping); + }; + + auto concatValues = [](const auto &first, const auto &second) { + SmallVector ret; + ret.reserve(first.size() + second.size()); + ret.assign(first.begin(), first.end()); + ret.append(second.begin(), second.end()); + return ret; + }; + + auto newLowerBounds = + concatValues(op.getMixedLowerBound(), innerOp.getMixedLowerBound()); + auto newUpperBounds = + concatValues(op.getMixedUpperBound(), innerOp.getMixedUpperBound()); + auto newSteps = concatValues(op.getMixedStep(), innerOp.getMixedStep()); + rewriter.replaceOpWithNewOp( + op, newLowerBounds, newUpperBounds, newSteps, ValueRange{}, + std::nullopt, bodyBuilder); + return success(); + } +}; + +struct MergeNestedForall + : public impl::MergeNestedForallBase { +public: + void runOnOperation() final { + auto &ctx = getContext(); + RewritePatternSet patterns(&ctx); + + patterns.add(patterns.getContext()); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 459e77fa8..510a186b5 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -34,6 +34,17 @@ namespace mlir::gc { +void populateCleanUpPasses(mlir::PassManager &pm) { + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createControlFlowSinkPass()); + pm.addPass(createCSEPass()); + pm.addPass(createSCCPPass()); + pm.addPass(createMem2Reg()); + pm.addPass(createTopologicalSortPass()); +} + // linalg + linalgX + tensor void populateFrontendPasses(mlir::OpPassManager &pm) { #ifdef GC_HAS_ONEDNN_DIALECT @@ -46,14 +57,17 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // todo: padding propagation pass // todo: layout propagation pass // todo: tensor constant propagation pass - // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass - // Fine-grain fusion pass + // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass pm.addNestedPass(createIterativeTilingAndFusion()); + // Fine-grain fusion pass + pm.addNestedPass(createDeepTileContractionNamedOp()); + // todo: fine-grain fusion pass // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. 
Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + populateCleanUpPasses(pm); } // scf + arith + math + vector + tensor + linalg.brgemm @@ -72,6 +86,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) { // oneDNN graph spec pm.addNestedPass(arith::createArithExpandOpsPass()); // todo: lower to physical vector pass, device dependent pass + populateCleanUpPasses(pm); } // scf + arith + math + vector + memref + linalg.brgemm @@ -91,6 +106,7 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) { pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addPass(createBufferizationToMemRefPass()); + populateCleanUpPasses(pm); } // scf + arith + math + vector + memref + func/microkernel @@ -107,6 +123,12 @@ void populateMicroKernelPasses(mlir::OpPassManager &pm) { void populateCPURuntimePasses(mlir::OpPassManager &pm) { // todo: flatten nested parallel pass to support coarse-grain usion // remove this pass after we add FlattenNestedParallel + pm.addPass(createSinkOpIntoInnerLoop()); + pm.addPass(createMergeNestedForall()); + populateCleanUpPasses(pm); + pm.addPass(createForallToParallelLoopPass()); + pm.addPass(createParallelLoopFusionPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createConvertSCFToOpenMPPass()); } @@ -149,7 +171,7 @@ void populateCPUPipeline(mlir::OpPassManager &pm) { pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); populateMicroKernelPasses(pm); populateCPURuntimePasses(pm); - // // back-end, llvm dialect + // back-end, llvm dialect populateLLVMPasses(pm); } diff --git a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp new file mode 100644 index 000000000..426b1e258 --- /dev/null +++ b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp @@ -0,0 +1,51 @@ +//===-- SinkOpIntoInnerLoop.cpp - DESC -------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/ControlFlowSinkUtils.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_SINKOPINTOINNERLOOP +#include "gc/Transforms/Passes.h.inc" + +namespace { + +struct SinkOpIntoInnerLoop + : public impl::SinkOpIntoInnerLoopBase { +public: + void runOnOperation() final { + auto &domInfo = getAnalysis(); + getOperation()->walk([&](LoopLikeOpInterface loop) { + SmallVector regionsToSink; + // Get the regions are that known to be executed at most once. + for (auto &it : loop->getRegions()) { + regionsToSink.push_back(&it); + } + // Sink side-effect free operations. + controlFlowSink( + regionsToSink, domInfo, + [](Operation *op, Region *) { return isMemoryEffectFree(op); }, + [](Operation *op, Region *region) { + // Move the operation to the beginning of the region's entry block. + // This guarantees the preservation of SSA dominance of all of the + // operation's uses are in the region. 
+ op->moveBefore(®ion->front(), region->front().begin()); + }); + }); + } +}; + +} // namespace +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index ea4d73722..cd01067c7 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -629,336 +629,36 @@ FailureOr tileReductionUsingForall( auto *it = llvm::find(dest, initOperand); assert(it != dest.end() && "dest operand not found in dest"); unsigned destNum = std::distance(dest.begin(), it); - SmallVector strides(numThreads.size(), b.getIndexAttr(1)); - SmallVector outOffsets(numThreads.size(), - b.getIndexAttr(0)); - SmallVector sizes = tiledSizes; - - auto currentReductionIdx = 0; - for (const auto &iteratorType : llvm::enumerate(tiledSizes)) { + auto dest = destBbArgs[destNum]; + auto destShape = cast(dest.getType()).getShape(); + SmallVector strides(destShape.size(), b.getIndexAttr(1)); + SmallVector outOffsets(destShape.size(), b.getIndexAttr(0)); + SmallVector sizes(destShape.size(), b.getIndexAttr(0)); + for (const auto &iteratorType : + llvm::enumerate(cast(dest.getType()).getShape())) { + sizes[iteratorType.index()] = + getAsIndexOpFoldResult(b.getContext(), iteratorType.value()); if (llvm::find(constantNewParallelDims, iteratorType.index()) != constantNewParallelDims.end()) { sizes[iteratorType.index()] = b.getIndexAttr(1); - currentReductionIdx++; - } else { - if (llvm::find(redDims, iteratorType.index() - currentReductionIdx) != - redDims.end()) { - currentReductionIdx--; - } - sizes[iteratorType.index()] = - tiledSizes[iteratorType.index() - currentReductionIdx]; } } + auto nonZeroDimIdx = 0; + auto currentReductionIdx = 0; for (const auto &iteratorType : llvm::enumerate(numThreads)) { if (!isConstantIntValue(iteratorType.value(), 0)) { - outOffsets[constantNewParallelDims[nonZeroDimIdx]] = - forallOp.getInductionVars()[nonZeroDimIdx]; + if (llvm::find(redDims, iteratorType.index()) != redDims.end()) { + outOffsets[constantNewParallelDims[currentReductionIdx++]] = + forallOp.getInductionVars()[nonZeroDimIdx]; + } nonZeroDimIdx++; } } // TODO: use SubsetExtractOpInterface once it is available. tiledDpsInitOperands.push_back(b.create( - loc, cast(initOperand.getType()), - destBbArgs[destNum], outOffsets, sizes, strides)); - } - - // 4.b. Clone the op and update init operands. - // We cannot use a IRMapping here because it can replace - // different OpOperands with the same value. - Operation *clonedOp = b.clone(*op.getOperation()); - b.modifyOpInPlace(clonedOp, [&]() { - for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( - cast(clonedOp).getDpsInitsMutable(), - tiledDpsInitOperands)) { - initOperandPtr.set(tiledInitValue); - } - }); - // 5. Tile the cloned op and delete the clone. 
- if (tileSizes.empty() || threadNums.empty()) { - FailureOr tilingResult = - cast(clonedOp).getTiledImplementation( - b, tiledOffsets, tiledSizes); - if (failed(tilingResult)) - return clonedOp->emitError("Failed to tile op: "); - if (tilingResult->tiledOps.size() != 1) { - return clonedOp->emitError("expected a single produced tiled op, got ") - << tilingResult->tiledOps.size(); - } - tiledOp = tilingResult->tiledOps.front(); - tilingResults = tilingResult->tiledValues; - } else { - LinalgTilingOptions options; - FailureOr maybeTiled = tileLinalgOpImpl( - b, cast(clonedOp), tileSizes, options); - if (failed(maybeTiled)) - return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); - - SmallVector ids = forallOp.getInductionVars(); - mapLoopToProcessorIds(cast(maybeTiled->loops.back()), ids, - materializedNonZeroNumThreads); - if (maybeTiled->loops.size() != 1) { - return clonedOp->emitError("expected a single produced loop"); - } - tiledOp = maybeTiled->op; - tilingResults = maybeTiled->loops.front()->getResults(); - } - - b.eraseOp(clonedOp); - } - - // 6. Insert the partial reductions back into a new tensor. - for (auto [index, result, bbArg] : llvm::zip( - llvm::seq(0, dest.size()), tilingResults, destBbArgs)) { - // 6.a. Partial subset information is inserted just before the terminator. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(forallOp.getTerminator()); - - SmallVector resultOffsets, resultSizes; - if (failed(tilingInterfaceOp.getResultTilePosition( - b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) - return op->emitOpError("output offsets couldn't be calculated"); - SmallVector resultOffsetsRank, resultSizesRank; - int64_t offIdx = 0; - int64_t sizeIdx = 0; - int64_t nonZeroDimIdx = 0; - for (int64_t i = 0, e = numThreads.size(); i < e; ++i) { - if (llvm::find(constantNewParallelDims, i) != - constantNewParallelDims.end()) { - resultOffsetsRank.push_back(forallOp.getInductionVars()[nonZeroDimIdx]); - resultSizesRank.push_back(b.getIndexAttr(1)); - nonZeroDimIdx++; - continue; - } - if (!isConstantIntValue(numThreads[i], 0)) { - nonZeroDimIdx++; - } - resultOffsetsRank.push_back(resultOffsets[offIdx++]); - resultSizesRank.push_back(resultSizes[sizeIdx++]); - } - SmallVector strides(resultSizesRank.size(), - b.getIndexAttr(1)); - - // 6.b. Parallel insertions are inserted at the end of the combining - // terminator. - b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); - b.create( - loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); - } - // 7. Merge the partial reductions. - b.setInsertionPointAfter(forallOp); - Operation *mergeOp = - linalgX::LinalgOpPartialReductionInterface::mergeReductions( - op, b, loc, forallOp->getResults(), constantNewParallelDims); - b.replaceOp(op, mergeOp->getResults()); - // 8. Return. - ForallReductionTilingResult results; - results.initialValues = initTensors; - results.loops = forallOp; - results.parallelTiledOps = {tiledOp}; - results.mergeOps = {mergeOp}; - return results; -} - -template -FailureOr static tileLinalgOpImpl( - RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - if (!options.tileSizeComputationFunction) - return failure(); - - // Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. 
- auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - getAsOpFoldResult(options.tileSizeComputationFunction(b, op)); - if (tileSizeVector.size() < nLoops) { - tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0)); - } - - return tileLinalgOpImpl(b, op, tileSizeVector, options); -} - -FailureOr tileAllUsingForall( - RewriterBase &b, PartialReductionOpInterface op, - ArrayRef threadNums, ArrayRef tileSizes, - ArrayRef newParallelDims, std::optional mapping) { - Location loc = op.getLoc(); - OpBuilder::InsertionGuard g(b); - - // Ops implementing PartialReductionOpInterface are expected to implement - // TilingInterface. - // TODO: proper core mechanism to tie interfaces together. - auto tilingInterfaceOp = cast(op.getOperation()); - - // Ops implementing PartialReductionOpInterface are not necessarily expected - // to implement TilingInterface.. This cast is unsafe atm. - // TODO: proper core mechanism to tie interfaces together. - // TODO: this function requires a pair of interfaces .. - auto destinationStyleOp = - dyn_cast(op.getOperation()); - if (!destinationStyleOp) - return b.notifyMatchFailure(op, "not a destination style op"); - - // Actually this only work for Linalg ops atm. - auto linalgOp = dyn_cast(op.getOperation()); - if (!linalgOp) - return b.notifyMatchFailure(op, "not a linalg op"); - - SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); - if (op->getNumResults() != 1) - return b.notifyMatchFailure( - op, "don't support ops with multiple results for now"); - - SmallVector iterators = - tilingInterfaceOp.getLoopIteratorTypes(); - SmallVector redDims; - for (auto [idx, iteratorType] : - llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { - if (iteratorType == utils::IteratorType::reduction) - redDims.push_back(idx); - } - - SmallVector numThreads(threadNums.begin(), threadNums.end()); - if (numThreads.empty()) { - SmallVector loopRanges = tilingInterfaceOp.getIterationDomain(b); - unsigned nLoops = loopRanges.size(); - numThreads.reserve(nLoops); - AffineExpr s0, s1; - bindSymbols(b.getContext(), s0, s1); - AffineExpr divExpr = s0.ceilDiv(s1); - for (const auto &it : llvm::zip(tileSizes, loopRanges)) { - OpFoldResult numTiles = std::get<0>(it); - if (!isConstantIntValue(numTiles, 0)) - numTiles = makeComposedFoldedAffineApply( - b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)}); - numThreads.push_back(numTiles); - } - } - - bool hasReductionThreads = false; - for (auto dim : redDims) { - if (!isConstantIntValue(numThreads[dim], 0) && - !isConstantIntValue(numThreads[dim], 1)) { - hasReductionThreads = true; - break; - } - } - - if (!tileSizes.empty() && tileSizes.size() != numThreads.size()) - return b.notifyMatchFailure(op, "if tile sizes are present it must have as " - "many elements as number of threads"); - - if ((unsigned)redDims.front() >= numThreads.size()) - return b.notifyMatchFailure( - op, "reduction dimension must be mapped to threads"); - SmallVector constantNewParallelDims; - for (auto dim : newParallelDims) { - if (getConstantIntValue(dim) == std::nullopt) - return b.notifyMatchFailure( - op, "Expected new parallel dims to be constant integers."); - constantNewParallelDims.push_back(*getConstantIntValue(dim)); - } - if (newParallelDims.empty()) - constantNewParallelDims = redDims; - if (constantNewParallelDims.size() != redDims.size()) - return b.notifyMatchFailure( - op, "reduction dimension must be mapped to new parallel dims"); - // 1. Create the inital tensor value. 
- FailureOr> maybeInitTensors; - SmallVector initTensors; - if (hasReductionThreads) { - maybeInitTensors = LinalgOpPartialReductionInterface:: - generateInitialTensorForPartialReduction( - op, b, loc, numThreads, redDims, constantNewParallelDims); - if (failed(maybeInitTensors)) - return b.notifyMatchFailure( - op, "Failed to create inital tensors for partial reduction"); - initTensors = maybeInitTensors.value(); - } - - // Gather destination tensors. - SmallVector dest; - if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) - return b.notifyMatchFailure(op, "failed to get destination tensors"); - Operation *tiledOp = nullptr; - - SmallVector nonZeroNumThreads = - llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 0); - })); - SmallVector materializedNonZeroNumThreads = - getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); - // 2. Create the ForallOp with an empty region. - scf::ForallOp forallOp = b.create( - loc, getAsOpFoldResult(materializedNonZeroNumThreads), - hasReductionThreads ? initTensors : dest, mapping); - // 3. Calculate the tile offsets and sizes for the subsequent loop that will - // be nested under `forallOp`. - SmallVector tiledOffsets, tiledSizes; - std::optional> nominalTileSizes = std::nullopt; - if (!tileSizes.empty() && threadNums.empty()) { - nominalTileSizes = tileSizes; - } - calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain, - /*omitTileOffsetBoundsCheck =*/false, - /*nominalTileSizes=*/nominalTileSizes, - tiledOffsets, tiledSizes); - // 4. Clone the tileable op and update its destination operands to use the - // output bbArgs of the ForallOp. - SmallVector tilingResults; - ArrayRef destBbArgs = forallOp.getRegionIterArgs(); - { - // 4.a. RAII guard, inserting within forallOp, before terminator. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(forallOp.getTerminator()); - - SmallVector tiledDpsInitOperands; - for (Value initOperand : destinationStyleOp.getDpsInits()) { - if (hasReductionThreads) { - auto *it = llvm::find(dest, initOperand); - assert(it != dest.end() && "dest operand not found in dest"); - unsigned destNum = std::distance(dest.begin(), it); - auto dest = destBbArgs[destNum]; - auto destShape = cast(dest.getType()).getShape(); - SmallVector strides(destShape.size(), b.getIndexAttr(1)); - SmallVector outOffsets(destShape.size(), - b.getIndexAttr(0)); - SmallVector sizes(destShape.size(), b.getIndexAttr(0)); - for (const auto &iteratorType : llvm::enumerate( - cast(dest.getType()).getShape())) { - sizes[iteratorType.index()] = - getAsIndexOpFoldResult(b.getContext(), iteratorType.value()); - if (llvm::find(constantNewParallelDims, iteratorType.index()) != - constantNewParallelDims.end()) { - sizes[iteratorType.index()] = b.getIndexAttr(1); - } - } - - auto nonZeroDimIdx = 0; - auto currentReductionIdx = 0; - for (const auto &iteratorType : llvm::enumerate(numThreads)) { - if (!isConstantIntValue(iteratorType.value(), 0)) { - if (llvm::find(redDims, iteratorType.index()) != redDims.end()) { - outOffsets[constantNewParallelDims[currentReductionIdx++]] = - forallOp.getInductionVars()[nonZeroDimIdx]; - } - nonZeroDimIdx++; - } - } - // TODO: use SubsetExtractOpInterface once it is available. 
- tiledDpsInitOperands.push_back(b.create( - loc, cast(initOperand.getType()), dest, - outOffsets, sizes, strides)); - } else { - auto *it = llvm::find(dest, initOperand); - assert(it != dest.end() && "dest operand not found in dest"); - unsigned destNum = std::distance(dest.begin(), it); - tiledDpsInitOperands.push_back(destBbArgs[destNum]); - } + loc, cast(initOperand.getType()), dest, outOffsets, + sizes, strides)); } // 4.b. Clone the op and update init operands. @@ -1023,10 +723,8 @@ FailureOr tileAllUsingForall( for (auto i = 0UL; i < numThreads.size(); ++i) { if (llvm::find(constantNewParallelDims, i) != constantNewParallelDims.end()) { - if (hasReductionThreads) { - resultOffsetsRank.push_back(b.getIndexAttr(1)); - resultSizesRank.push_back(b.getIndexAttr(1)); - } + resultOffsetsRank.push_back(b.getIndexAttr(1)); + resultSizesRank.push_back(b.getIndexAttr(1)); } else if (offIdx < resultOffsets.size()) { resultOffsetsRank.push_back(resultOffsets[offIdx]); resultSizesRank.push_back(resultSizes[offIdx++]); @@ -1039,12 +737,10 @@ FailureOr tileAllUsingForall( nonZeroDimIdx++; } } - if (hasReductionThreads) { - for (auto [parallelDims, redVar] : - llvm::zip(constantNewParallelDims, reductionInductionVars)) { - resultOffsetsRank[parallelDims] = redVar; - resultSizesRank[parallelDims] = b.getIndexAttr(1); - } + for (auto [parallelDims, redVar] : + llvm::zip(constantNewParallelDims, reductionInductionVars)) { + resultOffsetsRank[parallelDims] = redVar; + resultSizesRank[parallelDims] = b.getIndexAttr(1); } SmallVector strides(resultSizesRank.size(), b.getIndexAttr(1)); @@ -1058,13 +754,9 @@ FailureOr tileAllUsingForall( // 7. Merge the partial reductions. Operation *mergeOp = nullptr; b.setInsertionPointAfter(forallOp); - if (hasReductionThreads) { - mergeOp = linalgX::LinalgOpPartialReductionInterface::mergeReductions( - op, b, loc, forallOp->getResults(), constantNewParallelDims); - b.replaceOp(op, mergeOp->getResults()); - } else { - b.replaceOp(op, forallOp->getResults()); - } + mergeOp = linalgX::LinalgOpPartialReductionInterface::mergeReductions( + op, b, loc, forallOp->getResults(), constantNewParallelDims); + b.replaceOp(op, mergeOp->getResults()); // 8. Return. 
ForallReductionTilingResult results; results.initialValues = initTensors; From df1c683b2af2555fc4da75835fcc60d082627233 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Sun, 21 Jul 2024 19:00:49 -0700 Subject: [PATCH 15/21] polish code --- docs/deep_tile_matmul_design.md | 594 ++++++++++++++++++ include/gc/Analysis/MatmulConfigAnalysis.h | 87 ++- include/gc/Dialect/Arith/Utils/EasyBuild.h | 433 ------------- include/gc/IR/EasyBuild.h | 102 --- include/gc/IR/EasyBuildSCF.h | 187 ------ lib/gc/Analysis/MatmulConfigAnalysis.cpp | 373 ++++++----- lib/gc/Transforms/CMakeLists.txt | 2 +- .../Transforms/DeepTileContractionNamedOp.cpp | 477 +++++++------- lib/gc/Transforms/MergeNestedForall.cpp | 14 +- lib/gc/Transforms/Pipeline.cpp | 14 +- lib/gc/Transforms/SinkOpIntoInnerLoop.cpp | 2 +- lib/gc/Transforms/Tiling.hpp | 55 -- .../Transforms/{Tiling.cpp => TilingUtil.cpp} | 26 +- lib/gc/Transforms/TilingUtil.hpp | 26 + .../deepTileContractionNamedOp.mlir | 107 +++- .../test/gc/Transforms/mergeNestedForall.mlir | 93 +++ .../gc/Transforms/sinkOpIntoInnerLoop.mlir | 46 ++ 17 files changed, 1349 insertions(+), 1289 deletions(-) create mode 100644 docs/deep_tile_matmul_design.md delete mode 100644 include/gc/Dialect/Arith/Utils/EasyBuild.h delete mode 100644 include/gc/IR/EasyBuild.h delete mode 100644 include/gc/IR/EasyBuildSCF.h delete mode 100644 lib/gc/Transforms/Tiling.hpp rename lib/gc/Transforms/{Tiling.cpp => TilingUtil.cpp} (96%) create mode 100644 lib/gc/Transforms/TilingUtil.hpp create mode 100644 test/mlir/test/gc/Transforms/mergeNestedForall.mlir create mode 100644 test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir diff --git a/docs/deep_tile_matmul_design.md b/docs/deep_tile_matmul_design.md new file mode 100644 index 000000000..4e48d82de --- /dev/null +++ b/docs/deep_tile_matmul_design.md @@ -0,0 +1,594 @@ +# DOC: deep-Tiled Matmul + +## Introduction + +Tiling and parallelization are important for the performance of a computation intensitive workload (matmul, convolution, and e.t.c). Modern hardware is often equipped with multiple cores and multiple levels of cache, each with different characteristics in terms of size, latency, and bandwidth. To achieve good performance, it is important to utilize the parallelism of the underlying hardware and minimize the number of cache misses to improve the performance of the generated code. The goal of this document is to provide a design overview of the deep-tiled matmul in the graph compiler and its current situation in the community. + +## Current Situation in the MLIR Community + +According to the last section, tiling and parallelization are two important optimization techniques used in compilers to improve the performance of the generated code(matmul, convolution, and e.t.c). The code template could allow some complex optimization(some nontrivial memory copy/reuse to maximize the hardware efficiency), which is hard to write a unified pass in the compiler. + +In the upstream MLIR, there is already some support for tiling and parallelization optimization. The `Linalg` dialect provides a tiling interface to support tiling optimization. Besides, for better representing the concept of schedule, it also introduces the `Transform` dialect to declare the `schedule` in an IR form(vertical to the `payload`). + +This section will introduce the current situation in the MLIR community about the tiling interface, `Transform` dialect, hardware abstration layer and what is missing in the current upstream MLIR. 
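
For reference, the declarations surveyed in the next subsection are typically driven from C++ along the following lines. This is a minimal sketch rather than part of the proposal: the helper name `tileWithFixedSizes` and the hard-coded choices are made up for the example, and the sketch assumes only the `SCFTilingOptions`/`tileUsingSCF` declarations quoted below.

```c++
// Minimal illustrative sketch: tile a TilingInterface op with fixed tile
// sizes through the upstream `scf::tileUsingSCF` entry point described below.
// The helper name is hypothetical and used only for this example.
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

static LogicalResult tileWithFixedSizes(RewriterBase &rewriter,
                                        TilingInterface op,
                                        ArrayRef<int64_t> sizes) {
  scf::SCFTilingOptions options;
  // Upstream convention: a tile size of 0 leaves that dimension untiled.
  options.tileSizeComputationFunction = [sizes](OpBuilder &builder,
                                                Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    for (int64_t s : sizes)
      tileSizes.push_back(builder.getIndexAttr(s));
    return tileSizes;
  };
  // Ask for scf.forall so the generated tiles can run in parallel.
  options.loopType = scf::SCFTilingOptions::LoopType::ForallOp;

  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(result))
    return failure();
  // Use the values produced by the generated loop nest in place of the op.
  rewriter.replaceOp(op, result->replacements);
  return success();
}
```

The same options struct also exposes `interchangeVector` and `mappingVector`, so loop order and device mapping can be adjusted without writing a new tiling utility.
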
+ +### Tiling Interface And the Related Pass + +The MLIR provides the tiling interface as follows to support some simple tiling optimization. + +The tiling interface is a set of methods that an operation can implement to provide information about its iteration space and how it can be tiled. The tiling interface is used by the tiling pass to generate a tiled implementation of the operation. It could easily transform the operation like: + +```MLIR +%0 = linalg.generic ins(%in) outs(%out) {indexing_maps = [affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + : tensor -> tensor +``` + +into: + +```MLIR +%1 = scf.for %iv = %c0 to %dim_0 step %c4 iter_args(%arg3 = %out) -> (tensor) { + %2 = tensor.extract_slice %in[%iv] [%c4] [1] : tensor to tensor + %3 = tensor.extract_slice %out[%iv] [%c4] [1] : tensor to tensor + %4 = linalg.generic ins(%2) outs(%3) ["parallel"] : tensor -> tensor + %5 = tensor.insert_slice %4, %arg3 : tensor + scf.yield %5 +} +``` + +The tiling interface further provides several functions like `tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)` to support tile an op inherited the tiling interface, where the SCFTilingOption contains the loop type(scf::For or scf::Forall), interchange vector, mapping vector, and tile size. Through this function, the user could easily generate a tiled implementation of the operation on the parallel axis. + +```c++ +class TilingInterface : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::TilingInterfaceTrait {}; + /// Returns a list of iterator types that describe the number of loops. + SmallVector getLoopIteratorTypes(); + /// Returns a list of ranges that describe the loop bounds and + /// step for the loops of the operation. + SmallVector getIterationDomain(OpBuilder & b); + /// Method to generate the tiled implementation of an operation. + /// + /// The iteration space of the operation is returned by + /// `getIterationDomain`. The caller provides the information of the + /// tile within this iteration space whose implementation the + /// caller needs. + /// - `offsets` provides the offset of the tile in the coordinate system + /// of the original iteration space, i.e., if an iteration space + /// dimension had non-zero offset, it must be included in the offset + /// provided here (as opposed to zero-based offset "relative" to the + /// iteration space). + /// - `sizes` provides the size of the tile. + /// + /// The method returns the operation that is the tiled + /// implementation. + FailureOr getTiledImplementation(OpBuilder & b, ArrayRef offsets, ArrayRef sizes); + /// Method to return the position of the result tile computed by the tiled operation. + /// + /// Specifies what tile of the result of the original tensor is computed + /// by the tiled implementation. Expects the same `offsets` and `sizes` as + /// used to obtain the tiled implementation of the operation. + LogicalResult getResultTilePosition(OpBuilder & b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector & resultOffsets, SmallVector & resultSizes); + /// Method to generate the code that produces a tile of the result. + /// + /// Generates the IR that computes the tile of a result of the + /// operation. The `offsets` and `sizes` describe the tile of + /// the output required. This is different from + /// `getTiledImplementation` which generates the tiled + /// implementation of the operation given a tile of the + /// iteration space. 
This method generates a tiled + /// implementation of the operation based on the tile of the + /// result required. This method enables fusion by using tile + /// and fuse. The method returns failure if the operation can't be + /// tiled to generate the result tile. In practical terms this + /// implies it cannot be tiled and fused with its consumers. + /// + /// - `offsets` provides the offset of the tile in the coordinate system + /// of the original iteration space, i.e., if an iteration space + /// dimension had non-zero offset, it must be included in the offset + /// provided here (as opposed to zero-based offset "relative" to the + /// iteration space). + /// - `sizes` provides the size of the tile. + FailureOr generateResultTileValue(OpBuilder & b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes); + /// Generates the scalar implementation of the operation. + /// + /// Given the list `ivs` that represent points in the iteration space + /// (as specified by `getIterationDomain()`) returns the scalar operations + /// that represent the computation at that point in the iteration space. + /// This method is typically used as the "exit path", i.e. once all + /// transformations are done, this method can be used to lower to scalar + /// code that can then be lowered to LLVM or SPIR-V dialects. + LogicalResult generateScalarImplementation(OpBuilder & b, Location loc, ValueRange ivs); +}; + +struct SCFTilingOptions { + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. + SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; + + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector = {}; + + /// Specify which loop construct to use for tile and fuse. + enum class LoopType { ForOp, ForallOp }; + LoopType loopType = LoopType::ForOp; + + /// Specify mapping of loops to devices. This is only respected when the loop + /// constructs support such a mapping (like `scf.forall`). Will be ignored + /// when using loop constructs that dont support such a mapping (like + /// `scf.for`) + SmallVector mappingVector = {}; +}; +FailureOr tileUsingSCF(RewriterBase &rewriter, + TilingInterface op, + const SCFTilingOptions &options); + +/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying +/// tiling by `numThreads`. +/// If non-empty, the `mapping` is added as an attribute to the +/// resulting `scf.forall`. +/// Zero tile sizes indicate that the dimension is not tiled, and can be +/// thought of as tiling by the full size of data. It is the user's +/// responsibility to ensure that `numThreads` is a valid tiling specification +/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case). +struct ForallTilingResult { + Operation *tileOp; + Operation *tiledOp; +}; +FailureOr tileToForallOp(RewriterBase &builder, + TilingInterface op, + ArrayRef numThreads, + std::optional mapping); + +/// Same as `tileToForallOp`, but calculate the number of threads +/// required using the given tileSizes. +FailureOr +tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, + ArrayRef tileSizes, + std::optional mapping); +``` + +The above tiling interface only supports the tiling on the parallel axis. But in a workload like matmul, it is often required to do a tiling on the reduction axis for better performance considering the size of available memory/cache, computation intensity, cache communication, etc. 
+ +```MLIR +%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.addf %arg7, %arg9 : f32 + linalg.yield %1 : f32 + } -> tensor +``` + +into: + +```MLIR +%0 = tensor.empty(%dim_1) : tensor +%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor +%2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor) { + %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor to tensor + %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor to tensor + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%extracted_slice_2 : tensor) + outs(%extracted_slice : tensor) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %out : f32 + linalg.yield %5 : f32 + } -> tensor + %dim_3 = tensor.dim %1, %c0 : tensor + %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor into tensor + scf.yield %inserted_slice : tensor +} +%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%2 : tensor) + outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %4 = arith.addf %in, %out : f32 + linalg.yield %4 : f32 + } -> tensor +``` + +To support this kind of tiling, the MLIR also provide a `PartialReductionOpInterface` based on TilingInterface. The `PartialReductionOpInterface` is an interface with a set of methods that provide information about its partial reduction and how it can be tiled. Based on the `PartialReductionOpInterface`, it further provides a function `tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize)` to support tile an op inherited the `PartialReductionOpInterface`, where the `tileSize` is the tile size for the reduction axis. + +```c++ +class PartialReductionOpInterface : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::PartialReductionOpInterfaceTrait {}; + /// Method to generate a tensor initalized with the identity value of the + /// operation reduction. The tensor shape is equal to operation result + /// shape with new dimension for each non zero tile size. + FailureOr generateInitialTensorForPartialReduction(OpBuilder & b, Location loc, ArrayRef sizes, ArrayRef reductionDim); + /// Method to generate a tiled version of the operation where the tiled + /// reduction dimension are converted to parallel dimensions with a size + /// less or equal to the tile size. This is meant to be used with + /// `mergeReductions` method which will combine the partial reductions. + Operation*tileToPartialReduction(OpBuilder & b, Location loc, ValueRange init, ArrayRef offsets, ArrayRef sizes, ArrayRef reductionDims); + /// Method to merge partial reductions for an operation that has been + /// tiled along the reduction dimensions. This will only apply the + /// reduction the operation. + Operation*mergeReductions(OpBuilder & b, Location loc, ValueRange partialReduce, ArrayRef reductionDim); + +/// Method to tile a reduction and generate a parallel op within a serial loop. +/// Each of the partial reductions are calculated in parallel. 
Then after the +/// loop all the partial reduction are merged into a final reduction. +/// For example for the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// +/// into: +/// +/// ```mlir +/// %0 = linalg.fill ... : tensor<7x4xf32> +/// %1 = scf.for ... iter_args(%arg0 = %0) +/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "parallel"] +/// : tensor<7x?xf32> -> tensor<7x?xf32> +/// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x4xf32> -> tensor<7xf32> +/// ``` +FailureOr +tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef tileSize); + +/// Method to tile a reduction to parallel iterations computing partial +/// reductions. After the loop all the partial reduction are merged into a final +/// reduction. For example for the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// +/// into: +/// +/// ```mlir +/// %0 = linalg.fill ... : tensor<7x4xf32> +/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0) +/// -> (tensor<7x4xf32>) { +/// %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "reduction"] +/// : tensor<7x?xf32> -> tensor<7xf32> +/// %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x +FailureOr +tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, + ArrayRef tileSizes = {}, + std::optional mapping = std::nullopt); +}; +``` + +### Hardware Abstraction Layer(HAL) + +To achieve the best performance, a good schedule requires the hardware information as a reference. Hardware information like cache size, thread number, etc. is often needed to generate the best schedule. Hardware Abstraction Layer(HAL) is a layer of software that provides a hardware-independent interface to the underlying hardware. The mainstream dl compiler or performance library has a way to get the hardware information to guide the schedule like [IREE](https://iree.dev/developers/design-docs/design-roadmap/#hal-hardware-abstraction-layer-and-multi-architecture-executables), [TVM](https://tvm.apache.org/docs/arch/device_target_interactions.html#tvm-target-specific-overview), [onednn](https://github.com/oneapi-src/oneDNN), etc. However, the MLIR doesn't have such a hardware abstraction layer(HAL) to provide the hardware information. + +## Deep-Tiled Matmul Introduction + +This section will introduce the concept of the deep-tiled matmul optimization(nested matmul/managed_matmul in graph compiler v1) and how it could improve the performance. + +Deep-tiled matmul originally is a [matmul code template](https://github.com/oneapi-src/oneDNN/blob/main/src/graph/backend/graph_compiler/core/src/ops/templates/managed_matmul_core.cpp) in the [onednn graph compiler v1](https://arxiv.org/ftp/arxiv/papers/2301/2301.01333.pdf) with well-tuned default parameters to deliver good performance in the e2e model. 
+
+```c++
+parameter M, N, K, MBlock, NBlock, KBlock, MThreads, NThreads, KThreads, innermostMBlock, innermostNBlock, innermostKBlock
+tensor A, B, C
+tempC = create_tensor for C -> tensor([KThreads, M, N])
+parallel_for([PM, PN, PK]: [MThreads, NThreads, KThreads]) {
+  ASlice = extract_slice from A -> tensor([MOuterBlock, KOuterBlock])
+  BSlice = extract_slice from B -> tensor([KOuterBlock, NOuterBlock])
+  CSlice = extract_slice from C -> tensor([MOuterBlock, NOuterBlock])
+  MNumBlock = MOuterBlock / MBlock
+  NNumBlock = NOuterBlock / NBlock
+  KNumBlock = KOuterBlock / KBlock
+  for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) {
+    ASlice2 = extract_slice from ASlice -> tensor([MBlock, KBlock])
+    BSlice2 = extract_slice from BSlice -> tensor([KBlock, NBlock])
+    CSlice2 = extract_slice from CSlice -> tensor([1, MBlock, NBlock])
+    MNumInnerBlock = MBlock / innermostMBlock
+    NNumInnerBlock = NBlock / innermostNBlock
+    KNumInnerBlock = KBlock / innermostKBlock
+    for([im, in]: [MNumInnerBlock, NNumInnerBlock]) {
+      ASlice3 = extract_slice from ASlice2 -> tensor([innermostMBlock, KBlock])
+      BSlice3 = extract_slice from BSlice2 -> tensor([KBlock, innermostNBlock])
+      CSlice3 = extract_slice from CSlice2 -> tensor([innermostMBlock, innermostNBlock])
+      if(ok == 0) {
+        init CSlice3 with 0 (could use init_brgemm when it is available)
+      }
+      brgemm(bs=KNumInnerBlock, M=innermostMBlock, N=innermostNBlock, K=innermostKBlock,
+             A=ASlice3, B=BSlice3, C=CSlice3, onlyUpdate=(ok!=0));
+    }
+  }
+}
+C = final_reduce(tempC) -> [M, N]
+```
+
+## Proposal
+
+This section presents a proposal based on [Option 4](#option-4---outer-loop-based-on-tiling-interface--inner-loop-through-a-predefined-template-with-ir-builder) above to implement the deep-tiled matmul in graph compiler v2. According to the discussion above, option 4 can deliver high performance and maximally reuses the existing work in MLIR, which minimizes the difficulty of acceptance by the community. In the meantime, future optimizations like `loop reorder` and `axis split` can easily be added by changing the parameters. So this is the recommended way in this document, and the details are introduced in the following.
+
+### Position
+
+> The transformation control infrastructure provided by this dialect is positioned roughly between rewrite patterns and passes. A transformation that is executed by a transform operation is likely to be sufficiently complex to require at least a set of patterns to be implemented. It is also expected to be more focused than a pass: a pass typically applies identical transformations everywhere in the IR, a transform dialect-controlled transformation would apply to a small subset of operations selected, e.g., by a pattern-matching operation or generated by a previous transformation. It is discouraged, although technically possible, to run a pass pipeline as part of the transform op implementation. *From [MLIR documentation](https://mlir.llvm.org/docs/Dialects/Transform/)*
+
+As the MLIR documentation mentions, the scope order from largest to smallest is `pass > Transform dialect > rewrite patterns`. The deep-tiled matmul only applies to the operations `matmul` and `batch_matmul`, so it is better to implement it as a rewrite pattern. To better meet upstream's needs, it could be wrapped into an operation of the `Transform` dialect so that it can become a part of a `Transform` schedule.
+
+In graph compiler v2, this could be further wrapped into a pass `deepTilingRewriteForContractionOperation`, which could also contain other deep-tiling rewrite patterns in the future (`paddedConvolution`, `reduceLoweringConvolution`, `depthwiseConvolution`, etc.). This pass is expected to be executed after the padding/layout-propagation-related passes and before the fusion-related passes. The layout-related passes convert the input/output tensors to the required blocked layout to achieve better performance, and the fusion-related passes may depend on the tiled matmul's `insert_slice/extract_slice` as the anchor to do fusion.
+
+```MLIR
+...
+layout propagation related pass(pack/unpack, pad, propagation, etc)
+
+deepTilingRewriteForContractionOperation(deep-tiled matmul, deep-tiled padded conv, conv1x1, depthwise conv, etc)
+
+fusion related pass
+...
+```
+
+Besides, this rewrite pattern is expected to become a part of the linalg dialect, similar to the existing rewrite `ConvertConv2DToImg2Col` in MLIR. In graph compiler v2, it could be a part of `linalgX` before upstreaming.
+
+### Outer Loop Generation
+
+For outer loop generation, we will generate the loops step by step according to the parameters/config (`outermost loop for multicore -> loop for L2 cache -> loop for L1 cache`). This part would be implemented based on the tiling interface and its related utility functions, which maximally reuses the existing work in MLIR and reduces the maintenance burden. Besides, functions like `tileToForallOp` provide an `interchange` parameter which makes it easy to change the loop order according to the workload characteristics. This approach can also easily be reused by other operations like `convolution` and `depthwise convolution`, because they have a similar structure in this part.
+
+The expected implementation in pseudocode is as follows:
+
+```c++
+// generate the outer parallel loops with MThreads, NThreads
+linalg::tileToForallOp(rewriter, cast<TilingInterface>(matmul), {MThreads, NThreads});
+// generate the outer reduction loop with KThreads
+linalg::tileReductionUsingForall(rewriter, cast<PartialReductionOpInterface>(matmul), KThreads, tileSizes);
+// generate the middle three loops (MBlock, NBlock, KBlock)
+scf::tileUsingSCF(rewriter, cast<TilingInterface>(matmul), tileOption);
+// generate the inner loops (innerMostMBlock, innerMostNBlock, innerMostKBlock)
+scf::tileUsingSCF(rewriter, cast<TilingInterface>(matmul), tileOption);
+```
+
+As mentioned in the [Current Situation in the MLIR Community](#current-situation-in-the-mlir-community), there are still some gaps in current MLIR, such as the lack of balance211-style partitioning for not perfectly divisible cases and an inefficient placement of the partial-K threads for the CPU. These should be further enhanced in future work.
+
+### Inner Loop Body Generation
+
+Compared to outer loop generation, the inner loop body generation is sometimes op-specific. For example, the `squeeze stride` optimization for convolution doesn't make any sense for `matmul`. Besides, this part is possibly more complex than the outer loops (it may require tail processing and non-trivial memory copy/initialization) and is hard to unify into a single pass. So it is better to implement it as a predefined template through the IR builder, which makes the code more flexible. We could also add easy-builder/utility support to make it more readable.
+
+Below is the expected pseudocode of the inner loop body for the deep-tiled matmul in graph compiler v2.
+
+```c++
+A = tensor.extract_slice
+B = tensor.extract_slice
+C = tensor.extract_slice
+D3 = scf.if(ok == 0) {
+  D1 = init_brgemm(A,B,C) tensor<...>, tensor<...>, tensor<...> -> tensor<...>
+} else {
+  D2 = brgemm(A,B,C) tensor<...>, tensor<...>, tensor<...> -> tensor<...>
+} -> tensor<...>
+tensor.insert_slice D3
+```
+
+The inner loop body converts the `matmul` into a `batch_reduce_matmul`, which is finally lowered to the microkernel [`brgemm`](https://github.com/oneapi-src/oneDNN/pull/1852) call.
+
+### Config/Schedule
+
+```c++
+struct MatmulConfig {
+  int MThreads, NThreads, KThreads;
+  int MBlock, NBlock, KBlock;
+  int innerMostMBlock, innerMostNBlock, innerMostKBlock;
+  int loopOrder;
+};
+```
+
+The above is the expected config for the deep-tiled matmul. `MThreads, NThreads, KThreads` partition the iteration space of the matmul according to the number of threads. `MBlock, NBlock, KBlock` partition the iteration space and control the loop order according to the L2 cache size of the CPU. `innerMostMBlock, innerMostNBlock, innerMostKBlock` partition the iteration space and control the loop order according to the L1 cache size of the CPU. `loopOrder` controls the loop/iteration order according to the workload characteristics.
+
+A default heuristic config for these items will be tuned for performance.
+
+1. For `MThreads, NThreads, KThreads`, we should rely on the available threads, the memory required for the input/output/temporary buffers, and the L2/L3 cache sizes to build a cost model that maximizes workload balance and thread utilization while minimizing cache synchronization. But the number of threads on the K axis should be set carefully, as it hurts performance in most cases (it only helps for large K with small M and N).
+2. For `MBlock, NBlock, KBlock`, the L2 cache size and the memory required per core are needed to build a cost model that minimizes L2 cache misses.
+3. For `innerMostMBlock, innerMostNBlock, innerMostKBlock`, we need to know the L1 cache size, the number of available registers, and the vector/matrix-unit (AMX-like) length to decide the innermost block size so that hardware efficiency is maximized. Besides, if we convert the brgemm to an external library function call, the cost of the function call also needs to be considered. When M/N/K is not divisible by the vector length, we usually either choose a factor of M/N/K as the innermost block size or do packing/unpacking in advance to make it divisible (a tradeoff between reducing memory copies and maximizing hardware efficiency).
+4. The `loopOrder` is mainly related to the workload characteristics (data, weight, and output sizes), the cache sizes, and where the actual data/weight is located (L1/L2/L3/memory). It affects the order in which memory is visited and therefore the cache data locality.
+
+The description above shows what should be considered from the horizontal view (`[M/N/K]Threads`, `[M/N/K]Block`, `innerMost[M/N/K]Block`, `loopOrder`) of the config. However, from the vertical view (`MThreads, MBlock, innerMostMBlock`, `N...`, `K...`), the fields are interdependent, which also impacts performance, so the order in which they are decided matters. The breakdown of how to decide them is as follows; a code sketch of this prioritized decision process is shown below.
+
+1. Firstly, we need to decide `innerMost[M/N/K]Block`, which bounds the maximum hardware efficiency we can achieve, especially on machines with a specialized matrix computation unit (AMX-like). For example, if the physical matrix-vector size is 16x64 and we choose an innermost block size of 8x32, the theoretical efficiency will be a quarter of the maximum. Even for vector instruction sets like `avx512` and `avx2`, the innermost block still matters because it needs to align with the vector length (64/32/...). So the innermost block has the highest priority.
+2. After the innermost block is decided, the input and output matrices are divided into `[M/N/K]NumBlock` blocks with block size `innerMost[M/N/K]Block`. Then we decide what `[M/N/K]Threads` should be used to distribute these blocks so that the best workload balance, compute intensity, and cache utilization can be achieved.
+3. After step 2, the number of innermost blocks for every thread has been decided. Then we decide `[M/N/K]Block` to further partition the iteration space of the matmul so that L2 cache misses on a single core are minimized. These should be multiples of `innerMost[M/N/K]Block`.
+4. After the above steps, all tile sizes are decided and we have enough information about where the data is located (L1/L2/L3 and their sizes). The `loopOrder` can then be decided to maximize data locality/data reuse. What it decides is the order of these loops (`pmpnpkomonokiminik`, `pnpmpkokonominimik`, etc., where `p` is the outermost parallel loop, `o` is the middle outer loop, and `i` is the innermost loop).
+
+**Note**: In graph compiler v1, we also consider the impact of the previous matmul, as it decides where the output of the previous matmul is located (e.g., the 3rd core's L2 cache or the 4th core's). This could also be further enhanced in the future.
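+
+To make the decision order above concrete, here is a minimal sketch of the prioritized filtering it implies. It is illustrative only: it assumes the `MatmulConfig` struct above, a non-empty candidate list, hypothetical cost-model callbacks, and a hard-coded elimination ratio of one half, rather than the exact heuristics of the final analysis.
+
+```c++
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <utility>
+#include <vector>
+
+// A cost model returns "lower is better"; `shape` is {M, N, K}.
+using CostFn =
+    std::function<double(const MatmulConfig &, const std::vector<uint32_t> &)>;
+
+// Apply the cost models in priority order (innermost-block efficiency first,
+// then thread/workload balance, then per-thread memory footprint, then L2
+// computation intensity). Each stage keeps only the cheaper half of the
+// surviving candidates, so the highest-priority model prunes first.
+inline MatmulConfig pickConfig(std::vector<MatmulConfig> candidates,
+                               const std::vector<uint32_t> &shape,
+                               const std::vector<CostFn> &prioritizedModels) {
+  assert(!candidates.empty() && "expects at least one candidate config");
+  for (const CostFn &cost : prioritizedModels) {
+    std::vector<std::pair<double, MatmulConfig>> scored;
+    scored.reserve(candidates.size());
+    for (const MatmulConfig &cfg : candidates)
+      scored.emplace_back(cost(cfg, shape), cfg);
+    // Sort by cost only; ties keep their original (enumeration) order.
+    std::stable_sort(scored.begin(), scored.end(),
+                     [](const auto &a, const auto &b) { return a.first < b.first; });
+    size_t keep = std::max<size_t>(1, scored.size() / 2);
+    candidates.clear();
+    for (size_t i = 0; i < keep; ++i)
+      candidates.push_back(scored[i].second);
+  }
+  return candidates.front(); // the cheapest surviving candidate becomes the default
+}
+```
+
+The candidate list itself is expected to be enumerated from the factors and powers of two of M/N/K and the available thread count, and each cost model corresponds to one of the considerations listed above.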
+ +The heuristic default config will be implemented as an [analysis pass](https://mlir.llvm.org/docs/PassManagement/#analysis-management). In this way, the heuristic is maximally isolated from the real IR transformation and easier to be accepted by the upstream community(who want to separate the heuristics from passes as much as possible). By the way, other passes like layout/padding propagation could also know which tile size is preferable by the matmul and will not have a dependence cycle among these passes. + +All choices above need to be under the guidance of HAL. But the HAL support(multi-level cache size, machine kind, available threads, register vector length) is not fully ready in the MLIR now. So there is a risk here to tune a good performance for general. + +### Expected IR Change + +Below is a matmul example(`M=256, K=128, N=512`) of the expected IR change after applying the deep-tiled matmul rewrite pattern(with config `MThreads=2, NThreads=2, KThreads=1, MBlock=128, NBlock=256, KBlock=128, innerMostMBlock=32, innerMostKBlock=32, loopOrder=0`). + +```MLIR +%0 = linalg.matmul ins(%cst_0, %cst_1 : tensor<256x128xf32>, tensor<128x512xf32>) outs(%cst_2 : tensor<256x512xf32>) -> tensor<256x512xf32> +``` + +into: + +```MLIR +%0 = scf.forall (%arg0, %arg1) in (2, 2) shared_outs(%arg2 = %cst_3) -> (tensor<256x512xf32>) { + %1 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %2 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %3 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %4 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %5 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %6 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %extracted_slice = tensor.extract_slice %cst_1[%3, 0] [128, 128] [1, 1] : tensor<256x128xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %cst_2[0, %4] [128, 256] [1, 1] : tensor<128x512xf32> to tensor<128x256xf32> + %extracted_slice_6 = tensor.extract_slice %arg2[%5, %6] [128, 256] [1, 1] : tensor<256x512xf32> to tensor<128x256xf32> + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c128_7 = arith.constant 128 : index + %7 = scf.for %arg3 = %c0 to %c128 step %c128_7 iter_args(%arg4 = %extracted_slice_6) -> (tensor<128x256xf32>) { + %c0_8 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c256_9 = arith.constant 256 : index + %10 = scf.for %arg5 = %c0_8 to %c256 step %c256_9 iter_args(%arg6 = %arg4) -> (tensor<128x256xf32>) { + %c0_10 = arith.constant 0 : index + %c128_11 = arith.constant 128 : index + %c128_12 = arith.constant 128 : index + %11 = scf.for %arg7 = %c0_10 to %c128_11 step %c128_12 iter_args(%arg8 = %arg6) -> (tensor<128x256xf32>) { + %extracted_slice_13 = tensor.extract_slice %extracted_slice[%arg3, %arg7] [128, 128] [1, 1] : tensor<128x128xf32> to tensor<128x128xf32> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_5[%arg7, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %extracted_slice_15 = tensor.extract_slice %arg8[%arg3, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_16 = arith.constant 0 : index + %c128_17 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %12 = scf.for %arg9 = %c0_16 to %c128_17 step %c32 iter_args(%arg10 = %extracted_slice_15) -> (tensor<128x256xf32>) { + %c0_18 = arith.constant 0 : index + %c256_19 = arith.constant 256 : index + %c32_20 = arith.constant 32 : index + %13 = scf.for %arg11 = %c0_18 to %c256_19 step %c32_20 iter_args(%arg12 = 
%arg10) -> (tensor<128x256xf32>) { + %c0_21 = arith.constant 0 : index + %c128_22 = arith.constant 128 : index + %c128_23 = arith.constant 128 : index + %14 = scf.for %arg13 = %c0_21 to %c128_22 step %c128_23 iter_args(%arg14 = %arg12) -> (tensor<128x256xf32>) { + %extracted_slice_24 = tensor.extract_slice %extracted_slice_13[%arg9, %arg13] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> + %extracted_slice_25 = tensor.extract_slice %extracted_slice_14[%arg13, %arg11] [128, 32] [1, 1] : tensor<128x256xf32> to tensor<128x32xf32> + %extracted_slice_26 = tensor.extract_slice %arg14[%arg9, %arg11] [32, 32] [1, 1] : tensor<128x256xf32> to tensor<32x32xf32> + %expanded = tensor.expand_shape %extracted_slice_24 [[0, 1], [2]] : tensor<32x128xf32> into tensor<1x32x128xf32> + %expanded_27 = tensor.expand_shape %extracted_slice_25 [[0, 1], [2]] : tensor<128x32xf32> into tensor<1x128x32xf32> + %15 = linalg.batch_reduce_matmul ins(%expanded, %expanded_27 : tensor<1x32x128xf32>, tensor<1x128x32xf32>) outs(%extracted_slice_26 : tensor<32x32xf32>) -> tensor<32x32xf32> + %inserted_slice_28 = tensor.insert_slice %15 into %arg14[%arg9, %arg11] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_28 : tensor<128x256xf32> + } + scf.yield %14 : tensor<128x256xf32> + } + scf.yield %13 : tensor<128x256xf32> + } + %inserted_slice = tensor.insert_slice %12 into %arg8[%arg3, %arg5] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice : tensor<128x256xf32> + } + scf.yield %11 : tensor<128x256xf32> + } + scf.yield %10 : tensor<128x256xf32> + } + %8 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %9 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + scf.forall.in_parallel { + tensor.parallel_insert_slice %7 into %arg2[%8, %9] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x512xf32> + } +} +``` + +When the `KThreads=2`, there will be partial reduction in the loop + +```MLIR +%0 = scf.forall (%arg0, %arg1) in (2, 2) shared_outs(%arg2 = %cst_3) -> (tensor<256x512xf32>) { + %1 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %2 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %3 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %4 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %5 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %6 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + %extracted_slice = tensor.extract_slice %cst_1[%3, 0] [128, 128] [1, 1] : tensor<256x128xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %cst_2[0, %4] [128, 256] [1, 1] : tensor<128x512xf32> to tensor<128x256xf32> + %extracted_slice_6 = tensor.extract_slice %arg2[%5, %6] [128, 256] [1, 1] : tensor<256x512xf32> to tensor<128x256xf32> + %c0 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + %c2_8 = arith.constant 2 : index + %7 = tensor.empty() : tensor<128x256x2xf32> + %cst_9 = arith.constant 0.000000e+00 : f32 + %8 = linalg.fill ins(%cst_9 : f32) outs(%7 : tensor<128x256x2xf32>) -> tensor<128x256x2xf32> + %c2_10 = arith.constant 2 : index + %9 = scf.forall (%arg3) in (2) shared_outs(%arg4 = %8) -> (tensor<128x256x2xf32>) { + %13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3) + %extracted_slice_11 = tensor.extract_slice %arg4[0, 0, %arg3] [128, 256, 1] [1, 1, 1] : tensor<128x256x2xf32> to tensor<128x256xf32> + %c0_12 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c128_13 = arith.constant 128 : index + %14 = affine.apply affine_map<()[s0, s1] -> 
(s0 * s1)>()[%arg3, %c128_13] + %15 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%14, %c0_12] + %16 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%c2_10, %c128_13] + %17 = scf.for %arg5 = %15 to %c128 step %16 iter_args(%arg6 = %extracted_slice_11) -> (tensor<128x256xf32>) { + %extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %arg5] [128, 128] [1, 1] : tensor<128x128xf32> to tensor<128x128xf32> + %extracted_slice_15 = tensor.extract_slice %extracted_slice_5[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %extracted_slice_16 = tensor.extract_slice %arg6[0, 0] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_17 = arith.constant 0 : index + %c128_18 = arith.constant 128 : index + %c128_19 = arith.constant 128 : index + %18 = scf.for %arg7 = %c0_17 to %c128_18 step %c128_19 iter_args(%arg8 = %extracted_slice_16) -> (tensor<128x256xf32>) { + %c0_20 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c256_21 = arith.constant 256 : index + %19 = scf.for %arg9 = %c0_20 to %c256 step %c256_21 iter_args(%arg10 = %arg8) -> (tensor<128x256xf32>) { + %c0_22 = arith.constant 0 : index + %c128_23 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %20 = scf.for %arg11 = %c0_22 to %c128_23 step %c64 iter_args(%arg12 = %arg10) -> (tensor<128x256xf32>) { + %extracted_slice_24 = tensor.extract_slice %extracted_slice_14[%arg7, %arg11] [128, 64] [1, 1] : tensor<128x128xf32> to tensor<128x64xf32> + %extracted_slice_25 = tensor.extract_slice %extracted_slice_15[%arg11, %arg9] [64, 256] [1, 1] : tensor<128x256xf32> to tensor<64x256xf32> + %extracted_slice_26 = tensor.extract_slice %arg12[%arg7, %arg9] [128, 256] [1, 1] : tensor<128x256xf32> to tensor<128x256xf32> + %c0_27 = arith.constant 0 : index + %c128_28 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %21 = scf.for %arg13 = %c0_27 to %c128_28 step %c32 iter_args(%arg14 = %extracted_slice_26) -> (tensor<128x256xf32>) { + %c0_30 = arith.constant 0 : index + %c256_31 = arith.constant 256 : index + %c32_32 = arith.constant 32 : index + %22 = scf.for %arg15 = %c0_30 to %c256_31 step %c32_32 iter_args(%arg16 = %arg14) -> (tensor<128x256xf32>) { + %c0_33 = arith.constant 0 : index + %c64_34 = arith.constant 64 : index + %c64_35 = arith.constant 64 : index + %23 = scf.for %arg17 = %c0_33 to %c64_34 step %c64_35 iter_args(%arg18 = %arg16) -> (tensor<128x256xf32>) { + %extracted_slice_36 = tensor.extract_slice %extracted_slice_24[%arg13, %arg17] [32, 64] [1, 1] : tensor<128x64xf32> to tensor<32x64xf32> + %extracted_slice_37 = tensor.extract_slice %extracted_slice_25[%arg17, %arg15] [64, 32] [1, 1] : tensor<64x256xf32> to tensor<64x32xf32> + %extracted_slice_38 = tensor.extract_slice %arg18[%arg13, %arg15] [32, 32] [1, 1] : tensor<128x256xf32> to tensor<32x32xf32> + %expanded = tensor.expand_shape %extracted_slice_36 [[0, 1], [2]] : tensor<32x64xf32> into tensor<1x32x64xf32> + %expanded_39 = tensor.expand_shape %extracted_slice_37 [[0, 1], [2]] : tensor<64x32xf32> into tensor<1x64x32xf32> + %24 = linalg.batch_reduce_matmul ins(%expanded, %expanded_39 : tensor<1x32x64xf32>, tensor<1x64x32xf32>) outs(%extracted_slice_38 : tensor<32x32xf32>) -> tensor<32x32xf32> + %inserted_slice_40 = tensor.insert_slice %24 into %arg18[%arg13, %arg15] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_40 : tensor<128x256xf32> + } + scf.yield %23 : tensor<128x256xf32> + } + scf.yield %22 : tensor<128x256xf32> + } + 
%inserted_slice_29 = tensor.insert_slice %21 into %arg12[%arg7, %arg9] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice_29 : tensor<128x256xf32> + } + scf.yield %20 : tensor<128x256xf32> + } + scf.yield %19 : tensor<128x256xf32> + } + %inserted_slice = tensor.insert_slice %18 into %arg6[0, 0] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<128x256xf32> + scf.yield %inserted_slice : tensor<128x256xf32> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %17 into %arg4[0, 0, %arg3] [128, 256, 1] [1, 1, 1] : tensor<128x256xf32> into tensor<128x256x2xf32> + } + } + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%9 : tensor<128x256x2xf32>) outs(%extracted_slice_6 : tensor<128x256xf32>) { + ^bb0(%in: f32, %out: f32): + %13 = arith.addf %in, %out : f32 + linalg.yield %13 : f32 + } -> tensor<128x256xf32> + %11 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg0) + %12 = affine.apply affine_map<(d0) -> (d0 * 256)>(%arg1) + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg2[%11, %12] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x512xf32> + } +} +``` \ No newline at end of file diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index cbc259609..d991bec86 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -1,4 +1,4 @@ -//===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===// +//===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,66 +11,70 @@ #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMap.h" -#include -#include -#include +#include namespace mlir { namespace gc { using namespace mlir; +// A mock for the taget information +// TODO: replace it with upstream hardware description model struct SystemDesc { + + static int getPositiveIntFromStr(char *str, int defaultValue = 1) { + if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') { + return defaultValue; + } + auto val = std::stoi(str); + return val > 0 ? 
val : defaultValue; + } + // get runtime OMP_NUM_THREADS uint32_t getNumThreads() { char *numThreads = getenv("OMP_NUM_THREADS"); - if (numThreads) { - return std::stoi(numThreads); - } - return 1; + return getPositiveIntFromStr(numThreads, 1); } // get cache size by cacheLevel size_t getCacheSize(uint8_t cacheLevel) { if (cacheLevel == 1) { char *cacheSize = getenv("L1_CACHE_SIZE"); - if (cacheSize) { - return std::stoi(cacheSize); - } + return getPositiveIntFromStr(cacheSize, 0); } else if (cacheLevel == 2) { char *cacheSize = getenv("L2_CACHE_SIZE"); - if (cacheSize) { - return std::stoi(cacheSize); - } + return getPositiveIntFromStr(cacheSize, 0); } else if (cacheLevel == 3) { char *cacheSize = getenv("L3_CACHE_SIZE"); - if (cacheSize) { - return std::stoi(cacheSize); - } + return getPositiveIntFromStr(cacheSize, 0); } return 0; } - SmallVector getContractionOperationMaxVectorLength() { - return {512UL, 512UL}; + // get the maximum vector length in bits + size_t getMaxVectorLength() { + char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH"); + return getPositiveIntFromStr(maxVectorLanes, 512); } }; +// The configuration for matmul tiling +// TODO: support batch matmul struct MatmulConfig { - uint32_t MBlock, NBlock, KBlock; + // The number of threads distributed to M, N, K uint32_t MThreads, NThreads, KThreads; + // The innermost block size for M, N, K which will be directly converted to + // brgemm. uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; - friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, - const MatmulConfig &config); + // The outer block size for M, N, K which will be used to decide the loop tile + // size in single thread + uint32_t MBlock, NBlock, KBlock; }; enum DimType { Batch, M, N, K }; -[[maybe_unused]] static SmallVector -extractDimTypeIdx(ArrayRef tyList, DimType ty) { +// Extract the index of the given DimType in the DimType list +inline SmallVector extractDimTypeIdx(ArrayRef tyList, + DimType ty) { SmallVector idxList; for (auto [idx, type] : llvm::enumerate(tyList)) { if (type == ty) { @@ -80,9 +84,11 @@ extractDimTypeIdx(ArrayRef tyList, DimType ty) { return idxList; } -static FailureOr>> +// Get the operand dim type for every operand for the given linalg op +inline FailureOr>> getOprandDimType(linalg::LinalgOp &linalgOp) { - if (isa(linalgOp)) { + // TODO: replace the linalgx op with generic op + if (llvm::isa(linalgOp)) { return SmallVector>{ SmallVector{DimType::M, DimType::K}, SmallVector{DimType::K, DimType::N}, @@ -104,10 +110,31 @@ getOprandDimType(linalg::LinalgOp &linalgOp) { SmallVector{DimType::Batch, DimType::M, DimType::K}, SmallVector{DimType::Batch, DimType::K, DimType::N}, SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::K, DimType::M}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::K, DimType::M}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::N, DimType::K}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; } return 
failure(); } +// The analysis to extract the matmul configuration from the given linalg op struct MatmulConfigAnalysis { public: explicit MatmulConfigAnalysis(Operation *root); diff --git a/include/gc/Dialect/Arith/Utils/EasyBuild.h b/include/gc/Dialect/Arith/Utils/EasyBuild.h deleted file mode 100644 index f7656370d..000000000 --- a/include/gc/Dialect/Arith/Utils/EasyBuild.h +++ /dev/null @@ -1,433 +0,0 @@ -//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H -#define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H -#include "gc/IR/EasyBuild.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include -#include -#include - -namespace mlir { -namespace easybuild { - -namespace impl { - -template struct ToFloatType {}; - -template <> struct ToFloatType<4> { using type = Float32Type; }; -template <> struct ToFloatType<8> { using type = Float64Type; }; - -inline Type getElementType(Value v) { - auto type = v.getType(); - if (type.isa() || type.isa()) { - type = type.cast().getElementType(); - } - return type; -} - -} // namespace impl - -struct EBUnsigned; - -struct EBArithValue : public EBValue { - template - static T toIndex(const impl::StatePtr &state, uint64_t v); - - template - static auto wrapOrFail(const impl::StatePtr &state, T &&v); - - template static auto wrap(const impl::StatePtr &state, T &&v) { - auto ret = wrapOrFail(state, std::forward(v)); - if (failed(ret)) { - llvm_unreachable("Bad wrap"); - } - return *ret; - } - -protected: - using EBValue::EBValue; -}; - -struct EBUnsigned : public EBArithValue { - static FailureOr wrapOrFail(const impl::StatePtr &state, - Value v) { - auto type = impl::getElementType(v); - if (type.isUnsignedInteger() || type.isSignlessInteger() || - type.isIndex()) { - return EBUnsigned{state, v}; - } - return failure(); - } - static FailureOr wrapOrFail(const impl::StatePtr &state, - const OpFoldResult &v) { - if (v.is()) { - return wrapOrFail(state, v.get()); - } - auto attr = v.get(); - if (auto val = attr.dyn_cast()) { - if (val.getType().isIndex()) - return EBUnsigned{state, state->builder.create( - state->loc, val.getInt())}; - else - return EBUnsigned{state, state->builder.create( - state->loc, val.getInt(), val.getType())}; - } - return failure(); - } - friend struct EBArithValue; - friend struct OperatorHandlers; - -protected: - using EBArithValue::EBArithValue; -}; - -struct EBSigned : EBArithValue { - static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { - auto type = impl::getElementType(v); - if (type.isSignedInteger() || type.isSignlessInteger()) { - return EBSigned{state, v}; - } - return failure(); - } - static FailureOr wrapOrFail(const impl::StatePtr &state, - const OpFoldResult &v) { - if (v.is()) { - return wrapOrFail(state, v.get()); - } - auto attr = v.get(); - if (auto val = attr.dyn_cast()) - return EBSigned{state, state->builder.create( - state->loc, val.getInt(), val.getType())}; - return failure(); - } - friend struct EBArithValue; - friend struct OperatorHandlers; - -protected: - using EBArithValue::EBArithValue; -}; - -struct EBFloatPoint : EBArithValue { - static FailureOr wrapOrFail(const impl::StatePtr &state, - Value v) { - auto type = 
impl::getElementType(v); - if (type.isa()) { - return EBFloatPoint{state, v}; - } - return failure(); - } - static FailureOr wrapOrFail(const impl::StatePtr &state, - const OpFoldResult &v) { - if (v.is()) { - return wrapOrFail(state, v.get()); - } - auto attr = v.get(); - if (auto val = attr.dyn_cast()) - return EBFloatPoint{state, state->builder.create( - state->loc, val.getValue(), - val.getType().cast())}; - return failure(); - } - friend struct EBArithValue; - friend struct OperatorHandlers; - -protected: - using EBArithValue::EBArithValue; -}; - -template -inline T EBArithValue::toIndex(const impl::StatePtr &state, uint64_t v) { - return EBUnsigned{ - state, state->builder.create(state->loc, v)}; -} - -template -inline auto EBArithValue::wrapOrFail(const impl::StatePtr &state, T &&v) { - using DT = std::decay_t; - static_assert(std::is_arithmetic_v
, "Expecting arithmetic types"); - if constexpr (std::is_same_v) { - if (state->u64AsIndex) { - return FailureOr{toIndex(state, v)}; - } - } - - if constexpr (std::is_same_v) { - return FailureOr{ - EBUnsigned{state, state->builder.create( - state->loc, static_cast(v), 1)}}; - } else if constexpr (std::is_integral_v
) { - if constexpr (!std::is_signed_v
) { - return FailureOr{EBUnsigned{ - state, state->builder.create( - state->loc, static_cast(v), sizeof(T) * 8)}}; - } else { - return FailureOr{EBSigned{ - state, state->builder.create( - state->loc, static_cast(v), sizeof(T) * 8)}}; - } - } else { - using DType = typename impl::ToFloatType::type; - return FailureOr{ - EBFloatPoint{state, state->builder.create( - state->loc, APFloat{v}, - DType::get(state->builder.getContext()))}}; - } -} - -struct OperatorHandlers { - template - static V handleBinary(const V &a, const V &b) { - assert(a.builder == b.builder); - return {a.builder, - a.builder->builder.template create(a.builder->loc, a.v, b.v)}; - } - - template - static V handleBinaryConst(const V &a, const T2 &b) { - return handleBinary(a, EBArithValue::wrap(a.builder, b)); - } - - template - static V handleBinaryConst(const T2 &a, const V &b) { - return handleBinary(EBArithValue::wrap(b.builder, a), b); - } - - template - static EBUnsigned handleCmp(const V &a, const V &b, Pred predicate) { - assert(a.builder == b.builder); - return {a.builder, a.builder->builder.template create( - a.builder->loc, predicate, a.v, b.v)}; - } - - template - static EBUnsigned handleCmpConst(const V &a, const T2 &b, Pred predicate) { - return handleCmp(a, EBArithValue::wrap(a.builder, b), predicate); - } - - template - static EBUnsigned handleCmpConst(const T2 &a, const V &b, Pred predicate) { - return handleCmp(EBArithValue::wrap(b.builder, a), b, predicate); - } - - template - static T create(const impl::StatePtr &state, Args &&...v) { - return {state, - state->builder.create(state->loc, std::forward(v)...)}; - } -}; - -#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, OPCLASS, TYPE) \ - inline TYPE operator OP(const TYPE &a, const TYPE &b) { \ - return OperatorHandlers::handleBinary(a, b); \ - } \ - template inline TYPE operator OP(const TYPE &a, T b) { \ - return OperatorHandlers::handleBinaryConst(a, b); \ - } \ - template inline TYPE operator OP(T a, const TYPE &b) { \ - return OperatorHandlers::handleBinaryConst(a, b); \ - } - -#define DEF_EASYBUILD_BINARY_OPERATOR(OP, SIGNED, UNSIGNED, FLOAT) \ - DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ - DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) \ - DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, FLOAT, EBFloatPoint) - -DEF_EASYBUILD_BINARY_OPERATOR(+, arith::AddIOp, arith::AddIOp, arith::AddFOp) -DEF_EASYBUILD_BINARY_OPERATOR(-, arith::SubIOp, arith::SubIOp, arith::SubFOp) -DEF_EASYBUILD_BINARY_OPERATOR(*, arith::MulIOp, arith::MulIOp, arith::MulFOp) -DEF_EASYBUILD_BINARY_OPERATOR(/, arith::DivSIOp, arith::DivUIOp, arith::DivFOp) -DEF_EASYBUILD_BINARY_OPERATOR(%, arith::RemSIOp, arith::RemUIOp, arith::RemFOp) - -#undef DEF_EASYBUILD_BINARY_OPERATOR -#define DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(OP, SIGNED, UNSIGNED) \ - DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, SIGNED, EBSigned) \ - DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE(OP, UNSIGNED, EBUnsigned) - -DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(>>, arith::ShRSIOp, arith::ShRUIOp) -DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(<<, arith::ShLIOp, arith::ShLIOp) -DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(&, arith::AndIOp, arith::AndIOp) -DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(|, arith::OrIOp, arith::OrIOp) -DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT(^, arith::XOrIOp, arith::XOrIOp) - -#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_INT -#undef DEF_EASYBUILD_BINARY_OPERATOR_FOR_TYPE - -inline EBFloatPoint operator-(const EBFloatPoint &a) { - return OperatorHandlers::create(a.builder, a.v); -} - -#define 
DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \ - EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \ - return OperatorHandlers::handleCmp(a, b, PRED); \ - } \ - template EBUnsigned operator OP(const TYPE &a, T b) { \ - return OperatorHandlers::handleCmpConst(a, b, PRED); \ - } \ - template EBUnsigned operator OP(T a, const TYPE &b) { \ - return OperatorHandlers::handleCmpConst(a, b, PRED); \ - } - -DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::ult) -DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::ule) -DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::ugt) -DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::uge) -DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::eq) -DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBUnsigned, - arith::CmpIPredicate::ne) - -DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::slt) -DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::sle) -DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::sgt) -DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::sge) -DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::eq) -DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpIOp, EBSigned, - arith::CmpIPredicate::ne) - -DEF_EASYBUILD_CMP_OPERATOR(<, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::OLT) -DEF_EASYBUILD_CMP_OPERATOR(<=, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::OLE) -DEF_EASYBUILD_CMP_OPERATOR(>, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::OGT) -DEF_EASYBUILD_CMP_OPERATOR(>=, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::OGE) -DEF_EASYBUILD_CMP_OPERATOR(==, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::OEQ) -DEF_EASYBUILD_CMP_OPERATOR(!=, arith::CmpFOp, EBFloatPoint, - arith::CmpFPredicate::ONE) - -#undef DEF_EASYBUILD_CMP_OPERATOR - -namespace arithops { -inline EBFloatPoint castIntToFP(Type type, const EBSigned &v) { - return OperatorHandlers::create(v.builder, - type, v); -} - -inline EBFloatPoint castIntToFP(Type type, const EBUnsigned &v) { - return OperatorHandlers::create(v.builder, - type, v); -} - -template inline T castFPToInt(const EBFloatPoint &v) { - if constexpr (std::is_same_v) { - return OperatorHandlers::create(v.builder, v); - } else { - static_assert(std::is_same_v, - "Expecting EBUnsigned or EBSigned"); - return OperatorHandlers::create(v.builder, v); - } -} - -inline EBSigned ceildiv(const EBSigned &a, const EBSigned &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBUnsigned ceildiv(const EBUnsigned &a, const EBUnsigned &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBSigned floordiv(const EBSigned &a, const EBSigned &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBSigned extend(Type type, const EBSigned &a) { - return OperatorHandlers::create(a.builder, type, a); -} - -inline EBUnsigned extend(Type type, const EBUnsigned &a) { - return OperatorHandlers::create(a.builder, type, - a); -} - -inline EBFloatPoint extend(Type type, const EBFloatPoint &a) { - arith::FastMathFlagsAttr fastMathAttr; - return OperatorHandlers::create(a.builder, type, - a, fastMathAttr); -} - -inline EBSigned trunc(Type type, const EBSigned &a) { - return OperatorHandlers::create(a.builder, type, - a); -} - -inline EBFloatPoint trunc(Type type, const 
EBFloatPoint &a) { - return OperatorHandlers::create(a.builder, - type, a); -} - -template -inline T select(const EBUnsigned &pred, const T &trueValue, - const T &falseValue) { - static_assert(std::is_base_of_v, - "Expecting T to be a subclass of EBArithValue"); - return OperatorHandlers::create(pred.builder, pred, - trueValue, falseValue); -} - -template -inline TyTo bitcast(Type type, const TyFrom &v) { - return OperatorHandlers::create(v.builder, type, v); -} - -inline EBSigned min(const EBSigned &a, const EBSigned &b) { - return OperatorHandlers::create(a.builder, a, b); -} - -inline EBSigned max(const EBSigned &a, const EBSigned &b) { - return OperatorHandlers::create(a.builder, a, b); -} - -inline EBUnsigned min(const EBUnsigned &a, const EBUnsigned &b) { - return OperatorHandlers::create(a.builder, a, b); -} - -inline EBUnsigned max(const EBUnsigned &a, const EBUnsigned &b) { - return OperatorHandlers::create(a.builder, a, b); -} - -inline EBFloatPoint minnum(const EBFloatPoint &a, const EBFloatPoint &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBFloatPoint maxnum(const EBFloatPoint &a, const EBFloatPoint &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBFloatPoint minimum(const EBFloatPoint &a, const EBFloatPoint &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -inline EBFloatPoint maximum(const EBFloatPoint &a, const EBFloatPoint &b) { - return OperatorHandlers::create(a.builder, a, - b); -} - -} // namespace arithops - -} // namespace easybuild -} // namespace mlir -#endif diff --git a/include/gc/IR/EasyBuild.h b/include/gc/IR/EasyBuild.h deleted file mode 100644 index 4b6e72225..000000000 --- a/include/gc/IR/EasyBuild.h +++ /dev/null @@ -1,102 +0,0 @@ -//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_IR_EASYBUILD_H -#define MLIR_IR_EASYBUILD_H -#include "mlir/IR/Builders.h" -#include -#include -#include - -namespace mlir { -namespace easybuild { - -namespace impl { -struct EasyBuildState { - OpBuilder &builder; - Location loc; - bool u64AsIndex; - EasyBuildState(OpBuilder &builder, Location loc, bool u64AsIndex) - : builder{builder}, loc{loc}, u64AsIndex{u64AsIndex} {} -}; - -using StatePtr = std::shared_ptr; - -} // namespace impl - -struct EBValue { - std::shared_ptr builder; - Value v; - EBValue() = default; - EBValue(const impl::StatePtr &builder, Value v) : builder{builder}, v{v} {} - Value get() const { return v; } - operator Value() const { return v; } - - static FailureOr wrapOrFail(const impl::StatePtr &state, Value v) { - return EBValue{state, v}; - } -}; - -struct EBArithValue; - -struct EasyBuilder { - std::shared_ptr builder; - EasyBuilder(OpBuilder &builder, Location loc, bool u64AsIndex = false) - : builder{ - std::make_shared(builder, loc, u64AsIndex)} {} - EasyBuilder(const std::shared_ptr &builder) - : builder{builder} {} - void setLoc(const Location &l) { builder->loc = l; } - - template auto wrapOrFail(V &&v) { - return W::wrapOrFail(builder, std::forward(v)); - } - - Operation *getLastOperaion() { - return &*(--builder->builder.getInsertionPoint()); - } - - template auto wrap(V &&v) { - auto ret = wrapOrFail(std::forward(v)); - if (failed(ret)) { - llvm_unreachable("wrap failed!"); - } - return *ret; - } - - template auto operator()(V &&v) { - if constexpr (std::is_convertible_v) { - return EBValue{builder, std::forward(v)}; - } else { - return wrap(std::forward(v)); - } - } - - template auto toIndex(uint64_t v) const { - return W::toIndex(builder, v); - } - - template - auto F(Args &&...v) { - if constexpr (std::is_same_v) { - builder->builder.create(builder->loc, std::forward(v)...); - } else { - return wrap( - builder->builder.create(builder->loc, std::forward(v)...)); - } - } - - template - auto yield(Args &&...v) { - builder->builder.create(builder->loc, - ValueRange{std::forward(v)...}); - } -}; - -} // namespace easybuild -} // namespace mlir -#endif \ No newline at end of file diff --git a/include/gc/IR/EasyBuildSCF.h b/include/gc/IR/EasyBuildSCF.h deleted file mode 100644 index 3d7ce9d77..000000000 --- a/include/gc/IR/EasyBuildSCF.h +++ /dev/null @@ -1,187 +0,0 @@ -//===-- EasyBuildSCF.h - DESC -----------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -// -// This header file defines the helper classes, functions and macros to help to -// build general structured control flow. Developers can use the utilities in -// this header to easily compose control flow IR. 
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_EASYBUILDSCF_H -#define MLIR_IR_EASYBUILDSCF_H -#include "gc/IR/EasyBuild.h" -#include "mlir/Interfaces/LoopLikeInterface.h" - -namespace mlir { -namespace scf { -class IfOp; -} - -namespace easybuild { -namespace impl { - -struct ForRangeSimulatorImpl { - StatePtr s; - LoopLikeOpInterface op; - ForRangeSimulatorImpl(const StatePtr &s, LoopLikeOpInterface op) - : s{s}, op{op} { - s->builder.setInsertionPointToStart(&op.getLoopRegions().front()->front()); - } - ~ForRangeSimulatorImpl() { s->builder.setInsertionPointAfter(op); } -}; - -template -using NthTypeOf = typename std::tuple_element>::type; - -template struct ForVarBinding { - ForRangeSimulatorImpl *impl; - template auto get() { - using TOut = NthTypeOf; - if (auto wrapped = TOut::wrapOrFail( - impl->s, impl->op.getLoopRegions().front()->front().getArgument(I)); - succeeded(wrapped)) { - return *wrapped; - } - llvm_unreachable("Bad cast for the loop iterator"); - } -}; -} // namespace impl -} // namespace easybuild -} // namespace mlir - -namespace std { -template -struct tuple_size> - : std::integral_constant {}; - -template -struct tuple_element> { - using type = mlir::easybuild::impl::NthTypeOf; -}; -} // namespace std - -namespace mlir { -namespace easybuild { - -namespace impl { - -template struct ForRangeSimulator : ForRangeSimulatorImpl { - using ForRangeSimulatorImpl::ForRangeSimulatorImpl; - struct ForRangeIterator { - ForRangeSimulatorImpl *ptr; - bool consumed; - auto operator*() const { return ForVarBinding{ptr}; } - - ForRangeIterator &operator++() { - consumed = true; - return *this; - } - - bool operator!=(ForRangeIterator &other) const { - return consumed != other.consumed; - } - - ForRangeIterator(ForRangeSimulator *ptr) : ptr{ptr}, consumed{false} {} - ForRangeIterator() : ptr{nullptr}, consumed{true} {} - }; - - ForRangeIterator begin() { return ForRangeIterator(this); } - - ForRangeIterator end() { return ForRangeIterator(); } -}; -} // namespace impl - -template -auto forRangeIn(const impl::StatePtr &s, LoopLikeOpInterface op) { - return impl::ForRangeSimulator{s, op}; -} - -template -auto forRangeIn(const EasyBuilder &s, LoopLikeOpInterface op) { - return impl::ForRangeSimulator{s.builder, op}; -} - -#define EB_for for - -namespace impl { -struct IfSimulator; -struct IfIterator { - IfSimulator *ptr; - int index; - int operator*() const; - - IfIterator &operator++() { - index++; - return *this; - } - - bool operator!=(IfIterator &other) const { return index != other.index; } - - IfIterator(IfSimulator *ptr) : ptr{ptr}, index{0} {} - IfIterator(int numRegions) : ptr{nullptr}, index{numRegions} {} -}; - -struct IfSimulator { - StatePtr s; - Operation *op; - IfIterator begin() { return IfIterator(this); } - IfIterator end() { - int nonEmptyRegions = 0; - for (auto ® : op->getRegions()) { - if (reg.begin() != reg.end()) { - nonEmptyRegions++; - } - } - return IfIterator(nonEmptyRegions); - } - ~IfSimulator() { s->builder.setInsertionPointAfter(op); } -}; -inline int IfIterator::operator*() const { - auto &blocks = ptr->op->getRegion(index); - ptr->s->builder.setInsertionPointToStart(&blocks.back()); - return index; -} - -} // namespace impl - -impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) { - return impl::IfSimulator{s.builder, op}; -} - -template -impl::IfSimulator makeScfIfLikeRange(EBValue cond, TypeRange resultTypes) { - auto &s = cond.builder; - auto op = s->builder.create(s->loc, 
resultTypes, cond, true); - return impl::IfSimulator{s, op}; -} - -template -impl::IfSimulator makeScfIfLikeRange(EBValue cond, bool hasElse = true) { - auto &s = cond.builder; - auto op = s->builder.create(s->loc, TypeRange{}, cond, hasElse); - return impl::IfSimulator{s, op}; -} - -#define EB_if(BUILDER, ...) \ - for (auto &&eb_mlir_if_scope__ : \ - ::mlir::easybuild::makeIfRange(BUILDER, __VA_ARGS__)) \ - if (eb_mlir_if_scope__ == 0) - -// EB_scf_if(COND) -// EB_scf_if(COND, HAS_ELSE) -// EB_scf_if(COND, RESULT_TYPES) -#define EB_scf_if(...) \ - for (auto &&eb_mlir_if_scope__ : \ - ::mlir::easybuild::makeScfIfLikeRange(__VA_ARGS__)) \ - if (eb_mlir_if_scope__ == 0) -#define EB_else else - -} // namespace easybuild -} // namespace mlir -#endif \ No newline at end of file diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index d0147ee2e..682855cd4 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -1,4 +1,4 @@ -//===-- MatmulConfigAnalysis.cpp - DESC -------------------------*- C++ -*-===// +//===-- MatmulConfigAnalysis.cpp - Analysis for matmul config ---*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,9 @@ // //===----------------------------------------------------------------------===// -#include -#include - #include "gc/Analysis/MatmulConfigAnalysis.h" +#include +#include namespace mlir { namespace gc { @@ -19,9 +18,9 @@ namespace gc { llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const MatmulConfig &config) { - ss << "MBlock: " << config.MBlock << ", NBlock: " << config.NBlock - << ", KBlock: " << config.KBlock << ", MThreads: " << config.MThreads - << ", NThreads: " << config.NThreads << ", KThreads: " << config.KThreads + ss << "MThreads: " << config.MThreads << ", NThreads: " << config.NThreads + << ", KThreads: " << config.KThreads << ", MBlock: " << config.MBlock + << ", NBlock: " << config.NBlock << ", KBlock: " << config.KBlock << ", innerMostMBlock: " << config.innerMostMBlock << ", innerMostNBlock: " << config.innerMostNBlock << ", innerMostKBlock: " << config.innerMostKBlock; @@ -29,7 +28,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, } template -llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, std::vector arry) { +static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + std::vector arry) { ss << "["; for (auto [idx, a] : llvm::enumerate(arry)) { if (idx != 0) { @@ -41,101 +41,101 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, std::vector arry) { return ss; } +// generate the candidate for the block size(factor of `num`, pow of 2 which is +// less than `num`) std::vector getCandidate(uint32_t num, uint32_t floor, uint32_t ceil = std::numeric_limits::max()) { // factor std::vector candidates; - for (uint32_t i = 1; i <= num; i++) { - if (num % i == 0 && i <= ceil && i >= floor) { + uint32_t upperbound = std::min(num, ceil); + for (uint32_t i = floor; i <= upperbound; i++) { + if (num % i == 0) { candidates.push_back(i); } } // the pow of 2 - auto candidate = 1U; - while (candidate < num && candidate <= ceil && candidate >= floor) { + uint32_t candidate = 1U; + while (candidate < floor) + candidate *= 2; + while (candidate <= upperbound) { candidates.push_back(candidate); candidate *= 2; } + // remove duplicate candidates std::sort(candidates.begin(), candidates.end()); - auto last = std::unique(candidates.begin(), 
candidates.end()); - candidates.erase(last, candidates.end()); + candidates.erase(std::unique(candidates.begin(), candidates.end()), + candidates.end()); return candidates; } -bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, - ArrayRef shape) { - if (config.innerMostMBlock == 0 || config.innerMostNBlock == 0 || - config.innerMostKBlock == 0) { - return false; - } - - if (shape[0] % config.innerMostMBlock != 0 || - shape[1] % config.innerMostNBlock != 0 || - shape[2] % config.innerMostKBlock != 0) { - return false; - } - - return true; -} - +// check if the threads are valid bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { - auto numThreads = sysDesc.getNumThreads(); - auto actualThreads = 1U; - for (auto t : threads) { + uint32_t numThreads = sysDesc.getNumThreads(); + uint32_t actualThreads = 1U; + for (uint32_t t : threads) { actualThreads *= t; } return actualThreads == numThreads; } -double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, - ArrayRef shape, - const MatmulConfig &config, SystemDesc &sysDesc) { - auto dtypeSize = DataLayout().getTypeSizeInBits( +// calculate the cost of the hardware efficiency(whether the vector register is +// fully utilized) +double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + SystemDesc &sysDesc) { + size_t dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); - auto vectorLength = sysDesc.getContractionOperationMaxVectorLength(); - auto mMaxVectorLength = vectorLength[0] / dtypeSize; - auto kMaxVectorLength = - (vectorLength.size() > 1 ? vectorLength[1] : vectorLength[0]) / dtypeSize; - auto cost = (mMaxVectorLength - config.innerMostMBlock % mMaxVectorLength) % - mMaxVectorLength * 1.0 / config.innerMostMBlock + - (kMaxVectorLength - config.innerMostKBlock % kMaxVectorLength) % - kMaxVectorLength * 1.0 / config.innerMostKBlock + - (mMaxVectorLength - config.innerMostNBlock % mMaxVectorLength) % - mMaxVectorLength * 1.0 / config.innerMostNBlock; + size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize; + double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) % + maxVectorLength * 1.0 / config.innerMostMBlock + + (maxVectorLength - config.innerMostKBlock % maxVectorLength) % + maxVectorLength * 1.0 / config.innerMostKBlock + + (maxVectorLength - config.innerMostNBlock % maxVectorLength) % + maxVectorLength * 1.0 / config.innerMostNBlock; return cost; } +// calculate the cost of the workload balance double workloadBalancedCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - auto M = shape[0], N = shape[1], K = shape[2]; - auto MTaskNum = llvm::divideCeil(M, config.MBlock); - auto NTaskNum = llvm::divideCeil(N, config.NBlock); - auto KTaskNum = llvm::divideCeil(K, config.KBlock); - auto cost = (MTaskNum % config.MThreads) * 1.0 / MTaskNum + - (NTaskNum % config.NThreads) * 1.0 / NTaskNum + - (KTaskNum % config.KThreads) * 1.0 / KTaskNum; + if (shape.size() < 3) { + // Has an invalid shape + return 0; + } + uint32_t M = shape[0], N = shape[1], K = shape[2]; + uint32_t MTaskNum = llvm::divideCeil(M, config.MBlock); + uint32_t NTaskNum = llvm::divideCeil(N, config.NBlock); + uint32_t KTaskNum = llvm::divideCeil(K, config.KBlock); + double cost = (MTaskNum % config.MThreads) * 1.0 / MTaskNum + + (NTaskNum % config.NThreads) * 1.0 / NTaskNum + + (KTaskNum % config.KThreads) * 1.0 / KTaskNum; if (MTaskNum < config.MThreads || 
NTaskNum < config.NThreads || KTaskNum < config.KThreads) { - auto threadNotFulllyUtilizedPenalty = 10.0; + double threadNotFulllyUtilizedPenalty = 10.0; cost *= threadNotFulllyUtilizedPenalty; } return cost; } -constexpr unsigned bitPerByte = 8; + +// calculate the cost of the memory consumption on the thread double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - auto M = shape[0], N = shape[1], K = shape[2]; - auto dtypeSize = DataLayout().getTypeSizeInBits( + if (shape.size() < 3) { + // Has an invalid shape + return 0; + } + uint32_t M = shape[0], N = shape[1], K = shape[2]; + size_t dtypeSize = DataLayout().getTypeSize( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); // if use K split, there will be one more final reduce and break the post // fusion - - auto KSplitPenalty = 8.0 * (dtypeSize / bitPerByte); - auto memoryConsumptionPerThread = + double KSplitPenalty = 8.0 * dtypeSize; + double memoryConsumptionPerThread = M * K * 1.0 / config.MThreads / config.KThreads + K * N * 1.0 / config.KThreads / config.NThreads + M * N * ((config.KThreads - 1) * KSplitPenalty + 1.0) / config.MThreads / @@ -143,35 +143,36 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, return memoryConsumptionPerThread; } +// calculate the cost of the computation intensity on the L2 cache double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - double simulationPenalty = 0.7; - auto L2Cache = sysDesc.getCacheSize(2); - auto dtypeSize = DataLayout().getTypeSizeInBits( + double fullLoadRatio = 0.7; + uint32_t L2Cache = sysDesc.getCacheSize(2); + size_t dtypeSize = DataLayout().getTypeSize( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); - auto outOfCachePenalty = 1024; + uint32_t outOfCachePenalty = 1024; double FLOPS = 2.0 * config.MBlock * config.NBlock * config.KBlock; double memoryConsumption = config.MBlock * config.NBlock + config.NBlock * config.KBlock + config.MBlock * config.KBlock; double computationIntensity = FLOPS / memoryConsumption; - if (memoryConsumption * (dtypeSize / bitPerByte) > - L2Cache * simulationPenalty) { + if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) { computationIntensity /= outOfCachePenalty; } return 1 / computationIntensity; } using CostModelFn = - std::function shape, - MatmulConfig cfg, SystemDesc &sysDesc)>; + std::function shape, + MatmulConfig cfg, SystemDesc &sysDesc)>; +// filter the config by the cost model std::vector -filterConfigByCostModel(std::vector configs, +filterConfigByCostModel(ArrayRef configs, linalg::LinalgOp &linalgOp, ArrayRef shape, - SystemDesc &sysDesc, CostModelFn costModel, + SystemDesc &sysDesc, const CostModelFn &costModel, float eliminationRatio = 0.5, float threshold = -1) { std::vector result; std::vector costs; @@ -183,7 +184,8 @@ filterConfigByCostModel(std::vector configs, std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) { return costs[i1] < costs[i2]; }); - auto thresholdCost = costs[idx[(size_t)(eliminationRatio * configs.size())]]; + double thresholdCost = + costs[idx[(size_t)(eliminationRatio * configs.size())]]; thresholdCost = threshold < thresholdCost && threshold > 0 ? 
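
The filtering above keeps only the configs whose cost is at or below the cost found at the `eliminationRatio` percentile of the sorted candidates (tightened further by `threshold` when it is positive), and falls back to the full list when nothing survives. A standalone sketch with a toy integer payload in place of `MatmulConfig` and a precomputed cost array in place of the cost-model callbacks (names here are illustrative, not from the patch):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Keep the cheapest `eliminationRatio` fraction of candidates; an optional
// positive `threshold` can tighten the cut further.
static std::vector<int>
filterByCost(const std::vector<int> &configs, const std::vector<double> &costs,
             double eliminationRatio = 0.5, double threshold = -1) {
  if (configs.empty())
    return configs;
  std::vector<size_t> idx(configs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
  double thresholdCost =
      costs[idx[(size_t)(eliminationRatio * configs.size())]];
  if (threshold > 0 && threshold < thresholdCost)
    thresholdCost = threshold;
  std::vector<int> result;
  for (size_t i = 0; i < configs.size(); ++i)
    if (costs[idx[i]] <= thresholdCost)
      result.push_back(configs[idx[i]]);
  return result.empty() ? configs : result;
}

int main() {
  std::vector<int> configs = {10, 20, 30, 40};
  std::vector<double> costs = {0.9, 0.1, 0.5, 0.3};
  for (int c : filterByCost(configs, costs))
    std::cout << c << " "; // 20 40 30 (ordered cheapest first)
  std::cout << "\n";
}
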
threshold : thresholdCost; for (size_t i = 0; i < configs.size(); i++) { @@ -191,95 +193,78 @@ filterConfigByCostModel(std::vector configs, result.push_back(configs[idx[i]]); } } - llvm::errs() << "thresholdCost is: " << thresholdCost - << "\nbest with cost: " << costs[idx[0]] << "\n" - << configs[idx[0]] - << "\n worst with cost: " << costs[idx[configs.size() - 1]] - << "\n" - << configs[idx[configs.size() - 1]] << "\n"; - return result.size() > 0 ? result : configs; + LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost + << "\nbest with cost: " << costs[idx[0]] << "\n" + << configs[idx[0]] << "\n worst with cost: " + << costs[idx[configs.size() - 1]] << "\n" + << configs[idx[configs.size() - 1]] << "\n"); + if (result.empty()) { + result = configs; + } + return result; } +// prepare the config candidates std::vector prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ArrayRef shape, ArrayRef givenInnermostBlock) { std::vector configs; - auto threads = sysDesc.getNumThreads(); - auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U); - auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U); - auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U); - auto noSmallBlockNeedThreshold = 8 * 8U; - auto MBlockCandidates = getCandidate( - (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, + uint32_t threads = sysDesc.getNumThreads(); + std::vector MThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + std::vector NThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + std::vector KThreadsCandidates = + getCandidate((uint32_t)threads, 1U); + uint32_t noSmallBlockNeedThreshold = 8 * 8U; + std::vector MBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] >= noSmallBlockNeedThreshold ? 8U : 1U, (uint32_t)shape[0]); - auto NBlockCandidates = + std::vector NBlockCandidates = getCandidate((uint32_t)shape[1], - shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, shape[1]); - auto KBlockCandidates = + shape[1] >= noSmallBlockNeedThreshold ? 8U : 1U, shape[1]); + std::vector KBlockCandidates = getCandidate((uint32_t)shape[2], - shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, shape[2]); - auto innerMostMBlockCandidates = getCandidate( - (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); - auto innerMostNBlockCandidates = getCandidate( - (uint32_t)shape[1], shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); - auto innerMostKBlockCandidates = getCandidate( - (uint32_t)shape[2], shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); - if (givenInnermostBlock.size() == 3) { - innerMostMBlockCandidates = - givenInnermostBlock[0] != 0 - ? std::vector{givenInnermostBlock[0]} - : innerMostMBlockCandidates; - innerMostNBlockCandidates = - givenInnermostBlock[1] != 0 - ? std::vector{givenInnermostBlock[1]} - : innerMostNBlockCandidates; - innerMostKBlockCandidates = - givenInnermostBlock[2] != 0 - ? 
std::vector{givenInnermostBlock[2]} - : innerMostKBlockCandidates; - } - llvm::errs() << "MThreadsCandidates size: " << MThreadsCandidates.size() - << MThreadsCandidates << "\n"; - llvm::errs() << "NThreadsCandidates size: " << NThreadsCandidates.size() - << NThreadsCandidates << "\n"; - llvm::errs() << "KThreadsCandidates size: " << KThreadsCandidates.size() - << KThreadsCandidates << "\n"; - llvm::errs() << "MBlockCandidates size: " << MBlockCandidates.size() - << MBlockCandidates << "\n"; - llvm::errs() << "NBlockCandidates size: " << NBlockCandidates.size() - << NBlockCandidates << "\n"; - llvm::errs() << "KBlockCandidates size: " << KBlockCandidates.size() - << KBlockCandidates << "\n"; - llvm::errs() << "innerMostMBlockCandidates size: " - << innerMostMBlockCandidates.size() << innerMostMBlockCandidates - << "\n"; - llvm::errs() << "innerMostNBlockCandidates size: " - << innerMostNBlockCandidates.size() << innerMostNBlockCandidates - << "\n"; - llvm::errs() << "innerMostKBlockCandidates size: " - << innerMostKBlockCandidates.size() << innerMostKBlockCandidates - << "\n"; - for (auto MThreads : MThreadsCandidates) { - for (auto NThreads : NThreadsCandidates) { - for (auto KThreads : KThreadsCandidates) { + shape[2] >= noSmallBlockNeedThreshold ? 8U : 1U, shape[2]); + std::vector innerMostMBlockCandidates = + givenInnermostBlock[0] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[0]} + : getCandidate((uint32_t)shape[0], + shape[0] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + std::vector innerMostNBlockCandidates = + givenInnermostBlock[1] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[1]} + : getCandidate((uint32_t)shape[1], + shape[1] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + std::vector innerMostKBlockCandidates = + givenInnermostBlock[2] != 0 && givenInnermostBlock.size() == 3 + ? std::vector{givenInnermostBlock[2]} + : getCandidate((uint32_t)shape[2], + shape[2] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U); + + // TODO: improve via multi threading or add more constraints to restrict the + // candidate size + for (uint32_t MThreads : MThreadsCandidates) { + for (uint32_t NThreads : NThreadsCandidates) { + for (uint32_t KThreads : KThreadsCandidates) { if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) { continue; } - for (auto MBlock : MBlockCandidates) { - for (auto innerMostMBlock : innerMostMBlockCandidates) { + for (uint32_t MBlock : MBlockCandidates) { + for (uint32_t innerMostMBlock : innerMostMBlockCandidates) { if (MBlock % innerMostMBlock != 0 || shape[0] % innerMostMBlock != 0) { continue; } - for (auto NBlock : NBlockCandidates) { - for (auto innerMostNBlock : innerMostNBlockCandidates) { + for (uint32_t NBlock : NBlockCandidates) { + for (uint32_t innerMostNBlock : innerMostNBlockCandidates) { if (NBlock % innerMostNBlock != 0 || shape[1] % innerMostNBlock != 0) { continue; } - for (auto KBlock : KBlockCandidates) { - for (auto innerMostKBlock : innerMostKBlockCandidates) { + for (uint32_t KBlock : KBlockCandidates) { + for (uint32_t innerMostKBlock : innerMostKBlockCandidates) { if (KBlock % innerMostKBlock != 0 || shape[2] % innerMostKBlock != 0) { continue; @@ -298,57 +283,71 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, } } } - llvm::errs() << "Finish generating candidates. ConfigCandidates size: " - << configs.size() << "\n"; + LLVM_DEBUG( + llvm::dbgs() << "Finish generating candidates. 
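
The triple loop over thread candidates below prunes every split whose product does not match the runtime thread count before any block-size enumeration happens. A standalone sketch of just that pruning step (the `ThreadSplit` struct and helper name are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

struct ThreadSplit {
  uint32_t m, n, k;
};

// Enumerate (MThreads, NThreads, KThreads) triples from the candidate list
// and keep only those that use exactly `numThreads` threads.
static std::vector<ThreadSplit>
enumerateThreadSplits(const std::vector<uint32_t> &candidates,
                      uint32_t numThreads) {
  std::vector<ThreadSplit> splits;
  for (uint32_t m : candidates)
    for (uint32_t n : candidates)
      for (uint32_t k : candidates)
        if (m * n * k == numThreads)
          splits.push_back({m, n, k});
  return splits;
}

int main() {
  // Candidates for 8 threads: factors and powers of two in [1, 8].
  std::vector<uint32_t> candidates = {1, 2, 4, 8};
  for (const ThreadSplit &s : enumerateThreadSplits(candidates, 8))
    std::cout << s.m << "x" << s.n << "x" << s.k << "\n"; // 1x1x8 ... 8x1x1
}
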
ConfigCandidates size: " + << configs.size() << "\n"); return configs; } +// read the config from the attributes for tuning bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { - bool hasPredefinedConfig = false; - for (auto attr : attrs) { + size_t cfgItemCnt = 0; + for (auto &attr : attrs) { if (attr.getName() == "KBlock") { config.KBlock = cast(attr.getValue()).getInt(); - hasPredefinedConfig = true; + cfgItemCnt++; } else if (attr.getName() == "KThreads") { config.KThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "NBlock") { config.NBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "NThreads") { config.NThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "MBlock") { config.MBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "MThreads") { config.MThreads = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "innerMostMBlock") { config.innerMostMBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "innerMostNBlock") { config.innerMostNBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; } else if (attr.getName() == "innerMostKBlock") { config.innerMostKBlock = cast(attr.getValue()).getInt(); + cfgItemCnt++; } } - return hasPredefinedConfig; + return cfgItemCnt == 9; } -/* -thread utilization -computation intensity -cache locality -memory requirements -computation unit efficiency -padding/pack cost -workload balance -communication -previous matmul -*/ +// Analyze the workload and system description to generate the default config +// Factor to consider: +// thread utilization +// computation intensity +// cache locality +// memory requirements +// computation unit efficiency +// padding/pack cost +// workload balance +// communication +// previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { - SystemDesc sysDesc; if (auto linalgOp = dyn_cast(root)) { - auto oprandDimType = *getOprandDimType(linalgOp); + SystemDesc sysDesc; + SmallVector> oprandDimType = + *getOprandDimType(linalgOp); // get the origin M,N,K size - auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); - auto KDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::K); - auto NDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::N); + SmallVector MDimTypeIdx = + extractDimTypeIdx(oprandDimType[0], DimType::M); + SmallVector KDimTypeIdx = + extractDimTypeIdx(oprandDimType[1], DimType::K); + SmallVector NDimTypeIdx = + extractDimTypeIdx(oprandDimType[1], DimType::N); uint32_t M = 1U, N = 1U, K = 1U; for (auto [s, dimType] : llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), @@ -366,16 +365,17 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { K *= s; } } + // innermost Block, if the layout is blockied layout, the innermost block // will derived from the layout directly - auto defaultBlock = 32; + uint32_t defaultBlock = 32; config.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; config.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; config.innerMostKBlock = K % defaultBlock == 0 ? 
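
The M/N/K extraction below multiplies together all operand dimensions tagged with the same dimension kind, which lets blocked layouts such as MxKxmxk share one code path with plain 2D layouts. A standalone sketch, with a hypothetical `DimKind` enum standing in for the analysis' `DimType`:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class DimKind { M, N, K, Batch };

// Multiply all dimensions of an operand that are tagged with `kind`,
// e.g. a blocked A operand [M_outer, K_outer, m_block, k_block] with tags
// {M, K, M, K} yields M = M_outer * m_block.
static int64_t productOfKind(const std::vector<int64_t> &shape,
                             const std::vector<DimKind> &kinds, DimKind kind) {
  int64_t result = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    if (kinds[i] == kind)
      result *= shape[i];
  return result;
}

int main() {
  std::vector<int64_t> aShape = {128, 128, 32, 32}; // blocked A of a 4096x4096 matmul
  std::vector<DimKind> aKinds = {DimKind::M, DimKind::K, DimKind::M,
                                 DimKind::K};
  int64_t M = productOfKind(aShape, aKinds, DimKind::M);
  int64_t K = productOfKind(aShape, aKinds, DimKind::K);
  std::cout << "M=" << M << " K=" << K << "\n"; // M=4096 K=4096
}
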
defaultBlock : K; SmallVector givenInnermostBlock; if (MDimTypeIdx.size() > 1) { config.innerMostMBlock = 1; - for (auto i = 1UL; i < MDimTypeIdx.size(); i++) { + for (size_t i = 1UL; i < MDimTypeIdx.size(); i++) { config.innerMostMBlock *= linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]]; } @@ -385,7 +385,7 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { } if (NDimTypeIdx.size() > 1) { config.innerMostNBlock = 1; - for (auto i = 1UL; i < NDimTypeIdx.size(); i++) { + for (size_t i = 1UL; i < NDimTypeIdx.size(); i++) { config.innerMostNBlock *= linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]]; } @@ -395,7 +395,7 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { } if (KDimTypeIdx.size() > 1) { config.innerMostKBlock = 1; - for (auto i = 1UL; i < KDimTypeIdx.size(); i++) { + for (size_t i = 1UL; i < KDimTypeIdx.size(); i++) { config.innerMostKBlock *= linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]]; } @@ -404,39 +404,38 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { givenInnermostBlock.push_back(0); } - llvm::errs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; - - SmallVector> costModelList = { - {workloadBalancedCost, "workloadBalancedCost", 1}, - {hardwareEfficiencyCost, "hardwareEfficiencyCost", -1}, - {computationIntensityOnL2Cache, "computationIntensityOnL2Cache", -1}, - {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost", -1}}; + LLVM_DEBUG(llvm::dbgs() + << "M: " << M << ", N: " << N << ", K: " << K << "\n"); + // try to read the config from the attributes SmallVector attrs(linalgOp->getAttrs()); bool hasPredefinedConfig = readConfigFromAttrs(config, attrs); + // if there is a given config, skip the cost model if (!hasPredefinedConfig) { - llvm::errs() << "No predefined config\n"; - auto configCandidates = prepareConfigCandidates(root, sysDesc, {M, N, K}, - givenInnermostBlock); + LLVM_DEBUG(llvm::dbgs() << "No predefined config\n"); + // TODO: Could add a weight or priority for cost model + SmallVector> costModelList = + {{workloadBalancedCost, "workloadBalancedCost", 1}, + {vectorRegEfficiencyCost, "vectorRegEfficiencyCost ", -1}, + {computationIntensityOnL2Cache, "computationIntensityOnL2Cache", -1}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost", + -1}}; + SmallVector shape = {M, N, K}; + std::vector configCandidates = + prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock); for (auto [fn, name, threshold] : costModelList) { - llvm::errs() << "\n" << name << "\n"; configCandidates = filterConfigByCostModel( - configCandidates, linalgOp, {M, N, K}, sysDesc, fn, 0.5, threshold); - llvm::errs() << "ConfigCandidates size: " << configCandidates.size() - << "\n"; + configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold); } - if (configCandidates.size() > 0) { + if (!configCandidates.empty()) { config = configCandidates[0]; } } - llvm::errs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() - << ", MatmulConfig: " << config << "\n"; - for (auto [fn, name, threshold] : costModelList) { - auto cost = fn(linalgOp, {M, N, K}, config, sysDesc); - llvm::errs() << name << ": " << cost << "\n"; - } + LLVM_DEBUG(llvm::dbgs() + << "Final config\nNumThreads: " << sysDesc.getNumThreads() + << ", MatmulConfig: " << config << "\n"); } } } // namespace gc diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 3673103a5..705e257d7 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ 
b/lib/gc/Transforms/CMakeLists.txt @@ -16,7 +16,7 @@ gc_add_mlir_library(GcPasses TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp DeepTileContractionNamedOp.cpp - Tiling.cpp + TilingUtil.cpp SinkOpIntoInnerLoop.cpp MergeNestedForall.cpp diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 084ce539d..d82cc554a 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -1,4 +1,4 @@ -//===-- DeepTileContractionNamedOp.cpp - DESC -------------------*- C++ -*-===// +//===-- DeepTileContractionNamedOp.cpp - tile named op deeply ---*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,31 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "./Tiling.hpp" +#include "./TilingUtil.hpp" #include "gc/Analysis/MatmulConfigAnalysis.h" -#include "gc/Dialect/Arith/Utils/EasyBuild.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" -#include "gc/IR/EasyBuild.h" -#include "gc/IR/EasyBuildSCF.h" -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "gc/Transforms/Passes.h" @@ -46,46 +29,31 @@ namespace gc { namespace { +// Util function to tensor view a ranked tensor to another ranked tensor without +// change the data layout static Value tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, Value value, ArrayRef permutation = SmallVector{}) { - // TODO: add support for plain layout transpose Value result, currentValue = value; - auto loc = currentValue.getLoc(); - auto inTensorType = cast(currentValue.getType()); - auto inShape = inTensorType.getShape(); - auto outShape = outTensorType.getShape(); - auto tensorElementType = inTensorType.getElementType(); - + Location loc = currentValue.getLoc(); + RankedTensorType inTensorType = + cast(currentValue.getType()); + ArrayRef inShape = inTensorType.getShape(); + ArrayRef outShape = outTensorType.getShape(); + mlir::Type tensorElementType = inTensorType.getElementType(); + + // Check if the input and output tensor have the same shape if (inShape == outShape) { return currentValue; } - if (outTensorType.getNumDynamicDims() != inTensorType.getNumDynamicDims()) { - SmallVector alignOutShape(outShape.begin(), outShape.end()); - if (outShape.size() < inShape.size()) { - SmallVector oneVector(inShape.size() - outShape.size(), 1); - alignOutShape.insert(alignOutShape.begin(), oneVector.begin(), - oneVector.end()); - } else { - alignOutShape.erase(alignOutShape.begin(), - alignOutShape.begin() + - (outShape.size() - inShape.size())); - 
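
The reassociation logic in `tensorViewRankedTensor` pairs each output dimension with a run of input dimensions whose sizes multiply up to it, grouping unit dimensions eagerly. A simplified standalone sketch of the collapse direction (output rank smaller than input rank), assuming static shapes and no error handling; `collapseGroups` is an illustrative name:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Group input dims so each group's product equals one output dim, e.g.
// inShape = {1, 8, 32, 32}, outShape = {8, 32, 32} -> groups {0,1} {2} {3}
// (the leading unit input dim is absorbed into the first group).
static std::vector<std::vector<int64_t>>
collapseGroups(const std::vector<int64_t> &inShape,
               const std::vector<int64_t> &outShape) {
  std::vector<std::vector<int64_t>> groups;
  size_t inIdx = 0;
  for (size_t outIdx = 0; outIdx < outShape.size(); ++outIdx) {
    std::vector<int64_t> group;
    int64_t remaining = outShape[outIdx];
    if (remaining == 1) {
      group.push_back(inIdx++);
      groups.push_back(group);
      continue;
    }
    while (remaining > 1 && inIdx < inShape.size()) {
      remaining /= inShape[inIdx];
      group.push_back(inIdx++);
    }
    groups.push_back(group);
  }
  return groups;
}

int main() {
  for (const auto &g : collapseGroups({1, 8, 32, 32}, {8, 32, 32})) {
    std::cout << "{ ";
    for (int64_t d : g)
      std::cout << d << " ";
    std::cout << "} ";
  }
  std::cout << "\n"; // { 0 1 } { 2 } { 3 }
}
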
} - auto type = RankedTensorType::get(alignOutShape, tensorElementType); - currentValue = rewriter.create(loc, type, currentValue); - if (type == outTensorType) { - return currentValue; - } - } - if (outShape.size() < inShape.size()) { SmallVector reassocIndices; uint64_t outIdx = 0UL, inIdx = 0UL; while (inIdx < inShape.size() && outIdx < outShape.size()) { ReassociationIndices firstEntry; - auto remaining = outShape[outIdx++]; + int64_t remaining = outShape[outIdx++]; if (remaining == 1) { firstEntry.push_back(inIdx++); reassocIndices.push_back(firstEntry); @@ -104,7 +72,7 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, uint64_t outIdx = 0UL, inIdx = 0UL; while (outIdx < outShape.size() && inIdx < inShape.size()) { ReassociationIndices firstEntry; - auto remaining = inShape[inIdx++]; + int64_t remaining = inShape[inIdx++]; if (remaining == 1) { firstEntry.push_back(outIdx++); reassocIndices.push_back(firstEntry); @@ -122,20 +90,22 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, result = rewriter.create(loc, outTensorType, currentValue); } + // Transpose the tensor if permutation is not empty if (!permutation.empty()) { SmallVector transposeShape; - for (auto idx : permutation) { + for (int64_t idx : permutation) { transposeShape.push_back(outShape[idx]); } - auto initOp = rewriter.create(loc, transposeShape, - tensorElementType); - auto transposeOp = rewriter.create( + Operation *initOp = rewriter.create(loc, transposeShape, + tensorElementType); + Operation *transposeOp = rewriter.create( loc, result, initOp->getResult(0), permutation); result = transposeOp->getResult(0); } return result; } +// Check if the loop is dummy loop(has only one iteration) bool isDummyLoop(LoopLikeOpInterface loop) { std::optional tripCount = mlir::constantTripCount( *loop.getSingleLowerBound(), *loop.getSingleUpperBound(), @@ -146,6 +116,7 @@ bool isDummyLoop(LoopLikeOpInterface loop) { return false; } +// Build the linalg region for a linalg op static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { SmallVector argTypes; SmallVector argLocs; @@ -153,15 +124,15 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { argTypes.push_back(getElementTypeOrSelf(opOperand.getType())); argLocs.push_back(opOperand.getLoc()); } - auto initSize = op->getResults().size(); + size_t initSize = op->getResults().size(); ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); if (createTemporaryOp) { - auto argNum = body->getNumArguments(); + unsigned argNum = body->getNumArguments(); SmallVector vals; - for (auto i = initSize; i > 0; i--) { + for (size_t i = initSize; i > 0; i--) { vals.push_back(body->getArgument(argNum - i)); } OpBuilder::InsertionGuard g(b); @@ -169,28 +140,30 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { Location loc = b.getUnknownLoc(); b.create(loc, ValueRange(vals)); } else { - auto *dialect = static_cast(op->getDialect()); + linalg::LinalgDialect *dialect = + static_cast(op->getDialect()); linalg::LinalgDialect::RegionBuilderFunType fun = dialect->getRegionBuilder("linalg.matmul"); fun(b, *body, op->getAttrs()); } } -struct DtypeLegalizeResult { - Operation *linalgOp = nullptr; - Operation *castOp = nullptr; -}; - -bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { - auto dataType = +// Check if the linalgOp need 
to be legalized to f32 accumulation type +static bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { + mlir::Type dataType = dyn_cast(linalgOp.getDpsInputs()[0].getType()) .getElementType(); - auto resultType = + mlir::Type resultType = dyn_cast(linalgOp.getDpsInits()[0].getType()) .getElementType(); return (dataType.isBF16() || dataType.isF16()) && dataType == resultType; } +struct DtypeLegalizeResult { + Operation *linalgOp = nullptr; + Operation *castOp = nullptr; +}; + // Split a low precision matmul(bf16xbf16->bf16) to a combination // matmul(bf16xbf16->f32) + cast(f32->bf16) // if needFurtherFuse=true, a middle temporary linalgOp(bf16xbf16->(f32,bf16)) @@ -198,7 +171,8 @@ bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { static FailureOr matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, bool needCopyInit = true, bool needFurtherFuse = false) { - auto linalgOp = dyn_cast(op); + linalg::LinalgOp linalgOp = dyn_cast(op); + Location loc = linalgOp->getLoc(); DtypeLegalizeResult result; if (!linalgOp) return failure(); @@ -206,15 +180,14 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, if (needToLegalizeDtype(linalgOp)) { rewriter.setInsertionPoint(linalgOp); IRMapping mapping; - auto initOp = linalgOp.getDpsInits()[0].getDefiningOp(); - auto initValue = initOp->getResult(0); - auto initType = cast(initValue.getType()); - auto tensorShape = initType.getShape(); + Operation *initOp = linalgOp.getDpsInits()[0].getDefiningOp(); + Value initValue = initOp->getResult(0); + ShapedType initType = cast(initValue.getType()); + ArrayRef tensorShape = initType.getShape(); SmallVector mixedShape; - for (auto i = 0UL; i < tensorShape.size(); i++) { + for (size_t i = 0UL; i < tensorShape.size(); i++) { if (initType.isDynamicDim(i)) { - Value val = - rewriter.create(linalgOp.getLoc(), initValue, i); + Value val = rewriter.create(loc, initValue, i); mixedShape.push_back(val); } else { mixedShape.push_back( @@ -224,21 +197,21 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, Operation *currentOp; currentOp = rewriter.create( - linalgOp.getLoc(), mixedShape, Float32Type::get(op->getContext())); + loc, mixedShape, Float32Type::get(op->getContext())); if (needCopyInit) { - currentOp = rewriter.create( - linalgOp.getLoc(), initOp->getResult(0), currentOp->getResult(0)); + currentOp = rewriter.create(loc, initOp->getResult(0), + currentOp->getResult(0)); } SmallVector newOperands = linalgOp->getOperands(); - auto oldInit = newOperands.back(); + Value oldInit = newOperands.back(); newOperands.back() = currentOp->getResult(0); - auto indexingMaps = linalgOp.getIndexingMapsArray(); + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); indexingMaps.push_back(indexingMaps.back()); SmallVector attrs(linalgOp->getAttrs()); SmallVector types = {currentOp->getResult(0).getType()}; if (needFurtherFuse) { - auto segmentSize = rewriter.getNamedAttr( + NamedAttribute segmentSize = rewriter.getNamedAttr( "operandSegmentSizes", rewriter.getDenseI32ArrayAttr({2, 2})); for (auto &attr : attrs) { if (attr.getName() == "indexing_maps") @@ -249,13 +222,12 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, types.push_back(oldInit.getType()); newOperands.push_back(oldInit); } - OperationState state(linalgOp->getLoc(), linalgOp->getName(), newOperands, - types, attrs); + OperationState state(loc, linalgOp->getName(), newOperands, types, attrs); state.addRegion(); currentOp = rewriter.create(state); buildLinalgRegion(currentOp, needFurtherFuse); - auto castOp = 
rewriter.create( - linalgOp.getLoc(), currentOp->getResult(0), initOp->getResult(0)); + linalg::CopyOp castOp = rewriter.create( + loc, currentOp->getResult(0), initOp->getResult(0)); result.linalgOp = currentOp; result.castOp = castOp; } @@ -263,9 +235,10 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, return result; } +// Find the parent fill op of a value and will penetrate pack/pad ops static Operation *findParentFillOp(Value val) { SmallVector skipOpList = {"tensor.pack", "tensor.pad"}; - auto currentOp = val.getDefiningOp(); + Operation *currentOp = val.getDefiningOp(); while (currentOp && llvm::find(skipOpList, currentOp->getName().getStringRef()) != skipOpList.end() && @@ -279,22 +252,7 @@ static Operation *findParentFillOp(Value val) { return nullptr; } -[[maybe_unused]] static LogicalResult -indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, - Value rollingIdx, Value maximumRange, Value step) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(insertBlock); - mlir::easybuild::EasyBuilder eb{b, loc}; - auto vWraped = eb.wrap(v); - auto rollingIdxWraped = eb.wrap(rollingIdx); - auto stepWraped = eb.wrap(step); - auto maximumRangeWraped = eb.wrap(step); - auto newV = (vWraped + rollingIdxWraped) * stepWraped % - (maximumRangeWraped / stepWraped * stepWraped); - v.replaceAllUsesWith(newV); - return failure(); -} - +// Get the parallel dims of a matmul op static void getMatmulParallelDims(linalg::LinalgOp linalgOp, unsigned operandIdx, SmallVectorImpl &dims) { @@ -304,8 +262,8 @@ static void getMatmulParallelDims(linalg::LinalgOp linalgOp, linalgOp.getIteratorTypesArray(); ArrayRef results = map.getResults(); - for (auto dim : results) { - auto dimExpr = dyn_cast(dim); + for (AffineExpr dim : results) { + AffineDimExpr dimExpr = dyn_cast(dim); if (dimExpr && iteratorTypes[dimExpr.getPosition()] == mlir::utils::IteratorType::parallel) { dims.push_back(dimExpr.getPosition()); @@ -313,15 +271,8 @@ static void getMatmulParallelDims(linalg::LinalgOp linalgOp, } } -static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, - unsigned operandIdx) { - Value Operand; - unsigned dimPos; - [[maybe_unused]] auto result = - linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); - return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; -} - +// set the dynamic size to static size for ExtractSliceOp according to the tile +// config static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, Operation *op, bool isExtract, SmallVector size, @@ -332,7 +283,7 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, SmallVector mixedOffsets = extractSlice.getMixedOffsets(); SmallVector mixedSizes = extractSlice.getMixedSizes(); SmallVector mixedStrides = extractSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { + for (size_t i = 0UL; i < mixedSizes.size(); i++) { mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); } if (shrinDimNum > 0) { @@ -350,6 +301,8 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, } } +// set the dynamic size to static size for InsertSliceOp according to the tile +// config static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, Value source, SmallVector size) { @@ -359,7 +312,7 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, SmallVector mixedOffsets = insertSlice.getMixedOffsets(); SmallVector mixedSizes = 
insertSlice.getMixedSizes(); SmallVector mixedStrides = insertSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { + for (size_t i = 0UL; i < mixedSizes.size(); i++) { mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); } rewriter.replaceOpWithNewOp( @@ -395,14 +348,14 @@ struct OuterLoopGenerationResult { SmallVector reductionLoops; }; +// Generate outer loop for a linalg op static FailureOr generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, const OuterLoopGenerationOption &option) { - // TODO: handle the return value OuterLoopGenerationResult result; - auto nestedTileSizes = option.nestedTileSizes; - auto loopType = option.loopType; - auto loopDim = option.loopDim; + SmallVector> nestedTileSizes = option.nestedTileSizes; + SmallVector loopType = option.loopType; + SmallVector> loopDim = option.loopDim; SmallVector iteratorTypes = linalgOp.getIteratorTypesArray(); @@ -419,10 +372,9 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, linalg::LinalgOp currentOp = linalgOp; bool hasFullResult = !option.isPartialResult; - for (auto loopTypeIter : llvm::enumerate(loopType)) { - auto [i, loopType] = loopTypeIter; - auto currentDim = loopDim[i]; - auto currentTileSize = nestedTileSizes[i]; + for (auto [i, loopType] : llvm::enumerate(loopType)) { + ArrayRef currentDim = loopDim[i]; + ArrayRef currentTileSize = nestedTileSizes[i]; if (loopType == OuterLoopGenerationOption::LoopType::ForOp) { for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { scf::SCFTilingOptions tileOption; @@ -436,14 +388,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (iteratorTypes[d] == mlir::utils::IteratorType::reduction && tile != 0 && hasFullResult) { for (const auto &fn : option.innermostFullResultCallBacks) { - auto result = fn(b, currentOp.getLoc(), currentOp); + FailureOr result = + fn(b, currentOp->getLoc(), currentOp); if (succeeded(result)) { currentOp = *result; } } hasFullResult = false; } - auto tilingResult = scf::tileUsingSCF( + FailureOr tilingResult = scf::tileUsingSCF( b, cast(currentOp.getOperation()), tileOption); if (failed(tilingResult)) return failure(); @@ -463,10 +416,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, SmallVector threads( currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); SmallVector reductionDims; + SmallVector loopRanges = + cast(currentOp.getOperation()).getIterationDomain(b); currentOp.getReductionDims(reductionDims); bool tileOnReduction = false; for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { - if (llvm::find(reductionDims, d) != reductionDims.end()) { + if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 && + (!getConstantIntValue(loopRanges[d].size) || + tile != static_cast( + *getConstantIntValue(loopRanges[d].size)))) { tileOnReduction = true; } if (llvm::find(reductionDims, d) != reductionDims.end() && @@ -476,26 +434,24 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, } else tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); } - SmallVector loopRanges = - cast(currentOp.getOperation()).getIterationDomain(b); + OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); if (tileOnReduction) { - auto partialInterface = - dyn_cast(currentOp.getOperation()); for (auto [idx, tile] : llvm::enumerate(tileSizes)) { if (isConstantIntValue(tile, 0) && - llvm::find(reductionDims, d) != reductionDims.end()) { + llvm::find(reductionDims, idx) != reductionDims.end()) { tileSizes[idx] 
= loopRanges[idx].size; } } SmallVector newParallelDims; - for (auto i = 0UL; i < reductionDims.size(); i++) { + for (size_t i = 0UL; i < reductionDims.size(); i++) { newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i)); } - auto tilingResult = linalgX::tileReductionUsingForall( - b, cast(currentOp.getOperation()), {}, - tileSizes, newParallelDims, std::nullopt); + FailureOr tilingResult = + linalgX::tileReductionUsingForall( + b, cast(currentOp.getOperation()), + {}, tileSizes, newParallelDims, std::nullopt); if (failed(tilingResult) && tilingResult->parallelTiledOps.size() == 1UL) return failure(); @@ -503,16 +459,19 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, dyn_cast(tilingResult->parallelTiledOps.back()); if (!tilingResult->mergeOps.empty()) { for (const auto &fn : option.finalReduceCallBacks) { - auto result = fn(b, currentOp.getLoc(), *tilingResult); + FailureOr result = + fn(b, currentOp->getLoc(), *tilingResult); if (succeeded(result)) { currentOp = *result; } } } } else { - auto tilingInterface = cast(currentOp.getOperation()); - auto tilingResult = linalg::tileToForallOpUsingTileSizes( - b, tilingInterface, tileSizes, std::nullopt); + TilingInterface tilingInterface = + cast(currentOp.getOperation()); + FailureOr tilingResult = + linalg::tileToForallOpUsingTileSizes(b, tilingInterface, tileSizes, + std::nullopt); if (failed(tilingResult)) return failure(); b.replaceOp(currentOp, tilingResult->tileOp); @@ -524,6 +483,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return result; } +// Turn a OpFoldResult into a Value +static Value turnOpFoldResultIntoValue(RewriterBase &rewriter, Location loc, + OpFoldResult result) { + if (auto value = dyn_cast(result)) + return value; + if (auto attr = dyn_cast(result)) { + if (auto val = dyn_cast(attr)) { + if (val.getType().isIndex()) + return rewriter.create(loc, val.getInt()); + else + return rewriter.create(loc, val.getInt(), + val.getType()); + } + } + return Value(); +} + /* matmul(A, B) -> C ----------------> @@ -559,50 +535,56 @@ A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0)); } C = final_reduce(CSlice) */ -struct deepTileMatmul : public OpInterfaceRewritePattern { +struct DeepTileMatmul : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - FailureOr + static FailureOr outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - gc::MatmulConfig cfg, bool hasFillOp) const { + gc::MatmulConfig cfg, bool hasFillOp) { SmallVector KDimPos, MDimPos, NDimPos; linalgOp.getReductionDims(KDimPos); getMatmulParallelDims(linalgOp, 0, MDimPos); getMatmulParallelDims(linalgOp, 1, NDimPos); - OuterLoopGenerationOption option; - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - auto KFirstDim = getOprandDim(linalgOp, KDimPos[0], 1); - auto MFirstDim = getOprandDim(linalgOp, MDimPos[0], 0); - auto NFirstDim = getOprandDim(linalgOp, NDimPos[0], 1); - auto KParallelBlockSize = + + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + SmallVector loopRange = + cast(linalgOp.getOperation()) + .getIterationDomain(rewriter); + size_t KFirstDim = *getConstantIntValue(loopRange[KDimPos[0]].size); + size_t MFirstDim = *getConstantIntValue(loopRange[MDimPos[0]].size); + size_t NFirstDim = *getConstantIntValue(loopRange[NDimPos[0]].size); + + size_t KParallelBlockSize = KDimPos.size() > 1 ? 
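
The parallel block sizes computed here give each thread a whole number of cache-level blocks by rounding the block count up to the thread count and converting back to elements. A small worked sketch of that arithmetic (the `divideCeil` helper mimics `llvm::divideCeil` for positive operands):

#include <cstdint>
#include <iostream>

// Ceiling division, i.e. llvm::divideCeil behavior for positive values.
static uint64_t divideCeil(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

// Per-thread tile along one dimension: round the number of blocks up to the
// thread count, then convert back to elements.
static uint64_t parallelBlockSize(uint64_t dim, uint64_t block,
                                  uint64_t threads) {
  return divideCeil(divideCeil(dim, block), threads) * block;
}

int main() {
  // E.g. M = 4096, MBlock = 256, MThreads = 4:
  // divideCeil(4096, 256) = 16 blocks, divideCeil(16, 4) = 4 blocks/thread,
  // so each forall iteration owns a 4 * 256 = 1024-wide M slice.
  std::cout << parallelBlockSize(4096, 256, 4) << "\n"; // 1024
}
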
llvm::divideCeil(KFirstDim, cfg.KThreads) : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), cfg.KThreads) * cfg.KBlock; - auto MParallelBlockSize = + size_t MParallelBlockSize = MDimPos.size() > 1 ? llvm::divideCeil(MFirstDim, cfg.MThreads) : llvm::divideCeil(llvm::divideCeil(MFirstDim, cfg.MBlock), cfg.MThreads) * cfg.MBlock; - auto NParallelBlockSize = + size_t NParallelBlockSize = NDimPos.size() > 1 ? llvm::divideCeil(NFirstDim, cfg.NThreads) : llvm::divideCeil(llvm::divideCeil(NFirstDim, cfg.NBlock), cfg.NThreads) * cfg.NBlock; - auto KOuterBlockSize = KDimPos.size() > 1 - ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 - : cfg.KBlock; - auto MOuterBlockSize = MDimPos.size() > 1 - ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1 - : cfg.MBlock; - auto NOuterBlockSize = NDimPos.size() > 1 - ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 - : cfg.NBlock; - // Outer + size_t KOuterBlockSize = KDimPos.size() > 1 + ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 + : cfg.KBlock; + size_t MOuterBlockSize = MDimPos.size() > 1 + ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1 + : cfg.MBlock; + size_t NOuterBlockSize = NDimPos.size() > 1 + ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 + : cfg.NBlock; + + // Outer loop tile size for (auto [tile, dim] : llvm::zip(SmallVector{KParallelBlockSize, MParallelBlockSize, NParallelBlockSize}, @@ -612,7 +594,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { OuterLoopGenerationOption::LoopType::ForallOp); option.loopDim.emplace_back(SmallVector{dim}); } - // Middle + + // Middle loop tile size for (auto [tile, dim] : llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, KOuterBlockSize}, @@ -621,12 +604,12 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{dim}); } - // Inner if (KDimPos.size() == 1) { option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{KDimPos.back()}); } + // Inner loop tile size if (MDimPos.size() == 1) { option.nestedTileSizes.emplace_back( SmallVector{cfg.innerMostMBlock}); @@ -639,7 +622,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{NDimPos.back()}); } - for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { + for (size_t dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { if (dim != MDimPos.back() && dim != NDimPos.back() && iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { option.nestedTileSizes.emplace_back(SmallVector{1}); @@ -649,13 +632,16 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } } + // cast the low precision matmul to f32 when partial accumulation(result not + // full) is needed auto lowPrecisionCast = [&](RewriterBase &rewriter, Location loc, linalg::LinalgOp linalgop) -> FailureOr { - auto legalizedResult = matmulDtypeLegalize( + FailureOr legalizedResult = matmulDtypeLegalize( rewriter, linalgop.getOperation(), !hasFillOp, true); - if (legalizedResult->castOp && legalizedResult->linalgOp) { - auto linalgOp = legalizedResult->linalgOp; + if (succeeded(legalizedResult) && legalizedResult->castOp && + legalizedResult->linalgOp) { + Operation *linalgOp = legalizedResult->linalgOp; rewriter.replaceOp(linalgop, linalgOp->getResult(linalgOp->getNumResults() - 1)); return dyn_cast(linalgOp); @@ -669,7 +655,7 @@ 
struct deepTileMatmul : public OpInterfaceRewritePattern { [&](RewriterBase &rewriter, Location loc, const linalg::ForallReductionTilingResult &result) -> FailureOr { - auto initValue = result.initialValues; + ArrayRef initValue = result.initialValues; if (initValue.size() == 1 && isa(initValue[0].getDefiningOp())) { rewriter.replaceOp(initValue[0].getDefiningOp(), @@ -681,6 +667,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { }; option.finalReduceCallBacks.push_back(removeReduncantFill); } + return generateOuterLoop(rewriter, linalgOp, option); } @@ -694,20 +681,28 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalg::LinalgOp originOp, linalg::LinalgOp currentOp, innerBodyGenerationOption &option) const { - mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; - auto operandDimTypes = getOprandDimType(originOp); - auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); - auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); - auto BShape = originOp.getShape(originOp.getDpsInputOperand(1)); - auto CShape = originOp.getShape(originOp.getDpsInitOperand(0)); - - auto MDimNum = std::count_if((*operandDimTypes)[0].begin(), - (*operandDimTypes)[0].end(), - [](DimType d) { return d == DimType::M; }); - auto NDimNum = std::count_if((*operandDimTypes)[1].begin(), - (*operandDimTypes)[1].end(), - [](DimType d) { return d == DimType::N; }); + Location loc = currentOp->getLoc(); + FailureOr>> operandDimTypes = + getOprandDimType(originOp); + MatmulConfig cfg = + MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + ArrayRef AShape = + originOp.getShape(originOp.getDpsInputOperand(0)); + ArrayRef BShape = + originOp.getShape(originOp.getDpsInputOperand(1)); + ArrayRef CShape = originOp.getShape(originOp.getDpsInitOperand(0)); + + if (failed(operandDimTypes)) + return failure(); + + size_t MDimNum = std::count_if((*operandDimTypes)[0].begin(), + (*operandDimTypes)[0].end(), + [](DimType d) { return d == DimType::M; }); + size_t NDimNum = std::count_if((*operandDimTypes)[1].begin(), + (*operandDimTypes)[1].end(), + [](DimType d) { return d == DimType::N; }); // TODO: support plain in/block out format + // Calculate the innermost block size according to the config SmallVector AInnermostDims, BInnermostDims, CInnermostDims; bool firstM = true, firstK = true, firstN = true; if (MDimNum > 1) { @@ -769,15 +764,16 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { cfg.innerMostNBlock}; } + // Get the data/wei/dst data type OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(currentOp); - auto dataType = + mlir::Type dataType = dyn_cast(currentOp.getDpsInputs()[0].getType()) .getElementType(); - auto weightType = + mlir::Type weightType = dyn_cast(currentOp.getDpsInputs()[1].getType()) .getElementType(); - auto resultType = + mlir::Type resultType = dyn_cast(currentOp.getDpsInits()[0].getType()) .getElementType(); @@ -789,7 +785,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { setStaticSizeForExtractSliceOp(rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, AInnermostDims, MDimNum > 1); - for (auto init : currentOp.getDpsInits()) { + for (Value init : currentOp.getDpsInits()) { setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true, CInnermostDims, MDimNum > 1 ? 
2 : 0); } @@ -829,12 +825,12 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { .getShape() .size() == 3) { matmul = rewriter.create( - resultOprand.getLoc(), resultOprand.getType(), - ValueRange{dataOprand, weightOprand}, resultOprand); + loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, + resultOprand); } else { matmul = rewriter.create( - resultOprand.getLoc(), resultOprand.getType(), - ValueRange{dataOprand, weightOprand}, resultOprand); + loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, + resultOprand); } Value result = matmul.getOperation()->getResult(0); @@ -844,25 +840,41 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } if (option.needLowPrecisionCast) { + // fuse the low precision cast to the innermost body rewriter.setInsertionPointAfter(currentOp); - auto cond = eb(true); - for (auto loop : option.KLoopHandles) { - auto induceVar = - eb.wrap(*loop.getSingleInductionVar()); - auto upBound = - eb.wrap(*loop.getSingleUpperBound()); - auto step = eb.wrap(*loop.getSingleStep()); - auto currentCond = (induceVar + step) >= upBound; - cond = cond & currentCond; + Value cond; + for (LoopLikeOpInterface loop : option.KLoopHandles) { + Value induceVar = turnOpFoldResultIntoValue( + rewriter, loc, *loop.getSingleInductionVar()); + Value upBound = turnOpFoldResultIntoValue(rewriter, loc, + *loop.getSingleUpperBound()); + Value step = + turnOpFoldResultIntoValue(rewriter, loc, *loop.getSingleStep()); + Value currentCond = + rewriter.create(loc, induceVar, step); + currentCond = rewriter.create( + loc, arith::CmpIPredicate::sge, currentCond, upBound); + cond = cond ? rewriter.create(loc, cond, currentCond) + : currentCond; + } + scf::IfOp ifOp = rewriter.create( + loc, TypeRange{currentOp.getDpsInits().back().getType()}, + cond ? 
cond : rewriter.create(loc, true, 1), + true); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getThenRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + linalg::CopyOp castOp = rewriter.create( + loc, matmul->getResult(0), currentOp.getDpsInits().back()); + rewriter.create(loc, castOp->getResult(0)); } - EB_scf_if(cond, {currentOp.getDpsInits().back().getType()}) { - auto castOp = rewriter.create( - matmul.getLoc(), matmul->getResult(0), - currentOp.getDpsInits().back()); - eb.yield(castOp->getResult(0)); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getElseRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + rewriter.create(loc, currentOp.getDpsInits().back()); } - EB_else { eb.yield(currentOp.getDpsInits().back()); } - auto ifOp = eb.getLastOperaion(); // set static size for the insertSliceOp of copyOp for (Operation *user : currentOp->getResult(1).getUsers()) { setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0), @@ -873,9 +885,10 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { rewriter.replaceOp(currentOp, matmul->getResult(0)); } currentOp = matmul; + // Fuse the fill op to the innermost body if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { - auto fillValue = fillOp.getDpsInputs()[0]; + Value fillValue = fillOp.getDpsInputs()[0]; if (cfg.KThreads <= 1) { // if use k slicing, the fill op is still need to be kept for the reduce // init @@ -887,26 +900,38 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { } rewriter.setInsertionPointAfter(currentOp); - auto cond = eb(true); - for (auto loop : option.KLoopHandles) { - auto induceVar = eb.wrap( - loop.getLoopRegions().front()->front().getArgument(0)); - auto currentCond = induceVar == eb.toIndex(0); - cond = cond & currentCond; + Value cond; + arith::ConstantIndexOp zeroConst = + rewriter.create(loc, 0); + for (LoopLikeOpInterface loop : option.KLoopHandles) { + Value induceVar = loop.getLoopRegions().front()->front().getArgument(0); + Value currentCond = rewriter.create( + loc, arith::CmpIPredicate::eq, induceVar, zeroConst); + cond = cond ? rewriter.create(loc, cond, currentCond) + : currentCond; } - EB_scf_if(cond, {currentOp.getDpsInits()[0].getType()}) { - auto fillOp = rewriter.create( - currentOp->getLoc(), fillValue, currentOp.getDpsInits()[0]); + scf::IfOp ifOp = rewriter.create( + loc, TypeRange{currentOp.getDpsInits()[0].getType()}, + cond ? 
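
Both guards built around this point reduce to simple predicates over the K loops' induction variables, combined with `arith.andi` when there are several K loops: the fill runs only on the first K iteration (`iv == 0`) and the f32 to bf16 copy-back only on the last one (`iv + step >= upperBound`). A scalar sketch of those predicates over a single K loop:

#include <iostream>

// First-iteration predicate used to guard the zero-fill of the accumulator.
static bool isFirstKIteration(long iv) { return iv == 0; }

// Last-iteration predicate used to guard the f32 -> bf16 copy-back.
static bool isLastKIteration(long iv, long step, long upperBound) {
  return iv + step >= upperBound;
}

int main() {
  const long lb = 0, ub = 4096, step = 256;
  for (long iv = lb; iv < ub; iv += step) {
    bool fill = isFirstKIteration(iv);
    bool cast = isLastKIteration(iv, step, ub);
    if (fill || cast)
      std::cout << "iv=" << iv << (fill ? " fill" : "") << (cast ? " cast" : "")
                << "\n"; // prints "iv=0 fill" and "iv=3840 cast"
  }
}
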
cond : rewriter.create(loc, true, 1), + true); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getThenRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + linalg::FillOp fillOp = rewriter.create( + loc, fillValue, currentOp.getDpsInits()[0]); IRMapping mapping; mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0)); - auto res = rewriter.clone(*(currentOp.getOperation()), mapping); - eb.yield(res->getResult(0)); + Operation *res = rewriter.clone(*(currentOp.getOperation()), mapping); + rewriter.create(loc, res->getResult(0)); } - EB_else { - auto res = rewriter.clone(*(currentOp.getOperation())); - eb.yield(res->getResult(0)); + { + OpBuilder::InsertionGuard guard(rewriter); + Region ®ion = ifOp.getElseRegion(); + rewriter.setInsertionPointToStart(®ion.back()); + Operation *res = rewriter.clone(*(currentOp.getOperation())); + rewriter.create(loc, res->getResult(0)); } - auto ifOp = eb.getLastOperaion(); rewriter.replaceOp(currentOp, ifOp); } return success(); @@ -939,20 +964,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { // Step 1. Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + // cast(f32->bf16) if K slicing is needed - auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + MatmulConfig cfg = + MatmulConfigAnalysis(originOp.getOperation()).getConfig(); linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); if (cfg.KThreads > 1) { - auto result = matmulDtypeLegalize(rewriter, linalgOp.getOperation()); - if (result->castOp && result->linalgOp) { + FailureOr result = + matmulDtypeLegalize(rewriter, linalgOp.getOperation()); + if (succeeded(result) && result->castOp && result->linalgOp) { rewriter.replaceOp(linalgOp, result->castOp); linalgOp = dyn_cast(result->linalgOp); + } else { + return failure(); } needLowPrecisionCast = false; } // Step 2. 
Outer loop generation - auto outerLoopResult = outerLoopGeneration( + FailureOr outerLoopResult = outerLoopGeneration( rewriter, linalgOp, cfg, fillOp && isa(fillOp)); if (failed(outerLoopResult)) { return failure(); @@ -960,8 +989,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); // Step 3 generate inner loop body, convert the linalg.generic to brgemm - auto option = innerBodyGenerationOption{fillOp, needLowPrecisionCast, - outerLoopResult->reductionLoops}; + innerBodyGenerationOption option = innerBodyGenerationOption{ + fillOp, needLowPrecisionCast, outerLoopResult->reductionLoops}; if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) { return failure(); @@ -975,11 +1004,11 @@ struct DeepTileContractionNamedOp : public impl::DeepTileContractionNamedOpBase { public: void runOnOperation() final { - auto &ctx = getContext(); + MLIRContext &ctx = getContext(); IRRewriter rewriter(&ctx); RewritePatternSet patterns(&ctx); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); linalg::populateLinalgTilingCanonicalizationPatterns(patterns); linalg::ControlDropUnitDims options; options.rankReductionStrategy = @@ -987,7 +1016,7 @@ struct DeepTileContractionNamedOp linalg::populateFoldUnitExtentDimsPatterns(patterns, options); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - for (auto *dialect : ctx.getLoadedDialects()) + for (Dialect *dialect : ctx.getLoadedDialects()) dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx.getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, &ctx); diff --git a/lib/gc/Transforms/MergeNestedForall.cpp b/lib/gc/Transforms/MergeNestedForall.cpp index cd0442c4a..516981c9c 100644 --- a/lib/gc/Transforms/MergeNestedForall.cpp +++ b/lib/gc/Transforms/MergeNestedForall.cpp @@ -1,4 +1,4 @@ -//===-- MergeNestedForall.cpp - DESC -------------------*- C++ -*-===// +//===-- MergeNestedForall.cpp - Merge nested scf.forall op ------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,14 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Passes.h" - #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Dominance.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/ControlFlowSinkUtils.h" +#include "mlir/Transforms/Passes.h" namespace mlir { namespace gc { @@ -31,7 +25,7 @@ struct MergeNestedForallLoops : public OpRewritePattern { if (!llvm::hasSingleElement(outerBody.without_terminator())) return failure(); - auto innerOp = dyn_cast(outerBody.front()); + scf::ForallOp innerOp = dyn_cast(outerBody.front()); if (!innerOp) return failure(); @@ -97,4 +91,4 @@ struct MergeNestedForall } // namespace } // namespace gc -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 510a186b5..4e4a0dd25 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -34,15 +34,9 @@ namespace mlir::gc { -void populateCleanUpPasses(mlir::PassManager &pm) { +void populateCleanUpPasses(mlir::OpPassManager &pm) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(createLoopInvariantCodeMotionPass()); - pm.addPass(createControlFlowSinkPass()); - pm.addPass(createCSEPass()); - pm.addPass(createSCCPPass()); - pm.addPass(createMem2Reg()); - pm.addPass(createTopologicalSortPass()); } // linalg + linalgX + tensor @@ -67,6 +61,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // REMOVE this pass after the above passes are added. Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createControlFlowSinkPass()); populateCleanUpPasses(pm); } @@ -125,11 +121,13 @@ void populateCPURuntimePasses(mlir::OpPassManager &pm) { // remove this pass after we add FlattenNestedParallel pm.addPass(createSinkOpIntoInnerLoop()); pm.addPass(createMergeNestedForall()); - populateCleanUpPasses(pm); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createControlFlowSinkPass()); pm.addPass(createForallToParallelLoopPass()); pm.addPass(createParallelLoopFusionPass()); pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createConvertSCFToOpenMPPass()); + populateCleanUpPasses(pm); } void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) { diff --git a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp index 426b1e258..df04b6590 100644 --- a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp +++ b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp @@ -1,4 +1,4 @@ -//===-- SinkOpIntoInnerLoop.cpp - DESC -------------------*- C++ -*-===// +//===-- SinkOpIntoInnerLoop.cpp - sink op to inner if possible --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Transforms/Tiling.hpp b/lib/gc/Transforms/Tiling.hpp deleted file mode 100644 index 7c4188096..000000000 --- a/lib/gc/Transforms/Tiling.hpp +++ /dev/null @@ -1,55 +0,0 @@ -//===- Tilig.hpp - Tiling ops using TilingInterface --*- C++ -*-===// -// -// This file is only temporarily used to extend upstream or upcoming utility in -// TilingInterface, which finally aims for upstream. 
-// -//===----------------------------------------------------------------------===// - -#ifndef TEMPORARY_TILEUSINGINTERFACE_X_H -#define TEMPORARY_TILEUSINGINTERFACE_X_H - -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/CommandLine.h" -#include -#include -namespace mlir { -namespace linalgX { - -FailureOr tileReductionUsingForall( - RewriterBase &b, PartialReductionOpInterface op, - ArrayRef threadNums, ArrayRef tileSizes, - ArrayRef newParallelDims, std::optional mapping); - -FailureOr tileAllUsingForall( - RewriterBase &b, PartialReductionOpInterface op, - ArrayRef numThreads, ArrayRef tileSizes, - ArrayRef newParallelDims, std::optional mapping); - -} // namespace linalgX -} // namespace mlir - -#endif \ No newline at end of file diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/TilingUtil.cpp similarity index 96% rename from lib/gc/Transforms/Tiling.cpp rename to lib/gc/Transforms/TilingUtil.cpp index cd01067c7..25d94938c 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/TilingUtil.cpp @@ -1,40 +1,18 @@ -//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===// +//===-- TilingUtil.cpp - Implementation of linalg Tiling --------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file implements the linalg dialect Tiling pass. 
-// -//===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/CommandLine.h" #include #include diff --git a/lib/gc/Transforms/TilingUtil.hpp b/lib/gc/Transforms/TilingUtil.hpp new file mode 100644 index 000000000..e05680fa1 --- /dev/null +++ b/lib/gc/Transforms/TilingUtil.hpp @@ -0,0 +1,26 @@ +//===-- TilingUtil.hpp - Tile op using tiling interface ---------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEMPORARY_TILEUSINGINTERFACE_X_H +#define TEMPORARY_TILEUSINGINTERFACE_X_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Interfaces/TilingInterface.h" +#include +namespace mlir { +namespace linalgX { + +FailureOr tileReductionUsingForall( + RewriterBase &b, PartialReductionOpInterface op, + ArrayRef threadNums, ArrayRef tileSizes, + ArrayRef newParallelDims, std::optional mapping); +} // namespace linalgX +} // namespace mlir + +#endif \ No newline at end of file diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index bfb9a52e9..56e52f852 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -1,15 +1,4 @@ -// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s - -// ----- - -// /// CHECK-LABEL: @matmul_4Dx4D_f32 -// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>, %arg1 : tensor<128x128x32x32x1xf32>) -> tensor<128x128x32x32xf32> { -// %cst_0 = arith.constant 0.000000e+00 : f32 -// %0 = tensor.empty() : tensor<128x128x32x32xf32> -// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> -// %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> -// return %2 : tensor<128x128x32x32xf32> -// } +// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s | FileCheck %s // ----- @@ -18,29 +7,67 @@ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x409 %cst_0 = arith.constant 0.000000e+00 : f32 %0 = 
tensor.empty() : tensor<4096x4096xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: linalg.transpose + // CHECK: tensor.expand_shape + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalg.batch_reduce_matmul + // CHECK: else + // CHECK: linalg.batch_reduce_matmul + // CHECK: tensor.insert_slice + %2 = linalg.matmul {MThreads = 4 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> return %2 : tensor<4096x4096xf32> } // ----- -// /// CHECK-LABEL: @matmul_2Dx4D_f32 -// func.func @matmul_4Dx4D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<128x128x32x32x1xf32>) -> tensor<4096x4096xf32> { -// %cst_0 = arith.constant 0.000000e+00 : f32 -// %0 = tensor.empty() : tensor<4096x4096xf32> -// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> -// %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> -// return %2 : tensor<4096x4096xf32> -// } - -// ----- - -// /// CHECK-LABEL: @matmul_4Dx4D_bf16 +/// CHECK-LABEL: @matmul_4Dx4D_bf16 func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> { %cst_0 = arith.constant 0.000000e+00 : bf16 + // CHECK: tensor.empty() : tensor<128x128x32x32xbf16> %0 = tensor.empty() : tensor<128x128x32x32xbf16> + // CHECK-NOT: linalg.fill %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> - %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.empty() : tensor<8x8x32x32xf32> + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.if + // CHECK: linalg.copy + // CHECK: else + %2 = linalgx.mm4d_vnni {MThreads = 16 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : 
tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> return %2 : tensor<128x128x32x32xbf16> } @@ -51,7 +78,33 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 %cst_0 = arith.constant 0.000000e+00 : bf16 %0 = tensor.empty() : tensor<4096x4096xbf16> %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> - %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: linalg.transpose + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: linalg.reduce + // CHECK: linalg.copy + %2 = linalgx.mm2d_vnni {MThreads = 32 : i32, NThreads = 2 : i32, KThreads = 2 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> return %2 : tensor<4096x4096xbf16> } diff --git a/test/mlir/test/gc/Transforms/mergeNestedForall.mlir b/test/mlir/test/gc/Transforms/mergeNestedForall.mlir new file mode 100644 index 000000000..d878739c8 --- /dev/null +++ b/test/mlir/test/gc/Transforms/mergeNestedForall.mlir @@ -0,0 +1,93 @@ +// RUN: gc-opt --split-input-file --merge-nested-forall %s | FileCheck %s + +// ----- + +#map = affine_map<(d0) -> (d0 * 1024)> +#map1 = affine_map<(d0) -> (d0 * 2048)> +#map2 = affine_map<(d0)[s0, s1] -> (d0 * 2048 + s0 + s1)> +#map3 = affine_map<(d0)[s0, s1] -> (d0 * 1024 + s0 + s1)> +module { + func.func @matmul_2Dx2D_f32(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) { + // CHECK: scf.forall {{.*}} (4, 2) + scf.forall (%arg3) in (4) { + scf.forall (%arg4) in (2) { + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32xf32> + scf.for %arg5 = %c0 to %c1024 step %c256 { + %c2048 = arith.constant 2048 : index + scf.for %arg6 = %c0 to %c2048 step %c256 { + %c4096 = arith.constant 4096 : index + scf.for %arg7 = %c0 to %c4096 step %c256 { + %c32 = arith.constant 32 : index + scf.for %arg8 = %c0 to %c256 step %c32 { + scf.for %arg9 = %c0 to %c256 step %c32 { + %0 = affine.apply #map(%arg3) + %1 = affine.apply #map1(%arg4) + %subview = memref.subview %arg2[%0, 0] [1024, 4096] [1, 1] : memref<4096x4096xf32> to memref<1024x4096xf32, strided<[4096, 1], offset: ?>> + %subview_0 = memref.subview %subview[0, %1] [1024, 2048] [1, 1] : memref<1024x4096xf32, strided<[4096, 1], offset: ?>> to memref<1024x2048xf32, strided<[4096, 1], offset: ?>> + %subview_1 = memref.subview %subview_0[%arg5, 0] [256, 2048] [1, 1] : memref<1024x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x2048xf32, strided<[4096, 1], offset: ?>> + %subview_2 = 
memref.subview %subview_1[0, %arg6] [256, 256] [1, 1] : memref<256x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x256xf32, strided<[4096, 1], offset: ?>> + %subview_3 = memref.subview %subview_2[%arg8, 0] [32, 256] [1, 1] : memref<256x256xf32, strided<[4096, 1], offset: ?>> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %2 = arith.cmpi eq, %arg7, %c0 : index + %3 = affine.apply #map2(%arg4)[%arg9, %arg6] + %subview_4 = memref.subview %arg1[%arg7, %3] [256, 32] [1, 1] : memref<4096x4096xf32> to memref<256x32xf32, strided<[4096, 1], offset: ?>> + %subview_5 = memref.subview %subview_3[0, %arg9] [32, 32] [1, 1] : memref<32x256xf32, strided<[4096, 1], offset: ?>> to memref<32x32xf32, strided<[4096, 1], offset: ?>> + scf.parallel (%arg10, %arg11, %arg12) = (%c0, %c0, %c0) to (%c8, %c32, %c32) step (%c1, %c1, %c1) { + %4 = affine.apply #map3(%arg3)[%arg8, %arg5] + %subview_6 = memref.subview %arg0[%4, %arg7] [32, 256] [1, 1] : memref<4096x4096xf32> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %expand_shape_7 = memref.expand_shape %subview_6 [[0], [1, 2]] output_shape [32, 8, 32] : memref<32x256xf32, strided<[4096, 1], offset: ?>> into memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + %5 = memref.load %expand_shape_7[%arg11, %arg10, %arg12] : memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + memref.store %5, %alloc[%arg10, %arg11, %arg12] : memref<8x32x32xf32> + scf.reduce + } + %expand_shape = memref.expand_shape %subview_4 [[0, 1], [2]] output_shape [8, 32, 32] : memref<256x32xf32, strided<[4096, 1], offset: ?>> into memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + scf.if %2 { + scf.parallel (%arg10, %arg11) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + %cst = arith.constant 0.000000e+00 : f32 + memref.store %cst, %subview_5[%arg10, %arg11] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + scf.reduce + } + scf.for %arg10 = %c0 to %c8 step %c1 { + scf.parallel (%arg11, %arg12) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + scf.for %arg13 = %c0 to %c32 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11, %arg13] : memref<8x32x32xf32> + %5 = memref.load %expand_shape[%arg10, %arg13, %arg12] : memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + %6 = memref.load %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.addf %6, %7 : f32 + memref.store %8, %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + } + scf.reduce + } + } + } else { + scf.for %arg10 = %c0 to %c8 step %c1 { + scf.parallel (%arg11, %arg12) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + scf.for %arg13 = %c0 to %c32 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11, %arg13] : memref<8x32x32xf32> + %5 = memref.load %expand_shape[%arg10, %arg13, %arg12] : memref<8x32x32xf32, strided<[131072, 4096, 1], offset: ?>> + %6 = memref.load %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.addf %6, %7 : f32 + memref.store %8, %subview_5[%arg11, %arg12] : memref<32x32xf32, strided<[4096, 1], offset: ?>> + } + scf.reduce + } + } + } + } + } + } + } + } + memref.dealloc %alloc : memref<8x32x32xf32> + } + } + return + } +} + diff --git a/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir b/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir new file mode 100644 index 000000000..908d08883 --- /dev/null +++ 
b/test/mlir/test/gc/Transforms/sinkOpIntoInnerLoop.mlir @@ -0,0 +1,46 @@ +// RUN: gc-opt --split-input-file --sink-op-into-inner-loop %s | FileCheck %s + +func.func @matmul_2Dx2D_f32(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + %c4096 = arith.constant 4096 : index + %c32 = arith.constant 32 : index + // CHECK: scf.forall + // CHECK-NOT: affine.apply + // CHECK-NOT: memref.subview + // CHECK-NEXT: scf.forall + scf.forall (%arg3) in (4) { + %0 = affine.apply affine_map<(d0) -> (d0 * 1024)>(%arg3) + %subview = memref.subview %arg2[%0, 0] [1024, 4096] [1, 1] : memref<4096x4096xf32> to memref<1024x4096xf32, strided<[4096, 1], offset: ?>> + scf.forall (%arg4) in (2) { + %1 = affine.apply affine_map<(d0) -> (d0 * 2048)>(%arg4) + %subview_0 = memref.subview %subview[0, %1] [1024, 2048] [1, 1] : memref<1024x4096xf32, strided<[4096, 1], offset: ?>> to memref<1024x2048xf32, strided<[4096, 1], offset: ?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32xf32> + scf.for %arg5 = %c0 to %c1024 step %c256 { + %subview_1 = memref.subview %subview_0[%arg5, 0] [256, 2048] [1, 1] : memref<1024x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x2048xf32, strided<[4096, 1], offset: ?>> + scf.for %arg6 = %c0 to %c2048 step %c256 { + %subview_2 = memref.subview %subview_1[0, %arg6] [256, 256] [1, 1] : memref<256x2048xf32, strided<[4096, 1], offset: ?>> to memref<256x256xf32, strided<[4096, 1], offset: ?>> + scf.for %arg7 = %c0 to %c4096 step %c256 { + %2 = arith.cmpi eq, %arg7, %c0 : index + scf.for %arg8 = %c0 to %c256 step %c32 { + %3 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 1024 + s0 + s1)>(%arg3)[%arg8, %arg5] + %subview_3 = memref.subview %arg0[%3, %arg7] [32, 256] [1, 1] : memref<4096x4096xf32> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %subview_4 = memref.subview %subview_2[%arg8, 0] [32, 256] [1, 1] : memref<256x256xf32, strided<[4096, 1], offset: ?>> to memref<32x256xf32, strided<[4096, 1], offset: ?>> + %expand_shape = memref.expand_shape %subview_3 [[0], [1, 2]] output_shape [32, 8, 32] : memref<32x256xf32, strided<[4096, 1], offset: ?>> into memref<32x8x32xf32, strided<[4096, 32, 1], offset: ?>> + scf.for %arg9 = %c0 to %c256 step %c32 { + + } + } + } + } + } + memref.dealloc %alloc : memref<8x32x32xf32> + } + } + return +} \ No newline at end of file From af8aad6e574979dfd9e73c046531d0aa161eba8e Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Thu, 25 Jul 2024 23:05:46 -0700 Subject: [PATCH 16/21] support dlti --- include/gc/Analysis/MatmulConfigAnalysis.h | 68 ++++++++++++------- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 11 +-- .../Transforms/DeepTileContractionNamedOp.cpp | 2 +- lib/gc/Transforms/TilingUtil.hpp | 2 + .../deepTileContractionNamedOp.mlir | 47 +++++++++++++ 5 files changed, 99 insertions(+), 31 deletions(-) diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index d991bec86..e4604383f 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -10,51 +10,69 @@ #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H #include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include +#include 
"mlir/Interfaces/DataLayoutInterfaces.h" namespace mlir { namespace gc { using namespace mlir; -// A mock for the taget information -// TODO: replace it with upstream hardware description model struct SystemDesc { - - static int getPositiveIntFromStr(char *str, int defaultValue = 1) { - if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') { - return defaultValue; - } - auto val = std::stoi(str); - return val > 0 ? val : defaultValue; - } - // get runtime OMP_NUM_THREADS uint32_t getNumThreads() { - char *numThreads = getenv("OMP_NUM_THREADS"); - return getPositiveIntFromStr(numThreads, 1); + std::optional numThreads = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("num_threads")); + if (numThreads && isa(*numThreads)) { + return dyn_cast(*numThreads).getInt(); + } + return 1; } // get cache size by cacheLevel size_t getCacheSize(uint8_t cacheLevel) { if (cacheLevel == 1) { - char *cacheSize = getenv("L1_CACHE_SIZE"); - return getPositiveIntFromStr(cacheSize, 0); + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } } else if (cacheLevel == 2) { - char *cacheSize = getenv("L2_CACHE_SIZE"); - return getPositiveIntFromStr(cacheSize, 0); + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } } else if (cacheLevel == 3) { - char *cacheSize = getenv("L3_CACHE_SIZE"); - return getPositiveIntFromStr(cacheSize, 0); + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } } return 0; } // get the maximum vector length in bits size_t getMaxVectorLength() { - char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH"); - return getPositiveIntFromStr(maxVectorLanes, 512); + std::optional maxVectorLength = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("max_vector_width")); + if (maxVectorLength && isa(*maxVectorLength)) { + return dyn_cast(*maxVectorLength).getInt(); + } + return 512; } + + SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} + +private: + DataLayout layout; + MLIRContext *ctx; }; // The configuration for matmul tiling @@ -62,12 +80,12 @@ struct SystemDesc { struct MatmulConfig { // The number of threads distributed to M, N, K uint32_t MThreads, NThreads, KThreads; - // The innermost block size for M, N, K which will be directly converted to - // brgemm. - uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; // The outer block size for M, N, K which will be used to decide the loop tile // size in single thread uint32_t MBlock, NBlock, KBlock; + // The innermost block size for M, N, K which will be directly converted to + // brgemm. 
+ uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; }; enum DimType { Batch, M, N, K }; diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index 682855cd4..abe15d00b 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -88,6 +88,7 @@ double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp, size_t dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize; + // TODO: take matrix register like amx into account double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) % maxVectorLength * 1.0 / config.innerMostMBlock + (maxVectorLength - config.innerMostKBlock % maxVectorLength) % @@ -270,8 +271,8 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, continue; } MatmulConfig config{ - MBlock, NBlock, KBlock, MThreads, NThreads, KThreads, + MBlock, NBlock, KBlock, innerMostMBlock, innerMostNBlock, innerMostKBlock}; configs.push_back(config); } @@ -311,13 +312,13 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { } else if (attr.getName() == "MThreads") { config.MThreads = cast(attr.getValue()).getInt(); cfgItemCnt++; - } else if (attr.getName() == "innerMostMBlock") { + } else if (attr.getName() == "innermostMBlock") { config.innerMostMBlock = cast(attr.getValue()).getInt(); cfgItemCnt++; - } else if (attr.getName() == "innerMostNBlock") { + } else if (attr.getName() == "innermostNBlock") { config.innerMostNBlock = cast(attr.getValue()).getInt(); cfgItemCnt++; - } else if (attr.getName() == "innerMostKBlock") { + } else if (attr.getName() == "innermostKBlock") { config.innerMostKBlock = cast(attr.getValue()).getInt(); cfgItemCnt++; } @@ -338,7 +339,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { // previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { if (auto linalgOp = dyn_cast(root)) { - SystemDesc sysDesc; + SystemDesc sysDesc(root->getParentOfType()); SmallVector> oprandDimType = *getOprandDimType(linalgOp); // get the origin M,N,K size diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index d82cc554a..e84e8d105 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -243,7 +243,7 @@ static Operation *findParentFillOp(Value val) { llvm::find(skipOpList, currentOp->getName().getStringRef()) != skipOpList.end() && !isa(currentOp)) { - currentOp = currentOp->getResult(0).getDefiningOp(); + currentOp = currentOp->getOperand(0).getDefiningOp(); } if (currentOp && isa(currentOp)) { return currentOp; diff --git a/lib/gc/Transforms/TilingUtil.hpp b/lib/gc/Transforms/TilingUtil.hpp index e05680fa1..42460d374 100644 --- a/lib/gc/Transforms/TilingUtil.hpp +++ b/lib/gc/Transforms/TilingUtil.hpp @@ -16,6 +16,8 @@ namespace mlir { namespace linalgX { +// An enahncement for the upstream pass to support tiling reduction for MKmk +// like cases(with multiple reduction iterators). 
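+// Note added for clarity: when a reduction dimension is distributed across
+// threads (KThreads > 1), each forall thread produces a partial accumulator
+// and the partials are combined afterwards. A rough sketch of the resulting
+// IR (illustrative only; names and shapes are invented):
+//   %partials = scf.forall (%kt) in (%KThreads)
+//                 shared_outs(%acc = %init) { ... }
+//   %res = linalg.reduce { arith.addf } ins(%partials : ...)
+//            outs(%out : ...) dimensions = [0]
+// which matches the `linalg.reduce` / `linalg.copy` pattern checked in the
+// mm2d_vnni tests in deepTileContractionNamedOp.mlir.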
FailureOr tileReductionUsingForall( RewriterBase &b, PartialReductionOpInterface op, ArrayRef threadNums, ArrayRef tileSizes, diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index 56e52f852..9fcaa0722 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -108,3 +108,50 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 return %2 : tensor<4096x4096xbf16> } +// ----- + +module attributes { + dlti.target_system_spec = #dlti.target_system_spec< + "CPU": #dlti.target_device_spec< + #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, + #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, + #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>, + #dlti.dl_entry<"num_threads", 56 : i32>, + #dlti.dl_entry<"max_vector_width", 512 : i32>> + >} { + /// CHECK-LABEL: @matmul_2Dx4D_bf16_with_dlti +func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> { + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4096x4096xbf16> + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.forall + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice + // CHECK: linalg.transpose + // CHECK: scf.if + // CHECK: linalg.fill + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: else + // CHECK: linalgx.batch_reduce_matmul_vnni + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: scf.forall.in_parallel + // CHECK: linalg.reduce + // CHECK: linalg.copy + %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> + return %2 : tensor<4096x4096xbf16> +} + +} From 24198fbd070d7f5c94292e2e9d8b2ec76f64a6b1 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Mon, 29 Jul 2024 00:55:42 -0700 Subject: [PATCH 17/21] fix comments --- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 57 ++++++++----- .../Transforms/DeepTileContractionNamedOp.cpp | 50 ++++++------ .../deepTileContractionNamedOp.mlir | 80 +++++++++---------- 3 files changed, 100 insertions(+), 87 deletions(-) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index abe15d00b..ece062d4c 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, template static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, - std::vector arry) { + std::vector array) { ss << "["; - for (auto [idx, a] : llvm::enumerate(arry)) { - if (idx != 0) { - ss << ", "; - } - ss << a; - } + llvm::interleaveComma(array, ss); ss << "]"; return ss; } @@ -174,24 +169,23 @@ std::vector filterConfigByCostModel(ArrayRef configs, linalg::LinalgOp &linalgOp, ArrayRef shape, SystemDesc &sysDesc, const CostModelFn &costModel, - float eliminationRatio = 0.5, float threshold = -1) { + float preserveRatio = 0.5, float threshold 
= -1) { std::vector result; std::vector costs; std::vector idx; - for (auto [i, config] : llvm::enumerate(configs)) { + for (auto &&[i, config] : llvm::enumerate(configs)) { costs.push_back(costModel(linalgOp, shape, config, sysDesc)); idx.push_back(i); } std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) { return costs[i1] < costs[i2]; }); - double thresholdCost = - costs[idx[(size_t)(eliminationRatio * configs.size())]]; + double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]]; thresholdCost = threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost; - for (size_t i = 0; i < configs.size(); i++) { - if (costs[idx[i]] <= thresholdCost) { - result.push_back(configs[idx[i]]); + for (const auto &i : idx) { + if (costs[i] <= thresholdCost) { + result.push_back(configs[i]); } } LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost @@ -210,6 +204,11 @@ std::vector prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ArrayRef shape, ArrayRef givenInnermostBlock) { + if (shape.size() < 3) { + LLVM_DEBUG(llvm::dbgs() + << "The shape is invalid, no candidate is generated\n"); + return {}; + } std::vector configs; uint32_t threads = sysDesc.getNumThreads(); std::vector MThreadsCandidates = @@ -290,10 +289,25 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, return configs; } +bool validateConfig(const MatmulConfig &cfg) { + if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 || + cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 || + cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 || + cfg.innerMostKBlock <= 0) { + return false; + } + if (cfg.MBlock % cfg.innerMostMBlock != 0 || + cfg.NBlock % cfg.innerMostNBlock != 0 || + cfg.KBlock % cfg.innerMostKBlock != 0) { + return false; + } + return true; +} + // read the config from the attributes for tuning bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { size_t cfgItemCnt = 0; - for (auto &attr : attrs) { + for (const auto &attr : attrs) { if (attr.getName() == "KBlock") { config.KBlock = cast(attr.getValue()).getInt(); cfgItemCnt++; @@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { cfgItemCnt++; } } - return cfgItemCnt == 9; + if (validateConfig(config)) { + return cfgItemCnt == 9; + } else { + LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n"); + return false; + } } // Analyze the workload and system description to generate the default config @@ -350,14 +369,14 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SmallVector NDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::N); uint32_t M = 1U, N = 1U, K = 1U; - for (auto [s, dimType] : + for (auto &&[s, dimType] : llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), oprandDimType[0])) { if (dimType == DimType::M) { M *= s; } } - for (auto [s, dimType] : + for (auto &&[s, dimType] : llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)), oprandDimType[1])) { if (dimType == DimType::N) { @@ -425,7 +444,7 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SmallVector shape = {M, N, K}; std::vector configCandidates = prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock); - for (auto [fn, name, threshold] : costModelList) { + for (auto &&[fn, name, threshold] : costModelList) { configCandidates = filterConfigByCostModel( configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold); } diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp 
b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index e84e8d105..d07797a94 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { // Check if the linalgOp need to be legalized to f32 accumulation type static bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { mlir::Type dataType = - dyn_cast(linalgOp.getDpsInputs()[0].getType()) + dyn_cast(linalgOp.getDpsInputs()[0].getType()) .getElementType(); mlir::Type resultType = - dyn_cast(linalgOp.getDpsInits()[0].getType()) + dyn_cast(linalgOp.getDpsInits()[0].getType()) .getElementType(); return (dataType.isBF16() || dataType.isF16()) && dataType == resultType; } @@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, linalg::LinalgOp currentOp = linalgOp; bool hasFullResult = !option.isPartialResult; - for (auto [i, loopType] : llvm::enumerate(loopType)) { + for (auto &&[i, loopType] : llvm::enumerate(loopType)) { ArrayRef currentDim = loopDim[i]; ArrayRef currentTileSize = nestedTileSizes[i]; if (loopType == OuterLoopGenerationOption::LoopType::ForOp) { @@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, cast(currentOp.getOperation()).getIterationDomain(b); currentOp.getReductionDims(reductionDims); bool tileOnReduction = false; - for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) { if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 && (!getConstantIntValue(loopRanges[d].size) || tile != static_cast( @@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); if (tileOnReduction) { - for (auto [idx, tile] : llvm::enumerate(tileSizes)) { + for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) { if (isConstantIntValue(tile, 0) && llvm::find(reductionDims, idx) != reductionDims.end()) { tileSizes[idx] = loopRanges[idx].size; } } SmallVector newParallelDims; - for (size_t i = 0UL; i < reductionDims.size(); i++) { - newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i)); + for (auto iter : llvm::enumerate(reductionDims)) { + newParallelDims.push_back( + getAsIndexOpFoldResult(b.getContext(), iter.index())); } FailureOr tilingResult = linalgX::tileReductionUsingForall( b, cast(currentOp.getOperation()), {}, tileSizes, newParallelDims, std::nullopt); if (failed(tilingResult) && - tilingResult->parallelTiledOps.size() == 1UL) + llvm::hasSingleElement(tilingResult->parallelTiledOps)) return failure(); currentOp = dyn_cast(tilingResult->parallelTiledOps.back()); @@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { : cfg.NBlock; // Outer loop tile size - for (auto [tile, dim] : + for (auto &&[tile, dim] : llvm::zip(SmallVector{KParallelBlockSize, MParallelBlockSize, NParallelBlockSize}, SmallVector{KDimPos[0], MDimPos[0], NDimPos[0]})) { @@ -596,7 +597,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { } // Middle loop tile size - for (auto [tile, dim] : + for (auto &&[tile, dim] : llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, KOuterBlockSize}, SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]})) { @@ -604,19 +605,19 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); 
option.loopDim.emplace_back(SmallVector{dim}); } - if (KDimPos.size() == 1) { + if (llvm::hasSingleElement(KDimPos)) { option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{KDimPos.back()}); } // Inner loop tile size - if (MDimPos.size() == 1) { + if (llvm::hasSingleElement(MDimPos)) { option.nestedTileSizes.emplace_back( SmallVector{cfg.innerMostMBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{MDimPos.back()}); } - if (NDimPos.size() == 1) { + if (llvm::hasSingleElement(NDimPos)) { option.nestedTileSizes.emplace_back( SmallVector{cfg.innerMostNBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); @@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { const linalg::ForallReductionTilingResult &result) -> FailureOr { ArrayRef initValue = result.initialValues; - if (initValue.size() == 1 && + if (llvm::hasSingleElement(initValue) && isa(initValue[0].getDefiningOp())) { rewriter.replaceOp(initValue[0].getDefiningOp(), dyn_cast( @@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { SmallVector AInnermostDims, BInnermostDims, CInnermostDims; bool firstM = true, firstK = true, firstN = true; if (MDimNum > 1) { - for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) { + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[0])) { if (iter == DimType::M && firstM) { AInnermostDims.push_back(1); firstM = false; @@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { } firstM = true; firstN = true; - for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) { + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[2])) { if (iter == DimType::M && firstM) { CInnermostDims.push_back(1); firstM = false; @@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { if (NDimNum > 1) { firstN = true; firstK = true; - for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { + for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[1])) { if (iter == DimType::N && firstN) { BInnermostDims.push_back(1); firstN = false; @@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(currentOp); mlir::Type dataType = - dyn_cast(currentOp.getDpsInputs()[0].getType()) + dyn_cast(currentOp.getDpsInputs()[0].getType()) .getElementType(); mlir::Type weightType = - dyn_cast(currentOp.getDpsInputs()[1].getType()) + dyn_cast(currentOp.getDpsInputs()[1].getType()) .getElementType(); mlir::Type resultType = - dyn_cast(currentOp.getDpsInits()[0].getType()) + dyn_cast(currentOp.getDpsInits()[0].getType()) .getElementType(); // update the extractSlice to static size, replace it with @@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { currentOp.getDpsInits()[0]); // Create the brgemm op and replace the origin linalg op linalg::LinalgOp matmul; - if (dyn_cast(weightOprand.getType()) - .getShape() - .size() == 3) { + if (dyn_cast(weightOprand.getType()).getShape().size() == + 3) { matmul = rewriter.create( loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, resultOprand); @@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { // fuse the low precision cast to the innermost body rewriter.setInsertionPointAfter(currentOp); 
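      // Note added for clarity (inferred from the surrounding context, not
      // part of this change): the guard built below ANDs one check per K loop
      // in option.KLoopHandles, apparently comparing each induction variable
      // with its upper bound, so the f32 accumulator is cast back to bf16/f16
      // only on the final K iteration; the similar guard in the @@ -903 hunk
      // below compares against zero instead, fusing the fill into the first K
      // iteration.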
Value cond; - for (LoopLikeOpInterface loop : option.KLoopHandles) { + for (LoopLikeOpInterface &loop : option.KLoopHandles) { Value induceVar = turnOpFoldResultIntoValue( rewriter, loc, *loop.getSingleInductionVar()); Value upBound = turnOpFoldResultIntoValue(rewriter, loc, @@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { Value cond; arith::ConstantIndexOp zeroConst = rewriter.create(loc, 0); - for (LoopLikeOpInterface loop : option.KLoopHandles) { + for (LoopLikeOpInterface &loop : option.KLoopHandles) { Value induceVar = loop.getLoopRegions().front()->front().getArgument(0); Value currentCond = rewriter.create( loc, arith::CmpIPredicate::eq, induceVar, zeroConst); diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index 9fcaa0722..d9f4feb67 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -7,29 +7,29 @@ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x409 %cst_0 = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<4096x4096xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> - // CHECK: scf.forall - // CHECK: tensor.extract_slice - // CHECK: scf.forall - // CHECK: tensor.extract_slice + // CHECK: scf.forall {{.*}} (4) {{.*}} (tensor<4096x4096xf32>) { + // CHECK: tensor.extract_slice {{.*}} [1024, 4096] [1, 1] + // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<1024x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [1024, 2048] [1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1] // CHECK: scf.for // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1] + // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice - // CHECK: linalg.transpose - // CHECK: tensor.expand_shape + // CHECK: tensor.extract_slice {{.*}} [256, 32] [1, 1] + // CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1] + // CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2] + // CHECK: tensor.expand_shape {{.*}} output_shape [8, 32, 32] : tensor<256x32xf32> into tensor<8x32x32xf32> // CHECK: scf.if // CHECK: linalg.fill // CHECK: linalg.batch_reduce_matmul // CHECK: else // CHECK: linalg.batch_reduce_matmul - // CHECK: tensor.insert_slice + // CHECK: tensor.insert_slice {{.*}} [32, 256] [1, 1] %2 = linalg.matmul {MThreads = 4 : i32, NThreads = 2 : i32, KThreads = 1 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> return %2 : tensor<4096x4096xf32> } @@ -43,22 +43,22 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12 %0 = tensor.empty() : tensor<128x128x32x32xbf16> // CHECK-NOT: linalg.fill %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> - // CHECK: scf.forall - // CHECK: tensor.extract_slice - // CHECK: scf.forall - // CHECK: tensor.extract_slice + // CHECK: scf.forall {{.*}} (16) {{.*}} 
(tensor<128x128x32x32xbf16>) + // CHECK: tensor.extract_slice {{.*}} [8, 128, 32, 32] [1, 1, 1, 1] + // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<8x128x32x32xbf16>) + // CHECK: tensor.extract_slice {{.*}} [8, 64, 32, 32] [1, 1, 1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [8, 8, 32, 32] [1, 1, 1, 1] // CHECK: tensor.empty() : tensor<8x8x32x32xf32> // CHECK: scf.for // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1] // CHECK: scf.if // CHECK: linalg.fill // CHECK: linalgx.batch_reduce_matmul_vnni @@ -78,22 +78,22 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 %cst_0 = arith.constant 0.000000e+00 : bf16 %0 = tensor.empty() : tensor<4096x4096xbf16> %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> - // CHECK: scf.forall - // CHECK: tensor.extract_slice - // CHECK: scf.forall - // CHECK: tensor.extract_slice - // CHECK: scf.forall - // CHECK: tensor.extract_slice + // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<2x1x1x4096x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [1, 1, 1, 4096, 4096] [1, 1, 1, 1, 1] + // CHECK: scf.forall {{.*}} (16) {{.*}} (tensor<4096x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [256, 4096] [1, 1] + // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<256x4096xf32>) + // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1] // CHECK: scf.for // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice + // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1] + // CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice - // CHECK: tensor.extract_slice - // CHECK: linalg.transpose + // CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1] + // CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1] + // CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2] // CHECK: scf.if // CHECK: linalg.fill // CHECK: linalgx.batch_reduce_matmul_vnni @@ -102,7 +102,7 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 // CHECK: scf.forall.in_parallel // CHECK: scf.forall.in_parallel // CHECK: scf.forall.in_parallel - // CHECK: linalg.reduce + // CHECK: linalg.reduce {{.*}} dimensions = [0, 1, 2] // CHECK: linalg.copy %2 = linalgx.mm2d_vnni {MThreads = 32 : i32, NThreads = 2 : i32, KThreads = 2 : i32, MBlock = 256 : i32, NBlock = 256 : i32, KBlock = 256 : i32,innermostMBlock = 32 : i32, innermostNBlock = 32 : i32, innermostKBlock = 32 : i32 } ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> return %2 : tensor<4096x4096xbf16> @@ -128,12 +128,9 @@ func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: ten // CHECK: tensor.extract_slice // 
CHECK: scf.forall // CHECK: tensor.extract_slice - // CHECK: scf.forall - // CHECK: tensor.extract_slice // CHECK: scf.for // CHECK: tensor.extract_slice // CHECK: scf.for - // CHECK: scf.for // CHECK: tensor.extract_slice // CHECK: tensor.extract_slice // CHECK: scf.for @@ -147,9 +144,6 @@ func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: ten // CHECK: linalgx.batch_reduce_matmul_vnni // CHECK: scf.forall.in_parallel // CHECK: scf.forall.in_parallel - // CHECK: scf.forall.in_parallel - // CHECK: linalg.reduce - // CHECK: linalg.copy %2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> return %2 : tensor<4096x4096xbf16> } From 9f294ea026897eabc4292fba0c3c70a0173c4292 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 31 Jul 2024 22:00:17 -0700 Subject: [PATCH 18/21] format code --- lib/gc/Analysis/CMakeLists.txt | 2 +- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 82 ++++++-------- .../Transforms/DeepTileContractionNamedOp.cpp | 107 +++++++----------- lib/gc/Transforms/MergeNestedForall.cpp | 5 +- lib/gc/Transforms/SinkOpIntoInnerLoop.cpp | 3 +- 5 files changed, 77 insertions(+), 122 deletions(-) diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index 51163823a..d7160f350 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -13,4 +13,4 @@ gc_add_mlir_library(GcAnalysis ${mlir_dialect_libs} ${MLIR_LINK_COMPONENTS} GcInterface - ) +) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index ece062d4c..46181bdee 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -44,11 +44,10 @@ getCandidate(uint32_t num, uint32_t floor, // factor std::vector candidates; uint32_t upperbound = std::min(num, ceil); - for (uint32_t i = floor; i <= upperbound; i++) { - if (num % i == 0) { + for (uint32_t i = floor; i <= upperbound; i++) + if (num % i == 0) candidates.push_back(i); - } - } + // the pow of 2 uint32_t candidate = 1U; while (candidate < floor) @@ -68,9 +67,8 @@ getCandidate(uint32_t num, uint32_t floor, bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { uint32_t numThreads = sysDesc.getNumThreads(); uint32_t actualThreads = 1U; - for (uint32_t t : threads) { + for (uint32_t t : threads) actualThreads *= t; - } return actualThreads == numThreads; } @@ -154,9 +152,8 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, config.NBlock * config.KBlock + config.MBlock * config.KBlock; double computationIntensity = FLOPS / memoryConsumption; - if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) { + if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) computationIntensity /= outOfCachePenalty; - } return 1 / computationIntensity; } @@ -183,19 +180,17 @@ filterConfigByCostModel(ArrayRef configs, double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]]; thresholdCost = threshold < thresholdCost && threshold > 0 ? 
threshold : thresholdCost; - for (const auto &i : idx) { - if (costs[i] <= thresholdCost) { + for (const auto &i : idx) + if (costs[i] <= thresholdCost) result.push_back(configs[i]); - } - } + LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost << "\nbest with cost: " << costs[idx[0]] << "\n" << configs[idx[0]] << "\n worst with cost: " << costs[idx[configs.size() - 1]] << "\n" << configs[idx[configs.size() - 1]] << "\n"); - if (result.empty()) { + if (result.empty()) result = configs; - } return result; } @@ -248,27 +243,23 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, for (uint32_t MThreads : MThreadsCandidates) { for (uint32_t NThreads : NThreadsCandidates) { for (uint32_t KThreads : KThreadsCandidates) { - if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) { + if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) continue; - } for (uint32_t MBlock : MBlockCandidates) { for (uint32_t innerMostMBlock : innerMostMBlockCandidates) { if (MBlock % innerMostMBlock != 0 || - shape[0] % innerMostMBlock != 0) { + shape[0] % innerMostMBlock != 0) continue; - } for (uint32_t NBlock : NBlockCandidates) { for (uint32_t innerMostNBlock : innerMostNBlockCandidates) { if (NBlock % innerMostNBlock != 0 || - shape[1] % innerMostNBlock != 0) { + shape[1] % innerMostNBlock != 0) continue; - } for (uint32_t KBlock : KBlockCandidates) { for (uint32_t innerMostKBlock : innerMostKBlockCandidates) { if (KBlock % innerMostKBlock != 0 || - shape[2] % innerMostKBlock != 0) { + shape[2] % innerMostKBlock != 0) continue; - } MatmulConfig config{ MThreads, NThreads, KThreads, MBlock, NBlock, KBlock, @@ -293,14 +284,12 @@ bool validateConfig(const MatmulConfig &cfg) { if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 || cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 || cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 || - cfg.innerMostKBlock <= 0) { + cfg.innerMostKBlock <= 0) return false; - } if (cfg.MBlock % cfg.innerMostMBlock != 0 || cfg.NBlock % cfg.innerMostNBlock != 0 || - cfg.KBlock % cfg.innerMostKBlock != 0) { + cfg.KBlock % cfg.innerMostKBlock != 0) return false; - } return true; } @@ -371,19 +360,16 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { uint32_t M = 1U, N = 1U, K = 1U; for (auto &&[s, dimType] : llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), - oprandDimType[0])) { - if (dimType == DimType::M) { + oprandDimType[0])) + if (dimType == DimType::M) M *= s; - } - } for (auto &&[s, dimType] : llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)), oprandDimType[1])) { - if (dimType == DimType::N) { + if (dimType == DimType::N) N *= s; - } else if (dimType == DimType::K) { + else if (dimType == DimType::K) K *= s; - } } // innermost Block, if the layout is blockied layout, the innermost block @@ -395,30 +381,30 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SmallVector givenInnermostBlock; if (MDimTypeIdx.size() > 1) { config.innerMostMBlock = 1; - for (size_t i = 1UL; i < MDimTypeIdx.size(); i++) { - config.innerMostMBlock *= - linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]]; - } + for (auto &&[i, d] : llvm::enumerate(MDimTypeIdx)) + if (i != 0) + config.innerMostMBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(0))[d]; givenInnermostBlock.push_back(config.innerMostMBlock); } else { givenInnermostBlock.push_back(0); } if (NDimTypeIdx.size() > 1) { config.innerMostNBlock = 1; - for (size_t i = 1UL; i < NDimTypeIdx.size(); i++) { - 
config.innerMostNBlock *= - linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]]; - } + for (auto &&[i, d] : llvm::enumerate(NDimTypeIdx)) + if (i != 0) + config.innerMostNBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d]; givenInnermostBlock.push_back(config.innerMostNBlock); } else { givenInnermostBlock.push_back(0); } if (KDimTypeIdx.size() > 1) { config.innerMostKBlock = 1; - for (size_t i = 1UL; i < KDimTypeIdx.size(); i++) { - config.innerMostKBlock *= - linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]]; - } + for (auto &&[i, d] : llvm::enumerate(KDimTypeIdx)) + if (i != 0) + config.innerMostKBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d]; givenInnermostBlock.push_back(config.innerMostKBlock); } else { givenInnermostBlock.push_back(0); @@ -444,13 +430,11 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SmallVector shape = {M, N, K}; std::vector configCandidates = prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock); - for (auto &&[fn, name, threshold] : costModelList) { + for (auto &&[fn, name, threshold] : costModelList) configCandidates = filterConfigByCostModel( configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold); - } - if (!configCandidates.empty()) { + if (!configCandidates.empty()) config = configCandidates[0]; - } } LLVM_DEBUG(llvm::dbgs() diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index d07797a94..4c0373646 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -44,9 +44,8 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, mlir::Type tensorElementType = inTensorType.getElementType(); // Check if the input and output tensor have the same shape - if (inShape == outShape) { + if (inShape == outShape) return currentValue; - } if (outShape.size() < inShape.size()) { SmallVector reassocIndices; @@ -93,9 +92,8 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, // Transpose the tensor if permutation is not empty if (!permutation.empty()) { SmallVector transposeShape; - for (int64_t idx : permutation) { + for (int64_t idx : permutation) transposeShape.push_back(outShape[idx]); - } Operation *initOp = rewriter.create(loc, transposeShape, tensorElementType); Operation *transposeOp = rewriter.create( @@ -110,9 +108,8 @@ bool isDummyLoop(LoopLikeOpInterface loop) { std::optional tripCount = mlir::constantTripCount( *loop.getSingleLowerBound(), *loop.getSingleUpperBound(), *loop.getSingleStep()); - if (tripCount) { + if (tripCount) return *tripCount == 1; - } return false; } @@ -132,9 +129,8 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { if (createTemporaryOp) { unsigned argNum = body->getNumArguments(); SmallVector vals; - for (size_t i = initSize; i > 0; i--) { + for (size_t i = initSize; i > 0; i--) vals.push_back(body->getArgument(argNum - i)); - } OpBuilder::InsertionGuard g(b); b.setInsertionPointToEnd(body); Location loc = b.getUnknownLoc(); @@ -185,23 +181,21 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, ShapedType initType = cast(initValue.getType()); ArrayRef tensorShape = initType.getShape(); SmallVector mixedShape; - for (size_t i = 0UL; i < tensorShape.size(); i++) { + for (auto &&[i, t] : llvm::enumerate(tensorShape)) { if (initType.isDynamicDim(i)) { Value val = rewriter.create(loc, initValue, i); mixedShape.push_back(val); } else 
{ - mixedShape.push_back( - getAsIndexOpFoldResult(rewriter.getContext(), tensorShape[i])); + mixedShape.push_back(getAsIndexOpFoldResult(rewriter.getContext(), t)); } } Operation *currentOp; currentOp = rewriter.create( loc, mixedShape, Float32Type::get(op->getContext())); - if (needCopyInit) { + if (needCopyInit) currentOp = rewriter.create(loc, initOp->getResult(0), currentOp->getResult(0)); - } SmallVector newOperands = linalgOp->getOperands(); Value oldInit = newOperands.back(); newOperands.back() = currentOp->getResult(0); @@ -245,10 +239,8 @@ static Operation *findParentFillOp(Value val) { !isa(currentOp)) { currentOp = currentOp->getOperand(0).getDefiningOp(); } - if (currentOp && isa(currentOp)) { + if (currentOp && isa(currentOp)) return currentOp; - } - return nullptr; } @@ -262,12 +254,11 @@ static void getMatmulParallelDims(linalg::LinalgOp linalgOp, linalgOp.getIteratorTypesArray(); ArrayRef results = map.getResults(); - for (AffineExpr dim : results) { + for (const AffineExpr &dim : results) { AffineDimExpr dimExpr = dyn_cast(dim); if (dimExpr && iteratorTypes[dimExpr.getPosition()] == - mlir::utils::IteratorType::parallel) { + mlir::utils::IteratorType::parallel) dims.push_back(dimExpr.getPosition()); - } } } @@ -283,21 +274,19 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, SmallVector mixedOffsets = extractSlice.getMixedOffsets(); SmallVector mixedSizes = extractSlice.getMixedSizes(); SmallVector mixedStrides = extractSlice.getMixedStrides(); - for (size_t i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } - if (shrinDimNum > 0) { + for (auto &&[i, s] : llvm::enumerate(size)) + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); + if (shrinDimNum > 0) rewriter.replaceOpWithNewOp( extractSlice, mlir::RankedTensorType::get( SmallVector(size.begin() + shrinDimNum, size.end()), extractSlice.getResult().getType().getElementType()), extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); - } else { + else rewriter.replaceOpWithNewOp( extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); - } } } @@ -312,9 +301,8 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, SmallVector mixedOffsets = insertSlice.getMixedOffsets(); SmallVector mixedSizes = insertSlice.getMixedSizes(); SmallVector mixedStrides = insertSlice.getMixedStrides(); - for (size_t i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } + for (auto &&[i, s] : llvm::enumerate(size)) + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); rewriter.replaceOpWithNewOp( insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, mixedStrides); @@ -360,11 +348,10 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, linalgOp.getIteratorTypesArray(); if (loopType.size() != loopDim.size() || - loopDim.size() != nestedTileSizes.size()) { + loopDim.size() != nestedTileSizes.size()) return b.notifyMatchFailure( linalgOp, "loopType, loopDim and nestedTileSizes should have the same size"); - } if (linalgOp.hasPureBufferSemantics()) return b.notifyMatchFailure( @@ -376,7 +363,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, ArrayRef currentDim = loopDim[i]; ArrayRef currentTileSize = nestedTileSizes[i]; if (loopType == OuterLoopGenerationOption::LoopType::ForOp) { - for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) { + for (auto &&[d, tile] 
: llvm::zip(currentDim, currentTileSize)) { scf::SCFTilingOptions tileOption; SmallVector TileSizes( currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0)); @@ -390,9 +377,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, for (const auto &fn : option.innermostFullResultCallBacks) { FailureOr result = fn(b, currentOp->getLoc(), currentOp); - if (succeeded(result)) { + if (succeeded(result)) currentOp = *result; - } } hasFullResult = false; } @@ -404,9 +390,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (!isDummyLoop(tilingResult->loops.back())) { b.replaceOp(currentOp, tilingResult->replacements); currentOp = dyn_cast(tilingResult->tiledOps.back()); - if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) result.reductionLoops.push_back(tilingResult->loops.back()); - } result.loops.push_back(tilingResult->loops.back()); } } @@ -423,32 +408,29 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) { if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 && (!getConstantIntValue(loopRanges[d].size) || - tile != static_cast( - *getConstantIntValue(loopRanges[d].size)))) { + tile != + static_cast(*getConstantIntValue(loopRanges[d].size)))) tileOnReduction = true; - } if (llvm::find(reductionDims, d) != reductionDims.end() && !dyn_cast(currentOp.getOperation())) { tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), 0); tileOnReduction = false; - } else + } else { tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); + } } OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); if (tileOnReduction) { - for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) { + for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) if (isConstantIntValue(tile, 0) && - llvm::find(reductionDims, idx) != reductionDims.end()) { + llvm::find(reductionDims, idx) != reductionDims.end()) tileSizes[idx] = loopRanges[idx].size; - } - } SmallVector newParallelDims; - for (auto iter : llvm::enumerate(reductionDims)) { + for (auto iter : llvm::enumerate(reductionDims)) newParallelDims.push_back( getAsIndexOpFoldResult(b.getContext(), iter.index())); - } FailureOr tilingResult = linalgX::tileReductionUsingForall( b, cast(currentOp.getOperation()), @@ -462,9 +444,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, for (const auto &fn : option.finalReduceCallBacks) { FailureOr result = fn(b, currentOp->getLoc(), *tilingResult); - if (succeeded(result)) { + if (succeeded(result)) currentOp = *result; - } } } } else { @@ -623,7 +604,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); option.loopDim.emplace_back(SmallVector{NDimPos.back()}); } - for (size_t dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { + for (size_t dim = 0UL; dim < linalgOp.getNumLoops(); ++dim) { if (dim != MDimPos.back() && dim != NDimPos.back() && iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { option.nestedTileSizes.emplace_back(SmallVector{1}); @@ -658,12 +639,11 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { -> FailureOr { ArrayRef initValue = result.initialValues; if (llvm::hasSingleElement(initValue) && - isa(initValue[0].getDefiningOp())) { + isa(initValue[0].getDefiningOp())) rewriter.replaceOp(initValue[0].getDefiningOp(), dyn_cast( initValue[0].getDefiningOp()) .getDpsInits()[0]); - } return 
dyn_cast(result.parallelTiledOps.back()); }; option.finalReduceCallBacks.push_back(removeReduncantFill); @@ -786,7 +766,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { setStaticSizeForExtractSliceOp(rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, AInnermostDims, MDimNum > 1); - for (Value init : currentOp.getDpsInits()) { + for (const Value &init : currentOp.getDpsInits()) { setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true, CInnermostDims, MDimNum > 1 ? 2 : 0); } @@ -823,21 +803,19 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { // Create the brgemm op and replace the origin linalg op linalg::LinalgOp matmul; if (dyn_cast(weightOprand.getType()).getShape().size() == - 3) { + 3) matmul = rewriter.create( loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, resultOprand); - } else { + else matmul = rewriter.create( loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand}, resultOprand); - } Value result = matmul.getOperation()->getResult(0); // Insert the result back to the original tensor - for (Operation *user : currentOp->getResult(0).getUsers()) { + for (Operation *user : currentOp->getResult(0).getUsers()) setStaticSizeForInsertSliceOp(rewriter, user, result, CInnermostDims); - } if (option.needLowPrecisionCast) { // fuse the low precision cast to the innermost body @@ -876,10 +854,9 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { rewriter.create(loc, currentOp.getDpsInits().back()); } // set static size for the insertSliceOp of copyOp - for (Operation *user : currentOp->getResult(1).getUsers()) { + for (Operation *user : currentOp->getResult(1).getUsers()) setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0), CInnermostDims); - } rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)}); } else { rewriter.replaceOp(currentOp, matmul->getResult(0)); @@ -889,7 +866,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { // Fuse the fill op to the innermost body if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { Value fillValue = fillOp.getDpsInputs()[0]; - if (cfg.KThreads <= 1) { + if (cfg.KThreads <= 1) // if use k slicing, the fill op is still need to be kept for the reduce // init rewriter.replaceUsesWithIf(fillOp.getResult(0), fillOp.getDpsInits()[0], @@ -897,7 +874,6 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { return isa( operand.getOwner()); }); - } rewriter.setInsertionPointAfter(currentOp); Value cond; @@ -983,18 +959,16 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { // Step 2. 
Outer loop generation FailureOr outerLoopResult = outerLoopGeneration( rewriter, linalgOp, cfg, fillOp && isa(fillOp)); - if (failed(outerLoopResult)) { + if (failed(outerLoopResult)) return failure(); - } linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); // Step 3 generate inner loop body, convert the linalg.generic to brgemm innerBodyGenerationOption option = innerBodyGenerationOption{ fillOp, needLowPrecisionCast, outerLoopResult->reductionLoops}; - if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) { + if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) return failure(); - } rewriter.eraseOp(originOp); return success(); } @@ -1020,10 +994,9 @@ struct DeepTileContractionNamedOp dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx.getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, &ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); - } } }; diff --git a/lib/gc/Transforms/MergeNestedForall.cpp b/lib/gc/Transforms/MergeNestedForall.cpp index 516981c9c..07eb5ffbf 100644 --- a/lib/gc/Transforms/MergeNestedForall.cpp +++ b/lib/gc/Transforms/MergeNestedForall.cpp @@ -82,10 +82,9 @@ struct MergeNestedForall patterns.add(patterns.getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); - } } }; diff --git a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp index df04b6590..965a26392 100644 --- a/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp +++ b/lib/gc/Transforms/SinkOpIntoInnerLoop.cpp @@ -29,9 +29,8 @@ struct SinkOpIntoInnerLoop getOperation()->walk([&](LoopLikeOpInterface loop) { SmallVector regionsToSink; // Get the regions are that known to be executed at most once. - for (auto &it : loop->getRegions()) { + for (auto &it : loop->getRegions()) regionsToSink.push_back(&it); - } // Sink side-effect free operations. 
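// --------------------------------------------------------------------------
// Editor's sketch, for illustration only and not part of the patch: a plain
// C++ reference of the loop hierarchy that the deep-tile matmul rewrite above
// (generateOuterLoop + innerBodyGeneration) materializes with scf.forall /
// scf.for and a brgemm-style innermost kernel. It is simplified: the
// thread-level and cache-level splits are collapsed into one level each, the
// block sizes are made-up example values rather than what MatmulConfigAnalysis
// would pick, and C is assumed to be zero-initialized (the role of
// linalg.fill).
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

static void blockedMatmulReference(const std::vector<float> &A,
                                   const std::vector<float> &B,
                                   std::vector<float> &C, size_t M, size_t N,
                                   size_t K) {
  const size_t MBlock = 64, NBlock = 64, KBlock = 64; // cache-level tiles
  const size_t iM = 16, iN = 16, iK = 16;             // innermost tiles
  for (size_t mo = 0; mo < M; mo += MBlock)           // scf.forall over M
    for (size_t no = 0; no < N; no += NBlock)         // scf.forall over N
      for (size_t ko = 0; ko < K; ko += KBlock)       // scf.for over K blocks
        for (size_t mi = mo; mi < std::min(mo + MBlock, M); mi += iM)
          for (size_t ni = no; ni < std::min(no + NBlock, N); ni += iN)
            for (size_t ki = ko; ki < std::min(ko + KBlock, K); ki += iK)
              // Innermost block: the region the generated brgemm would cover.
              for (size_t m = mi; m < std::min(mi + iM, M); ++m)
                for (size_t n = ni; n < std::min(ni + iN, N); ++n)
                  for (size_t k = ki; k < std::min(ki + iK, K); ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];
}

int main() {
  const size_t M = 96, N = 80, K = 72; // sizes with tails on every level
  std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);
  blockedMatmulReference(A, B, C, M, N, K);
  for (float c : C)
    assert(c == static_cast<float>(2 * K)); // every element sums K terms of 2
  return 0;
}
// ------------------------------- end sketch -------------------------------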
controlFlowSink( regionsToSink, domInfo, From 3c5567f73c14de0cad96f1c408e232bbda4379a4 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Mon, 5 Aug 2024 19:59:59 -0700 Subject: [PATCH 19/21] replace sysDesc with target info --- include/gc/Analysis/MatmulConfigAnalysis.h | 56 --------------------- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 42 +++++++++------- test/mlir/unittests/Analysis/CMakeLists.txt | 1 + 3 files changed, 24 insertions(+), 75 deletions(-) diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index e4604383f..7bd1bb4f0 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -19,62 +19,6 @@ namespace gc { using namespace mlir; -struct SystemDesc { - // get runtime OMP_NUM_THREADS - uint32_t getNumThreads() { - std::optional numThreads = layout.getDevicePropertyValue( - Builder(ctx).getStringAttr("CPU" /* device ID*/), - Builder(ctx).getStringAttr("num_threads")); - if (numThreads && isa(*numThreads)) { - return dyn_cast(*numThreads).getInt(); - } - return 1; - } - // get cache size by cacheLevel - size_t getCacheSize(uint8_t cacheLevel) { - if (cacheLevel == 1) { - std::optional cacheSize = layout.getDevicePropertyValue( - Builder(ctx).getStringAttr("CPU" /* device ID*/), - Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); - if (cacheSize && isa(*cacheSize)) { - return dyn_cast(*cacheSize).getInt(); - } - } else if (cacheLevel == 2) { - std::optional cacheSize = layout.getDevicePropertyValue( - Builder(ctx).getStringAttr("CPU" /* device ID*/), - Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); - if (cacheSize && isa(*cacheSize)) { - return dyn_cast(*cacheSize).getInt(); - } - } else if (cacheLevel == 3) { - std::optional cacheSize = layout.getDevicePropertyValue( - Builder(ctx).getStringAttr("CPU" /* device ID*/), - Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); - if (cacheSize && isa(*cacheSize)) { - return dyn_cast(*cacheSize).getInt(); - } - } - return 0; - } - - // get the maximum vector length in bits - size_t getMaxVectorLength() { - std::optional maxVectorLength = layout.getDevicePropertyValue( - Builder(ctx).getStringAttr("CPU" /* device ID*/), - Builder(ctx).getStringAttr("max_vector_width")); - if (maxVectorLength && isa(*maxVectorLength)) { - return dyn_cast(*maxVectorLength).getInt(); - } - return 512; - } - - SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} - -private: - DataLayout layout; - MLIRContext *ctx; -}; - // The configuration for matmul tiling // TODO: support batch matmul struct MatmulConfig { diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index 46181bdee..b31e0933e 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "gc/Analysis/MatmulConfigAnalysis.h" +#include "gc/Analysis/TargetDescriptionAnalysis.h" #include #include @@ -64,7 +65,8 @@ getCandidate(uint32_t num, uint32_t floor, } // check if the threads are valid -bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { +bool validateThreads(ArrayRef threads, + CPUTargetDescriptionAnalysis &sysDesc) { uint32_t numThreads = sysDesc.getNumThreads(); uint32_t actualThreads = 1U; for (uint32_t t : threads) @@ -77,24 +79,25 @@ bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp, ArrayRef shape, 
const MatmulConfig &config, - SystemDesc &sysDesc) { + CPUTargetDescriptionAnalysis &sysDesc) { size_t dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); - size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize; + size_t maxVectorWidth = sysDesc.getMaxVectorWidth() / dtypeSize; // TODO: take matrix register like amx into account - double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) % - maxVectorLength * 1.0 / config.innerMostMBlock + - (maxVectorLength - config.innerMostKBlock % maxVectorLength) % - maxVectorLength * 1.0 / config.innerMostKBlock + - (maxVectorLength - config.innerMostNBlock % maxVectorLength) % - maxVectorLength * 1.0 / config.innerMostNBlock; + double cost = (maxVectorWidth - config.innerMostMBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostMBlock + + (maxVectorWidth - config.innerMostKBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostKBlock + + (maxVectorWidth - config.innerMostNBlock % maxVectorWidth) % + maxVectorWidth * 1.0 / config.innerMostNBlock; return cost; } // calculate the cost of the workload balance double workloadBalancedCost(linalg::LinalgOp &linalgOp, ArrayRef shape, - const MatmulConfig &config, SystemDesc &sysDesc) { + const MatmulConfig &config, + CPUTargetDescriptionAnalysis &sysDesc) { if (shape.size() < 3) { // Has an invalid shape return 0; @@ -118,7 +121,7 @@ double workloadBalancedCost(linalg::LinalgOp &linalgOp, double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, - SystemDesc &sysDesc) { + CPUTargetDescriptionAnalysis &sysDesc) { if (shape.size() < 3) { // Has an invalid shape return 0; @@ -141,7 +144,7 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, - SystemDesc &sysDesc) { + CPUTargetDescriptionAnalysis &sysDesc) { double fullLoadRatio = 0.7; uint32_t L2Cache = sysDesc.getCacheSize(2); size_t dtypeSize = DataLayout().getTypeSize( @@ -157,16 +160,17 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, return 1 / computationIntensity; } -using CostModelFn = - std::function shape, - MatmulConfig cfg, SystemDesc &sysDesc)>; +using CostModelFn = std::function shape, MatmulConfig cfg, + CPUTargetDescriptionAnalysis &sysDesc)>; // filter the config by the cost model std::vector filterConfigByCostModel(ArrayRef configs, linalg::LinalgOp &linalgOp, ArrayRef shape, - SystemDesc &sysDesc, const CostModelFn &costModel, - float preserveRatio = 0.5, float threshold = -1) { + CPUTargetDescriptionAnalysis &sysDesc, + const CostModelFn &costModel, float preserveRatio = 0.5, + float threshold = -1) { std::vector result; std::vector costs; std::vector idx; @@ -196,7 +200,7 @@ filterConfigByCostModel(ArrayRef configs, // prepare the config candidates std::vector -prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, +prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc, ArrayRef shape, ArrayRef givenInnermostBlock) { if (shape.size() < 3) { @@ -347,7 +351,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { // previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { if (auto linalgOp = dyn_cast(root)) { - SystemDesc sysDesc(root->getParentOfType()); + CPUTargetDescriptionAnalysis sysDesc(root); SmallVector> oprandDimType = *getOprandDimType(linalgOp); // get the 
origin M,N,K size diff --git a/test/mlir/unittests/Analysis/CMakeLists.txt b/test/mlir/unittests/Analysis/CMakeLists.txt index d78877afe..ed253bfdf 100644 --- a/test/mlir/unittests/Analysis/CMakeLists.txt +++ b/test/mlir/unittests/Analysis/CMakeLists.txt @@ -3,5 +3,6 @@ add_mlir_unittest(GCAnalysisTests ) target_link_libraries(GCAnalysisTests PRIVATE + GcPasses GcAnalysis GcJitWrapper) From a205731ac236625ed53f60d040d443cb5a36b1eb Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 31 Jul 2024 21:20:19 -0700 Subject: [PATCH 20/21] deprecated tileToForallUsingTileSize --- .../Transforms/DeepTileContractionNamedOp.cpp | 35 ++++++++++--------- lib/gc/Transforms/Pipeline.cpp | 5 +-- .../deepTileContractionNamedOp.mlir | 12 +++---- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index 4c0373646..b356fa11c 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -168,11 +168,12 @@ static FailureOr matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, bool needCopyInit = true, bool needFurtherFuse = false) { linalg::LinalgOp linalgOp = dyn_cast(op); - Location loc = linalgOp->getLoc(); - DtypeLegalizeResult result; if (!linalgOp) return failure(); + Location loc = linalgOp->getLoc(); + DtypeLegalizeResult result; + if (needToLegalizeDtype(linalgOp)) { rewriter.setInsertionPoint(linalgOp); IRMapping mapping; @@ -449,15 +450,15 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, } } } else { - TilingInterface tilingInterface = - cast(currentOp.getOperation()); - FailureOr tilingResult = - linalg::tileToForallOpUsingTileSizes(b, tilingInterface, tileSizes, - std::nullopt); + scf::SCFTilingOptions tileOption; + tileOption.setTileSizes(tileSizes); + tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + FailureOr tilingResult = scf::tileUsingSCF( + b, cast(currentOp.getOperation()), tileOption); if (failed(tilingResult)) return failure(); - b.replaceOp(currentOp, tilingResult->tileOp); - currentOp = dyn_cast(tilingResult->tiledOp); + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); } } } @@ -499,8 +500,8 @@ NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM for([om, on, ok]: [MNumBlock, NNumBlock, KNumBlock]) { ASlice2 = ASlice[om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + 1) * KBlock] - BSlice2 = BSlice[0, om * MBlock: (om + 1) * MBlock, ok * KBlock: (ok + -1) * KBlock] + BSlice2 = BSlice[0, ok * KBlock: (ok + 1) * KBlock, on * NBlock: (on + +1) * NBlock] CSlice3 = CSlice2[0, om * MBlock: (om + 1) * MBlock, on * NBlock: (on + 1) * NBlock] (init with 0 when ok == 0) MNumInnerBlock = MBlock / iim_block_ @@ -539,11 +540,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { size_t NFirstDim = *getConstantIntValue(loopRange[NDimPos[0]].size); size_t KParallelBlockSize = - KDimPos.size() > 1 - ? llvm::divideCeil(KFirstDim, cfg.KThreads) - : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), - cfg.KThreads) * - cfg.KBlock; + cfg.KThreads == 1 + ? 0 + : (KDimPos.size() > 1 + ? llvm::divideCeil(KFirstDim, cfg.KThreads) + : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), + cfg.KThreads) * + cfg.KBlock); size_t MParallelBlockSize = MDimPos.size() > 1 ? 
llvm::divideCeil(MFirstDim, cfg.MThreads) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 4e4a0dd25..f198c6c75 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -52,9 +52,10 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // todo: layout propagation pass // todo: tensor constant propagation pass // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass - pm.addNestedPass(createIterativeTilingAndFusion()); - // Fine-grain fusion pass pm.addNestedPass(createDeepTileContractionNamedOp()); + + // Fine-grain fusion pass + pm.addNestedPass(createIterativeTilingAndFusion()); // todo: fine-grain fusion pass // todo: lower linalg to arith/math on virtual vector pass diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index d9f4feb67..aa6cacbd7 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -7,9 +7,9 @@ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>, %arg1: tensor<4096x409 %cst_0 = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<4096x4096xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> - // CHECK: scf.forall {{.*}} (4) {{.*}} (tensor<4096x4096xf32>) { + // CHECK: scf.forall {{.*}} (0) to (4096) step (1024) {{.*}} (tensor<4096x4096xf32>) { // CHECK: tensor.extract_slice {{.*}} [1024, 4096] [1, 1] - // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<1024x4096xf32>) + // CHECK: scf.forall {{.*}} (0) to (4096) step (2048) {{.*}} (tensor<1024x4096xf32>) // CHECK: tensor.extract_slice {{.*}} [1024, 2048] [1, 1] // CHECK: scf.for // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1] @@ -43,9 +43,9 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12 %0 = tensor.empty() : tensor<128x128x32x32xbf16> // CHECK-NOT: linalg.fill %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> - // CHECK: scf.forall {{.*}} (16) {{.*}} (tensor<128x128x32x32xbf16>) + // CHECK: scf.forall {{.*}} (0) to (128) step (8) {{.*}} (tensor<128x128x32x32xbf16>) // CHECK: tensor.extract_slice {{.*}} [8, 128, 32, 32] [1, 1, 1, 1] - // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<8x128x32x32xbf16>) + // CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<8x128x32x32xbf16>) // CHECK: tensor.extract_slice {{.*}} [8, 64, 32, 32] [1, 1, 1, 1] // CHECK: scf.for // CHECK: tensor.extract_slice {{.*}} [8, 8, 32, 32] [1, 1, 1, 1] @@ -80,9 +80,9 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12 %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<2x1x1x4096x4096xf32>) // CHECK: tensor.extract_slice {{.*}} [1, 1, 1, 4096, 4096] [1, 1, 1, 1, 1] - // CHECK: scf.forall {{.*}} (16) {{.*}} (tensor<4096x4096xf32>) + // CHECK: scf.forall {{.*}} (0) to (4096) step (256) {{.*}} (tensor<4096x4096xf32>) // CHECK: tensor.extract_slice {{.*}} [256, 4096] [1, 1] - // CHECK: scf.forall {{.*}} (2) {{.*}} (tensor<256x4096xf32>) + // CHECK: scf.forall {{.*}} (0) to (128) step (64) {{.*}} (tensor<256x4096xf32>) // CHECK: tensor.extract_slice {{.*}} [256, 2048] [1, 1] // CHECK: scf.for // CHECK: tensor.extract_slice {{.*}} [256, 256] [1, 1] From ccd02f28bb4c71c73c4fca6dd4a9bc683a9e9fef Mon Sep 17 00:00:00 2001 From: "Zhong, 
Zhicong" Date: Tue, 6 Aug 2024 20:16:22 -0700 Subject: [PATCH 21/21] use expand/collapse_shape to do rank alter --- .../Transforms/DeepTileContractionNamedOp.cpp | 32 +++++++++++-------- .../deepTileContractionNamedOp.mlir | 6 +++- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index b356fa11c..30d0e022f 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -275,19 +275,22 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter, SmallVector mixedOffsets = extractSlice.getMixedOffsets(); SmallVector mixedSizes = extractSlice.getMixedSizes(); SmallVector mixedStrides = extractSlice.getMixedStrides(); + auto targetTensor = mlir::RankedTensorType::get( + SmallVector(size.begin() + shrinDimNum, size.end()), + extractSlice.getResult().getType().getElementType()); for (auto &&[i, s] : llvm::enumerate(size)) mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); - if (shrinDimNum > 0) - rewriter.replaceOpWithNewOp( - extractSlice, - mlir::RankedTensorType::get( - SmallVector(size.begin() + shrinDimNum, size.end()), - extractSlice.getResult().getType().getElementType()), - extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); - else - rewriter.replaceOpWithNewOp( - extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, - mixedStrides); + Operation *newExtractSliceOp = rewriter.create( + extractSlice->getLoc(), extractSlice.getSource(), mixedOffsets, + mixedSizes, mixedStrides); + if (shrinDimNum > 0) { + rewriter.setInsertionPointAfter(newExtractSliceOp); + Value viewResult = tensorViewRankedTensor( + rewriter, targetTensor, newExtractSliceOp->getResult(0)); + rewriter.replaceOp(extractSlice, viewResult); + } else { + rewriter.replaceOp(extractSlice, newExtractSliceOp); + } } } @@ -304,9 +307,12 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op, SmallVector mixedStrides = insertSlice.getMixedStrides(); for (auto &&[i, s] : llvm::enumerate(size)) mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s); + auto targetTensor = mlir::RankedTensorType::get( + size, insertSlice.getDest().getType().getElementType()); + Value viewResult = tensorViewRankedTensor(rewriter, targetTensor, source); rewriter.replaceOpWithNewOp( - insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, - mixedStrides); + insertSlice, viewResult, insertSlice.getDest(), mixedOffsets, + mixedSizes, mixedStrides); } } diff --git a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir index aa6cacbd7..7cb39e2b3 100644 --- a/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir +++ b/test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir @@ -55,10 +55,13 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12 // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] // CHECK: scf.for - // CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x8x32x32xbf16> into tensor<8x32x32xbf16> // CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1] + // CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16> // CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1] + // 
CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xf32> into tensor<32x32xf32>
// CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xbf16> into tensor<32x32xbf16>
// CHECK: scf.if
// CHECK: linalg.fill
// CHECK: linalgx.batch_reduce_matmul_vnni
@@ -92,6 +95,7 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
// CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1]
// CHECK: scf.for
// CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16>
// CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1]
// CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2]
// CHECK: scf.if
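// --------------------------------------------------------------------------
// Editor's note, illustration only and not part of the patch: the padding
// penalty used by vectorRegEfficiencyCost (patch 19 above), restated as
// standalone C++. The 512-bit vector width and bf16 element type below are
// assumed example values; the pass itself reads getMaxVectorWidth() from
// CPUTargetDescriptionAnalysis and the element type from the matmul operand.
#include <cstdint>
#include <iostream>

// Fraction of a vector register left unused when `block` is not a multiple of
// the vector length in elements.
static double paddingPenalty(uint32_t block, uint32_t vecLen) {
  return (vecLen - block % vecLen) % vecLen * 1.0 / block;
}

int main() {
  const uint32_t maxVectorWidthBits = 512;                // assumed target
  const uint32_t dtypeBits = 16;                          // bf16
  const uint32_t vecLen = maxVectorWidthBits / dtypeBits; // 32 elements

  // innerMostMBlock = 24, innerMostK/NBlock = 32 (example config):
  std::cout << paddingPenalty(24, vecLen) << "\n"; // 8/24 = 0.333...
  std::cout << paddingPenalty(32, vecLen) << "\n"; // 0, no padding needed
  // Total cost is the sum of the three per-dimension penalties, as in the
  // patch; lower is better, so multiples of the vector length win.
  double cost = paddingPenalty(24, vecLen) + paddingPenalty(32, vecLen) +
                paddingPenalty(32, vecLen);
  std::cout << cost << "\n"; // 0.333...
  return 0;
}
// ------------------------------- end sketch -------------------------------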
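// Editor's note, illustration only and not part of the patch: the
// KParallelBlockSize expression changed in patch 20 above, restated
// standalone. The inputs below (KFirstDim = 4096, KBlock = 256, KThreads = 2,
// a single K loop dimension) are assumed example values, not the
// configuration of the lit tests.
#include <cstdint>
#include <iostream>

static uint64_t divideCeil(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

static uint64_t kParallelBlockSize(uint64_t KFirstDim, uint64_t KBlock,
                                   uint64_t KThreads, bool multipleKDims) {
  if (KThreads == 1)
    return 0; // tile size 0: leave K un-split instead of a dummy forall
  if (multipleKDims)
    return divideCeil(KFirstDim, KThreads);
  // Split the K blocks across threads, keeping each chunk KBlock-aligned.
  return divideCeil(divideCeil(KFirstDim, KBlock), KThreads) * KBlock;
}

int main() {
  // 4096 / 256 = 16 K blocks; 16 blocks / 2 threads = 8; 8 * 256 = 2048.
  std::cout << kParallelBlockSize(4096, 256, 2, false) << "\n"; // 2048
  std::cout << kParallelBlockSize(4096, 256, 1, false) << "\n"; // 0
  return 0;
}
// ------------------------------- end sketch -------------------------------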
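// Editor's note, illustration only and not part of the patch: patch 21 above
// replaces the rank-reducing extract_slice with a full-rank slice followed by
// a view (tensorViewRankedTensor), which is why the new CHECK lines expect
// tensor.collapse_shape ops such as 1x8x32x32 -> 8x32x32. The sketch below
// shows one plausible way to derive the collapsed shape and a reassociation
// that folds the leading unit dims dropped by shrinDimNum into the first
// retained dimension; the helper in the patch may build it differently.
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Reassociation = std::vector<std::vector<int64_t>>;

static Reassociation collapseLeadingUnitDims(const Shape &shape,
                                             size_t shrinkDimNum,
                                             Shape &collapsed) {
  collapsed.assign(shape.begin() + shrinkDimNum, shape.end());
  Reassociation groups;
  // Fold dims [0, shrinkDimNum] into the first retained dimension...
  std::vector<int64_t> firstGroup;
  for (size_t i = 0; i <= shrinkDimNum; ++i)
    firstGroup.push_back(static_cast<int64_t>(i));
  groups.push_back(firstGroup);
  // ...and keep every remaining dimension as its own group.
  for (size_t i = shrinkDimNum + 1; i < shape.size(); ++i)
    groups.push_back({static_cast<int64_t>(i)});
  return groups;
}

int main() {
  Shape collapsed;
  Reassociation groups =
      collapseLeadingUnitDims({1, 8, 32, 32}, /*shrinkDimNum=*/1, collapsed);
  assert((collapsed == Shape{8, 32, 32}));          // 1x8x32x32 -> 8x32x32
  assert(groups.size() == 3 && groups[0].size() == 2); // {0,1},{2},{3}

  collapseLeadingUnitDims({1, 1, 32, 32}, /*shrinkDimNum=*/2, collapsed);
  assert((collapsed == Shape{32, 32}));             // 1x1x32x32 -> 32x32
  return 0;
}
// ------------------------------- end sketch -------------------------------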