Skip to content

Commit

Permalink
[flang] Lower count intrinsic
Browse files Browse the repository at this point in the history
This patch adds lowering for the count intrinsic.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D121782

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: mleair <leairmark@gmail.com>
  • Loading branch information
3 people committed Mar 16, 2022
1 parent 8dca38d commit 6ed9d3a
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 36 deletions.
137 changes: 101 additions & 36 deletions flang/lib/Lower/IntrinsicCall.cpp
Expand Up @@ -38,6 +38,19 @@
#define PGMATH_DECLARE
#include "flang/Evaluate/pgmath.h.inc"

/// This file implements lowering of Fortran intrinsic procedures.
/// Intrinsics are lowered to a mix of FIR and MLIR operations as
/// well as call to runtime functions or LLVM intrinsics.

/// Lowering of intrinsic procedure calls is based on a map that associates
/// Fortran intrinsic generic names to FIR generator functions.
/// All generator functions are member functions of the IntrinsicLibrary class
/// and have the same interface.
/// If no generator is given for an intrinsic name, a math runtime library
/// is searched for an implementation and, if a runtime function is found,
/// a call is generated for it. LLVM intrinsics are handled as a math
/// runtime library here.

/// Enums used to templatize and share lowering of MIN and MAX.
enum class Extremum { Min, Max };

Expand Down Expand Up @@ -81,19 +94,6 @@ enum class ExtremumBehavior {
// possible to implement it without some target dependent runtime.
};

/// This file implements lowering of Fortran intrinsic procedures.
/// Intrinsics are lowered to a mix of FIR and MLIR operations as
/// well as call to runtime functions or LLVM intrinsics.

/// Lowering of intrinsic procedure calls is based on a map that associates
/// Fortran intrinsic generic names to FIR generator functions.
/// All generator functions are member functions of the IntrinsicLibrary class
/// and have the same interface.
/// If no generator is given for an intrinsic name, a math runtime library
/// is searched for an implementation and, if a runtime function is found,
/// a call is generated for it. LLVM intrinsics are handled as a math
/// runtime library here.

fir::ExtendedValue Fortran::lower::getAbsentIntrinsicArgument() {
return fir::UnboxedValue{};
}
Expand Down Expand Up @@ -439,6 +439,7 @@ struct IntrinsicLibrary {
fir::ExtendedValue genAssociated(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genChar(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genDim(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genDotProduct(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
Expand Down Expand Up @@ -592,6 +593,10 @@ static constexpr IntrinsicHandler handlers[]{
{{{"pointer", asInquired}, {"target", asInquired}}},
/*isElemental=*/false},
{"char", &I::genChar},
{"count",
&I::genCount,
{{{"mask", asAddr}, {"dim", asValue}, {"kind", asValue}}},
/*isElemental=*/false},
{"cpu_time",
&I::genCpuTime,
{{{"time", asAddr}}},
Expand Down Expand Up @@ -1644,31 +1649,64 @@ IntrinsicLibrary::genChar(mlir::Type type,
return fir::CharBoxValue{cast, len};
}

// DIM
mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
if (resultType.isa<mlir::IntegerType>()) {
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
auto diff = builder.create<mlir::arith::SubIOp>(loc, args[0], args[1]);
auto cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, diff, zero);
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
// COUNT
fir::ExtendedValue
IntrinsicLibrary::genCount(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);

// Handle mask argument
fir::BoxValue mask = builder.createBox(loc, args[0]);
unsigned maskRank = mask.rank();

assert(maskRank > 0);

// Handle optional dim argument
bool absentDim = isAbsent(args[1]);
mlir::Value dim =
absentDim ? builder.createIntegerConstant(loc, builder.getIndexType(), 0)
: fir::getBase(args[1]);

if (absentDim || maskRank == 1) {
// Result is scalar if no dim argument or mask is rank 1.
// So, call specialized Count runtime routine.
return builder.createConvert(
loc, resultType,
fir::runtime::genCount(builder, loc, fir::getBase(mask), dim));
}
assert(fir::isa_real(resultType) && "Only expects real and integer in DIM");
mlir::Value zero = builder.createRealZeroConstant(loc, resultType);
auto diff = builder.create<mlir::arith::SubFOp>(loc, args[0], args[1]);
auto cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OGT, diff, zero);
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
}

// DOT_PRODUCT
fir::ExtendedValue
IntrinsicLibrary::genDotProduct(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc,
stmtCtx, args);
// Call general CountDim runtime routine.

// Handle optional kind argument
bool absentKind = isAbsent(args[2]);
mlir::Value kind = absentKind ? builder.createIntegerConstant(
loc, builder.getIndexType(),
builder.getKindMap().defaultIntegerKind())
: fir::getBase(args[2]);

// Create mutable fir.box to be passed to the runtime for the result.
mlir::Type type = builder.getVarLenSeqTy(resultType, maskRank - 1);
fir::MutableBoxValue resultMutableBox =
fir::factory::createTempMutableBox(builder, loc, type);

mlir::Value resultIrBox =
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);

fir::runtime::genCountDim(builder, loc, resultIrBox, fir::getBase(mask), dim,
kind);

// Handle cleanup of allocatable result descriptor and return
fir::ExtendedValue res =
fir::factory::genMutableBoxRead(builder, loc, resultMutableBox);
return res.match(
[&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue {
// Add cleanup code
addCleanUpForTemp(loc, box.getAddr());
return box;
},
[&](const auto &) -> fir::ExtendedValue {
fir::emitFatalError(loc, "unexpected result for COUNT");
});
}

// CPU_TIME
Expand Down Expand Up @@ -1699,6 +1737,33 @@ void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
charArgs[2], values);
}

// DIM
mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
if (resultType.isa<mlir::IntegerType>()) {
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
auto diff = builder.create<mlir::arith::SubIOp>(loc, args[0], args[1]);
auto cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, diff, zero);
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
}
assert(fir::isa_real(resultType) && "Only expects real and integer in DIM");
mlir::Value zero = builder.createRealZeroConstant(loc, resultType);
auto diff = builder.create<mlir::arith::SubFOp>(loc, args[0], args[1]);
auto cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OGT, diff, zero);
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
}

// DOT_PRODUCT
fir::ExtendedValue
IntrinsicLibrary::genDotProduct(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc,
stmtCtx, args);
}

// IAND
mlir::Value IntrinsicLibrary::genIand(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
Expand Down
45 changes: 45 additions & 0 deletions flang/test/Lower/Intrinsics/count.f90
@@ -0,0 +1,45 @@
! RUN: bbc -emit-fir %s -o - | FileCheck %s

! CHECK-LABEL: count_test1
! CHECK-SAME: %[[arg0:.*]]: !fir.ref<i32>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
subroutine count_test1(rslt, mask)
integer :: rslt
logical :: mask(:)
! CHECK-DAG: %[[c1:.*]] = arith.constant 0 : index
! CHECK-DAG: %[[a2:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
! CHECK: %[[a4:.*]] = fir.convert %[[c1]] : (index) -> i32
rslt = count(mask)
! CHECK: %[[a5:.*]] = fir.call @_FortranACount(%[[a2]], %{{.*}}, %{{.*}}, %[[a4]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
end subroutine

! CHECK-LABEL: test_count2
! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?xi32>>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>{{.*}})
subroutine test_count2(rslt, mask)
integer :: rslt(:)
logical :: mask(:,:)
! CHECK-DAG: %[[c1_i32:.*]] = arith.constant 1 : i32
! CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
! CHECK-DAG: %[[a0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
! CHECK: %[[a5:.*]] = fir.convert %[[a0]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
! CHECK: %[[a6:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
! CHECK: %[[a7:.*]] = fir.convert %[[c4]] : (index) -> i32
rslt = count(mask, dim=1)
! CHECK: %{{.*}} = fir.call @_FortranACountDim(%[[a5]], %[[a6]], %[[c1_i32]], %[[a7]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
! CHECK: %[[a10:.*]] = fir.load %[[a0]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: %[[a12:.*]] = fir.box_addr %[[a10]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
! CHECK: fir.freemem %[[a12]]
end subroutine

! CHECK-LABEL: test_count3
! CHECK-SAME: %[[arg0:.*]]: !fir.ref<i32>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
subroutine test_count3(rslt, mask)
integer :: rslt
logical :: mask(:)
! CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
! CHECK-DAG: %[[a1:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
! CHECK: %[[a3:.*]] = fir.convert %[[c0]] : (index) -> i32
call bar(count(mask, kind=2))
! CHECK: %[[a4:.*]] = fir.call @_FortranACount(%[[a1]], %{{.*}}, %{{.*}}, %[[a3]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
! CHECK: %{{.*}} = fir.convert %[[a4]] : (i64) -> i16
end subroutine

0 comments on commit 6ed9d3a

Please sign in to comment.