-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[flang] do not rely on existing fir.convert in TargetRewrite #157413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-codegen Author: None (jeanPerier) ChangesTargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type. Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them. Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later. Full diff: https://github.com/llvm/llvm-project/pull/157413.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
private:
// Replace `op` and remove it.
void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
- op->replaceAllUsesWith(newValues);
+ llvm::SmallVector<mlir::Value> casts;
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+ if (oldValue.getType() == newValue.getType())
+ casts.push_back(newValue);
+ else
+ casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+ oldValue.getType(), newValue));
+ }
+ op->replaceAllUsesWith(casts);
op->dropAllReferences();
op->erase();
}
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
%1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
return %1 : () -> ()
}
+ func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+ %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+ return %0 : () -> !fits_in_reg
+ }
func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
%0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_sret(
|
@llvm/pr-subscribers-flang-fir-hlfir Author: None (jeanPerier) ChangesTargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type. Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them. Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later. Full diff: https://github.com/llvm/llvm-project/pull/157413.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
private:
// Replace `op` and remove it.
void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
- op->replaceAllUsesWith(newValues);
+ llvm::SmallVector<mlir::Value> casts;
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+ if (oldValue.getType() == newValue.getType())
+ casts.push_back(newValue);
+ else
+ casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+ oldValue.getType(), newValue));
+ }
+ op->replaceAllUsesWith(casts);
op->dropAllReferences();
op->erase();
}
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
%1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
return %1 : () -> ()
}
+ func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+ %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+ return %0 : () -> !fits_in_reg
+ }
func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
%0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_sret(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
TargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.
Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them.
Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later.