diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp index 90b4b01e30f62..d29c1d06503d7 100644 --- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp @@ -1010,6 +1010,21 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter, Value value = store.getValue(); rewriter.setInsertionPointAfter(store); + // Small local optimization that avoids the round-trip: + // %25 = memref.load ... : memref + // %26 = fir.convert %25 : (i32) -> !fir.logical<4> // from load rewrite + // %27 = fir.convert %26 : (!fir.logical<4>) -> i32 // from store rewrite + // memref.store %27, ... : memref + // which would normalize the loaded value to 1 and break TRANSFER-like flows, + // e.g. transfer(transfer(i, .true.), 0). + if (auto to = value.getDefiningOp()) { + Value raw = to.getValue(); + if (auto memrefTy = dyn_cast(converted.getType())) + if (raw.getType() == memrefTy.getElementType() && + isa_and_nonnull(raw.getDefiningOp())) + value = raw; + } + if (isa(value.getType())) { Type convertedType = typeConverter.convertType(value.getType()); value = diff --git a/flang/test/Transforms/FIRToMemRef/logical.mlir b/flang/test/Transforms/FIRToMemRef/logical.mlir index 75a9fac3e1e45..1c23944a8d75b 100644 --- a/flang/test/Transforms/FIRToMemRef/logical.mlir +++ b/flang/test/Transforms/FIRToMemRef/logical.mlir @@ -28,3 +28,22 @@ func.func @store_scalar(%arg0: !fir.ref>) { fir.store %2 to %1 : !fir.ref> return } + +// CHECK-LABEL: func.func @store_loaded_logical +// CHECK: [[DUMMY:%[0-9]+]] = fir.undefined !fir.dscope +// CHECK: [[SRC_DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]] +// CHECK: [[DST_DECLARE:%[0-9]+]] = fir.declare %arg1 dummy_scope [[DUMMY]] +// CHECK: [[SRC_MEM:%[0-9]+]] = fir.convert [[SRC_DECLARE]] : (!fir.ref>) -> memref +// CHECK: [[LOAD:%[0-9]+]] = memref.load [[SRC_MEM]][] : memref +// CHECK: [[TOLOGICAL:%[0-9]+]] = fir.convert [[LOAD]] : (i32) -> !fir.logical<4> +// CHECK: [[DST_MEM:%[0-9]+]] = fir.convert [[DST_DECLARE]] : (!fir.ref>) -> memref +// CHECK-NOT: fir.convert [[TOLOGICAL]] : (!fir.logical<4>) -> i32 +// CHECK: memref.store [[LOAD]], [[DST_MEM]][] : memref +func.func @store_loaded_logical(%arg0: !fir.ref>, %arg1: !fir.ref>) { + %0 = fir.undefined !fir.dscope + %1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "src"} : (!fir.ref>, !fir.dscope) -> !fir.ref> + %2 = fir.declare %arg1 dummy_scope %0 {uniq_name = "dst"} : (!fir.ref>, !fir.dscope) -> !fir.ref> + %3 = fir.load %1 : !fir.ref> + fir.store %3 to %2 : !fir.ref> + return +}