Skip to content

Commit

Permalink
[flang][hlfir] Fixed KindMapping for HLFIR intrinsics lowering.
Browse files Browse the repository at this point in the history
hlfir.count lowering was using incorrect default integer kind
by ignoring the kind specified in the ModuleOp.

Reviewed By: tblah

Differential Revision: https://reviews.llvm.org/D156017
  • Loading branch information
vzakhari committed Jul 24, 2023
1 parent 3eedff3 commit 60f02aa
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 50 deletions.
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
}
}
FirOpBuilder(mlir::OpBuilder &builder, mlir::Operation *op)
: FirOpBuilder(builder, fir::getKindMapping(op), op) {}

// The listener self-reference has to be updated in case of copy-construction.
FirOpBuilder(const FirOpBuilder &other)
Expand Down
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Dialect/Support/FIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

namespace mlir {
class ModuleOp;
class Operation;
} // namespace mlir

namespace fir {
Expand All @@ -43,6 +44,12 @@ void setKindMapping(mlir::ModuleOp mod, KindMapping &kindMap);
/// default.
KindMapping getKindMapping(mlir::ModuleOp mod);

/// Get the KindMapping instance that is in effect for the specified
/// operation. The KindMapping is taken from the operation itself,
/// if the operation is a ModuleOp, or from its parent ModuleOp.
/// If a ModuleOp cannot be reached, the function returns default KindMapping.
KindMapping getKindMapping(mlir::Operation *op);

/// Helper for determining the target from the host, etc. Tools may use this
/// function to provide a consistent interpretation of the `--target=<string>`
/// command-line option.
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ fir::KindMapping fir::getKindMapping(mlir::ModuleOp mod) {
return fir::KindMapping(ctx);
}

fir::KindMapping fir::getKindMapping(mlir::Operation *op) {
auto moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);
if (moduleOp)
return getKindMapping(moduleOp);

moduleOp = op->getParentOfType<mlir::ModuleOp>();
return getKindMapping(moduleOp);
}

std::string fir::determineTargetTriple(llvm::StringRef triple) {
// Treat "" or "default" as stand-ins for the default machine.
if (triple.empty() || triple == "default")
Expand Down
30 changes: 10 additions & 20 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ struct ApplyOpConversion : public mlir::OpConversionPattern<hlfir::ApplyOp> {
if (fir::isa_trivial(apply.getType())) {
result = rewriter.create<fir::LoadOp>(loc, result);
} else {
auto module = apply->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, apply.getOperation());
result =
packageBufferizedExpr(loc, builder, hlfir::Entity{result}, false);
}
Expand Down Expand Up @@ -288,8 +287,7 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
matchAndRewrite(hlfir::ConcatOp concat, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = concat->getLoc();
auto module = concat->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, concat.getOperation());
assert(adaptor.getStrings().size() >= 2 &&
"must have at least two strings operands");
if (adaptor.getStrings().size() > 2)
Expand Down Expand Up @@ -328,8 +326,7 @@ struct SetLengthOpConversion
matchAndRewrite(hlfir::SetLengthOp setLength, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = setLength->getLoc();
auto module = setLength->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, setLength.getOperation());
// Create a temp with the new length.
hlfir::Entity string = getBufferizedExprStorage(adaptor.getString());
auto charType = hlfir::getFortranElementType(setLength.getType());
Expand Down Expand Up @@ -362,8 +359,7 @@ struct GetLengthOpConversion
matchAndRewrite(hlfir::GetLengthOp getLength, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = getLength->getLoc();
auto module = getLength->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, getLength.getOperation());
hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
mlir::Value length = hlfir::genCharLength(loc, builder, bufferizedExpr);
if (!length)
Expand Down Expand Up @@ -436,8 +432,7 @@ struct AssociateOpConversion
matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = associate->getLoc();
auto module = associate->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, associate.getOperation());
mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource());
const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType());

Expand Down Expand Up @@ -577,8 +572,7 @@ struct EndAssociateOpConversion
matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = endAssociate->getLoc();
auto module = endAssociate->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, endAssociate.getOperation());
genFreeIfMustFree(loc, builder, adaptor.getVar(), adaptor.getMustFree());
rewriter.eraseOp(endAssociate);
return mlir::success();
Expand All @@ -597,8 +591,7 @@ struct DestroyOpConversion
mlir::Location loc = destroy->getLoc();
hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
if (!fir::isa_trivial(bufferizedExpr.getType())) {
auto module = destroy->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, destroy.getOperation());
mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr());
mlir::Value firBase = bufferizedExpr.getFirBase();
genFreeIfMustFree(loc, builder, firBase, mustFree);
Expand All @@ -617,8 +610,7 @@ struct NoReassocOpConversion
matchAndRewrite(hlfir::NoReassocOp noreassoc, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = noreassoc->getLoc();
auto module = noreassoc->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, noreassoc.getOperation());
mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getVal());
mlir::Value result =
builder.create<hlfir::NoReassocOp>(loc, bufferizedExpr);
Expand Down Expand Up @@ -677,8 +669,7 @@ struct ElementalOpConversion
matchAndRewrite(hlfir::ElementalOp elemental, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = elemental->getLoc();
auto module = elemental->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, elemental.getOperation());
// The body of the elemental op may contain operation that will require
// to be translated. Notify the rewriter about the cloned operations.
HLFIRListener listener{builder, rewriter};
Expand Down Expand Up @@ -743,11 +734,10 @@ struct CharExtremumOpConversion
matchAndRewrite(hlfir::CharExtremumOp char_extremum, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = char_extremum->getLoc();
auto module = char_extremum->getParentOfType<mlir::ModuleOp>();
auto predicate = char_extremum.getPredicate();
bool predIsMin =
predicate == hlfir::CharExtremumPredicate::min ? true : false;
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, char_extremum.getOperation());
assert(adaptor.getStrings().size() >= 2 &&
"must have at least two strings operands");
auto numOperands = adaptor.getStrings().size();
Expand Down
12 changes: 4 additions & 8 deletions flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ class CopyInOpConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
matchAndRewrite(hlfir::CopyInOp copyInOp,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = copyInOp.getLoc();
auto module = copyInOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, copyInOp.getOperation());
CopyInResult result = copyInOp.getVarIsPresent()
? genOptionalCopyIn(loc, builder, copyInOp)
: genNonOptionalCopyIn(loc, builder, copyInOp);
Expand All @@ -259,8 +258,7 @@ class CopyOutOpConversion : public mlir::OpRewritePattern<hlfir::CopyOutOp> {
matchAndRewrite(hlfir::CopyOutOp copyOutOp,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = copyOutOp.getLoc();
auto module = copyOutOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, copyOutOp.getOperation());

builder.genIfThen(loc, copyOutOp.getWasCopied())
.genThen([&]() {
Expand Down Expand Up @@ -323,8 +321,7 @@ class DeclareOpConversion : public mlir::OpRewritePattern<hlfir::DeclareOp> {
mlir::Value hlfirBase;
mlir::Type hlfirBaseType = declareOp.getBase().getType();
if (hlfirBaseType.isa<fir::BaseBoxType>()) {
auto module = declareOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, declareOp.getOperation());
// Helper to generate the hlfir fir.box with the local lower bounds and
// type parameters.
auto genHlfirBox = [&]() -> mlir::Value {
Expand Down Expand Up @@ -423,8 +420,7 @@ class DesignateOpConversion
matchAndRewrite(hlfir::DesignateOp designate,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = designate.getLoc();
auto module = designate->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
fir::FirOpBuilder builder(rewriter, designate.getOperation());

hlfir::Entity baseEntity(designate.getMemref());

Expand Down
4 changes: 1 addition & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -91,8 +90,7 @@ class InlineElementalConversion
assert(elemental.getRegion().hasOneBlock() &&
"expect elemental region to have one block");

fir::FirOpBuilder builder{rewriter,
fir::KindMapping{rewriter.getContext()}};
fir::FirOpBuilder builder{rewriter, elemental.getOperation()};
builder.setInsertionPointAfter(apply);
hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
elemental.getLoc(), builder, elemental, apply.getIndices());
Expand Down
21 changes: 7 additions & 14 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
mlir::PatternRewriter &rewriter,
const fir::IntrinsicArgumentLoweringRules *argLowering) const {
mlir::Location loc = op->getLoc();
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, op};
fir::FirOpBuilder builder{rewriter, op};

llvm::SmallVector<fir::ExtendedValue, 3> ret;
llvm::SmallVector<std::function<void()>, 2> cleanupFns;
Expand Down Expand Up @@ -229,8 +228,7 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
return mlir::failure();
}

fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, operation};
fir::FirOpBuilder builder{rewriter, operation.getOperation()};
const mlir::Location &loc = operation->getLoc();

mlir::Type i32 = builder.getI32Type();
Expand Down Expand Up @@ -271,8 +269,7 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
mlir::LogicalResult
matchAndRewrite(hlfir::CountOp count,
mlir::PatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, count};
fir::FirOpBuilder builder{rewriter, count.getOperation()};
const mlir::Location &loc = count->getLoc();

mlir::Type i32 = builder.getI32Type();
Expand Down Expand Up @@ -304,8 +301,7 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
mlir::LogicalResult
matchAndRewrite(hlfir::MatmulOp matmul,
mlir::PatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, matmul};
fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
const mlir::Location &loc = matmul->getLoc();

mlir::Value lhs = matmul.getLhs();
Expand Down Expand Up @@ -336,8 +332,7 @@ struct DotProductOpConversion
mlir::LogicalResult
matchAndRewrite(hlfir::DotProductOp dotProduct,
mlir::PatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, dotProduct};
fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
const mlir::Location &loc = dotProduct->getLoc();

mlir::Value lhs = dotProduct.getLhs();
Expand Down Expand Up @@ -368,8 +363,7 @@ class TransposeOpConversion
mlir::LogicalResult
matchAndRewrite(hlfir::TransposeOp transpose,
mlir::PatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, transpose};
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
const mlir::Location &loc = transpose->getLoc();

mlir::Value arg = transpose.getArray();
Expand Down Expand Up @@ -399,8 +393,7 @@ struct MatmulTransposeOpConversion
mlir::LogicalResult
matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
mlir::PatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping, multranspose};
fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
const mlir::Location &loc = multranspose->getLoc();

mlir::Value lhs = multranspose.getLhs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
Expand All @@ -40,8 +39,7 @@ class TransposeAsElementalConversion
matchAndRewrite(hlfir::TransposeOp transpose,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = transpose.getLoc();
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping};
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
hlfir::ExprType expr = transpose.getType();
mlir::Type elementType = expr.getElementType();
hlfir::Entity array = hlfir::Entity{transpose.getArray()};
Expand Down
47 changes: 47 additions & 0 deletions flang/test/HLFIR/count-lowering-default-int-kinds.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Test hlfir.count operation lowering with different default integer kinds.
// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s

module attributes {fir.defaultkind = "a1c4d8i8l4r4", fir.kindmap = ""} {
func.func @test_i8(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
%4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi64>
return
}
}
// CHECK-LABEL: func.func @test_i8
// CHECK: %[[KIND:.*]] = arith.constant 8 : index
// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none

module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = ""} {
func.func @test_i4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
%4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi32>
return
}
}
// CHECK-LABEL: func.func @test_i4
// CHECK: %[[KIND:.*]] = arith.constant 4 : index
// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none

module attributes {fir.defaultkind = "a1c4d8i2l4r4", fir.kindmap = ""} {
func.func @test_i2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
%4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi16>
return
}
}
// CHECK-LABEL: func.func @test_i2
// CHECK: %[[KIND:.*]] = arith.constant 2 : index
// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none

module attributes {fir.defaultkind = "a1c4d8i1l4r4", fir.kindmap = ""} {
func.func @test_i1(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
%4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi8>
return
}
}
// CHECK-LABEL: func.func @test_i1
// CHECK: arith.constant 1 : index
// CHECK: %[[KIND:.*]] = arith.constant 1 : index
// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
Loading

0 comments on commit 60f02aa

Please sign in to comment.