Skip to content

Commit

Permalink
[mlir][Vector] Add support for Value indices to vector.extract/insert
Browse files Browse the repository at this point in the history
`vector.extract/insert` ops only support constant indices. This PR is
extending them so that arbitrary values can be used instead.

This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

Differential Revision: https://reviews.llvm.org/D155034
  • Loading branch information
dcaballe committed Sep 22, 2023
1 parent 6ebc179 commit 98f6289
Show file tree
Hide file tree
Showing 19 changed files with 535 additions and 197 deletions.
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,24 @@ inline bool isReductionIterator(Attribute attr) {
return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
}

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant op.
SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> foldResults);

/// Returns the constant index ops in `values`. `values` are expected to be
/// constant operations.
SmallVector<arith::ConstantIndexOp>
getAsConstantIndexOps(ArrayRef<Value> values);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
95 changes: 77 additions & 18 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,7 @@ def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]>,
Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
Results<(outs AnyType)> {
InferTypeOpAdaptorWithIsCompatible]> {
let summary = "extract operation";
let description = [{
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
Expand All @@ -535,21 +533,55 @@ def Vector_ExtractOp :

```mlir
%1 = vector.extract %0[3]: vector<4x8x16xf32>
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: vector<4x8x16xf32>
%3 = vector.extract %1[]: vector<f32>
%4 = vector.extract %0[%a, %b, %c]: vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<4x8x16xf32>
```
}];

let arguments = (ins
AnyVectorOfAnyRank:$vector,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyType:$result);

let builders = [
// Convenience builder which assumes the values in `position` are defined by
// ConstantIndexOp.
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
OpBuilder<(ins "Value":$source, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
];

let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}

/// Return a vector with all the static and dynamic position indices.
SmallVector<OpFoldResult> getMixedPosition() {
OpBuilder builder(getContext());
return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
}

unsigned getNumIndices() {
return getStaticPosition().size();
}

bool hasDynamicPosition() {
auto dynPos = getDynamicPosition();
return std::any_of(dynPos.begin(), dynPos.end(),
[](Value operand) { return operand != nullptr; });
}
}];
let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";

let assemblyFormat = [{
$vector ``
custom<DynamicIndexList>($dynamic_position, $static_position)
attr-dict `:` type($vector)
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
Expand Down Expand Up @@ -638,9 +670,7 @@ def Vector_InsertOp :
Vector_Op<"insert", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
Results<(outs AnyVectorOfAnyRank:$res)> {
AllTypesMatch<["dest", "result"]>]> {
let summary = "insert operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
Expand All @@ -651,24 +681,53 @@ def Vector_InsertOp :

```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
%8 = vector.insert %6, %7[] : f32 into vector<f32>
%11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
```
}];
let assemblyFormat = [{
$source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
}];

let arguments = (ins
AnyType:$source,
AnyVectorOfAnyRank:$dest,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyVectorOfAnyRank:$result);

let builders = [
// Convenience builder which assumes all values are constant indices.
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];

let extraClassDeclaration = [{
Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}

/// Return a vector with all the static and dynamic position indices.
SmallVector<OpFoldResult> getMixedPosition() {
OpBuilder builder(getContext());
return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
}

unsigned getNumIndices() {
return getStaticPosition().size();
}

bool hasDynamicPosition() {
return llvm::any_of(getDynamicPosition(),
[](Value operand) { return operand != nullptr; });
}
}];

let assemblyFormat = [{
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
attr-dict `:` type($source) `into` type($dest)
}];

let hasCanonicalizer = 1;
Expand Down
93 changes: 66 additions & 27 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
if (auto attr = foldResult.dyn_cast<Attribute>()) {
auto intAttr = cast<IntegerAttr>(attr);
return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
}

return foldResult.get<Value>();
}

namespace {

/// Trivial Vector to LLVM conversions
Expand Down Expand Up @@ -1079,41 +1091,53 @@ class VectorExtractOpConversion
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
ArrayRef<int64_t> positionArray = extractOp.getPosition();

// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();

SmallVector<OpFoldResult> positionVec;
for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
if (pos.is<Value>())
// Make sure we use the value that has been already converted to LLVM.
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
else
positionVec.push_back(pos);
}

// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArray.empty()) {
ArrayRef<OpFoldResult> position(positionVec);
if (position.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}

// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
if (extractOp.hasDynamicPosition())
return failure();

Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getVector(), positionArray);
loc, adaptor.getVector(), getAsIntegers(position));
rewriter.replaceOp(extractOp, extracted);
return success();
}

// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
if (positionArray.size() > 1) {
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
}
if (position.size() > 1) {
if (extractOp.hasDynamicPosition())
return failure();

// Remaining extraction of element from 1-D LLVM vector
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(extractOp, extracted);
SmallVector<int64_t> nMinusOnePosition =
getAsIntegers(position.drop_back());
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
nMinusOnePosition);
}

Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
// Remaining extraction of element from 1-D LLVM vector.
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
lastPosition);
return success();
}
};
Expand Down Expand Up @@ -1194,48 +1218,63 @@ class VectorInsertOpConversion
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
ArrayRef<int64_t> positionArray = insertOp.getPosition();

// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();

SmallVector<OpFoldResult> positionVec;
for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
if (pos.is<Value>())
// Make sure we use the value that has been already converted to LLVM.
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
else
positionVec.push_back(pos);
}

// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArray.empty()) {
ArrayRef<OpFoldResult> position(positionVec);
if (position.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}

// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
if (insertOp.hasDynamicPosition())
return failure();

Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
rewriter.replaceOp(insertOp, inserted);
return success();
}

// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
if (positionArray.size() > 1) {
if (position.size() > 1) {
if (insertOp.hasDynamicPosition())
return failure();

oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
loc, extracted, getAsIntegers(position.drop_back()));
}

// Insertion of an element into a 1-D LLVM vector.
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.getSource(), constant);
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

// Potential insertion of resulting 1-D vector into array.
if (positionArray.size() > 1) {
if (position.size() > 1) {
if (insertOp.hasDynamicPosition())
return failure();

inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted, positionArray.drop_back());
loc, adaptor.getDest(), inserted,
getAsIntegers(position.drop_back()));
}

rewriter.replaceOp(insertOp, inserted);
Expand Down
26 changes: 14 additions & 12 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,10 +1063,11 @@ struct UnrollTransferReadConversion
/// If the result of the TransferReadOp has exactly one user, which is a
/// vector::InsertOp, return that operation's indices.
void getInsertionIndices(TransferReadOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto insertOp = getInsertOp(xferOp))
indices.assign(insertOp.getPosition().begin(),
insertOp.getPosition().end());
SmallVectorImpl<OpFoldResult> &indices) const {
if (auto insertOp = getInsertOp(xferOp)) {
auto pos = insertOp.getMixedPosition();
indices.append(pos.begin(), pos.end());
}
}

/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
Expand Down Expand Up @@ -1110,9 +1111,9 @@ struct UnrollTransferReadConversion
getXferIndices(b, xferOp, iv, xferIndices);

// Indices for the new vector.insert op.
SmallVector<int64_t, 8> insertionIndices;
SmallVector<OpFoldResult, 8> insertionIndices;
getInsertionIndices(xferOp, insertionIndices);
insertionIndices.push_back(i);
insertionIndices.push_back(rewriter.getIndexAttr(i));

auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
Expand Down Expand Up @@ -1195,10 +1196,11 @@ struct UnrollTransferWriteConversion
/// If the input of the given TransferWriteOp is an ExtractOp, return its
/// indices.
void getExtractionIndices(TransferWriteOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto extractOp = getExtractOp(xferOp))
indices.assign(extractOp.getPosition().begin(),
extractOp.getPosition().end());
SmallVectorImpl<OpFoldResult> &indices) const {
if (auto extractOp = getExtractOp(xferOp)) {
auto pos = extractOp.getMixedPosition();
indices.append(pos.begin(), pos.end());
}
}

/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
Expand Down Expand Up @@ -1235,9 +1237,9 @@ struct UnrollTransferWriteConversion
getXferIndices(b, xferOp, iv, xferIndices);

// Indices for the new vector.extract op.
SmallVector<int64_t, 8> extractionIndices;
SmallVector<OpFoldResult, 8> extractionIndices;
getExtractionIndices(xferOp, extractionIndices);
extractionIndices.push_back(i);
extractionIndices.push_back(b.getI64IntegerAttr(i));

auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
Expand Down
Loading

0 comments on commit 98f6289

Please sign in to comment.