Skip to content

Commit

Permalink
[flang] add hlfir.product operation
Browse files Browse the repository at this point in the history
Adds a HLFIR operation for the PRODUCT intrinsic according to
the design set out in flang/doc/HighLevelFIR.md

Since the PRODUCT intrinsic is essentially identical to SUM
in terms of its arguments and result characteristics in the
Fortran Standard, the operation definition and subsequent
tests also take the same form.

Differential Revision: https://reviews.llvm.org/D147624
  • Loading branch information
jacob-crawley committed May 4, 2023
1 parent b88023c commit 41b5268
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 28 deletions.
25 changes: 25 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,31 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
let hasVerifier = 1;
}

def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "PRODUCT transformational intrinsic";
let description = [{
Multiplies the elements of an array, optionally along a particular dimension,
optionally if a mask is true.
}];

let arguments = (ins
AnyFortranNumericalArrayObject:$array,
Optional<AnyIntegerType>:$dim,
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
);

let results = (outs hlfir_ExprType);

let assemblyFormat = [{
$array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

def hlfir_SetLengthOp : hlfir_Op<"set_length", []> {
let summary = "change the length of a character entity";
let description = [{
Expand Down
74 changes: 46 additions & 28 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,35 +489,18 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
}

//===----------------------------------------------------------------------===//
// SetLengthOp
// ReductionOp
//===----------------------------------------------------------------------===//

void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value string,
mlir::Value len) {
fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen();
if (auto cstLen = fir::getIntIfConstant(len))
resultTypeLen = *cstLen;
unsigned kind = getCharacterKind(string.getType());
auto resultType = hlfir::ExprType::get(
builder.getContext(), hlfir::ExprType::Shape{},
fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
false);
build(builder, result, resultType, string, len);
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult hlfir::SumOp::verify() {
mlir::Operation *op = getOperation();
template <typename ReductionOp>
static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();

auto results = op->getResultTypes();
assert(results.size() == 1);

mlir::Value array = getArray();
mlir::Value mask = getMask();
mlir::Value array = reductionOp->getArray();
mlir::Value mask = reductionOp->getMask();

fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
Expand All @@ -537,7 +520,7 @@ mlir::LogicalResult hlfir::SumOp::verify() {

if (!maskShape.empty()) {
if (maskShape.size() != arrayShape.size())
return emitWarning("MASK must be conformable to ARRAY");
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
Expand All @@ -546,32 +529,67 @@ mlir::LogicalResult hlfir::SumOp::verify() {
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return emitWarning("MASK must be conformable to ARRAY");
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
}
}
}

if (resultTy.isArray()) {
// Result is of the same type as ARRAY
if (resultTy.getEleTy() != numTy)
return emitOpError(
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();

// Result has rank n-1
if (resultShape.size() != (arrayShape.size() - 1))
return emitOpError("result rank must be one less than ARRAY");
return reductionOp->emitOpError(
"result rank must be one less than ARRAY");
} else {
// Result is of the same type as ARRAY
if (resultTy.getElementType() != numTy)
return emitOpError(
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");
}

return mlir::success();
}

//===----------------------------------------------------------------------===//
// ProductOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult hlfir::ProductOp::verify() {
return verifyReductionOp<hlfir::ProductOp *>(this);
}

//===----------------------------------------------------------------------===//
// SetLengthOp
//===----------------------------------------------------------------------===//

void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value string,
mlir::Value len) {
fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen();
if (auto cstLen = fir::getIntIfConstant(len))
resultTypeLen = *cstLen;
unsigned kind = getCharacterKind(string.getType());
auto resultType = hlfir::ExprType::get(
builder.getContext(), hlfir::ExprType::Shape{},
fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
false);
build(builder, result, resultType, string, len);
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult hlfir::SumOp::verify() {
return verifyReductionOp<hlfir::SumOp *>(this);
}

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,30 @@ func.func @bad_concat_4(%arg0: !fir.ref<!fir.char<1,30>>) {
return
}

// -----
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<f32>
}

// -----
func.func @bad_product2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
}

// -----
func.func @bad_product3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
}

// -----
func.func @bad_product4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.product' op result rank must be one less than ARRAY}}
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
}

// -----
func.func @bad_sum1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.sum' op result must have the same element type as ARRAY argument}}
Expand Down

0 comments on commit 41b5268

Please sign in to comment.