Skip to content

Commit

Permalink
[mlir][vector] Add vector.scalable.insert/extract ops
Browse files Browse the repository at this point in the history
These new operations match the semantics of
llvm.experimental.vector.insert and llvm.experimental.vector.extract.

`vector.scalable.insert` and `vector.scalable.extract` allow,
respectively, insert vectors into scalable vectors, and extract vectors
from scalable vectors.

The discussion about the inclusion of these operations is here:
https://discourse.llvm.org/t/rfc-interfacing-between-fixed-length-and-scalable-vectors-for-vls-vector-code-on-scalable-vector-architectures

Differential Revision: https://reviews.llvm.org/D127875
  • Loading branch information
jsetoain committed Nov 8, 2022
1 parent 54fb173 commit aa9647e
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 1 deletion.
108 changes: 108 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -727,6 +727,114 @@ def Vector_InsertOp :
let hasVerifier = 1;
}

def Vector_ScalableInsertOp :
Vector_Op<"scalable.insert", [Pure,
AllElementTypesMatch<["source", "dest"]>,
AllTypesMatch<["dest", "res"]>,
PredOpTrait<"position is a multiple of the source length.",
CPred<
"(getPos() % getSourceVectorType().getNumElements()) == 0"
>>]>,
Arguments<(ins VectorOfRank<[1]>:$source,
ScalableVectorOfRank<[1]>:$dest,
I64Attr:$pos)>,
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
let summary = "insert subvector into scalable vector operation";
// NOTE: This operation is designed to map to `llvm.vector.insert`, and its
// documentation should be kept aligned with LLVM IR:
// https://llvm.org/docs/LangRef.html#llvm-vector-insert-intrinsic
let description = [{
This operations takes a rank-1 fixed-length or scalable subvector and
inserts it within the destination scalable vector starting from the
position specificed by `pos`. If the source vector is scalable, the
insertion position will be scaled by the runtime scaling factor of the
source subvector.

The insertion position must be a multiple of the minimum size of the source
vector. For the operation to be well defined, the source vector must fit in
the destination vector from the specified position. Since the destination
vector is scalable and its runtime length is unknown, the validity of the
operation can't be verified nor guaranteed at compile time.

Example:

```mlir
%2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
%5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
%8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
```

Invalid example:
```mlir
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
```
}];

let assemblyFormat = [{
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
}];

let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return getSource().getType().cast<VectorType>();
}
VectorType getDestVectorType() {
return getDest().getType().cast<VectorType>();
}
}];
}

def Vector_ScalableExtractOp :
Vector_Op<"scalable.extract", [Pure,
AllElementTypesMatch<["source", "res"]>,
PredOpTrait<"position is a multiple of the result length.",
CPred<
"(getPos() % getResultVectorType().getNumElements()) == 0"
>>]>,
Arguments<(ins ScalableVectorOfRank<[1]>:$source,
I64Attr:$pos)>,
Results<(outs VectorOfRank<[1]>:$res)> {
let summary = "extract subvector from scalable vector operation";
// NOTE: This operation is designed to map to `llvm.vector.extract`, and its
// documentation should be kept aligned with LLVM IR:
// https://llvm.org/docs/LangRef.html#llvm-vector-extract-intrinsic
let description = [{
Takes rank-1 source vector and a position `pos` within the source
vector, and extracts a subvector starting from that position.

The extraction position must be a multiple of the minimum size of the result
vector. For the operation to be well defined, the destination vector must
fit within the source vector from the specified position. Since the source
vector is scalable and its runtime length is unknown, the validity of the
operation can't be verified nor guaranteed at compile time.

Example:

```mlir
%1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
%3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
```

Invalid example:
```mlir
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
```
}];

let assemblyFormat = [{
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
}];

let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return getSource().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return getRes().getType().cast<VectorType>();
}
}];
}

def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [Pure,
PredOpTrait<"operand #0 and result have same element type",
Expand Down
28 changes: 28 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -579,11 +579,39 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
== }]
# allowedlength>)>]>;

// Whether the number of elements of a fixed-length vector is from the given
// `allowedRanks` list
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
And<[IsFixedVectorTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{$_self.cast<::mlir::VectorType>().getRank()
== }]
# allowedlength>)>]>;

// Whether the number of elements of a scalable vector is from the given
// `allowedRanks` list
class IsScalableVectorOfRankPred<list<int> allowedRanks> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{$_self.cast<::mlir::VectorType>().getRank()
== }]
# allowedlength>)>]>;

// Any vector where the rank is from the given `allowedRanks` list
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;

// Any fixed-length vector where the rank is from the given `allowedRanks` list
class FixedVectorOfRank<list<int> allowedRanks> : Type<
IsFixedVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;

// Any scalable vector where the rank is from the given `allowedRanks` list
class ScalableVectorOfRank<list<int> allowedRanks> : Type<
IsScalableVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;

// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
class VectorOfRankAndType<list<int> allowedRanks,
Expand Down
35 changes: 34 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Expand Up @@ -857,6 +857,37 @@ class VectorInsertOpConversion
}
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
: public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
using ConvertOpToLLVMPattern<
vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
return success();
}
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
: public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
using ConvertOpToLLVMPattern<
vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
extOp, typeConverter->convertType(extOp.getResultVectorType()),
adaptor.getSource(), adaptor.getPos());
return success();
}
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
Expand Down Expand Up @@ -1329,7 +1360,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Expand Up @@ -2140,3 +2140,25 @@ func.func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>

// -----

// CHECK-LABEL: @vector_scalable_insert
// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
// CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
%0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
// CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
%1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
return %1 : vector<[4]xf32>
}

// -----

// CHECK-LABEL: @vector_scalable_extract
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
// CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
%0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
return %0 : vector<8xf32>
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Expand Up @@ -1632,3 +1632,16 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
return
}

// -----

func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
}

// -----

func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
%0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
}
25 changes: 25 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Expand Up @@ -853,3 +853,28 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
return
}

// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
%sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
// CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
%0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
// CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
%1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
// CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
%2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
return
}

// CHECK-LABEL: func @vector_scalable_extract(
// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
// CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
%0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
// CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
%1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
// CHECK-NEXT: vector.scalable.extract %[[SV]][4] : vector<4xi32> from vector<[8]xi32>
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
return
}

0 comments on commit aa9647e

Please sign in to comment.