Skip to content

Conversation

jeanPerier
Copy link
Contributor

TargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.

Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them.

Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Sep 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 8, 2025

@llvm/pr-subscribers-flang-codegen

Author: None (jeanPerier)

Changes

TargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.

Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them.

Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later.


Full diff: https://github.com/llvm/llvm-project/pull/157413.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+9-1)
  • (modified) flang/test/Fir/struct-return-x86-64.fir (+16-4)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 private:
   // Replace `op` and remove it.
   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
-    op->replaceAllUsesWith(newValues);
+    llvm::SmallVector<mlir::Value> casts;
+    for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+      if (oldValue.getType() == newValue.getType())
+        casts.push_back(newValue);
+      else
+        casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+                                               oldValue.getType(), newValue));
+    }
+    op->replaceAllUsesWith(casts);
     op->dropAllReferences();
     op->erase();
   }
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
     %1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
     return %1 : () -> ()
   }
+  func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+    %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+    return %0 : () -> !fits_in_reg
+  }
   func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
     %0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
     fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_inreg() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_sret() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_sret(

@llvmbot
Copy link
Member

llvmbot commented Sep 8, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Changes

TargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.

Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them.

Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later.


Full diff: https://github.com/llvm/llvm-project/pull/157413.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+9-1)
  • (modified) flang/test/Fir/struct-return-x86-64.fir (+16-4)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 private:
   // Replace `op` and remove it.
   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
-    op->replaceAllUsesWith(newValues);
+    llvm::SmallVector<mlir::Value> casts;
+    for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+      if (oldValue.getType() == newValue.getType())
+        casts.push_back(newValue);
+      else
+        casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+                                               oldValue.getType(), newValue));
+    }
+    op->replaceAllUsesWith(casts);
     op->dropAllReferences();
     op->erase();
   }
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
     %1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
     return %1 : () -> ()
   }
+  func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+    %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+    return %0 : () -> !fits_in_reg
+  }
   func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
     %0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
     fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_inreg() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_sret() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_sret(

Copy link
Contributor

@rscottmanley rscottmanley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@jeanPerier jeanPerier merged commit 3beec2f into llvm:main Sep 8, 2025
13 checks passed
@jeanPerier jeanPerier deleted the convert_target_rewrite branch September 8, 2025 15:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants