Skip to content

Commit

Permalink
[mlir][vector] Add vector.from_elements op (#95938)
Browse files Browse the repository at this point in the history
This commit adds a new operation to the vector dialect:
`vector.from_elements`

The op constructs a new vector from a given list of scalar values. It is
similar to `tensor.from_elements`.
```mlir
%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32>
```

Constructing a new vector from elements was tedious before this op
existed: a typical way was to define an `arith.constant ... :
vector<...>`, followed by a chain of `vector.insert`.

Folders/canonicalizations are added that can fold `vector.extract` ops
and convert the `vector.from_elements` op into a `vector.splat` op.

The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence
of scalar insertions in the form of `llvm.insertelement`. Only 0-D and
1-D vectors are currently supported in the LLVM lowering.
  • Loading branch information
matthias-springer committed Jun 19, 2024
1 parent bacbf26 commit c6ff244
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 4 deletions.
40 changes: 37 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -720,10 +720,9 @@ def Vector_ExtractOp :
return getStaticPosition().size();
}

/// Return "true" if the op has at least one dynamic position.
bool hasDynamicPosition() {
auto dynPos = getDynamicPosition();
return std::any_of(dynPos.begin(), dynPos.end(),
[](Value operand) { return operand != nullptr; });
return !getDynamicPosition().empty();
}
}];

Expand Down Expand Up @@ -769,6 +768,41 @@ def Vector_FMAOp :
}];
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
number of elements must match the number of elements in the result type.
All elements must have the same type, which must match the element type of
the result vector type.

`elements` are a flattened version of the result vector in row-major order.

Example:

```mlir
// %f1
%0 = vector.from_elements %f1 : vector<f32>
// [%f1, %f2]
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
```
}];

let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs AnyVectorOfAnyRank:$result);
let assemblyFormat = "$elements attr-dict `:` type($result)";
let hasCanonicalizer = 1;
}

def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
TypesMatchWith<"source operand type matches element type of result",
Expand Down
27 changes: 26 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
}
};

/// Conversion pattern for a `vector.from_elements`.
struct VectorFromElementsLowering
: public ConvertOpToLLVMPattern<vector::FromElementsOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
// Such ops should be handled in the same way as vector.insert.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
Expand All @@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering>(converter);
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
Expand Down
111 changes: 111 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,45 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
return Value();
}

/// Try to fold the extraction of a scalar from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b : vector<2xf32>
/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
/// ==> fold to %a
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
// Dynamic extractions cannot be folded.
if (extractOp.hasDynamicPosition())
return {};

// Look for extract(from_elements).
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
if (!fromElementsOp)
return {};

// Scalable vectors are not supported.
auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
if (vecType.isScalable())
return {};

// Only extractions of scalars are supported.
int64_t rank = vecType.getRank();
ArrayRef<int64_t> indices = extractOp.getStaticPosition();
if (extractOp.getType() != vecType.getElementType())
return {};
assert(static_cast<int64_t>(indices.size()) == rank &&
"unexpected number of indices");

// Compute flattened/linearized index and fold to operand.
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
flatIndex += indices[i] * stride;
stride *= vecType.getDimSize(i);
}
return fromElementsOp.getElements()[flatIndex];
}

OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
Expand All @@ -1895,6 +1934,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldExtractStridedOpFromInsertChain(*this))
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
return OpFoldResult();
}

Expand Down Expand Up @@ -2099,13 +2140,60 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
return success();
}

/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
PatternRewriter &rewriter) {
// Dynamic positions are not supported.
if (extractOp.hasDynamicPosition())
return failure();

// Scalar extracts are handled by the folder.
auto resultType = dyn_cast<VectorType>(extractOp.getType());
if (!resultType)
return failure();

// Look for extracts from a from_elements op.
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
if (!fromElementsOp)
return failure();
VectorType inputType = fromElementsOp.getType();

// Scalable vectors are not supported.
if (resultType.isScalable() || inputType.isScalable())
return failure();

// Compute the position of first extracted element and flatten/linearize the
// position.
SmallVector<int64_t> firstElementPos =
llvm::to_vector(extractOp.getStaticPosition());
firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
int flatIndex = 0;
int stride = 1;
for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
flatIndex += firstElementPos[i] * stride;
stride *= inputType.getDimSize(i);
}

// Replace the op with a smaller from_elements op.
rewriter.replaceOpWithNewOp<FromElementsOp>(
extractOp, resultType,
fromElementsOp.getElements().slice(flatIndex,
resultType.getNumElements()));
return success();
}
} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
Expand All @@ -2122,6 +2210,29 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

/// Rewrite a vector.from_elements into a vector.splat if all elements are the
/// same SSA value. E.g.:
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) {
if (!llvm::all_equal(fromElementsOp.getElements()))
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
fromElementsOp.getElements().front());
return success();
}

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
return %0 : vector<2x2xi64>
}

// -----

// CHECK-LABEL: func.func @vector_from_elements_1d(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<3xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32>
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32>
// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32>
// CHECK: return %[[insert2]]
func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> {
%0 = vector.from_elements %a, %b, %a : vector<3xf32>
return %0 : vector<3xf32>
}

// -----

// CHECK-LABEL: func.func @vector_from_elements_0d(
// CHECK-SAME: %[[a:.*]]: f32)
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<1xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector<f32>
// CHECK: return %[[cast]]
func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
%0 = vector.from_elements %a : vector<f32>
return %0 : vector<f32>
}
69 changes: 69 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
return
}

// -----

func.func @invalid_from_elements(%a: f32) {
// expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}}
vector.from_elements %a : vector<2xf32>
return
}

// -----

// expected-note @+1 {{prior use here}}
func.func @invalid_from_elements(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}

// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
// CHECK: vector.from_elements %[[a]] : vector<f32>
%0 = vector.from_elements %a : vector<f32>
// CHECK: vector.from_elements %[[a]] : vector<1xf32>
%1 = vector.from_elements %a : vector<1xf32>
// CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32>
%2 = vector.from_elements %a, %b : vector<1x2xf32>
// CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
%3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
}

0 comments on commit c6ff244

Please sign in to comment.