diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index 03952da95b11e..265e268ab1b09 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -2383,7 +2383,7 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
   auto context{builder.getContext()};
   auto argBases{getBasesForArgs(args)};
 
-  mlir::vector::SplatOp splatOp{nullptr};
+  mlir::vector::BroadcastOp splatOp{nullptr};
   mlir::Type retTy{nullptr};
   switch (vop) {
   case VecOp::Splat: {
@@ -2391,9 +2391,9 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     auto vecTyInfo{getVecTypeFromFir(argBases[0])};
 
     auto extractOp{genVecExtract(resultType, args)};
-    splatOp =
-        mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
-                                      vecTyInfo.toMlirVectorType(context));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, vecTyInfo.toMlirVectorType(context),
+        *(extractOp.getUnboxed()));
     retTy = vecTyInfo.toFirVectorType();
     break;
   }
@@ -2401,8 +2401,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     assert(args.size() == 1);
     auto vecTyInfo{getVecTypeFromEle(argBases[0])};
 
-    splatOp = mlir::vector::SplatOp::create(
-        builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
     retTy = vecTyInfo.toFirVectorType();
     break;
   }
@@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
     auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
 
     // the intrinsic always returns vector(integer(4))
-    splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
-                                            mlir::VectorType::get(4, eleTy));
+    splatOp = mlir::vector::BroadcastOp::create(
+        builder, loc, mlir::VectorType::get(4, eleTy), intOp);
     retTy = fir::VectorType::get(4, eleTy);
     break;
   }
@@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
   auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
 
   auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
-  auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
+  auto splatRes{
+      mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
 
   mlir::Value result{nullptr};
   if (mlirTy != splatRes.getType()) {
diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md
index 6c8949d70b4a3..839dc75ff0214 100644
--- a/mlir/docs/Dialects/Vector.md
+++ b/mlir/docs/Dialects/Vector.md
@@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
 // Produces a vector<3x7x8xf32>
 %b = arith.mulf %0, %1 : vector<3x7x8xf32>
 // Produces a vector<3x7x8xf32>
-%c = vector.splat %1 : vector<3x7x8xf32>
+%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
 
 %d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
 %e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
@@ -176,8 +176,6 @@ infrastructure can apply iteratively.
 ### Virtual Vector to Hardware Vector Lowering
 
 For now, `VV -> HWV` are specified in C++ (see for instance the
-[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
-or the [VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
 
 Simple
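Note for readers of this patch (not part of the diff): every mechanical `vector.splat` migration below follows the same one-line recipe; the broadcast form only spells out the source type explicitly. A minimal sketch, with illustrative types:

```mlir
%s = arith.constant 10.1 : f32
// removed form:  %t = vector.splat %s : vector<8x16xf32>
%t = vector.broadcast %s : f32 to vector<8x16xf32>
```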
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..41e075467910f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2880,53 +2880,6 @@ def Vector_PrintOp :
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def Vector_SplatOp : Vector_Op<"splat", [
-    Pure,
-    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-    TypesMatchWith<"operand type matches element type of result",
-                   "aggregate", "input",
-                   "::llvm::cast<VectorType>($_self).getElementType()">
-  ]> {
-  let summary = "vector splat or broadcast operation";
-  let description = [{
-    Note: This operation is deprecated. Please use vector.broadcast.
-
-    Broadcast the operand to all elements of the result vector. The type of the
-    operand must match the element type of the vector type.
-
-    Example:
-
-    ```mlir
-    %s = arith.constant 10.1 : f32
-    %t = vector.splat %s : vector<8x16xf32>
-    ```
-
-    This operation is deprecated, the preferred representation of the above is:
-
-    ```mlir
-    %s = arith.constant 10.1 : f32
-    %t = vector.broadcast %s : f32 to vector<8x16xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyType:$input);
-  let results = (outs AnyVectorOfAnyRank:$aggregate);
-
-  let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-
-  let hasFolder = 1;
-
-  // vector.splat is deprecated, and vector.broadcast should be used instead.
-  // Canonicalize vector.splat to vector.broadcast.
-  let hasCanonicalizer = 1;
-}
 
 //===----------------------------------------------------------------------===//
 // VectorScaleOp
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa5698d767..247dba101cfc1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
           current = op.getSource();
           return false;
         })
-        .Case<vector::SplatOp>([&current](auto op) {
-          current = op.getInput();
-          return false;
-        })
         .Default([](Operation *) { return false; });
 
     if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0a4a97a..1002ebe6875b6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  AFTER:
 ///  ```mlir
 ///  ...
-///  %pad_1d = vector.splat %pad : vector<[4]xi32>
+///  %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
 ///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
 ///      iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    ...
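The ArmSMEToSCF comment above shows the broadcast form with a scalable type; `vector.broadcast` covers the same result shapes that `vector.splat` did, including scalable and 0-D vectors. A small sketch (values assumed):

```mlir
%pad = arith.constant 0 : i32
// Scalable 1-D result, as in the tile-load padding above.
%pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
// 0-D results are also legal.
%pad_0d = vector.broadcast %pad : i32 to vector<i32>
```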
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a691180..778c616f1bf44 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
   }
 };
 
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
-                                PatternRewriter &rewriter) const final {
-
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
-                                                     splatOp.getInput());
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 546164628b795..5355909b62a7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
   }
 };
 
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
-                                                     adaptor.getInput());
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToLowRankLowering,
                VectorBroadcastScalarToNdLowering, VectorScalableInsertOpLowering,
                VectorScalableExtractOpLowering, MaskedReductionOpConversion,
                VectorInterleaveOpLowering,
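With `VectorSplatToBroadcast` deleted, scalar splats reach LLVM through the pre-existing `VectorBroadcastScalarToLowRankLowering` and `VectorBroadcastScalarToNdLowering` paths. For a rank-1 target this bottoms out in the usual insert/shuffle splat idiom; a rough sketch of the end result, not the literal pattern output:

```mlir
// %r = vector.broadcast %s : f32 to vector<4xf32> lowers roughly to:
%p = llvm.mlir.poison : vector<4xf32>
%i0 = llvm.mlir.constant(0 : i32) : i32
%e = llvm.insertelement %s, %p[%i0 : i32] : vector<4xf32>
%r = llvm.shufflevector %e, %p [0, 0, 0, 0] : vector<4xf32>
```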
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f5fbeee..56e8fee191432 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
   }
 };
 
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
-    : public OpConversionPattern<vector::SplatOp> {
-  using Base::Base;
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
-                                                     adaptor.getInput());
-    return success();
-  }
-};
-
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
-      VectorShuffleOpConvert, VectorInterleaveOpConvert,
-      VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
-      VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+      VectorScalarBroadcastPattern, VectorLoadOpConverter,
+      VectorStoreOpConverter, VectorStepOpConvert>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f534f8e..d018cddeb8dc1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
       vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, arith::SelectOp, vector::SplatOp,
-                    vector::BroadcastOp>();
+                    arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..dc58ac3cdee6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1664,10 +1664,10 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
 static bool isBroadcastLike(Operation *op) {
-  if (isa<BroadcastOp, SplatOp>(op))
+  if (isa<BroadcastOp>(op))
     return true;
 
   auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3131,12 +3131,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
 };
 
 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
 ///
-/// Examples:
+/// Example:
 ///
-/// scalar_source --> vector.splat --> value     - return scalar_source
 /// scalar_source --> vector.broadcast --> value - return scalar_source
 static Value getScalarSplatSource(Value value) {
   // Block argument:
@@ -3144,10 +3143,6 @@ static Value getScalarSplatSource(Value value) {
   if (!defOp)
     return {};
 
-  // Splat:
-  if (auto splat = dyn_cast<SplatOp>(defOp))
-    return splat.getInput();
-
   auto broadcast = dyn_cast<BroadcastOp>(defOp);
 
   // Not broadcast (and not splat):
@@ -7393,41 +7388,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
       patterns.getContext(), benefit);
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getInput();
-  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
-    return {};
-
-  // SplatElementsAttr::get treats single value for second arg as being a
-  // splat.
-  return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
-  using Base::Base;
-  LogicalResult matchAndRewrite(SplatOp splatOp,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
-                                                     splatOp.getOperand());
-    return success();
-  }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), argRanges.front());
-}
-
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
                                        arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf5a8161..3a3231d513369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
 
-  // TODO: add support to `vector.splat`.
+  // TODO: add support to `vector.broadcast`.
   // Finding the mask creation operation.
   while (maskOp &&
          !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(
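To make the updated `isBroadcastLike` concrete: it now accepts `vector.broadcast` plus `vector.shape_cast`s that only prepend unit dimensions. Sketch (operands `%v` and `%m` assumed):

```mlir
%b = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>     // broadcastlike
%c = vector.shape_cast %v : vector<4xf32> to vector<1x1x4xf32>  // broadcastlike: only prepends 1s
%d = vector.shape_cast %m : vector<2x4xf32> to vector<8xf32>    // not broadcastlike
```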
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c9f15f..1b656d82f3201 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
   }
 };
 
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-///   vector.splat %value : vector<4x4xf32>
-/// is converted to:
-///   %out_1d = vector.splat %value : vector<16xf32>
-///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
-    : public OpConversionPattern<vector::SplatOp> {
-  using Base::Base;
-
-  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
-                                                 dstTy);
-    return success();
-  }
-};
-
 /// This pattern converts the CreateMaskOp to work on a linearized vector.
 /// It currently supports only 2D masks with a unit outer dimension.
 /// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns) {
   patterns
      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements,
-           LinearizeVectorToElements>(typeConverter, patterns.getContext());
+           LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+           LinearizeVectorFromElements, LinearizeVectorToElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
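The deleted doc comment above describes splat linearization; an equivalent `vector.broadcast` input linearizes to the same shape, a 1-D broadcast followed by a `shape_cast` back to N-D. Sketch (names assumed):

```mlir
// vector.broadcast %value : f32 to vector<4x4xf32> is linearized to:
%out_1d = vector.broadcast %value : f32 to vector<16xf32>
%out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
```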
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d6a6d7cdba673..726da1e9a3d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
 // This transforms IR like:
 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
 // Into:
-//   %cst = vector.splat %c0_f32 : vector<4xf32>
+//   %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
 //   %1 = vector.extract_strided_slice %0 {
 //          offsets = [0], sizes = [4], strides = [1]
 //        } : vector<8xf16> to vector<4xf16>
@@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
   return newElementType;
 }
 
-/// If `value` is the result of a splat or broadcast operation, return the input
-/// of the splat/broadcast operation.
+/// If `value` is the result of a broadcast operation, return the input
+/// of the broadcast operation.
 static Value getBroadcastLikeSource(Value value) {
 
   Operation *op = value.getDefiningOp();
@@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcast.getSource();
 
-  if (auto splat = dyn_cast<vector::SplatOp>(op))
-    return splat.getInput();
-
   return {};
 }
 
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
 ///
 /// Example:
 /// ```
@@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
 ///   %r = arith.addi %arg0, %arg1 : index
 ///   %b = vector.broadcast %r : index to vector<1x4xindex>
 /// ```
-///
-/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
-/// ops.
 struct ReorderElementwiseOpsOnBroadcast final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
     Type resultElemType = resultType.getElementType();
 
     // Get the type of the first non-constant operand
-    Value splatSource;
+    Value broadcastSource;
     for (Value operand : op->getOperands()) {
       Operation *definingOp = operand.getDefiningOp();
       if (!definingOp)
         return failure();
       if (definingOp->hasTrait<OpTrait::ConstantLike>())
         continue;
-      splatSource = getBroadcastLikeSource(operand);
+      broadcastSource = getBroadcastLikeSource(operand);
       break;
     }
-    if (!splatSource)
+    if (!broadcastSource)
       return failure();
 
     Type unbroadcastResultType =
-        cloneOrReplace(splatSource.getType(), resultElemType);
+        cloneOrReplace(broadcastSource.getType(), resultElemType);
 
     // Make sure that all operands are broadcast from identically-shaped types:
-    //  * scalar (`vector.broadcast` + `vector.splat`), or
+    //  * scalar (`vector.broadcast`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+    if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
           if (auto source = getBroadcastLikeSource(val))
             return haveSameShapeAndScaling(source.getType(),
-                                           splatSource.getType());
+                                           broadcastSource.getType());
           SplatElementsAttr splatConst;
           return matchPattern(val, m_Constant(&splatConst));
         })) {
@@ -1271,19 +1265,18 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
   }
 };
 
-/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
 ///
 /// Example:
 /// ```
-///   %0 = vector.splat %arg2 : vector<1xf32>
+///   %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
 ///   vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
 /// ```
 /// Gets converted to:
 /// ```
 ///   memref.store %arg2, %arg0[%arg1] : memref<?xf32>
 /// ```
-class StoreOpFromSplatOrBroadcast final
-    : public OpRewritePattern<vector::StoreOp> {
+class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
 public:
   using Base::Base;
 
@@ -1308,9 +1301,9 @@ class StoreOpFromSplatOrBroadcast final
       return rewriter.notifyMatchFailure(
           op, "value to store is not from a broadcast");
 
-    // Checking for single use so we can remove splat.
-    Operation *splat = toStore.getDefiningOp();
-    if (!splat->hasOneUse())
+    // Checking for single use so we can remove broadcast.
+    Operation *broadcast = toStore.getDefiningOp();
+    if (!broadcast->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
     Value base = op.getBase();
@@ -1321,7 +1314,7 @@ class StoreOpFromSplatOrBroadcast final
     } else {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
     }
-    rewriter.eraseOp(splat);
+    rewriter.eraseOp(broadcast);
     return success();
   }
 };
@@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
   // TODO: Consider converting these patterns to canonicalizations.
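For context on `populateSinkVectorMemOpsPatterns`: the two registered patterns compose as below (a sketch; `%src`, `%dst` and the indices are assumed, and each rewrite also needs its single-use conditions to hold):

```mlir
%x = vector.load %src[%i] : memref<8xf32>, vector<4xf32>
%e = vector.extract %x[1] : f32 from vector<4xf32>        // ExtractOpFromLoad: becomes a scalar memref.load
%b = vector.broadcast %e : f32 to vector<1xf32>
vector.store %b, %dst[%j] : memref<8xf32>, vector<1xf32>  // StoreOpFromBroadcast: becomes memref.store
```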
-  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
-      patterns.getContext(), benefit);
+  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
+                                                        benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index eb9feaad15c5b..a75f30d57fa74 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -86,7 +86,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
 // CHECK:       %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
 // CHECK:       spirv.ReturnValue %[[VAL]] : vector<4xf32>
 func.func @splat(%f : f32) -> vector<4xf32> {
-  %splat = vector.splat %f : vector<4xf32>
+  %splat = vector.broadcast %f : f32 to vector<4xf32>
   return %splat : vector<4xf32>
 }
 
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index c8a434bb8f5de..1735e08782528 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -429,38 +429,6 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   return
 }
 
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_i32(
-// CHECK-SAME:  %[[SRC:.*]]: i32) {
-// CHECK:       %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK:       arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK:       %[[VSCALE:.*]] = vector.vscale
-// CHECK:       %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
-// CHECK:       scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
-// CHECK:         arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
-func.func @splat_vec2d_from_i32(%arg0: i32) {
-  %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
-  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_f16(
-// CHECK-SAME:  %[[SRC:.*]]: f16) {
-// CHECK:       %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
-// CHECK:       scf.for
-// CHECK:         arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
-func.func @splat_vec2d_from_f16(%arg0: f16) {
-  %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
-  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
-  return
-}
 
 //===----------------------------------------------------------------------===//
 // vector.transpose
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 5973c2ba2cbd0..cb48ca3374e8d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2216,23 +2216,6 @@ func.func @compress_store_op_with_alignment(%arg0: memref, %arg1: vecto
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// vector.splat is converted to vector.broadcast. Then, vector.broadcast is
-// converted to LLVM.
-// CHECK-LABEL: @splat_0d
-// CHECK-NOT: splat
-// CHECK: return
-func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
-  %a = vector.splat %elt : vector<f32>
-  %b = vector.splat %elt : vector<4xf32>
-  %c = vector.splat %elt : vector<[4]xf32>
-  return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.scalable_insert
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
index 9e65a96869460..681209276ad6b 100644
--- a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
+++ b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
@@ -105,9 +105,9 @@ func.func @ipowi32_fold(%result : memref<?xi32>) {
 
   // --- Test vector folding ---
   %arg11_base = arith.constant 2 : i32
-  %arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
+  %arg11_base_vec = vector.broadcast %arg11_base : i32 to vector<2x2xi32>
   %arg11_power = arith.constant 30 : i32
-  %arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
+  %arg11_power_vec = vector.broadcast %arg11_power : i32 to vector<2x2xi32>
   %res11_vec = math.ipowi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
   %i11 = arith.constant 11 : index
   %res11 = vector.extract %res11_vec[1, 1] : i32 from vector<2x2xi32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bccf5d5b77b0e..d093bc92cd8c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -837,7 +837,7 @@ func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2
 // CHECK-LABEL: fold_extract_vector_from_splat
 //       CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
 func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.splat %a : vector<1x2x4xf32>
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
deleted file mode 100644
index e4a9391770b6c..0000000000000
--- a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
+++ /dev/null
@@ -1,126 +0,0 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-
-// This file should be removed when vector.splat is removed.
-// This file tests canonicalization/folding with vector.splat.
-// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
-
-
-// CHECK-LABEL: fold_extract_splat
-// CHECK-SAME: %[[A:.*]]: f32
-// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
-  %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
-  return %r : f32
-}
-
-// -----
-
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
-// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
-  %0 = vector.splat %arg0 : vector<16x4xf16>
-  %1 = vector.extract_strided_slice %0
-    {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
-    vector<16x4xf16> to vector<2x4xf16>
-  return %1 : vector<2x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @splat_fold
-// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-// CHECK-NEXT: return [[V]] : vector<4xf32>
-func.func @splat_fold() -> vector<4xf32> {
-  %c = arith.constant 1.0 : f32
-  %v = vector.splat %c : vector<4xf32>
-  return %v : vector<4xf32>
-
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
-// CHECK: return %[[VAL_1]] : vector<3x4xf32>
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
-  %splat = vector.splat %arg : vector<4x3xf32>
-  %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
-  return %0 : vector<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_strided_slice_splat
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
-  %splat0 = vector.splat %x : vector<4x4xf32>
-  %splat1 = vector.splat %x : vector<8x16xf32>
-  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
-    : vector<4x4xf32> into vector<8x16xf32>
-  return %0 : vector<8x16xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @shuffle_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
-  %v0 = vector.splat %x : vector<4xi32>
-  %v1 = vector.splat %x : vector<2xi32>
-  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
-  return %shuffle : vector<4xi32>
-}
-
-
-// -----
-
-// CHECK-LABEL: func @insert_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
-  %v0 = vector.splat %x : vector<4x3xi32>
-  %v1 = vector.splat %x : vector<2x4x3xi32>
-  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
-  return %insert : vector<2x4x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
-// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
-  // Splat scalar to 0D and extract scalar.
-  %0 = vector.splat %a : vector<f32>
-  %1 = vector.extract %0[] : f32 from vector<f32>
-
-  // Broadcast scalar to 0D and extract scalar.
-  %2 = vector.splat %a : vector<f32>
-  %3 = vector.extract %2[] : f32 from vector<f32>
-
-  // Splat scalar to 2D and extract scalar.
-  %6 = vector.splat %a : vector<2x3xf32>
-  %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-
-  // Broadcast scalar to 3D and extract scalar.
-  %8 = vector.splat %a : vector<5x6x7xf32>
-  %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
-
-  // Extract 2D from 3D that was broadcasted from a scalar.
-  // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
-  %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
-
-  // Extract 1D from 2D that was splat'ed from a scalar.
-  // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
-  %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
-
-  // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
-  return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
-}
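As the deleted file's header notes, each case has a `vector.broadcast` twin in canonicalize.mlir; for instance, the shuffle-of-splats case is covered there in this form (sketch, names assumed):

```mlir
%v0 = vector.broadcast %x : i32 to vector<4xi32>
%v1 = vector.broadcast %x : i32 to vector<2xi32>
// Canonicalizes to a single vector.broadcast of %x to vector<4xi32>.
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
```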
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index b2f16bb3dac9c..4da8d8a967c73 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -28,7 +28,7 @@ func.func @float_constant_splat() -> vector<8xf32> {
 // CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
 func.func @vector_splat() -> vector<4xindex> {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
-  %1 = vector.splat %0 : vector<4xindex>
+  %1 = vector.broadcast %0 : index to vector<4xindex>
   %2 = test.reflect_bounds %1 : vector<4xindex>
   func.return %2 : vector<4xindex>
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6ee70fdd89a85..5f035e35a1b86 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -320,7 +320,7 @@ func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>)
 func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
   // expected-error@+1 {{ requires memref or ranked tensor type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
 }
@@ -330,7 +330,7 @@ func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
 func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
   // expected-error@+1 {{ requires vector type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
 }
@@ -414,7 +414,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
   // expected-note@+1 {{prior use here}}
-  %mask = vector.splat %c1 : vector<3x8x7xi1>
+  %mask = vector.broadcast %c1 : i1 to vector<3x8x7xi1>
   // expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
   %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
 }
@@ -424,7 +424,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
   // expected-error@+1 {{requires source vector element and vector result ranks to match}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
 }
@@ -434,7 +434,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<6xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<6xf32>
   // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
@@ -444,7 +444,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<2x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
   // expected-error@+1 {{ expects the in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
@@ -454,8 +454,8 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<2x3xf32>
-  %mask = vector.splat %c1 : vector<2x3xi1>
+  %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
+  %mask = vector.broadcast %c1 : i1 to vector<2x3xi1>
   // expected-error@+1 {{does not support masks with vector element type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
@@ -492,7 +492,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
 func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
   // expected-error@+1 {{ requires vector type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<?x?xvector<4x3xf32>>, vector<4x3xf32>
 }
@@ -502,7 +502,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<4x3xf32>>) {
 func.func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
   // expected-error@+1 {{ requires memref or ranked tensor type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
 }
@@ -1980,29 +1980,6 @@ func.func @invalid_step_2d() {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @vector_splat_invalid_result(%v : f32) {
-  // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'memref<8xf32>'}}
-  vector.splat %v : memref<8xf32>
-  return
-}
-
-// -----
-
-// expected-note @+1 {{prior use here}}
-func.func @vector_splat_type_mismatch(%a: f32) {
-  // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
-  %0 = vector.splat %a : vector<1xi32>
-  return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.load
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fe697c8b9c057..ee5cfbcda5c19 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,33 +428,6 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
 
 // -----
 
-// CHECK-LABEL: linearize_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
-func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
-
-  // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
-  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
-  // CHECK: return %[[CAST]] : vector<4x2xi32>
-  %0 = vector.splat %arg0 : vector<4x2xi32>
-  return %0 : vector<4x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: linearize_scalable_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
-func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
-
-  // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
-  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
-  // CHECK: return %[[CAST]] : vector<4x[2]xi32>
-  %0 = vector.splat %arg0 : vector<4x[2]xi32>
-  return %0 : vector<4x[2]xi32>
-
-}
-
-// -----
-
 // CHECK-LABEL: linearize_create_mask
 // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
 func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 550e52af7874b..da9a1a8180a05 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -45,11 +45,11 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %i0 = arith.constant 0 : index
   %i1 = arith.constant 1 : i1
 
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
-  %v0 = vector.splat %c0 : vector<4x3xi32>
-  %vi0 = vector.splat %i0 : vector<4x3xindex>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+  %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+  %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
   %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
-  %m2 = vector.splat %i1 : vector<4x5xi1>
+  %m2 = vector.broadcast %i1 : i1 to vector<4x5xi1>
   //
   // CHECK: vector.transfer_read
   %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@@ -106,9 +106,9 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
 
-  %vf0 = vector.splat %f0 : vector<4x3xf32>
-  %v0 = vector.splat %c0 : vector<4x3xi32>
-  %vi0 = vector.splat %i0 : vector<4x3xindex>
+  %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+  %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+  %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
 
   //
   // CHECK: vector.transfer_read
@@ -922,28 +922,6 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
   return %2#0 : vector<4x8x16x32xf32>
 }
 
-// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
-func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
-  // CHECK: vector.splat %[[s]] : vector<8xf32>
-  %v = vector.splat %s : vector<8xf32>
-
-  // CHECK: vector.splat %[[s]] : vector<4xf32>
-  %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
-
-  // CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
-  %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
-  return
-}
-
-// CHECK-LABEL: func @vector_splat_0d(
-func.func @vector_splat_0d(%a: f32) -> vector<f32> {
-  // CHECK: vector.splat %{{.*}} : vector<f32>
-  %0 = vector.splat %a : vector<f32>
-  return %0 : vector<f32>
-}
-
-
 // CHECK-LABEL: func @vector_mask
 func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
   // CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index e74eb08339684..6e5d68c859e2c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -49,7 +49,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
   %idx_4 = arith.constant 4 : index
   %mask = vector.create_mask %idx_1 : vector<4xi1>
   %s = arith.constant 0.0 : f32
-  %pass_thru = vector.splat %s : vector<4xf32>
+  %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
   %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %0: vector<4xf32>
 }
@@ -65,7 +65,7 @@ func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4
   %idx_4 = arith.constant 4 : index
   %mask = vector.create_mask %idx_1 : vector<4xi1>
   %s = arith.constant 0.0 : f32
-  %pass_thru = vector.splat %s : vector<4xf32>
+  %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
   %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %0: vector<4xf32>
 }
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 12a911ca8c826..0c5fec8c4055a 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -107,7 +107,7 @@ func.func @return_not_in_function() {
 // -----
 
 func.func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
-  vector.splat %v : vector<8xf64>
+  vector.broadcast %v : f64 to vector<8xf64>
   // expected-error@-1 {{expects different type than prior uses}}
   return
 }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
index 6ec103193ac6b..1938a3c8ab484 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
@@ -21,13 +21,6 @@ func.func @print_vector_0d(%a: vector<f32>) {
   return
 }
 
-func.func @splat_0d(%a: f32) {
-  %1 = vector.splat %a : vector<f32>
-  // CHECK: ( 42 )
-  vector.print %1: vector<f32>
-  return
-}
-
 func.func @broadcast_0d(%a: f32) {
   %1 = vector.broadcast %a : f32 to vector<f32>
   // CHECK: ( 42 )
diff --git a/mlir/test/mlir-runner/utils.mlir b/mlir/test/mlir-runner/utils.mlir
index 0c25078449987..d3fc23b423a56 100644
--- a/mlir/test/mlir-runner/utils.mlir
+++ b/mlir/test/mlir-runner/utils.mlir
@@ -56,7 +56,7 @@ func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interf
 func.func @vector_splat_2d() {
   %c0 = arith.constant 0 : index
   %f10 = arith.constant 10.0 : f32
-  %vf10 = vector.splat %f10: !vector_type_C
+  %vf10 = vector.broadcast %f10: f32 to
!vector_type_C
   %C = memref.alloc() : !matrix_type_CC
   memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC
diff --git a/mlir/utils/tree-sitter-mlir/queries/highlights.scm b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
index 59e280bab414a..ca52bcce042f7 100644
--- a/mlir/utils/tree-sitter-mlir/queries/highlights.scm
+++ b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
@@ -181,7 +181,6 @@
   "vector.insert_strided_slice"
   "vector.matrix_multiply"
   "vector.print"
-  "vector.splat"
   "vector.transfer_read"
   "vector.transfer_write"
   "vector.yield"