diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp index 2ef36a21a1ac0..cb144fad67211 100644 --- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp +++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp @@ -14,6 +14,38 @@ namespace mlir { #include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc" } // namespace mlir +static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t, + unsigned operandNumber) { + // Scalars are allowed (treated as rank-0). verifyImpl checks the rank. + if (t.isIntOrIndexOrFloat() || isa(t)) + return success(); + + // Vectors are allowed. + if (isa(t)) + return success(); + + // MemRefs: must be ranked. + if (isa(t)) { + return op->emitOpError("operand #") + << operandNumber << " must be a ranked memref, but got " << t; + } + if (isa(t)) + return success(); + + // Tensors: must be ranked. + if (isa(t)) { + return op->emitOpError("operand #") + << operandNumber << " must be a ranked tensor, but got " << t; + } + if (isa(t)) + return success(); + + // Any other shaped type is not supported by this interface. + return op->emitOpError("operand #") + << operandNumber + << " must be ranked tensor/memref, vector, or scalar, but got " << t; +} + LogicalResult mlir::IndexingMapOpInterface::verifyImpl() { // All input/output operands must be indexed. if (static_cast(getIndexingMapsArray().size()) != @@ -26,14 +58,27 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() { SmallVector allShapesSizes; for (OpOperand &opOperand : getOperation()->getOpOperands()) { + Type ty = opOperand.get().getType(); + if (failed(verifyIndexingMapOperandType(getOperation(), ty, + opOperand.getOperandNumber()))) + return failure(); AffineMap indexingMap = getMatchingIndexingMap(&opOperand); - SmallVector shape = getStaticOperandShape(&opOperand); - int64_t rank = shape.size(); - // Symbols disallowed. if (indexingMap.getNumSymbols() != 0) return this->emitOpError("unexpected symbols in indexing_map #") << opOperand.getOperandNumber(); + // Handle scalars. + if (ty.isIntOrIndexOrFloat() || isa(ty)) { + int64_t rank = 0; + if (indexingMap.getNumResults() != rank) + return this->emitOpError("expected operand #") + << opOperand.getOperandNumber() << " rank (" << rank + << ") to match the result rank of indexing_map (" + << indexingMap.getNumResults() << ")"; + continue; + } + SmallVector shape = getStaticOperandShape(&opOperand); + int64_t rank = shape.size(); // Result rank must match operand rank. if (indexingMap.getNumResults() != rank) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 355d801f8732c..cc33205eb1486 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -321,6 +321,24 @@ func.func @generic_result_tensor_type(%arg0: memref // ----- +// Unranked tensor inputs must be diagnosed. +func.func @generic_unranked_input_tensor(%in: tensor<*xf32>) { + %out = tensor.empty() : tensor<16x16xf32> + // expected-error @+1 {{'linalg.generic' op operand #0 must be a ranked tensor, but got 'tensor<*xf32>'}} + %r = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%in : tensor<*xf32>) + outs(%out : tensor<16x16xf32>) { + ^bb0(%a: f32, %b: f32): + linalg.yield %a : f32 + } -> tensor<16x16xf32> + return +} + +// ----- + func.func @generic(%arg0: memref) { // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32}} linalg.generic {