Skip to content

Commit

Permalink
[flang][hlfir] move intrinsic lowering out of BufferizeHLFIR
Browse files Browse the repository at this point in the history
This move is useful for a few reasons:
  - It is easier to see what the intrinsic lowering is doing when the
    operations it creates are not immediately lowered
  - When lowering a HLFIR intrinsic generates an operation which needs
    to be lowered by another pattern matcher in the same pass, MLIR will
    run that other substitution before validating and finalizing the
    original changes. This means that the erasure of operations is not
    yet visible to subsequent matchers, which hugely complicates
    transformations (in this case, hlfir.exprs cannot be rewritten
    because they are still used by the now-erased HLFIR intrinsic op.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D145502
  • Loading branch information
tblah committed Mar 17, 2023
1 parent acdb199 commit 9cbeb97
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 237 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace hlfir {

std::unique_ptr<mlir::Pass> createConvertHLFIRtoFIRPass();
std::unique_ptr<mlir::Pass> createBufferizeHLFIRPass();
std::unique_ptr<mlir::Pass> createLowerHLFIRIntrinsicsPass();

#define GEN_PASS_REGISTRATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,9 @@ def BufferizeHLFIR : Pass<"bufferize-hlfir", "::mlir::ModuleOp"> {
let constructor = "hlfir::createBufferizeHLFIRPass()";
}

def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> {
let summary = "Lower HLFIR transformational intrinsic operations";
let constructor = "hlfir::createLowerHLFIRIntrinsicsPass()";
}

#endif //FORTRAN_DIALECT_HLFIR_PASSES
220 changes: 1 addition & 219 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Runtime/Assign.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
Expand All @@ -31,7 +29,6 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include <mlir/Support/LogicalResult.h>
#include <optional>

namespace hlfir {
#define GEN_PASS_DEF_BUFFERIZEHLFIR
Expand Down Expand Up @@ -510,219 +507,6 @@ struct ElementalOpConversion
}
};

/// Base class for passes converting transformational intrinsic operations into
/// runtime calls
template <class OP>
class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
using mlir::OpConversionPattern<OP>::OpConversionPattern;

protected:
struct IntrinsicArgument {
mlir::Value val; // allowed to be null if the argument is absent
mlir::Type desiredType;
};

/// Lower the arguments to the intrinsic: adding nesecarry boxing and
/// conversion to match the signature of the intrinsic in the runtime library.
llvm::SmallVector<fir::ExtendedValue, 3>
lowerArguments(mlir::Operation *op,
const llvm::ArrayRef<IntrinsicArgument> &args,
mlir::ConversionPatternRewriter &rewriter,
const fir::IntrinsicArgumentLoweringRules *argLowering) const {
mlir::Location loc = op->getLoc();
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping};

llvm::SmallVector<fir::ExtendedValue, 3> ret;

for (size_t i = 0; i < args.size(); ++i) {
mlir::Value arg = args[i].val;
mlir::Type desiredType = args[i].desiredType;
if (!arg) {
ret.emplace_back(fir::getAbsentIntrinsicArgument());
continue;
}
hlfir::Entity entity{arg};

fir::ArgLoweringRule argRules =
fir::lowerIntrinsicArgumentAs(*argLowering, i);
switch (argRules.lowerAs) {
case fir::LowerIntrinsicArgAs::Value: {
if (args[i].desiredType != arg.getType()) {
arg = builder.createConvert(loc, desiredType, arg);
entity = hlfir::Entity{arg};
}
auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
if (cleanup)
TODO(loc, "extended value cleanup");
ret.emplace_back(exv);
} break;
case fir::LowerIntrinsicArgAs::Addr: {
auto [exv, cleanup] =
hlfir::convertToAddress(loc, builder, entity, desiredType);
if (cleanup)
TODO(loc, "extended value cleanup");
ret.emplace_back(exv);
} break;
case fir::LowerIntrinsicArgAs::Box: {
auto [box, cleanup] =
hlfir::convertToBox(loc, builder, entity, desiredType);
if (cleanup)
TODO(loc, "extended value cleanup");
ret.emplace_back(box);
} break;
case fir::LowerIntrinsicArgAs::Inquired: {
if (args[i].desiredType != arg.getType()) {
arg = builder.createConvert(loc, desiredType, arg);
entity = hlfir::Entity{arg};
}
// Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
// are translated to fir::ExtendedValue without transofrmation (notably,
// pointers/allocatable are not dereferenced).
// TODO: once lowering to FIR retires, UBOUND and LBOUND can be
// simplified since the fir.box lowered here are now guarenteed to
// contain the local lower bounds thanks to the hlfir.declare (the extra
// rebox can be removed).
auto [exv, cleanup] =
hlfir::translateToExtendedValue(loc, builder, entity);
if (cleanup)
TODO(loc, "extended value cleanup");
ret.emplace_back(exv);
} break;
}
}

return ret;
}

void processReturnValue(mlir::Operation *op,
const fir::ExtendedValue &resultExv, bool mustBeFreed,
fir::FirOpBuilder &builder,
mlir::PatternRewriter &rewriter) const {
mlir::Location loc = op->getLoc();

mlir::Value firBase = fir::getBase(resultExv);
mlir::Type firBaseTy = firBase.getType();

std::optional<hlfir::EntityWithAttributes> resultEntity;
if (fir::isa_trivial(firBaseTy)) {
resultEntity = hlfir::EntityWithAttributes{firBase};
} else {
resultEntity =
hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
fir::FortranVariableFlagsAttr{});
}

if (resultEntity->isVariable()) {
hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
loc, *resultEntity, builder.createBool(loc, mustBeFreed));
resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
}

rewriter.replaceOp(op, resultEntity->getBase());
}
};

struct SumOpConversion : public HlfirIntrinsicConversion<hlfir::SumOp> {
using HlfirIntrinsicConversion<hlfir::SumOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::SumOp sum, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping};
const mlir::Location &loc = sum->getLoc();
HLFIRListener listener{builder, rewriter};
builder.setListener(&listener);

mlir::Type i32 = builder.getI32Type();
mlir::Type logicalType = fir::LogicalType::get(
builder.getContext(), builder.getKindMap().defaultLogicalKind());

llvm::SmallVector<IntrinsicArgument, 3> inArgs;
inArgs.push_back({sum.getArray(), sum.getArray().getType()});
inArgs.push_back({sum.getDim(), i32});
inArgs.push_back({sum.getMask(), logicalType});

auto *argLowering = fir::getIntrinsicArgumentLowering("sum");
llvm::SmallVector<fir::ExtendedValue, 3> args =
lowerArguments(sum, inArgs, rewriter, argLowering);

mlir::Type scalarResultType = hlfir::getFortranElementType(sum.getType());

auto [resultExv, mustBeFreed] =
fir::genIntrinsicCall(builder, loc, "sum", scalarResultType, args);

processReturnValue(sum, resultExv, mustBeFreed, builder, rewriter);
return mlir::success();
}
};

struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping};
const mlir::Location &loc = matmul->getLoc();
HLFIRListener listener{builder, rewriter};
builder.setListener(&listener);

mlir::Value lhs = matmul.getLhs();
mlir::Value rhs = matmul.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});

auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
llvm::SmallVector<fir::ExtendedValue, 2> args =
lowerArguments(matmul, inArgs, rewriter, argLowering);

mlir::Type scalarResultType =
hlfir::getFortranElementType(matmul.getType());

auto [resultExv, mustBeFreed] =
fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);

processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
return mlir::success();
}
};

class TransposeOpConversion
: public HlfirIntrinsicConversion<hlfir::TransposeOp> {
using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::KindMapping kindMapping{rewriter.getContext()};
fir::FirOpBuilder builder{rewriter, kindMapping};
const mlir::Location &loc = transpose->getLoc();
HLFIRListener listener{builder, rewriter};
builder.setListener(&listener);

mlir::Value arg = transpose.getArray();
llvm::SmallVector<IntrinsicArgument, 1> inArgs;
inArgs.push_back({arg, arg.getType()});

auto *argLowering = fir::getIntrinsicArgumentLowering("transpose");
llvm::SmallVector<fir::ExtendedValue, 1> args =
lowerArguments(transpose, inArgs, rewriter, argLowering);

mlir::Type scalarResultType =
hlfir::getFortranElementType(transpose.getType());

auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
builder, loc, "transpose", scalarResultType, args);

processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
return mlir::success();
}
};

class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
public:
void runOnOperation() override {
Expand All @@ -740,9 +524,7 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, ConcatOpConversion, DestroyOpConversion,
ElementalOpConversion, EndAssociateOpConversion,
MatmulOpConversion, NoReassocOpConversion,
SetLengthOpConversion, SumOpConversion, TransposeOpConversion>(
context);
NoReassocOpConversion, SetLengthOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addIllegalOp<hlfir::ApplyOp, hlfir::AssociateOp, hlfir::ElementalOp,
hlfir::EndAssociateOp, hlfir::SetLengthOp,
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(HLFIRTransforms
BufferizeHLFIR.cpp
ConvertToFIR.cpp
LowerHLFIRIntrinsics.cpp

DEPENDS
FIRDialect
Expand Down

0 comments on commit 9cbeb97

Please sign in to comment.