diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index c51c7b7270fae..7500bf7d1d9eb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { std::function<bool(vector::BitCastOp)> controlFn; }; +static bool haveSameShapeAndScaling(Type t, Type u) { + auto tVec = dyn_cast<VectorType>(t); + auto uVec = dyn_cast<VectorType>(u); + if (!tVec) { + return !uVec; + } + if (!uVec) { + return false; + } + return tVec.getShape() == uVec.getShape() && + tVec.getScalableDims() == uVec.getScalableDims(); +} + +/// If `type` is shaped, clone it with `newElementType`. Otherwise, +/// return `newElementType`. +static Type cloneOrReplace(Type type, Type newElementType) { + if (auto shapedType = dyn_cast<ShapedType>(type)) { + return shapedType.clone(newElementType); + } + return newElementType; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). 
Ex: /// /// Example: @@ -988,16 +1010,14 @@ struct ReorderElementwiseOpsOnBroadcast final PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); - if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) + auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); + if (!resultType) return failure(); if (!OpTrait::hasElementwiseMappableTraits(op)) return rewriter.notifyMatchFailure( op, "Op doesn't have ElementwiseMappableTraits"); if (op->getNumOperands() == 0) return failure(); - if (op->getResults()[0].getType() != op->getOperand(0).getType()) - return rewriter.notifyMatchFailure(op, - "result and operand type mismatch"); if (isa<vector::FMAOp>(op)) { return rewriter.notifyMatchFailure( op, @@ -1005,6 +1025,7 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } + Type resultElemType = resultType.getElementType(); // Get the type of the first non-constant operand Operation *firstBroadcastOrSplat = nullptr; for (Value operand : op->getOperands()) { @@ -1020,24 +1041,23 @@ } if (!firstBroadcastOrSplat) return failure(); - Type firstBroadcastOrSplatType = - firstBroadcastOrSplat->getOperand(0).getType(); + Type unbroadcastResultType = cloneOrReplace( + firstBroadcastOrSplat->getOperand(0).getType(), resultElemType); - // Make sure that all operands are broadcast from identical types: + // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. 
- if (!llvm::all_of( - op->getOperands(), [&firstBroadcastOrSplatType](Value val) { - if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>()) - return (bcastOp.getOperand().getType() == - firstBroadcastOrSplatType); - if (auto splatOp = val.getDefiningOp<vector::SplatOp>()) - return (splatOp.getOperand().getType() == - firstBroadcastOrSplatType); - SplatElementsAttr splatConst; - return matchPattern(val, m_Constant(&splatConst)); - })) { + if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) { + if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>()) + return haveSameShapeAndScaling(bcastOp.getOperand().getType(), + unbroadcastResultType); + if (auto splatOp = val.getDefiningOp<vector::SplatOp>()) + return haveSameShapeAndScaling(splatOp.getOperand().getType(), + unbroadcastResultType); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); + })) { return failure(); } @@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final SplatElementsAttr splatConst; if (matchPattern(operand, m_Constant(&splatConst))) { Attribute newConst; - if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) { - newConst = splatConst.resizeSplat(shapedTy); + Type elementType = getElementTypeOrSelf(operand.getType()); + Type newType = cloneOrReplace(unbroadcastResultType, elementType); + if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) { + newConst = splatConst.resizeSplat(newTypeShaped); } else { newConst = splatConst.getSplatValue<Attribute>(); } Operation *newConstOp = operand.getDefiningOp()->getDialect()->materializeConstant( - rewriter, newConst, firstBroadcastOrSplatType, - operand.getLoc()); + rewriter, newConst, newType, operand.getLoc()); srcValues.push_back(newConstOp->getResult(0)); } else { srcValues.push_back(operand.getDefiningOp()->getOperand(0)); @@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - firstBroadcastOrSplatType, op->getAttrs()); + 
unbroadcastResultType, op->getAttrs()); // Replace the original Op with the elementwise Op - auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - op, vectorType, elementwiseOp->getResults()); + op, resultType, elementwiseOp->getResults()); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index f8638ab843ecb..ef881ba05a416 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> { // ----- -// The source and the result for arith.cmp have different types - not supported - -// CHECK-LABEL: func.func @negative_source_and_result_mismatch -// CHECK: %[[BROADCAST:.+]] = vector.broadcast -// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]] -// CHECK: return %[[RETURN]] -func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> { +// The source and the result for arith.cmp have different types + +// CHECK-LABEL: func.func @source_and_result_mismatch( +// CHECK-SAME: %[[ARG0:.+]]: f32) +// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] +// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1> +// CHECK: return %[[BROADCAST]] +func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> { %0 = vector.broadcast %arg0 : f32 to vector<1xf32> %1 = arith.cmpf uno, %0, %0 : vector<1xf32> return %1 : vector<1xi1> @@ -210,53 +211,6 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { return %1 : vector<1xf32> } -//===----------------------------------------------------------------------===// -// [Pattern: ReorderCastOpsOnBroadcast] -// -// Reorder casting ops and vector ops. The casting ops have almost identical -// pattern, so only arith.extsi op is tested. 
-//===----------------------------------------------------------------------===// - -// ----- - -func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32> - // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32> - %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8> - %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> - return %r : vector<2x4xi32> -} - -// ----- - -func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32> - // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32> - %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8> - %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32> - return %r : vector<2x[4]xi32> -} - -// ----- - -func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 - // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32> - %b = vector.broadcast %a : i8 to vector<2x4xi8> - %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> - return %r : vector<2x4xi32> -} - -// ----- - -func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 - // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32> - %b = vector.broadcast %a : i8 to vector<2x[4]xi8> - %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32> - return %r : vector<2x[4]xi32> -} - // ----- // CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( @@ -321,6 +275,113 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xi return %2 : vector<1x4xindex> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32 
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16> + %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16> + %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32> + return %1 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32 +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32> + %cst = arith.constant dense<3> : vector<1x4xi32> + %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32> + return %2 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32> +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : 
vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<3> : vector<3x4xi32> + %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32> + return %2 : vector<3x4xf32> +} + +//===----------------------------------------------------------------------===// +// [Pattern: ReorderCastOpsOnBroadcast] +// +// Reorder casting ops and vector ops. The casting ops have almost identical +// pattern, so only arith.extsi op is tested. +//===----------------------------------------------------------------------===// + +// ----- + +func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32> + // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32> + %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8> + %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> + return %r : vector<2x4xi32> +} + +// ----- + +func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32> + // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32> + %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8> + %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32> + return %r : vector<2x[4]xi32> +} + +// ----- + +func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 + // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32> + %b = vector.broadcast %a : i8 to vector<2x4xi8> + %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> + return %r : vector<2x4xi32> +} + +// ----- + +func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> { + // CHECK: 
%[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 + // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32> + %b = vector.broadcast %a : i8 to vector<2x[4]xi8> + %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32> + return %r : vector<2x[4]xi32> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderElementwiseOpsOnTranspose] //===----------------------------------------------------------------------===//