diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h index 51d6b8d4545f0..4ac2528c1aedb 100644 --- a/flang/include/flang/Runtime/CUDA/memory.h +++ b/flang/include/flang/Runtime/CUDA/memory.h @@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value, void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes, unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0); -/// Data transfer from a pointer to a descriptor. -void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src, - std::size_t bytes, unsigned mode, const char *sourceFile = nullptr, - int sourceLine = 0); - /// Data transfer from a descriptor to a pointer. void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src, std::size_t bytes, unsigned mode, const char *sourceFile = nullptr, diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index a28d0a562f2f0..89d0af1fcd136 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -23,6 +23,7 @@ #include "flang/Runtime/allocatable.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) { return false; } +static mlir::Value getShapeFromDecl(mlir::Value src) { + if (auto declareOp = src.getDefiningOp()) + return declareOp.getShape(); + if (auto declareOp = src.getDefiningOp()) + return declareOp.getShape(); + return mlir::Value{}; +} + struct CUFDataTransferOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -528,54 +537,54 @@ struct CUFDataTransferOpConversion } // Conversion of data transfer involving at least one descriptor. - if (mlir::isa(srcTy) && - mlir::isa(dstTy)) { - // Transfer between two descriptor. + if (mlir::isa(dstTy)) { + // Transfer to a descriptor. mlir::func::FuncOp func = isDstGlobal(op) ? fir::runtime::getRuntimeFunc(loc, builder) : fir::runtime::getRuntimeFunc( loc, builder); - - auto fTy = func.getFunctionType(); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); mlir::Value dst = op.getDst(); mlir::Value src = op.getSrc(); - llvm::SmallVector args{fir::runtime::createArguments( - builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; - builder.create(loc, func, args); - rewriter.eraseOp(op); - } else if (mlir::isa(dstTy) && fir::isa_trivial(srcTy)) { - // Scalar to descriptor transfer. - mlir::Value val = op.getSrc(); - if (op.getSrc().getDefiningOp() && - mlir::isa(op.getSrc().getDefiningOp())) { - mlir::Value alloc = builder.createTemporary(loc, srcTy); - builder.create(loc, op.getSrc(), alloc); - val = alloc; + + if (!mlir::isa(srcTy)) { + // If src is not a descriptor, create one. + mlir::Value addr; + if (fir::isa_trivial(srcTy) && + mlir::matchPattern(op.getSrc().getDefiningOp(), + mlir::m_Constant())) { + // Put constant in memory if it is not. + mlir::Value alloc = builder.createTemporary(loc, srcTy); + builder.create(loc, op.getSrc(), alloc); + addr = alloc; + } else { + addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); + } + mlir::Type boxTy = fir::BoxType::get(srcTy); + llvm::SmallVector lenParams; + mlir::Value box = + builder.createBox(loc, boxTy, addr, getShapeFromDecl(src), + /*slice=*/nullptr, lenParams, + /*tdesc=*/nullptr); + mlir::Value memBox = builder.createTemporary(loc, box.getType()); + builder.create(loc, box, memBox); + src = memBox; } - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc(loc, - builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); llvm::SmallVector args{fir::runtime::createArguments( - builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)}; + builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); } else { // Type used to compute the width. mlir::Type computeType = dstTy; auto seqTy = mlir::dyn_cast(dstTy); - bool dstIsDesc = false; if (mlir::isa(dstTy)) { - dstIsDesc = true; computeType = srcTy; seqTy = mlir::dyn_cast(srcTy); } @@ -606,11 +615,8 @@ struct CUFDataTransferOpConversion rewriter.create(loc, nbElement, widthValue); mlir::func::FuncOp func = - dstIsDesc - ? fir::runtime::getRuntimeFunc( - loc, builder) - : fir::runtime::getRuntimeFunc( - loc, builder); + fir::runtime::getRuntimeFunc( + loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp index 0e03c618663eb..2d499f93fbaec 100644 --- a/flang/runtime/CUDA/memory.cpp +++ b/flang/runtime/CUDA/memory.cpp @@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes, CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind)); } -void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr, - std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) { - Terminator terminator{sourceFile, sourceLine}; - terminator.Crash( - "not yet implemented: CUDA data transfer from a pointer to a descriptor"); -} - void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc, std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) { Terminator terminator{sourceFile, sourceLine}; diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index a760650d14358..6a33190168024 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -29,13 +29,16 @@ func.func @_QPsub2() { } // CHECK-LABEL: func.func @_QPsub2() +// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box // CHECK: %[[TEMP:.*]] = fir.alloca i32 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub2Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) // CHECK: %[[C2:.*]] = arith.constant 2 : i32 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref +// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref) -> !fir.box +// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref> // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref>>>) -> !fir.ref> -// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref) -> !fir.llvm_ptr -// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.llvm_ptr, !fir.ref, i32) -> none +// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>) -> !fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none func.func @_QPsub3() { %0 = cuf.alloc !fir.box>> {bindc_name = "adev", data_attr = #cuf.cuda, uniq_name = "_QFsub3Eadev"} -> !fir.ref>>> @@ -48,12 +51,15 @@ func.func @_QPsub3() { } // CHECK-LABEL: func.func @_QPsub3() +// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub3Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref) -> !fir.box +// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref> // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref>>>) -> !fir.ref> -// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref) -> !fir.llvm_ptr -// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.llvm_ptr, !fir.ref, i32) -> none - +// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>) -> !fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none + func.func @_QPsub4() { %0 = cuf.alloc !fir.box>> {bindc_name = "adev", data_attr = #cuf.cuda, uniq_name = "_QFsub4Eadev"} -> !fir.ref>>> %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub4Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) @@ -67,15 +73,14 @@ func.func @_QPsub4() { return } // CHECK-LABEL: func.func @_QPsub4() +// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box> // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub4Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) -// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) -// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index -// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index -// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index +// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref>> // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref>>>) -> !fir.ref> -// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref>) -> !fir.llvm_ptr -// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64 -// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>>) -> !fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index @@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref {fir.bindc_name = "n"}) { } // CHECK-LABEL: func.func @_QPsub5 +// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box> // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub5Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2> // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref>, !fir.shape<2>) -> (!fir.box>, !fir.ref>) -// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index -// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index -// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index +// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref>, !fir.shape<2>) -> !fir.box> +// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref>> // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref>>>) -> !fir.ref> -// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref>) -> !fir.llvm_ptr -// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64 -// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>>) -> !fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index @@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} { // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref>>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none +fir.global @_QMmod2Eadev {data_attr = #cuf.cuda} : !fir.box>> { + %c0 = arith.constant 0 : index + %0 = fir.zero_bits !fir.heap> + %1 = fir.shape %c0 : (index) -> !fir.shape<1> + %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap>, !fir.shape<1>) -> !fir.box>> + fir.has_value %2 : !fir.box>> +} +func.func @_QPdesc_global_ptr() { + %c10 = arith.constant 10 : index + %0 = fir.address_of(@_QMmod2Eadev) : !fir.ref>>> + %1 = fir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMmod2Eadev"} : (!fir.ref>>>) -> !fir.ref>>> + %2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"} + %3 = fir.shape %c10 : (index) -> !fir.shape<1> + %4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref>>> + return +} + +// CHECK-LABEL: func.func @_QPdesc_global_ptr() +// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box> +// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref>>> +// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMmod2Eadev"} : (!fir.ref>>>) -> !fir.ref>>> +// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"} +// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1> +// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> +// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref>> +// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref>>>) -> !fir.ref> +// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>>) -> !fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none } // end of module