Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][ArmSME] Add arith-to-arm-sme conversion pass (#78197)
Existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. There's currently only a single op that's converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions.
- Loading branch information
Showing
18 changed files
with
257 additions
and
114 deletions.
There are no files selected for viewing
27 changes: 27 additions & 0 deletions
27
mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H | ||
#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H | ||
|
||
#include <memory> | ||
|
||
namespace mlir { | ||
|
||
class RewritePatternSet; | ||
class Pass; | ||
|
||
#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
|
||
namespace arith { | ||
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns); | ||
} // namespace arith | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/ArmSME/IR/ArmSME.h" | ||
#include "mlir/Dialect/ArmSME/Utils/Utils.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
#define DEBUG_TYPE "arith-to-arm-sme" | ||
|
||
using namespace mlir; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Conversion helpers | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Returns true if 'val' is a splat of zero, false otherwise. | ||
static bool isSplatZero(Type elemType, DenseElementsAttr val) { | ||
if (llvm::isa<FloatType>(elemType)) | ||
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); | ||
if (llvm::isa<IntegerType>(elemType)) | ||
return val && val.isSplat() && val.getSplatValue<APInt>().isZero(); | ||
return false; | ||
} | ||
|
||
namespace { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ConstantOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Conversion pattern for dense arith.constant. | ||
struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { | ||
using OpRewritePattern<arith::ConstantOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(arith::ConstantOp constantOp, | ||
PatternRewriter &rewriter) const final { | ||
auto tileType = dyn_cast<VectorType>(constantOp.getType()); | ||
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) | ||
return failure(); | ||
|
||
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); | ||
if (!denseAttr || !denseAttr.isSplat()) | ||
return failure(); | ||
|
||
auto tileElementType = tileType.getElementType(); | ||
|
||
// Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. | ||
if (isSplatZero(tileElementType, denseAttr)) { | ||
rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType); | ||
return success(); | ||
} | ||
|
||
// Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' | ||
// ops that broadcast the constant to each tile slice. | ||
auto loc = constantOp.getLoc(); | ||
|
||
// To fill a tile with a constant, we create a 1-D splat of the constant, | ||
// then move that into each tile slice (the largest unit we can set at once, | ||
// outside of operations like the outerproduct). | ||
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); | ||
auto denseAttr1D = DenseElementsAttr::get( | ||
tileSliceType, denseAttr.getSplatValue<Attribute>()); | ||
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); | ||
|
||
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); | ||
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, | ||
Value currentTile) { | ||
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile | ||
// slice. | ||
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>( | ||
loc, tileType, constantOp1D, currentTile, tileSliceIndex); | ||
return nextTile.getResult(); | ||
}; | ||
auto forOp = mlir::arm_sme::createLoopOverTileSlices( | ||
rewriter, loc, initTile, makeLoopBody); | ||
rewriter.replaceOp(constantOp, forOp.getResult(0)); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Pattern population | ||
//===----------------------------------------------------------------------===// | ||
|
||
void mlir::arith::populateArithToArmSMEConversionPatterns( | ||
RewritePatternSet &patterns) { | ||
patterns.add<ConstantOpToArmSMELowering>(patterns.getContext()); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Pass definition | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace { | ||
struct ArithToArmSMEConversionPass final | ||
: impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> { | ||
using impl::ArithToArmSMEConversionPassBase< | ||
ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
arith::populateArithToArmSMEConversionPatterns(patterns); | ||
if (failed( | ||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
add_mlir_conversion_library(MLIRArithToArmSME | ||
ArithToArmSME.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRArmSMEDialect | ||
MLIRArithDialect | ||
MLIRPass | ||
MLIRTransforms | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.