From 126e161f04e99f74c5ed14966c5a9ee942517db0 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 13 Oct 2025 11:41:20 -0700 Subject: [PATCH] [flang][cuda] Make sure dstEleTy is set when used in CUFOpConversion --- .../Optimizer/Transforms/CUFOpConversion.cpp | 4 +- flang/test/Fir/CUDA/cuda-data-transfer.fir | 40 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 609a1fc9fb02c..e5c5ba9082426 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -558,6 +558,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, if (srcTy.isInteger(1)) { // i1 is not a supported type in the descriptor and it is actually coming // from a LOGICAL constant. Use the destination type to avoid mismatch. + assert(dstEleTy && "expect dst element type to be set"); srcTy = dstEleTy; src = createConvertOp(rewriter, loc, srcTy, src); addr = builder.createTemporary(loc, srcTy); @@ -652,7 +653,8 @@ struct CUFDataTransferOpConversion // Initialization of an array from a scalar value should be implemented // via a kernel launch. Use the flang runtime via the Assign function // until we have more infrastructure. - mlir::Value src = emboxSrc(rewriter, op, symtab); + mlir::Type dstEleTy = fir::unwrapInnerType(fir::unwrapRefType(dstTy)); + mlir::Value src = emboxSrc(rewriter, op, symtab, dstEleTy); mlir::Value dst = emboxDst(rewriter, op, symtab); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc( diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index 669300cf64737..5d3215dd07fce 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -651,5 +651,45 @@ func.func @_QPsub28() { // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> () +func.func @_QPtesti4(%arg0: !fir.ref {fir.bindc_name = "n1"}, %arg1: !fir.ref {fir.bindc_name = "n2"}, %arg2: !fir.ref {fir.bindc_name = "n3"}, %arg3: !fir.ref {fir.bindc_name = "n4"}) { + %true = arith.constant true + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + %0 = fir.dummy_scope : !fir.dscope + %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFtesti4En1"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFtesti4En2"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %3:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFtesti4En3"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %4:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFtesti4En4"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %5 = fir.load %1#0 : !fir.ref + %6 = arith.divsi %5, %c2_i32 : i32 + %7 = fir.convert %6 : (i32) -> index + %8 = arith.cmpi sgt, %7, %c0 : index + %9 = arith.select %8, %7, %c0 : index + %10 = fir.load %2#0 : !fir.ref + %11 = arith.divsi %10, %c2_i32 : i32 + %12 = fir.convert %11 : (i32) -> index + %13 = arith.cmpi sgt, %12, %c0 : index + %14 = arith.select %13, %12, %c0 : index + %15 = fir.load %3#0 : !fir.ref + %16 = arith.divsi %15, %c2_i32 : i32 + %17 = fir.convert %16 : (i32) -> index + %18 = arith.cmpi sgt, %17, %c0 : index + %19 = arith.select %18, %17, %c0 : index + %20 = fir.load %4#0 : !fir.ref + %21 = arith.divsi %20, %c2_i32 : i32 + %22 = fir.convert %21 : (i32) -> index + %23 = arith.cmpi sgt, %22, %c0 : index + %24 = arith.select %23, %22, %c0 : index + %25 = cuf.alloc !fir.array>, %9, %14, %19, %24 : index, index, index, index {bindc_name = "lma", data_attr = #cuf.cuda, uniq_name = "_QFtesti4Elma"} -> !fir.ref>> + %26 = fir.shape %9, %14, %19, %24 : (index, index, index, index) -> !fir.shape<4> + %27:2 = hlfir.declare %25(%26) {data_attr = #cuf.cuda, uniq_name = "_QFtesti4Elma"} : (!fir.ref>>, !fir.shape<4>) -> (!fir.box>>, !fir.ref>>) + cuf.data_transfer %true to %27#1, %26 : !fir.shape<4> {transfer_kind = #cuf.cuda_transfer} : i1, !fir.ref>> + cuf.free %27#1 : !fir.ref>> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPtesti4 +// CHECK: fir.call @_FortranACUFDataTransferCstDesc + } // end of module