From 665995b9182b3712c47e83a80125b88599dea726 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 15 Sep 2023 09:10:31 +0000 Subject: [PATCH] [mlir][Conversion] Allow lowering to fixed arrays of scalable vectors This allows lowering vector types like: vector<3x[4]> or vector<3x2x[4]> to LLVM IR, i.e. vectors where the trailing dim is scalable. This is contingent on: https://discourse.llvm.org/t/rfc-enable-arrays-of-scalable-vector-types/72935 More tests will be added in later patches, however, some MLIR fixes are needed first. Depends on: D158517 Reviewed By: awarzynski Differential Revision: https://reviews.llvm.org/D158752 --- mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 3 ++- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 49e0513e629d9..fe3a8c6d41090 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -507,7 +507,8 @@ FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const { type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - if (type.isScalable() && (type.getRank() > 1)) + // Only the trailing dimension can be scalable. + if (llvm::is_contained(type.getScalableDims().drop_back(), true)) return failure(); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 4c06324087a01..7b29ef44c1f2f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2260,3 +2260,13 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> { %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32> return %0 : vector<8xf32> } + +// ----- + +// CHECK-LABEL: @make_fixed_vector_of_scalable_vector +func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64> +{ + // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<3 x vector<[2]xf64>> + %res = vector.broadcast %f : f64 to vector<3x[2]xf64> + return %res : vector<3x[2]xf64> +}