-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][math] Propagate scalability in convert-math-to-llvm
#82635
Conversation
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis also generally increases the coverage of scalable vector types in the math-to-llvm tests. Full diff: https://github.com/llvm/llvm-project/pull/82635.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 1b729611a36235..23e957288eb95e 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 3de2f11d1d12c7..ca8bba56ccd57c 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @log1p_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]] : vector<[4]xf32>
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ %0 = math.log1p %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1(
// CHECK-SAME: f32
func.func @expm1(%arg0 : f32) {
@@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @expm1_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32>
+ %0 = math.expm1 %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
@@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) {
// -----
+// CHECK-LABEL: func @cttz_scalable_vec(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.cttz %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @ctpop(
// CHECK-SAME: i32
func.func @ctpop(%arg0 : i32) {
@@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) {
// -----
+// CHECK-LABEL: func @ctpop_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.ctpop %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_double(
// CHECK-SAME: f64
func.func @rsqrt_double(%arg0 : f64) {
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32>{
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+ %0 = math.rsqrt %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
@@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector_fmf(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32>
+ %0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_multidim_vector(
func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_multidim_vector(
+func.func @rsqrt_multidim_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+ %0 = math.rsqrt %arg0 : vector<4x[4]xf32>
+ func.return %0 : vector<4x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fpowi(
// CHECK-SAME: f64
func.func @fpowi(%arg0 : f64, %arg1 : i32) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one minor comment but otherwise LGTM, cheers!
Co-authored-by: Cullen Rhodes <cullen.rhodes@arm.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
This also generally increases the coverage of scalable vector types in the math-to-llvm tests.