diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 29bd100164b79..08239230f793f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2436,6 +2436,65 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
   }];
 }
 
+def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
+    AttrSizedOperandSegments]> {
+  let summary = "call CUDA kernel";
+
+  let description = [{
+    Launch a CUDA kernel from the host.
+
+    ```
+      // launch simple kernel with no arguments. bytes and stream values are
+      // optional in the chevron notation.
+      fir.cuda_kernel_launch @kernel<<<%gx, %gy, %bx, %by, %bz>>>()
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolRefAttr:$callee,
+    I32:$grid_x,
+    I32:$grid_y,
+    I32:$block_x,
+    I32:$block_y,
+    I32:$block_z,
+    Optional<I32>:$bytes,
+    Optional<I32>:$stream,
+    Variadic<AnyType>:$args
+  );
+
+  let assemblyFormat = [{
+    $callee `<` `<` `<` $grid_x `,` $grid_y `,` $block_x `,` $block_y `,`
+        $block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
+        `` `(` ( $args^ `:` type($args) )? `)` attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    mlir::CallInterfaceCallable getCallableForCallee() {
+      return getCalleeAttr();
+    }
+
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
+    }
+    mlir::FunctionType getFunctionType();
+
+    unsigned getNbNoArgOperand() {
+      unsigned nbNoArgOperand = 5; // grid and block values are always present.
+      if (getBytes()) ++nbNoArgOperand;
+      if (getStream()) ++nbNoArgOperand;
+      return nbNoArgOperand;
+    }
+
+    operand_range getArgOperands() {
+      return {operand_begin() + getNbNoArgOperand(), operand_end()};
+    }
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      return mlir::MutableOperandRange(
+          *this, getNbNoArgOperand(), getArgs().size() - 1);
+    }
+  }];
+}
+
 // Constant operations that support Fortran
 
 def fir_StringLitOp : fir_Op<"string_lit", [NoMemoryEffect]> {
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index d8271b1f14635..baf08b58a91b3 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -149,6 +149,21 @@ static bool mustCastFuncOpToCopeWithImplicitInterfaceMismatch(
   return false;
 }
 
+static mlir::Value readDim3Value(fir::FirOpBuilder &builder, mlir::Location loc,
+                                 mlir::Value dim3Addr, llvm::StringRef comp) {
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type refI32Ty = fir::ReferenceType::get(i32Ty);
+  llvm::SmallVector<mlir::Value> lenParams;
+
+  mlir::Value designate = builder.create<hlfir::DesignateOp>(
+      loc, refI32Ty, dim3Addr, /*component=*/comp,
+      /*componentShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+      /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
+      mlir::Value{}, lenParams);
+
+  return hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{designate});
+}
+
 std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
@@ -394,7 +409,67 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
   mlir::Value callResult;
   unsigned callNumResults;
-  if (caller.requireDispatchCall()) {
+
+  if (!caller.getCallDescription().chevrons().empty()) {
+    // A call to a CUDA kernel with the chevron syntax.
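+    // Chevron values arrive in source order: chevrons()[0] holds the grid,
+    // chevrons()[1] the block, and the optional chevrons()[2] and
+    // chevrons()[3] the bytes and stream values.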
+
+    mlir::Type i32Ty = builder.getI32Type();
+    mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+
+    mlir::Value grid_x, grid_y;
+    if (caller.getCallDescription().chevrons()[0].GetType()->category() ==
+        Fortran::common::TypeCategory::Integer) {
+      // If grid is an integer, it is converted to dim3(grid,1,1). Since z is
+      // not used for the number of thread blocks, it is omitted in the op.
+      grid_x = builder.createConvert(
+          loc, i32Ty,
+          fir::getBase(converter.genExprValue(
+              caller.getCallDescription().chevrons()[0], stmtCtx)));
+      grid_y = one;
+    } else {
+      auto dim3Addr = converter.genExprAddr(
+          caller.getCallDescription().chevrons()[0], stmtCtx);
+      grid_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
+      grid_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
+    }
+
+    mlir::Value block_x, block_y, block_z;
+    if (caller.getCallDescription().chevrons()[1].GetType()->category() ==
+        Fortran::common::TypeCategory::Integer) {
+      // If block is an integer, it is converted to dim3(block,1,1).
+      block_x = builder.createConvert(
+          loc, i32Ty,
+          fir::getBase(converter.genExprValue(
+              caller.getCallDescription().chevrons()[1], stmtCtx)));
+      block_y = one;
+      block_z = one;
+    } else {
+      auto dim3Addr = converter.genExprAddr(
+          caller.getCallDescription().chevrons()[1], stmtCtx);
+      block_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
+      block_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
+      block_z = readDim3Value(builder, loc, fir::getBase(dim3Addr), "z");
+    }
+
+    mlir::Value bytes; // bytes is optional.
+    if (caller.getCallDescription().chevrons().size() > 2)
+      bytes = builder.createConvert(
+          loc, i32Ty,
+          fir::getBase(converter.genExprValue(
+              caller.getCallDescription().chevrons()[2], stmtCtx)));
+
+    mlir::Value stream; // stream is optional.
+    if (caller.getCallDescription().chevrons().size() > 3)
+      stream = builder.createConvert(
+          loc, i32Ty,
+          fir::getBase(converter.genExprValue(
+              caller.getCallDescription().chevrons()[3], stmtCtx)));
+
+    builder.create<fir::CUDAKernelLaunch>(
+        loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, block_x,
+        block_y, block_z, bytes, stream, operands);
+    callNumResults = 0;
+  } else if (caller.requireDispatchCall()) {
     // Procedure call requiring a dynamic dispatch. Call is created with
     // fir.dispatch.
diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
new file mode 100644
index 0000000000000..c1e89d1978e4c
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test lowering of CUDA procedure calls.
+
+module test_call
+  use, intrinsic :: __fortran_builtins, only: __builtin_dim3
+contains
+  attributes(global) subroutine dev_kernel0()
+  end
+
+  attributes(global) subroutine dev_kernel1(a)
+    real :: a
+  end
+
+  subroutine host()
+    real, device :: a
+! CHECK-LABEL: func.func @_QMtest_callPhost()
+! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+    call dev_kernel0<<<10, 20>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
+
+    call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
+! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
+! CHECK: %[[DIM3_GRID:.*]]:2 = hlfir.declare %[[ADDR_DIM3_GRID]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.0"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
+! CHECK: %[[GRID_X:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[GRID_X_LOAD:.*]] = fir.load %[[GRID_X]] : !fir.ref<i32>
+! CHECK: %[[GRID_Y:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[GRID_Y_LOAD:.*]] = fir.load %[[GRID_Y]] : !fir.ref<i32>
+! CHECK: %[[ADDR_DIM3_BLOCK:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
+! CHECK: %[[DIM3_BLOCK:.*]]:2 = hlfir.declare %[[ADDR_DIM3_BLOCK]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.1"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
+! CHECK: %[[BLOCK_X:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[BLOCK_X_LOAD:.*]] = fir.load %[[BLOCK_X]] : !fir.ref<i32>
+! CHECK: %[[BLOCK_Y:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
+! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
+
+    call dev_kernel0<<<10, 20, 2>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()
+
+    call dev_kernel0<<<10, 20, 2, 0>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
+
+    call dev_kernel1<<<1, 32>>>(a)
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%[[A]]#1 : !fir.ref<f32>)
+  end
+
+end
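
For illustration, a sketch of the op's textual form once the optional bytes and stream values are present, put together from the assembly format and the CHECK lines above (the `%c...` constants and the `%a` kernel argument are illustrative, not taken from the patch):

```
// Grid of 10x1 blocks, blocks of 20x1x1 threads, 2 bytes of dynamic shared
// memory, launched on stream 0, passing one kernel argument by reference.
fir.cuda_kernel_launch @kernel<<<%c10, %c1, %c20, %c1, %c1, %c2, %c0>>>(%a : !fir.ref<f32>)
```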