diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index efbf051645610..6c3d25df12d8d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) { } static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) { + // All types in the body should be a supported element type for VectorType. + for (Operation &innerOp : op->getRegion(0).front()) { + if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) { + return !VectorType::isValidElementType(type); + })) { + return failure(); + } + if (llvm::any_of(innerOp.getResultTypes(), [](Type type) { + return !VectorType::isValidElementType(type); + })) { + return failure(); + } + } if (isElementwise(op)) return success(); // TODO: isaConvolutionOpInterface that can also infer from generic features. diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 99617d50684ed..dbd09576cb76b 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) { // ----- +// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types +func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex>, %arg0 : complex) { + // CHECK-NOT: vector.broadcast + // CHECK-NOT: vector.transfer_write + linalg.generic { + indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : complex) + outs(%A: memref<8x16xcomplex>) { + ^bb(%0: complex, %1: complex) : + linalg.yield %0 : complex + } + return +} + +// ----- + // CHECK-LABEL: func @test_vectorize_fill func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>