From f8c34d1cf408d6e9d3d45634de45dd40b31d4041 Mon Sep 17 00:00:00 2001 From: alessandra simmons <30960626+MixedMatched@users.noreply.github.com> Date: Thu, 2 Oct 2025 18:34:34 -0400 Subject: [PATCH 01/56] [clang][Driver][HIP] Change OffloadingActionBuilder to respect the --no-gpu-bundle-output flag --- clang/lib/Driver/Driver.cpp | 16 +++++++++++----- clang/test/Driver/no-gpu-bundle-respected.hip | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 clang/test/Driver/no-gpu-bundle-respected.hip diff --git a/clang/lib/Driver/Driver.cpp b/clang/lib/Driver/Driver.cpp index 85a1335785542..245ecf6053202 100644 --- a/clang/lib/Driver/Driver.cpp +++ b/clang/lib/Driver/Driver.cpp @@ -3844,6 +3844,9 @@ class OffloadingActionBuilder final { /// Flag set to true if all valid builders allow file bundling/unbundling. bool CanUseBundler; + /// Flag set to false if an argument turns off bundling. + bool ShouldUseBundler; + public: OffloadingActionBuilder(Compilation &C, DerivedArgList &Args, const Driver::InputList &Inputs) @@ -3878,6 +3881,9 @@ class OffloadingActionBuilder final { } CanUseBundler = ValidBuilders && ValidBuilders == ValidBuildersSupportingBundling; + + ShouldUseBundler = Args.hasFlag(options::OPT_gpu_bundle_output, + options::OPT_no_gpu_bundle_output, true); } ~OffloadingActionBuilder() { @@ -4029,11 +4035,11 @@ class OffloadingActionBuilder final { SB->appendTopLevelActions(OffloadAL); } - // If we can use the bundler, replace the host action by the bundling one in - // the resulting list. Otherwise, just append the device actions. For - // device only compilation, HostAction is a null pointer, therefore only do - // this when HostAction is not a null pointer. - if (CanUseBundler && HostAction && + // If we can and should use the bundler, replace the host action by the + // bundling one in the resulting list. Otherwise, just append the device + // actions. For device only compilation, HostAction is a null pointer, + // therefore only do this when HostAction is not a null pointer. + if (CanUseBundler && ShouldUseBundler && HostAction && HostAction->getType() != types::TY_Nothing && !OffloadAL.empty()) { // Add the host action to the list in order to create the bundling action. OffloadAL.push_back(HostAction); diff --git a/clang/test/Driver/no-gpu-bundle-respected.hip b/clang/test/Driver/no-gpu-bundle-respected.hip new file mode 100644 index 0000000000000..1587551f0322d --- /dev/null +++ b/clang/test/Driver/no-gpu-bundle-respected.hip @@ -0,0 +1,18 @@ +// RUN: %clang -ccc-print-phases -c -emit-llvm \ +// RUN: --offload-arch=gfx900,gfx1030 -O3 -x hip %s \ +// RUN: 2>&1 | FileCheck %s --check-prefix=OFFLOAD + +// RUN: %clang -ccc-print-phases -c -emit-llvm \ +// RUN: --gpu-bundle-output --offload-arch=gfx900,gfx1030 -O3 -x hip %s \ +// RUN: 2>&1 | FileCheck %s --check-prefix=OFFLOAD + +// RUN: %clang -ccc-print-phases -c -emit-llvm \ +// RUN: --no-gpu-bundle-output --offload-arch=gfx900,gfx1030 -O3 -x hip %s \ +// RUN: 2>&1 | FileCheck %s --check-prefix=OFFLOAD2 + +// OFFLOAD: clang-offload-bundler +// OFFLOAD2-NOT: clang-offload-bundler + +int square(int num) { + return num * num; +} From 1ddc8ed7f6512a3b5c238cac88c402cda9d09312 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 2 Oct 2025 19:07:25 -0400 Subject: [PATCH 02/56] [mlir][vector] Simplify op rewrite pattern inheriting constructors. NFC. (#161670) Use the `Base` type alias from https://github.com/llvm/llvm-project/pull/158433. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 82 +++++++++---------- .../Transforms/LowerVectorBroadcast.cpp | 2 +- .../Vector/Transforms/LowerVectorContract.cpp | 2 +- .../Vector/Transforms/LowerVectorGather.cpp | 6 +- .../Transforms/LowerVectorInterleave.cpp | 2 +- .../Vector/Transforms/LowerVectorMask.cpp | 6 +- .../Transforms/LowerVectorMultiReduction.cpp | 10 +-- .../Vector/Transforms/LowerVectorScan.cpp | 2 +- .../Transforms/LowerVectorShapeCast.cpp | 4 +- .../Vector/Transforms/LowerVectorShuffle.cpp | 2 +- .../Vector/Transforms/LowerVectorStep.cpp | 2 +- ...LowerVectorToFromElementsToShuffleTree.cpp | 2 +- .../Transforms/LowerVectorTranspose.cpp | 6 +- .../Transforms/VectorDropLeadUnitDim.cpp | 12 +-- .../VectorEmulateMaskedLoadStore.cpp | 4 +- .../Transforms/VectorEmulateNarrowType.cpp | 16 ++-- ...sertExtractStridedSliceRewritePatterns.cpp | 8 +- .../Vector/Transforms/VectorLinearize.cpp | 24 +++--- .../Transforms/VectorTransferOpTransforms.cpp | 2 +- .../Vector/Transforms/VectorTransforms.cpp | 40 ++++----- 20 files changed, 117 insertions(+), 117 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb4686997c1b9..b0132e889302f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -580,7 +580,7 @@ namespace { // ElideSingleElementReduction for ReduceOp. struct ElideUnitDimsInMultiDimReduction : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -730,7 +730,7 @@ std::optional> ReductionOp::getShapeForUnroll() { namespace { struct ElideSingleElementReduction : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -2197,7 +2197,7 @@ namespace { // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2220,7 +2220,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. class ExtractOpFromCreateMask final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, class FromElementsToShapeCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { @@ -2938,7 +2938,7 @@ namespace { // Fold broadcast1(broadcast2(x)) into broadcast1(x). struct BroadcastFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -3109,7 +3109,7 @@ namespace { // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) { /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ class ShuffleSplat final : public OpRewritePattern { /// vector.interleave. class ShuffleInterleave : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3326,7 +3326,7 @@ namespace { // broadcast. class InsertToBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -3344,7 +3344,7 @@ class InsertToBroadcast final : public OpRewritePattern { /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). class InsertSplatToSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3380,7 +3380,7 @@ class InsertSplatToSplat final : public OpRewritePattern { /// %result = vector.from_elements %c1, %c2 : vector<2xi32> class InsertChainFullyInitialized final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3748,7 +3748,7 @@ namespace { class FoldInsertStridedSliceSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ class FoldInsertStridedSliceSplat final class FoldInsertStridedSliceOfExtract final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3798,7 +3798,7 @@ class FoldInsertStridedSliceOfExtract final class InsertStridedSliceConstantFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; // Do not create constants with more than `vectorSizeFoldThreashold` elements, // unless the source vector constant has a single use. @@ -4250,7 +4250,7 @@ namespace { // %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, @@ -4310,7 +4310,7 @@ class StridedSliceCreateMaskFolder final class StridedSliceConstantMaskFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -4365,7 +4365,7 @@ class StridedSliceConstantMaskFolder final class StridedSliceBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4416,7 +4416,7 @@ class StridedSliceBroadcast final /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4448,7 +4448,7 @@ class StridedSliceSplat final : public OpRewritePattern { class ContiguousExtractStridedSliceToExtract final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -5023,7 +5023,7 @@ namespace { /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -5458,7 +5458,7 @@ namespace { /// any other uses. class FoldWaw final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!llvm::isa(writeOp.getShapedType())) @@ -5514,7 +5514,7 @@ class FoldWaw final : public OpRewritePattern { struct SwapExtractSliceOfTransferWrite : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() { namespace { class MaskedLoadFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { switch (getMaskFormat(load.getMask())) { @@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() { namespace { class MaskedStoreFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { namespace { class GatherFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -5910,7 +5910,7 @@ class GatherFolder final : public OpRewritePattern { /// maskedload. Only 1D fixed vectors are supported for now. class FoldContiguousGather final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { if (!isa(op.getBase().getType())) @@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() { namespace { class ScatterFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -5982,7 +5982,7 @@ class ScatterFolder final : public OpRewritePattern { /// maskedstore. Only 1D fixed vectors are supported for now. class FoldContiguousScatter final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { if (failed(isZeroBasedContiguousSeq(op.getIndices()))) @@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() { namespace { class ExpandLoadFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() { namespace { class CompressStoreFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { class ShapeCastCreateMaskFolderTrailingOneDim final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeOp, PatternRewriter &rewriter) const override { @@ -6330,7 +6330,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final /// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -6614,7 +6614,7 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6646,7 +6646,7 @@ class TransposeFolder final : public OpRewritePattern { /// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6663,7 +6663,7 @@ class FoldTransposeSplat final : public OpRewritePattern { /// Folds transpose(create_mask) into a new transposed create_mask. class FoldTransposeCreateMask final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transpOp, PatternRewriter &rewriter) const override { @@ -6700,7 +6700,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern { /// Folds transpose(shape_cast) into a new shape_cast. class FoldTransposeShapeCast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6750,7 +6750,7 @@ class FoldTransposeShapeCast final : public OpRewritePattern { /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6). class FoldTransposeBroadcast : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit) {} @@ -6971,7 +6971,7 @@ namespace { /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// %0 = arith.select %mask, %a, %passthru : vector<8xf32> /// class CanonializeEmptyMaskOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const override { @@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // vector.broadcast. class SplatToBroadcastPattern final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(SplatOp splatOp, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(splatOp, splatOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index dedc3b3f30201..61d9357e19bb4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -34,7 +34,7 @@ namespace { /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 65702ffa152d9..efe8d14b3532a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1151,7 +1151,7 @@ FailureOr ContractionOpLowering::lowerReduction( /// class OuterProductOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 1f96a3a108006..6bc8347bc6f76 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -50,7 +50,7 @@ namespace { /// /// Supports vector types with a fixed leading dimension. struct UnrollGather : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern { /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. struct RemoveStrideFromGatherSource : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern { /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 9d6a865a9301f..479fc0c6a9d8c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -163,7 +163,7 @@ class UnrollDeinterleaveOp final /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 5617b067d249e..7730c4e7c950a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -48,7 +48,7 @@ namespace { /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -100,7 +100,7 @@ class CreateMaskOpLowering : public OpRewritePattern { /// will be folded at LLVM IR level. class ConstantMaskOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -184,7 +184,7 @@ namespace { /// and actually match the traits of its the nested `MaskableOpInterface`. template struct MaskOpRewritePattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; private: LogicalResult matchAndRewrite(MaskOp maskOp, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 4773732d8d9a6..e86e2a97038db 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -39,7 +39,7 @@ namespace { class InnerOuterDimReductionConversion : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -136,7 +136,7 @@ class InnerOuterDimReductionConversion class ReduceMultiDimReductionRank : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit ReduceMultiDimReductionRank( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -304,7 +304,7 @@ class ReduceMultiDimReductionRank /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise /// a sequence of vector.reduction ops. struct TwoDimMultiReductionToReduction : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index af4851eb5f158..258f2cbc77736 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -99,7 +99,7 @@ namespace { /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 603ea41d43360..c5f22b2eafeb7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { } public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -356,7 +356,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { class ScalableShapeCastOpRewritePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 78102f7325b9f..8f46ad6ea892b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -44,7 +44,7 @@ namespace { /// struct MixedSizeInputShuffleOpRewrite final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp index ee5568aefda27..08e7c895831ce 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; namespace { struct StepToArithConstantOpRewrite final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a868abd85..7521e2491335b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ced3e6d1..c3f7de0ac3c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -300,7 +300,7 @@ namespace { /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) @@ -395,7 +395,7 @@ class TransposeOpLowering : public OpRewritePattern { class Transpose2DWithUnitDimToShapeCast : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; Transpose2DWithUnitDimToShapeCast(MLIRContext *context, PatternBenefit benefit = 1) @@ -433,7 +433,7 @@ class Transpose2DWithUnitDimToShapeCast class TransposeOp2DToShuffleLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOp2DToShuffleLowering( vector::VectorTransposeLowering vectorTransposeLowering, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index cab12894487e2..963b2c803bc5a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -54,7 +54,7 @@ namespace { // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { @@ -541,7 +541,7 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { // vector.broadcast back to the original shape. struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index bdbb792041e3d..7acc120508a44 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -48,7 +48,7 @@ namespace { /// struct VectorMaskedLoadOpConverter final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, PatternRewriter &rewriter) const override { @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final /// struct VectorMaskedStoreOpConverter final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 264cbc1869b9a..3a6684f4edfb7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -548,7 +548,7 @@ namespace { // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to // `false` to generate non-atomic RMW sequences. struct ConvertVectorStore final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern(context), @@ -827,7 +827,7 @@ struct ConvertVectorStore final : OpConversionPattern { /// adjusted mask . struct ConvertVectorMaskedStore final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final /// those cases, loads are converted to byte-aligned, byte-sized loads and the /// target vector is extracted from the loaded vector. struct ConvertVectorLoad final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern { /// bitcasting, since each `i8` container element holds two `i4` values. struct ConvertVectorMaskedLoad final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, // TODO: Document-me struct ConvertVectorTransferRead final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, @@ -1942,7 +1942,7 @@ namespace { /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. struct RewriteBitCastOfTruncI : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, PatternRewriter &rewriter) const override { @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern { /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) : OpRewritePattern(context, benefit) {} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index f6d6555f4c6e2..9e49873a4b4b0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -34,7 +34,7 @@ using namespace mlir::vector; class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -84,7 +84,7 @@ class DecomposeDifferentRankInsertStridedSlice class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is @@ -183,7 +183,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle class Convert1DExtractStridedSliceIntoShuffle : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -271,7 +271,7 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final class DecomposeNDExtractStridedSlice : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 82bac8c499028..71fba71c9f15f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -214,7 +214,7 @@ SmallVector static getStridedSliceInsertionIndices( /// vector.extract_strided_slice operation. struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final /// struct LinearizeVectorExtract final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -501,7 +501,7 @@ struct LinearizeVectorExtract final /// struct LinearizeVectorInsert final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -575,7 +575,7 @@ struct LinearizeVectorInsert final /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> struct LinearizeVectorSplat final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -629,7 +629,7 @@ struct LinearizeVectorSplat final /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern { /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -778,7 +778,7 @@ struct LinearizeVectorStore final /// struct LinearizeVectorFromElements final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final /// struct LinearizeVectorToElements final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c364a8b54167c..1121d9550f265 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1081,7 +1081,7 @@ class RewriteScalarExtractOfTransferRead /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) /// to memref.store. class RewriteScalarWrite : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 866f789ec6a39..d6a6d7cdba673 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -78,7 +78,7 @@ namespace { /// ``` struct MultiReduceToContract : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -138,7 +138,7 @@ struct MultiReduceToContract /// ``` struct CombineContractABTranspose final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -202,7 +202,7 @@ struct CombineContractABTranspose final /// ``` struct CombineContractResultTranspose final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp resTOp, PatternRewriter &rewriter) const override { @@ -568,7 +568,7 @@ static SmallVector getIntValueVector(ArrayAttr arrayAttr) { // %2 = vector.extract %1[1] : f16 from vector<2xf16> struct BubbleDownVectorBitCastForExtract : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> struct BubbleDownBitCastForStridedSliceExtract : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> // struct BubbleUpBitCastForInsert : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern { // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert // %7 = vector.insert_strided_slice %6, %cst { // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> struct BreakDownVectorBitCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: BreakDownVectorBitCast(MLIRContext *context, @@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final class ExtractOpFromElementwise final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) { /// ``` class ExtractOpFromLoad final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1285,7 +1285,7 @@ class ExtractOpFromLoad final : public OpRewritePattern { class StoreOpFromSplatOrBroadcast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp op, PatternRewriter &rewriter) const override { @@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { /// InstCombine seems to handle vectors with multiple elements but not the /// single element ones. struct FoldI1Select : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::SelectOp selectOp, PatternRewriter &rewriter) const override { @@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { /// Drop inner most contiguous unit dimensions from transfer_read operand. class DropInnerMostUnitDimsTransferRead : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { @@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite /// with the RHS transposed) lowering. struct CanonicalizeContractMatmulToMMT final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; using FilterConstraintType = std::function; @@ -1845,7 +1845,7 @@ struct CanonicalizeContractMatmulToMMT final template struct FoldArithExtIntoContractionOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp /// %b = vector.reduction %a, %acc /// ``` struct ChainedReduction final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { @@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final /// ``` struct DropUnitDimsFromTransposeOp final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> /// ``` struct DropUnitDimsFromScfForOp final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { @@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern { /// %c = vector.reduction %b, %acc /// ``` struct ReduceRedundantZero final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { From dadcbaeaf45f54e6c54e703ea6daec969f3c2cd7 Mon Sep 17 00:00:00 2001 From: "S. VenkataKeerthy" <31350914+svkeerthy@users.noreply.github.com> Date: Thu, 2 Oct 2025 16:35:12 -0700 Subject: [PATCH 03/56] [NFC][IR2Vec] Moving `parseVocabSection()` to `VocabStorage` (#161711) --- llvm/include/llvm/Analysis/IR2Vec.h | 9 ++- llvm/lib/Analysis/IR2Vec.cpp | 86 ++++++++++++++--------------- 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index b7c301580a8a4..ed43f19b4a7d3 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -210,6 +210,13 @@ class VocabStorage { const_iterator end() const { return const_iterator(this, getNumSections(), 0); } + + using VocabMap = std::map; + /// Parse a vocabulary section from JSON and populate the target vocabulary + /// map. + static Error parseVocabSection(StringRef Key, + const json::Value &ParsedVocabValue, + VocabMap &TargetVocab, unsigned &Dim); }; /// Class for storing and accessing the IR2Vec vocabulary. @@ -600,8 +607,6 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin { Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab, VocabMap &ArgVocab); - Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, - VocabMap &TargetVocab, unsigned &Dim); void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab, VocabMap &ArgVocab); void emitError(Error Err, LLVMContext &Ctx); diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index af30422b73759..295b6d33525d9 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -330,6 +330,43 @@ bool VocabStorage::const_iterator::operator!=( return !(*this == Other); } +Error VocabStorage::parseVocabSection(StringRef Key, + const json::Value &ParsedVocabValue, + VocabMap &TargetVocab, unsigned &Dim) { + json::Path::Root Path(""); + const json::Object *RootObj = ParsedVocabValue.getAsObject(); + if (!RootObj) + return createStringError(errc::invalid_argument, + "JSON root is not an object"); + + const json::Value *SectionValue = RootObj->get(Key); + if (!SectionValue) + return createStringError(errc::invalid_argument, + "Missing '" + std::string(Key) + + "' section in vocabulary file"); + if (!json::fromJSON(*SectionValue, TargetVocab, Path)) + return createStringError(errc::illegal_byte_sequence, + "Unable to parse '" + std::string(Key) + + "' section from vocabulary"); + + Dim = TargetVocab.begin()->second.size(); + if (Dim == 0) + return createStringError(errc::illegal_byte_sequence, + "Dimension of '" + std::string(Key) + + "' section of the vocabulary is zero"); + + if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), + [Dim](const std::pair &Entry) { + return Entry.second.size() == Dim; + })) + return createStringError( + errc::illegal_byte_sequence, + "All vectors in the '" + std::string(Key) + + "' section of the vocabulary are not of the same dimension"); + + return Error::success(); +} + // ==----------------------------------------------------------------------===// // Vocabulary //===----------------------------------------------------------------------===// @@ -460,43 +497,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) { // IR2VecVocabAnalysis //===----------------------------------------------------------------------===// -Error IR2VecVocabAnalysis::parseVocabSection( - StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, - unsigned &Dim) { - json::Path::Root Path(""); - const json::Object *RootObj = ParsedVocabValue.getAsObject(); - if (!RootObj) - return createStringError(errc::invalid_argument, - "JSON root is not an object"); - - const json::Value *SectionValue = RootObj->get(Key); - if (!SectionValue) - return createStringError(errc::invalid_argument, - "Missing '" + std::string(Key) + - "' section in vocabulary file"); - if (!json::fromJSON(*SectionValue, TargetVocab, Path)) - return createStringError(errc::illegal_byte_sequence, - "Unable to parse '" + std::string(Key) + - "' section from vocabulary"); - - Dim = TargetVocab.begin()->second.size(); - if (Dim == 0) - return createStringError(errc::illegal_byte_sequence, - "Dimension of '" + std::string(Key) + - "' section of the vocabulary is zero"); - - if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), - [Dim](const std::pair &Entry) { - return Entry.second.size() == Dim; - })) - return createStringError( - errc::illegal_byte_sequence, - "All vectors in the '" + std::string(Key) + - "' section of the vocabulary are not of the same dimension"); - - return Error::success(); -} - // FIXME: Make this optional. We can avoid file reads // by auto-generating a default vocabulary during the build time. Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, @@ -513,16 +513,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, return ParsedVocabValue.takeError(); unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; - if (auto Err = - parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim)) + if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue, + OpcVocab, OpcodeDim)) return Err; - if (auto Err = - parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim)) + if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue, + TypeVocab, TypeDim)) return Err; - if (auto Err = - parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim)) + if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue, + ArgVocab, ArgDim)) return Err; if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) From b5ce27a4fcb58f8d5590c4daafb8bb90dbaf92b4 Mon Sep 17 00:00:00 2001 From: Phoebe Wang Date: Fri, 3 Oct 2025 09:22:55 +0800 Subject: [PATCH 04/56] [X86][AMX] Combine constant zero vector and AMX cast to tilezero (#92384) Found this problem when investigating #91207 --- llvm/lib/Target/X86/X86LowerAMXType.cpp | 30 ++++++++ llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll | 78 +++++---------------- 2 files changed, 48 insertions(+), 60 deletions(-) diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 278ae46b8a5f5..0ba71ada8638e 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -854,6 +854,7 @@ class X86LowerAMXCast { : Func(F), SC(ShapeC), DT(nullptr) {} bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); + bool combineTilezero(IntrinsicInst *Cast); bool combineLdSt(SmallVectorImpl &Casts); bool combineAMXcast(TargetLibraryInfo *TLI); bool transformAMXCast(IntrinsicInst *AMXCast); @@ -1175,6 +1176,26 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { return EraseLoad; } +// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) +// --> +// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) +bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) { + Value *Row = nullptr, *Col = nullptr; + Use &U = *(Cast->use_begin()); + unsigned OpNo = U.getOperandNo(); + auto *II = cast(U.getUser()); + if (!isAMXIntrinsic(II)) + return false; + + std::tie(Row, Col) = SC->getShape(II, OpNo); + + IRBuilder<> Builder(Cast); + Value *NewInst = + Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col}); + Cast->replaceAllUsesWith(NewInst); + return true; +} + bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { bool Change = false; for (auto *Cast : Casts) { @@ -1198,6 +1219,14 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { for (auto *Store : DeadStores) Store->eraseFromParent(); } else { // x86_cast_vector_to_tile + // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) + // --> + // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) + if (isa(Cast->getOperand(0))) { + Change |= combineTilezero(cast(Cast)); + continue; + } + auto *Load = dyn_cast(Cast->getOperand(0)); if (!Load || !Load->hasOneUse()) continue; @@ -1210,6 +1239,7 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { // Set the operand is null so that load instruction can be erased. Cast->setOperand(0, nullptr); Load->eraseFromParent(); + Change = true; } } } diff --git a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll index 6ef7219cfebdb..9cf7aab0b3655 100644 --- a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll @@ -56,14 +56,9 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-LABEL: PR90954: ; CHECK: # %bb.0: ; CHECK-NEXT: pushq %rbp -; CHECK-NEXT: movq %rsp, %rbp -; CHECK-NEXT: pushq %r15 ; CHECK-NEXT: pushq %r14 -; CHECK-NEXT: pushq %r13 -; CHECK-NEXT: pushq %r12 ; CHECK-NEXT: pushq %rbx -; CHECK-NEXT: andq $-1024, %rsp # imm = 0xFC00 -; CHECK-NEXT: subq $5120, %rsp # imm = 0x1400 +; CHECK-NEXT: subq $2912, %rsp # imm = 0xB60 ; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0 ; CHECK-NEXT: vmovups %zmm0, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) @@ -79,29 +74,26 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-NEXT: movw $64, %cx ; CHECK-NEXT: movw $16, %di ; CHECK-NEXT: movb $1, %r8b -; CHECK-NEXT: movl $64, %r9d -; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r10 -; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r11 -; CHECK-NEXT: xorl %ebx, %ebx -; CHECK-NEXT: xorl %r14d, %r14d +; CHECK-NEXT: xorl %r9d, %r9d +; CHECK-NEXT: xorl %r10d, %r10d ; CHECK-NEXT: jmp .LBB1_1 ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: .LBB1_5: # in Loop: Header=BB1_1 Depth=1 -; CHECK-NEXT: incq %r14 -; CHECK-NEXT: addl %edx, %ebx +; CHECK-NEXT: incq %r10 +; CHECK-NEXT: addl %edx, %r9d ; CHECK-NEXT: .LBB1_1: # =>This Loop Header: Depth=1 ; CHECK-NEXT: # Child Loop BB1_2 Depth 2 -; CHECK-NEXT: movslq %ebx, %r15 -; CHECK-NEXT: leaq (%rsi,%r15,4), %r15 -; CHECK-NEXT: xorl %r12d, %r12d -; CHECK-NEXT: xorl %r13d, %r13d +; CHECK-NEXT: movslq %r9d, %r11 +; CHECK-NEXT: leaq (%rsi,%r11,4), %r11 +; CHECK-NEXT: xorl %ebx, %ebx +; CHECK-NEXT: xorl %r14d, %r14d ; CHECK-NEXT: jmp .LBB1_2 ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: .LBB1_4: # in Loop: Header=BB1_2 Depth=2 -; CHECK-NEXT: tilestored %tmm1, (%r15,%rax) -; CHECK-NEXT: incq %r13 -; CHECK-NEXT: addq $64, %r15 -; CHECK-NEXT: decq %r12 +; CHECK-NEXT: tilestored %tmm1, (%r11,%rax) +; CHECK-NEXT: incq %r14 +; CHECK-NEXT: addq $64, %r11 +; CHECK-NEXT: decq %rbx ; CHECK-NEXT: je .LBB1_5 ; CHECK-NEXT: .LBB1_2: # Parent Loop BB1_1 Depth=1 ; CHECK-NEXT: # => This Inner Loop Header: Depth=2 @@ -110,46 +102,12 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-NEXT: testb %r8b, %r8b ; CHECK-NEXT: jne .LBB1_4 ; CHECK-NEXT: # %bb.3: # in Loop: Header=BB1_2 Depth=2 -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: tileloadd (%r10,%r9), %tmm1 -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: tileloadd (%r11,%r9), %tmm2 +; CHECK-NEXT: tilezero %tmm1 +; CHECK-NEXT: tilezero %tmm2 ; CHECK-NEXT: tdpbf16ps %tmm2, %tmm1, %tmm0 -; CHECK-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill -; CHECK-NEXT: movabsq $64, %rax -; CHECK-NEXT: tilestored %tmm0, 3072(%rsp,%rax) # 1024-byte Folded Spill -; CHECK-NEXT: tileloadd 3072(%rsp,%rax), %tmm1 # 1024-byte Folded Reload -; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload +; CHECK-NEXT: movabsq $64, %rbp +; CHECK-NEXT: tilestored %tmm0, 896(%rsp,%rbp) # 1024-byte Folded Spill +; CHECK-NEXT: tileloadd 896(%rsp,%rbp), %tmm1 # 1024-byte Folded Reload ; CHECK-NEXT: jmp .LBB1_4 %4 = shl i32 %2, 4 %5 = icmp eq i64 0, 0 From d0ffd7bc5d7df18aa80e6c0095ea9b5812308f6f Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Fri, 3 Oct 2025 12:01:38 +1000 Subject: [PATCH 05/56] [orc-rt] Add CallableTraitsHelper, refactor WrapperFunction to use it. (#161761) CallableTraitsHelper identifies the return type and argument types of a callable type and passes those to an implementation class template to operate on. The CallableArgInfo utility uses CallableTraitsHelper to provide typedefs for the return type and argument types (as a tuple) of a callable type. In WrapperFunction.h, the detail::WFCallableTraits utility is rewritten in terms of CallableTraitsHandler (and renamed to WFHandlerTraits). --- orc-rt/include/orc-rt/CallableTraitsHelper.h | 74 ++++++++++++++++++ orc-rt/include/orc-rt/WrapperFunction.h | 77 ++++++++----------- orc-rt/unittests/CMakeLists.txt | 1 + orc-rt/unittests/CallableTraitsHelperTest.cpp | 69 +++++++++++++++++ 4 files changed, 176 insertions(+), 45 deletions(-) create mode 100644 orc-rt/include/orc-rt/CallableTraitsHelper.h create mode 100644 orc-rt/unittests/CallableTraitsHelperTest.cpp diff --git a/orc-rt/include/orc-rt/CallableTraitsHelper.h b/orc-rt/include/orc-rt/CallableTraitsHelper.h new file mode 100644 index 0000000000000..12d7d5672c73a --- /dev/null +++ b/orc-rt/include/orc-rt/CallableTraitsHelper.h @@ -0,0 +1,74 @@ +//===- CallableTraitsHelper.h - Callable arg/ret type extractor -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// CallableTraitsHelper API. +// +//===----------------------------------------------------------------------===// + +#ifndef ORC_RT_CALLABLETRAITSHELPER_H +#define ORC_RT_CALLABLETRAITSHELPER_H + +#include +#include + +namespace orc_rt { + +/// CallableTraitsHelper takes an implementation class template Impl and some +/// callable type C and passes the return and argument types of C to the Impl +/// class template. +/// +/// This can be used to simplify the implementation of classes that need to +/// operate on callable types. +template