diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index fa935542d40f7..ac285b5d403df 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { private: // Replace `op` and remove it. void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { - op->replaceAllUsesWith(newValues); + llvm::SmallVector casts; + for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) { + if (oldValue.getType() == newValue.getType()) + casts.push_back(newValue); + else + casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(), + oldValue.getType(), newValue)); + } + op->replaceAllUsesWith(casts); op->dropAllReferences(); op->erase(); } diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir index 5d1e6129d8f69..b45983daa97ba 100644 --- a/flang/test/Fir/struct-return-x86-64.fir +++ b/flang/test/Fir/struct-return-x86-64.fir @@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data %1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ()) return %1 : () -> () } + func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) { + %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg + return %0 : () -> !fits_in_reg + } func.func @test_dispatch_inreg(%arg0: !fir.ref, %arg1: !fir.class>) { %0 = fir.dispatch "bar"(%arg1 : !fir.class>) (%arg1 : !fir.class>) -> !fits_in_reg {pass_arg_pos = 0 : i32} fir.store %0 to %arg0 : !fir.ref @@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data // CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) { // CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple -// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple) -> (() -> ()) -// CHECK: return %[[VAL_1]] : () -> () +// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple) -> (() -> !fir.type) +// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type) -> (() -> ()) +// CHECK: return %[[VAL_2]] : () -> () +// CHECK: } + +// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type) { +// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple +// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple) -> (() -> !fir.type) +// CHECK: return %[[VAL_1]] : () -> !fir.type // CHECK: } // CHECK-LABEL: func.func @test_dispatch_inreg( @@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data // CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) { // CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref}>>) -> () -// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref}>>) -> ()) -> (() -> ()) -// CHECK: return %[[VAL_1]] : () -> () +// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref}>>) -> ()) -> (() -> !fir.type}>) +// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type}>) -> (() -> ()) +// CHECK: return %[[VAL_2]] : () -> () // CHECK: } // CHECK-LABEL: func.func @test_dispatch_sret(