diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..d47206f52def8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -691,8 +691,9 @@ def Vector_ExtractOp : InferTypeOpAdaptorWithIsCompatible]> { let summary = "extract operation"; let description = [{ - Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at - the proper position. Degenerates to an element type if n-k is zero. + Extracts an (n − k)-D result sub-vector from an n-D source vector at a + specified k-D position. When n = k, the result degenerates to a scalar + element. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. The result is undefined if any @@ -704,7 +705,6 @@ def Vector_ExtractOp : ```mlir %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32> %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32> - %3 = vector.extract %1[]: vector from vector %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32> @@ -886,9 +886,9 @@ def Vector_InsertOp : AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ - Takes an n-D source vector, an (n+k)-D destination vector and a k-D position - and inserts the n-D source into the (n+k)-D destination at the proper - position. Degenerates to a scalar or a 0-d vector source type when n = 0. + Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination + vector at a specified k-D position. When n = 0, value-to-store degenerates + to a scalar element inserted into the n-D destination vector. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. The result is undefined if any @@ -900,8 +900,7 @@ def Vector_InsertOp : ```mlir %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32> - %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32> ``` diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index cc5623068ab10..002dfebd2b602 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1324,6 +1324,8 @@ struct UnrollTransferReadConversion for (int64_t i = 0; i < dimSize; ++i) { Value iv = rewriter.create(loc, i); + // FIXME: Rename this lambda - it does much more than just + // in-bounds-check generation. vec = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), /*inBoundsCase=*/ @@ -1338,12 +1340,21 @@ struct UnrollTransferReadConversion insertionIndices.push_back(rewriter.getIndexAttr(i)); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); + auto newXferOp = b.create( loc, newXferVecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeAssignMask(b, xferOp, newXferOp, i); - return b.create(loc, newXferOp, vec, + + Value valToInser = newXferOp.getResult(); + if (newXferVecType.getRank() == 0) { + // vector.insert does not accept rank-0 as the non-indexed + // argument. Extract the scalar before inserting. + valToInser = b.create(loc, valToInser, + SmallVector()); + } + return b.create(loc, valToInser, vec, insertionIndices); }, /*outOfBoundsCase=*/ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2a2357319bd23..dc4bcd9b6bd84 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { } LogicalResult vector::ExtractOp::verify() { + if (auto resTy = dyn_cast(getResult().getType())) + if (resTy.getRank() == 0) + return emitError( + "expected a scalar instead of a 0-d vector as the result type"); + // Note: This check must come before getMixedPosition() to prevent a crash. auto dynamicMarkersCount = llvm::count_if(getStaticPosition(), ShapedType::isDynamic); @@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result, } LogicalResult InsertOp::verify() { + if (auto srcTy = dyn_cast(getValueToStoreType())) + if (srcTy.getRank() == 0) + return emitError( + "expected a scalar instead of a 0-d vector as the source operand"); + SmallVector position = getMixedPosition(); auto destVectorType = getDestVectorType(); if (position.size() > static_cast(destVectorType.getRank())) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 04810ed52584f..a2622c06fa71c 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @extract_0d(%arg0: vector) { - // expected-error@+1 {{expected position attribute of rank no greater than vector rank}} - %1 = vector.extract %arg0[0] : f32 from vector +func.func @extract_0d_result(%arg0: vector) { + // expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}} + %1 = vector.extract %arg0[] : vector from vector } // ----- @@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { // ----- -func.func @insert_0d(%a: vector, %b: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}} - %1 = vector.insert %a, %b[2, 6] : vector into vector<4x8x16xf32> -} - -// ----- - -func.func @insert_0d(%a: f32, %b: vector) { - // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} - %1 = vector.insert %a, %b[0] : f32 into vector +func.func @insert_0d_value_to_store(%a: vector, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}} + %1 = vector.insert %a, %b[0, 0, 0] : vector into vector<4x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index f3220aed4360c..7d43f2a84dc77 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, } // CHECK-LABEL: @insert_0d -func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32>) { +func.func @insert_0d(%a: f32, %b: vector) -> vector { // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector %1 = vector.insert %a, %b[] : f32 into vector - // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector into vector<2x3xf32> - %2 = vector.insert %b, %c[0, 1] : vector into vector<2x3xf32> - return %1, %2 : vector, vector<2x3xf32> + return %1 : vector } // CHECK-LABEL: @insert_poison_idx