Commit 33465bb

[mlir][Vector] Remove vector.extractelement and vector.insertelement ops (#149603)
This PR removes `vector.extractelement` and `vector.insertelement` ops from the code base in favor of the `vector.extract` and `vector.insert` counterparts. See RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
1 parent 0efcb83 commit 33465bb
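
For downstream IR that still uses the removed ops, the rewrite is mostly mechanical. The following is a minimal sketch of the migration, not part of the commit itself. The function names are hypothetical, and the `arith.index_cast` step assumes that dynamic positions of `vector.extract` and `vector.insert` must be of `index` type, whereas the removed ops also accepted signless integers such as `i32`.

```mlir
// Hypothetical functions illustrating the old -> new spelling.

// Dynamic position on a 1-D vector.
// Before: %e = vector.extractelement %v[%i : index] : vector<8xf32>
func.func @extract_dynamic(%v: vector<8xf32>, %i: index) -> f32 {
  %e = vector.extract %v[%i] : f32 from vector<8xf32>
  return %e : f32
}

// Non-index position: cast to index first (assumption stated above).
// Before: %e = vector.extractelement %v[%c : i32] : vector<16xf32>
func.func @extract_i32_position(%v: vector<16xf32>, %c: i32) -> f32 {
  %idx = arith.index_cast %c : i32 to index
  %e = vector.extract %v[%idx] : f32 from vector<16xf32>
  return %e : f32
}

// Static position on a 1-D vector.
// Before: %r = vector.insertelement %s, %v[%c0 : index] : vector<32xf32>
func.func @insert_static(%s: f32, %v: vector<32xf32>) -> vector<32xf32> {
  %r = vector.insert %s, %v [0] : f32 into vector<32xf32>
  return %r : vector<32xf32>
}

// 0-D vectors use an empty position list.
// Before: %e = vector.extractelement %z[] : vector<f32>
func.func @extract_0d(%z: vector<f32>) -> f32 {
  %e = vector.extract %z[] : f32 from vector<f32>
  return %e : f32
}
```

Note the asymmetry in the type suffix: `vector.extract` spells it `<element> from <vector>`, while `vector.insert` spells it `<element> into <vector>`.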

28 files changed, 65 additions and 1004 deletions

mlir/docs/Dialects/Vector.md

Lines changed: 1 addition & 1 deletion
@@ -294,7 +294,7 @@ LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g.
 `llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates
 following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR
 operations are prefixed by the `vector.` dialect prefix (e.g.
-`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector`
+`vector.insert`). Such ops operate exclusively on MLIR `n-D` `vector`
 types.

 ### Alternatives For Lowering an n-D Vector Type to LLVM

mlir/docs/Tutorials/transform/Ch0.md

Lines changed: 10 additions & 10 deletions
@@ -46,7 +46,7 @@ When no support is available, such an operation can be transformed into a loop:
 %c8 = arith.constant 8 : index
 %init = arith.constant 0.0 : f32
 %result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) {
-  %element = vector.extractelement %0[%i : index] : vector<8xf32>
+  %element = vector.extract %0[%i] : f32 from vector<8xf32>
   %updated = arith.addf %partial, %element : f32
   scf.yield %updated : f32
 }
@@ -145,7 +145,7 @@ linalg.generic {
   %c0 = arith.constant 0.0 : f32
   %0 = arith.cmpf ogt %in_one, %c0 : f32
   %1 = arith.select %0, %in_one, %c0 : f32
-  linalg.yield %1 : f32
+  linalg.yield %1 : f32
 }
 ```

@@ -185,7 +185,7 @@ In the case of `linalg.generic` operations, the iteration space is implicit and
 For example, tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `linalg.generic` expressing the same operation on a `2x8` tensor.

 ```mlir
-// A special "multi-for" loop that supports tensor-insertion semantics
+// A special "multi-for" loop that supports tensor-insertion semantics
 // as opposed to implicit updates. The resulting 8x16 tensor will be produced
 // by this loop.
 // The trip count of iterators is computed dividing the original tensor size,
@@ -202,9 +202,9 @@ For example, tiling the matrix multiplication presented above with tile sizes `(
   // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced.
   %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1]
     : tensor<8x10xf32> to tensor<2x10xf32>
-  %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1]
+  %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1]
     : tensor<10x16xf32> to tensor<10x8xf32>
-  %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1]
+  %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1]
     : tensor<8x16xf32> to tensor<2x8xf32>

   // This is exactly the same operation as before, but now operating on smaller
@@ -214,7 +214,7 @@ For example, tiling the matrix multiplication presented above with tile sizes `(
                      affine_map<(i, j, k) -> (k, j)>,
                      affine_map<(i, j, k) -> (i, j)>],
     iterator_types = ["parallel", "parallel", "reduction"]
-  } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>)
+  } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>)
     outs(%result_slice : tensor<2x8xf32>) -> tensor<2x8xf32> {
   ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32):
     %0 = arith.mulf %lhs_one, %rhs_one : f32
@@ -238,15 +238,15 @@ After materializing loops with tiling, another key code generation transformatio
 1. the subset (slice) of the operand that is used by the tile, and
 2. the tensor-level structured operation producing the whole tensor that is being sliced.

-By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand.
+By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand.

 Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself. This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled dimension will be used in its entirety.


 ```mlir
 // Same loop as before.
-%0 = scf.forall (%i, %j) in (4, 2)
-  shared_outs(%shared = %init)
+%0 = scf.forall (%i, %j) in (4, 2)
+  shared_outs(%shared = %init)
   -> (tensor<8x16xf32>, tensor<8x16xf32>) {
   // Scale the loop induction variables by the tile sizes.
   %1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
@@ -286,7 +286,7 @@ Let us assume that the matrix multiplication operation is followed by another op
     indexing_maps = [affine_map<(i, j) -> (i, j)>,
                      affine_map<(i, j) -> (i, j)>],
     iterator_types = ["parallel", "parallel"]
-  } ins(%partial : tensor<2x8xf32>)
+  } ins(%partial : tensor<2x8xf32>)
     outs(%shared_slice : tensor<2x8xf32>) {
   ^bb0(%in: f32, %out: f32):
     %5 = arith.mulf %in, %in : f32

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
@@ -380,7 +380,7 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {

 After:
 %3 = memref.load %2[] : memref<f32>
-%4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
+%4 = vector.insert %3, %cst [0] : f32 into vector<32xf32>
 %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
   %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
   %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 100 deletions
@@ -646,55 +646,6 @@ def Vector_DeinterleaveOp :
   }];
 }

-def Vector_ExtractElementOp :
-  Vector_Op<"extractelement", [Pure,
-     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-     TypesMatchWith<"result type matches element type of vector operand",
-                    "vector", "result",
-                    "::llvm::cast<VectorType>($_self).getElementType()">]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector,
-                   Optional<AnySignlessIntegerOrIndex>:$position)>,
-    Results<(outs AnyType:$result)> {
-  let summary = "extractelement operation";
-  let description = [{
-    Note: This operation is deprecated. Please use vector.extract insert.
-
-    Takes a 0-D or 1-D vector and a optional dynamic index position and
-    extracts the scalar at that position.
-
-    Note that this instruction resembles vector.extract, but is restricted to
-    0-D and 1-D vectors.
-    If the vector is 0-D, the position must be std::nullopt.
-
-
-    It is meant to be closer to LLVM's version:
-    https://llvm.org/docs/LangRef.html#extractelement-instruction
-
-    Example:
-
-    ```mlir
-    %c = arith.constant 15 : i32
-    %1 = vector.extractelement %0[%c : i32]: vector<16xf32>
-    %2 = vector.extractelement %z[]: vector<f32>
-    ```
-  }];
-  let assemblyFormat = [{
-    $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector)
-  }];
-
-  let builders = [
-    // 0-D builder.
-    OpBuilder<(ins "Value":$source)>,
-  ];
-  let extraClassDeclaration = [{
-    VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getVector().getType());
-    }
-  }];
-  let hasVerifier = 1;
-  let hasFolder = 1;
-}
-
 def Vector_ExtractOp :
   Vector_Op<"extract", [Pure,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
@@ -890,57 +841,6 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
   let hasCanonicalizer = 1;
 }

-def Vector_InsertElementOp :
-  Vector_Op<"insertelement", [Pure,
-     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-     TypesMatchWith<"source operand type matches element type of result",
-                    "result", "source",
-                    "::llvm::cast<VectorType>($_self).getElementType()">,
-     AllTypesMatch<["dest", "result"]>]>,
-     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
-                    Optional<AnySignlessIntegerOrIndex>:$position)>,
-     Results<(outs AnyVectorOfAnyRank:$result)> {
-  let summary = "insertelement operation";
-  let description = [{
-    Note: This operation is deprecated. Please use vector.insert instead.
-
-    Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
-    position and inserts the source into the destination at the proper position.
-
-    Note that this instruction resembles vector.insert, but is restricted to 0-D
-    and 1-D vectors.
-
-    It is meant to be closer to LLVM's version:
-    https://llvm.org/docs/LangRef.html#insertelement-instruction
-
-    Example:
-
-    ```mlir
-    %c = arith.constant 15 : i32
-    %f = arith.constant 0.0f : f32
-    %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
-    %2 = vector.insertelement %f, %z[]: vector<f32>
-    ```
-  }];
-  let assemblyFormat = [{
-    $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
-    type($result)
-  }];
-
-  let builders = [
-    // 0-D builder.
-    OpBuilder<(ins "Value":$source, "Value":$dest)>,
-  ];
-  let extraClassDeclaration = [{
-    Type getSourceType() { return getSource().getType(); }
-    VectorType getDestVectorType() {
-      return ::llvm::cast<VectorType>(getDest().getType());
-    }
-  }];
-  let hasVerifier = 1;
-  let hasFolder = 1;
-}
-
 def Vector_InsertOp :
   Vector_Op<"insert", [Pure,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 2 additions & 2 deletions
@@ -397,8 +397,8 @@ def DotOp : AVX_LowOp<"dot", [Pure,

 ```mlir
 %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
-%1 = vector.extractelement %0[%i0 : i32]: vector<8xf32>
-%2 = vector.extractelement %0[%i4 : i32]: vector<8xf32>
+%1 = vector.extract %0[%i0] : f32 from vector<8xf32>
+%2 = vector.extract %0[%i4] : f32 from vector<8xf32>
 %d = arith.addf %1, %2 : f32
 ```
 }];

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 68 deletions
@@ -1070,39 +1070,6 @@ class VectorShuffleOpConversion
   }
 };

-class VectorExtractElementOpConversion
-    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
-  using ConvertOpToLLVMPattern<
-      vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = extractEltOp.getSourceVectorType();
-    auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
-    // Bail if result type cannot be lowered.
-    if (!llvmType)
-      return failure();
-
-    if (vectorType.getRank() == 0) {
-      Location loc = extractEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
-      auto zero = LLVM::ConstantOp::create(rewriter, loc,
-                                           typeConverter->convertType(idxType),
-                                           rewriter.getIntegerAttr(idxType, 0));
-      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
-          extractEltOp, llvmType, adaptor.getVector(), zero);
-      return success();
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
-        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
-    return success();
-  }
-};
-
 class VectorExtractOpConversion
     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
 public:
@@ -1206,39 +1173,6 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
   }
 };

-class VectorInsertElementOpConversion
-    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = insertEltOp.getDestVectorType();
-    auto llvmType = typeConverter->convertType(vectorType);
-
-    // Bail if result type cannot be lowered.
-    if (!llvmType)
-      return failure();
-
-    if (vectorType.getRank() == 0) {
-      Location loc = insertEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
-      auto zero = LLVM::ConstantOp::create(rewriter, loc,
-                                           typeConverter->convertType(idxType),
-                                           rewriter.getIntegerAttr(idxType, 0));
-      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
-      return success();
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
-        adaptor.getPosition());
-    return success();
-  }
-};
-
 class VectorInsertOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertOp> {
 public:
@@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorGatherOpConversion, VectorScatterOpConversion>(
       converter, useVectorAlignment);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
-               VectorExtractElementOpConversion, VectorExtractOpConversion,
-               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+               VectorExtractOpConversion, VectorFMAOp1DConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 2 additions & 2 deletions
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
 /// vector.print %el : i32 punctuation <no_punctuation>
 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 /// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
