diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp index abb34d9129f85..b9539565df8a3 100644 --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -38,6 +38,19 @@ #define PGMATH_DECLARE #include "flang/Evaluate/pgmath.h.inc" +/// This file implements lowering of Fortran intrinsic procedures. +/// Intrinsics are lowered to a mix of FIR and MLIR operations as +/// well as calls to runtime functions or LLVM intrinsics. + +/// Lowering of intrinsic procedure calls is based on a map that associates +/// Fortran intrinsic generic names with FIR generator functions. +/// All generator functions are member functions of the IntrinsicLibrary class +/// and have the same interface. +/// If no generator is given for an intrinsic name, a math runtime library +/// is searched for an implementation and, if a runtime function is found, +/// a call is generated for it. LLVM intrinsics are handled as a math +/// runtime library here. + /// Enums used to templatize and share lowering of MIN and MAX. enum class Extremum { Min, Max }; @@ -81,19 +94,6 @@ enum class ExtremumBehavior { // possible to implement it without some target dependent runtime. }; -/// This file implements lowering of Fortran intrinsic procedures. -/// Intrinsics are lowered to a mix of FIR and MLIR operations as -/// well as call to runtime functions or LLVM intrinsics. - -/// Lowering of intrinsic procedure calls is based on a map that associates -/// Fortran intrinsic generic names to FIR generator functions. -/// All generator functions are member functions of the IntrinsicLibrary class -/// and have the same interface. -/// If no generator is given for an intrinsic name, a math runtime library -/// is searched for an implementation and, if a runtime function is found, -/// a call is generated for it. LLVM intrinsics are handled as a math -/// runtime library here. 
- fir::ExtendedValue Fortran::lower::getAbsentIntrinsicArgument() { return fir::UnboxedValue{}; } @@ -439,6 +439,7 @@ struct IntrinsicLibrary { fir::ExtendedValue genAssociated(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genChar(mlir::Type, llvm::ArrayRef); + fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef); mlir::Value genDim(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genDotProduct(mlir::Type, llvm::ArrayRef); @@ -592,6 +593,10 @@ static constexpr IntrinsicHandler handlers[]{ {{{"pointer", asInquired}, {"target", asInquired}}}, /*isElemental=*/false}, {"char", &I::genChar}, + {"count", + &I::genCount, + {{{"mask", asAddr}, {"dim", asValue}, {"kind", asValue}}}, + /*isElemental=*/false}, {"cpu_time", &I::genCpuTime, {{{"time", asAddr}}}, @@ -1644,31 +1649,64 @@ IntrinsicLibrary::genChar(mlir::Type type, return fir::CharBoxValue{cast, len}; } -// DIM -mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType, - llvm::ArrayRef args) { - assert(args.size() == 2); - if (resultType.isa()) { - mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); - auto diff = builder.create(loc, args[0], args[1]); - auto cmp = builder.create( - loc, mlir::arith::CmpIPredicate::sgt, diff, zero); - return builder.create(loc, cmp, diff, zero); +// COUNT +fir::ExtendedValue +IntrinsicLibrary::genCount(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 3); + + // Handle mask argument + fir::BoxValue mask = builder.createBox(loc, args[0]); + unsigned maskRank = mask.rank(); + + assert(maskRank > 0); + + // Handle optional dim argument + bool absentDim = isAbsent(args[1]); + mlir::Value dim = + absentDim ? builder.createIntegerConstant(loc, builder.getIndexType(), 0) + : fir::getBase(args[1]); + + if (absentDim || maskRank == 1) { + // Result is scalar if no dim argument or mask is rank 1. + // So, call specialized Count runtime routine. 
+ return builder.createConvert( + loc, resultType, + fir::runtime::genCount(builder, loc, fir::getBase(mask), dim)); } - assert(fir::isa_real(resultType) && "Only expects real and integer in DIM"); - mlir::Value zero = builder.createRealZeroConstant(loc, resultType); - auto diff = builder.create(loc, args[0], args[1]); - auto cmp = builder.create( - loc, mlir::arith::CmpFPredicate::OGT, diff, zero); - return builder.create(loc, cmp, diff, zero); -} -// DOT_PRODUCT -fir::ExtendedValue -IntrinsicLibrary::genDotProduct(mlir::Type resultType, - llvm::ArrayRef args) { - return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc, - stmtCtx, args); + // Call general CountDim runtime routine. + + // Handle optional kind argument + bool absentKind = isAbsent(args[2]); + mlir::Value kind = absentKind ? builder.createIntegerConstant( + loc, builder.getIndexType(), + builder.getKindMap().defaultIntegerKind()) + : fir::getBase(args[2]); + + // Create mutable fir.box to be passed to the runtime for the result. 
+ mlir::Type type = builder.getVarLenSeqTy(resultType, maskRank - 1); + fir::MutableBoxValue resultMutableBox = + fir::factory::createTempMutableBox(builder, loc, type); + + mlir::Value resultIrBox = + fir::factory::getMutableIRBox(builder, loc, resultMutableBox); + + fir::runtime::genCountDim(builder, loc, resultIrBox, fir::getBase(mask), dim, + kind); + + // Handle cleanup of allocatable result descriptor and return + fir::ExtendedValue res = + fir::factory::genMutableBoxRead(builder, loc, resultMutableBox); + return res.match( + [&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue { + // Add cleanup code + addCleanUpForTemp(loc, box.getAddr()); + return box; + }, + [&](const auto &) -> fir::ExtendedValue { + fir::emitFatalError(loc, "unexpected result for COUNT"); + }); } // CPU_TIME @@ -1699,6 +1737,33 @@ void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef args) { charArgs[2], values); } +// DIM +mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 2); + if (resultType.isa()) { + mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); + auto diff = builder.create(loc, args[0], args[1]); + auto cmp = builder.create( + loc, mlir::arith::CmpIPredicate::sgt, diff, zero); + return builder.create(loc, cmp, diff, zero); + } + assert(fir::isa_real(resultType) && "Only expects real and integer in DIM"); + mlir::Value zero = builder.createRealZeroConstant(loc, resultType); + auto diff = builder.create(loc, args[0], args[1]); + auto cmp = builder.create( + loc, mlir::arith::CmpFPredicate::OGT, diff, zero); + return builder.create(loc, cmp, diff, zero); +} + +// DOT_PRODUCT +fir::ExtendedValue +IntrinsicLibrary::genDotProduct(mlir::Type resultType, + llvm::ArrayRef args) { + return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc, + stmtCtx, args); +} + // IAND mlir::Value IntrinsicLibrary::genIand(mlir::Type resultType, llvm::ArrayRef args) { diff --git 
a/flang/test/Lower/Intrinsics/count.f90 b/flang/test/Lower/Intrinsics/count.f90 new file mode 100644 index 0000000000000..212a5653a74ac --- /dev/null +++ b/flang/test/Lower/Intrinsics/count.f90 @@ -0,0 +1,45 @@ +! RUN: bbc -emit-fir %s -o - | FileCheck %s + +! CHECK-LABEL: count_test1 +! CHECK-SAME: %[[arg0:.*]]: !fir.ref{{.*}}, %[[arg1:.*]]: !fir.box>>{{.*}}) +subroutine count_test1(rslt, mask) + integer :: rslt + logical :: mask(:) + ! CHECK-DAG: %[[c1:.*]] = arith.constant 0 : index + ! CHECK-DAG: %[[a2:.*]] = fir.convert %[[arg1]] : (!fir.box>>) -> !fir.box + ! CHECK: %[[a4:.*]] = fir.convert %[[c1]] : (index) -> i32 + rslt = count(mask) + ! CHECK: %[[a5:.*]] = fir.call @_FortranACount(%[[a2]], %{{.*}}, %{{.*}}, %[[a4]]) : (!fir.box, !fir.ref, i32, i32) -> i64 + end subroutine + + ! CHECK-LABEL: test_count2 + ! CHECK-SAME: %[[arg0:.*]]: !fir.box>{{.*}}, %[[arg1:.*]]: !fir.box>>{{.*}}) + subroutine test_count2(rslt, mask) + integer :: rslt(:) + logical :: mask(:,:) + ! CHECK-DAG: %[[c1_i32:.*]] = arith.constant 1 : i32 + ! CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index + ! CHECK-DAG: %[[a0:.*]] = fir.alloca !fir.box>> + ! CHECK: %[[a5:.*]] = fir.convert %[[a0]] : (!fir.ref>>>) -> !fir.ref> + ! CHECK: %[[a6:.*]] = fir.convert %[[arg1]] : (!fir.box>>) -> !fir.box + ! CHECK: %[[a7:.*]] = fir.convert %[[c4]] : (index) -> i32 + rslt = count(mask, dim=1) + ! CHECK: %{{.*}} = fir.call @_FortranACountDim(%[[a5]], %[[a6]], %[[c1_i32]], %[[a7]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none + ! CHECK: %[[a10:.*]] = fir.load %[[a0]] : !fir.ref>>> + ! CHECK: %[[a12:.*]] = fir.box_addr %[[a10]] : (!fir.box>>) -> !fir.heap> + ! CHECK: fir.freemem %[[a12]] + end subroutine + + ! CHECK-LABEL: test_count3 + ! CHECK-SAME: %[[arg0:.*]]: !fir.ref{{.*}}, %[[arg1:.*]]: !fir.box>>{{.*}}) + subroutine test_count3(rslt, mask) + integer :: rslt + logical :: mask(:) + ! CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + ! 
CHECK-DAG: %[[a1:.*]] = fir.convert %[[arg1]] : (!fir.box>>) -> !fir.box + ! CHECK: %[[a3:.*]] = fir.convert %[[c0]] : (index) -> i32 + call bar(count(mask, kind=2)) + ! CHECK: %[[a4:.*]] = fir.call @_FortranACount(%[[a1]], %{{.*}}, %{{.*}}, %[[a3]]) : (!fir.box, !fir.ref, i32, i32) -> i64 + ! CHECK: %{{.*}} = fir.convert %[[a4]] : (i64) -> i16 + end subroutine +