Skip to content

Commit

Permalink
[flang] Lower procedure designator
Browse files Browse the repository at this point in the history
This patch adds lowering for procedure designator.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122153

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
  • Loading branch information
3 people committed Mar 21, 2022
1 parent 826bdf5 commit 5754bae
Show file tree
Hide file tree
Showing 9 changed files with 651 additions and 19 deletions.
8 changes: 8 additions & 0 deletions flang/include/flang/Lower/IntrinsicCall.h
Expand Up @@ -82,6 +82,14 @@ ArgLoweringRule lowerIntrinsicArgumentAs(mlir::Location,
/// Return place-holder for absent intrinsic arguments.
fir::ExtendedValue getAbsentIntrinsicArgument();

/// Get SymbolRefAttr of runtime (or wrapper function containing inlined
// implementation) of an unrestricted intrinsic (defined by its signature
// and generic name)
mlir::SymbolRefAttr
getUnrestrictedIntrinsicSymbolRefAttr(fir::FirOpBuilder &, mlir::Location,
llvm::StringRef name,
mlir::FunctionType signature);

//===----------------------------------------------------------------------===//
// Direct access to intrinsics that may be used by lowering outside
// of intrinsic call lowering.
Expand Down
77 changes: 64 additions & 13 deletions flang/include/flang/Optimizer/Dialect/CanonicalizationPatterns.td
Expand Up @@ -23,28 +23,80 @@ def IdenticalTypePred : Constraint<CPred<"$0.getType() == $1.getType()">>;
def IntegerTypePred : Constraint<CPred<"fir::isa_integer($0.getType())">>;
def IndexTypePred : Constraint<CPred<"$0.getType().isa<mlir::IndexType>()">>;

def SmallerWidthPred
: Constraint<CPred<"$0.getType().getIntOrFloatBitWidth() "
"<= $1.getType().getIntOrFloatBitWidth()">>;
// Widths are monotonic.
// $0.bits >= $1.bits >= $2.bits or $0.bits <= $1.bits <= $2.bits
def MonotonicTypePred
: Constraint<CPred<"(($0.getType().isa<mlir::IntegerType>() && "
" $1.getType().isa<mlir::IntegerType>() && "
" $2.getType().isa<mlir::IntegerType>()) || "
" ($0.getType().isa<mlir::FloatType>() && "
" $1.getType().isa<mlir::FloatType>() && "
" $2.getType().isa<mlir::FloatType>())) && "
"(($0.getType().getIntOrFloatBitWidth() <= "
" $1.getType().getIntOrFloatBitWidth() && "
" $1.getType().getIntOrFloatBitWidth() <= "
" $2.getType().getIntOrFloatBitWidth()) || "
" ($0.getType().getIntOrFloatBitWidth() >= "
" $1.getType().getIntOrFloatBitWidth() && "
" $1.getType().getIntOrFloatBitWidth() >= "
" $2.getType().getIntOrFloatBitWidth()))">>;

def IntPred : Constraint<CPred<
"$0.getType().isa<mlir::IntegerType>() && "
"$1.getType().isa<mlir::IntegerType>()">>;

// If both are int type and the first is smaller than the second.
// $0.bits <= $1.bits
def SmallerWidthPred : Constraint<CPred<
"$0.getType().getIntOrFloatBitWidth() <= "
"$1.getType().getIntOrFloatBitWidth()">>;
def StrictSmallerWidthPred : Constraint<CPred<
"$0.getType().getIntOrFloatBitWidth() < "
"$1.getType().getIntOrFloatBitWidth()">>;

// floats or ints that undergo successive extensions or successive truncations.
def ConvertConvertOptPattern
: Pat<(fir_ConvertOp (fir_ConvertOp $arg)),
: Pat<(fir_ConvertOp:$res (fir_ConvertOp:$irm $arg)),
(fir_ConvertOp $arg),
[(MonotonicTypePred $res, $irm, $arg)]>;

// Widths are increasingly monotonic to type index, so there is no
// possibility of a truncation before the conversion to index.
// $res == index && $irm.bits >= $arg.bits
def ConvertAscendingIndexOptPattern
: Pat<(fir_ConvertOp:$res (fir_ConvertOp:$irm $arg)),
(fir_ConvertOp $arg),
[(IndexTypePred $res), (IntPred $irm, $arg),
(SmallerWidthPred $arg, $irm)]>;

// Widths are decreasingly monotonic from type index, so the truncations
// continue to lop off more bits.
// $arg == index && $res.bits < $irm.bits
def ConvertDescendingIndexOptPattern
: Pat<(fir_ConvertOp:$res (fir_ConvertOp:$irm $arg)),
(fir_ConvertOp $arg),
[(IntegerTypePred $arg)]>;
[(IndexTypePred $arg), (IntPred $irm, $res),
(SmallerWidthPred $res, $irm)]>;

// Useless convert to exact same type.
def RedundantConvertOptPattern
: Pat<(fir_ConvertOp:$res $arg),
(replaceWithValue $arg),
[(IdenticalTypePred $res, $arg)
,(IntegerTypePred $arg)]>;
[(IdenticalTypePred $res, $arg)]>;

// Useless extension followed by truncation to get same width integer.
def CombineConvertOptPattern
: Pat<(fir_ConvertOp:$res(fir_ConvertOp:$irm $arg)),
(replaceWithValue $arg),
[(IdenticalTypePred $res, $arg)
,(IntegerTypePred $arg)
,(IntegerTypePred $irm)
,(SmallerWidthPred $arg, $irm)]>;
[(IntPred $res, $arg), (IdenticalTypePred $res, $arg),
(IntPred $arg, $irm), (SmallerWidthPred $arg, $irm)]>;

// Useless extension followed by truncation to get smaller width integer.
def CombineConvertTruncOptPattern
: Pat<(fir_ConvertOp:$res(fir_ConvertOp:$irm $arg)),
(fir_ConvertOp $arg),
[(IntPred $res, $arg), (StrictSmallerWidthPred $res, $arg),
(IntPred $arg, $irm), (SmallerWidthPred $arg, $irm)]>;

def createConstantOp
: NativeCodeCall<"$_builder.create<mlir::arith::ConstantOp>"
Expand All @@ -55,7 +107,6 @@ def createConstantOp
def ForwardConstantConvertPattern
: Pat<(fir_ConvertOp:$res (Arith_ConstantOp:$cnt $attr)),
(createConstantOp $res, $attr),
[(IndexTypePred $res)
,(IntegerTypePred $cnt)]>;
[(IndexTypePred $res), (IntegerTypePred $cnt)]>;

#endif // FORTRAN_FIR_REWRITE_PATTERNS
74 changes: 71 additions & 3 deletions flang/lib/Lower/ConvertExpr.cpp
Expand Up @@ -330,6 +330,16 @@ static bool isParenthesizedVariable(const Fortran::evaluate::Expr<T> &expr) {
}
}

/// Does \p expr only refer to symbols that are mapped to IR values in \p symMap
/// ?
static bool allSymbolsInExprPresentInMap(const Fortran::lower::SomeExpr &expr,
Fortran::lower::SymMap &symMap) {
for (const auto &sym : Fortran::evaluate::CollectSymbols(expr))
if (!symMap.lookupSymbol(sym))
return false;
return true;
}

/// Generate a load of a value from an address. Beware that this will lose
/// any dynamic type information for polymorphic entities (note that unlimited
/// polymorphic cannot be loaded and must not be provided here).
Expand Down Expand Up @@ -743,11 +753,69 @@ class ScalarExprLowering {
/// The type of the function indirection is not guaranteed to match the one
/// of the ProcedureDesignator due to Fortran implicit typing rules.
ExtValue genval(const Fortran::evaluate::ProcedureDesignator &proc) {
TODO(getLoc(), "genval ProcedureDesignator");
mlir::Location loc = getLoc();
if (const Fortran::evaluate::SpecificIntrinsic *intrinsic =
proc.GetSpecificIntrinsic()) {
mlir::FunctionType signature =
Fortran::lower::translateSignature(proc, converter);
// Intrinsic lowering is based on the generic name, so retrieve it here in
// case it is different from the specific name. The type of the specific
// intrinsic is retained in the signature.
std::string genericName =
converter.getFoldingContext().intrinsics().GetGenericIntrinsicName(
intrinsic->name);
mlir::SymbolRefAttr symbolRefAttr =
Fortran::lower::getUnrestrictedIntrinsicSymbolRefAttr(
builder, loc, genericName, signature);
mlir::Value funcPtr =
builder.create<fir::AddrOfOp>(loc, signature, symbolRefAttr);
return funcPtr;
}
const Fortran::semantics::Symbol *symbol = proc.GetSymbol();
assert(symbol && "expected symbol in ProcedureDesignator");
mlir::Value funcPtr;
mlir::Value funcPtrResultLength;
if (Fortran::semantics::IsDummy(*symbol)) {
Fortran::lower::SymbolBox val = symMap.lookupSymbol(*symbol);
assert(val && "Dummy procedure not in symbol map");
funcPtr = val.getAddr();
if (fir::isCharacterProcedureTuple(funcPtr.getType(),
/*acceptRawFunc=*/false))
std::tie(funcPtr, funcPtrResultLength) =
fir::factory::extractCharacterProcedureTuple(builder, loc, funcPtr);
} else {
std::string name = converter.mangleName(*symbol);
mlir::FuncOp func =
Fortran::lower::getOrDeclareFunction(name, proc, converter);
funcPtr = builder.create<fir::AddrOfOp>(loc, func.getFunctionType(),
builder.getSymbolRefAttr(name));
}
if (Fortran::lower::mustPassLengthWithDummyProcedure(proc, converter)) {
// The result length, if available here, must be propagated along the
// procedure address so that call sites where the result length is assumed
// can retrieve the length.
Fortran::evaluate::DynamicType resultType = proc.GetType().value();
if (const auto &lengthExpr = resultType.GetCharLength()) {
// The length expression may refer to dummy argument symbols that are
// meaningless without any actual arguments. Leave the length as
// unknown in that case, it be resolved on the call site
// with the actual arguments.
if (allSymbolsInExprPresentInMap(toEvExpr(*lengthExpr), symMap)) {
mlir::Value rawLen = fir::getBase(genval(*lengthExpr));
// F2018 7.4.4.2 point 5.
funcPtrResultLength =
Fortran::lower::genMaxWithZero(builder, getLoc(), rawLen);
}
}
if (!funcPtrResultLength)
funcPtrResultLength = builder.createIntegerConstant(
loc, builder.getCharacterLengthType(), -1);
return fir::CharBoxValue{funcPtr, funcPtrResultLength};
}
return funcPtr;
}

ExtValue genval(const Fortran::evaluate::NullPointer &) {
TODO(getLoc(), "genval NullPointer");
return builder.createNullConstant(getLoc());
}

static bool
Expand Down
46 changes: 46 additions & 0 deletions flang/lib/Lower/IntrinsicCall.cpp
Expand Up @@ -574,6 +574,12 @@ struct IntrinsicLibrary {
mlir::Value invokeGenerator(SubroutineGenerator generator,
llvm::ArrayRef<mlir::Value> args);

/// Get pointer to unrestricted intrinsic. Generate the related unrestricted
/// intrinsic if it is not defined yet.
mlir::SymbolRefAttr
getUnrestrictedIntrinsicSymbolRefAttr(llvm::StringRef name,
mlir::FunctionType signature);

/// Add clean-up for \p temp to the current statement context;
void addCleanUpForTemp(mlir::Location loc, mlir::Value temp);
/// Helper function for generating code clean-up for result descriptors
Expand Down Expand Up @@ -1608,6 +1614,39 @@ IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
};
}

mlir::SymbolRefAttr IntrinsicLibrary::getUnrestrictedIntrinsicSymbolRefAttr(
llvm::StringRef name, mlir::FunctionType signature) {
// Unrestricted intrinsics signature follows implicit rules: argument
// are passed by references. But the runtime versions expect values.
// So instead of duplicating the runtime, just have the wrappers loading
// this before calling the code generators.
bool loadRefArguments = true;
mlir::FuncOp funcOp;
if (const IntrinsicHandler *handler = findIntrinsicHandler(name))
funcOp = std::visit(
[&](auto generator) {
return getWrapper(generator, name, signature, loadRefArguments);
},
handler->generator);

if (!funcOp) {
llvm::SmallVector<mlir::Type> argTypes;
for (mlir::Type type : signature.getInputs()) {
if (auto refType = type.dyn_cast<fir::ReferenceType>())
argTypes.push_back(refType.getEleTy());
else
argTypes.push_back(type);
}
mlir::FunctionType soughtFuncType =
builder.getFunctionType(argTypes, signature.getResults());
IntrinsicLibrary::RuntimeCallGenerator rtCallGenerator =
getRuntimeCallGenerator(name, soughtFuncType);
funcOp = getWrapper(rtCallGenerator, name, signature, loadRefArguments);
}

return mlir::SymbolRefAttr::get(funcOp);
}

void IntrinsicLibrary::addCleanUpForTemp(mlir::Location loc, mlir::Value temp) {
assert(stmtCtx);
fir::FirOpBuilder *bldr = &builder;
Expand Down Expand Up @@ -3611,3 +3650,10 @@ mlir::Value Fortran::lower::genPow(fir::FirOpBuilder &builder,
mlir::Value x, mlir::Value y) {
return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow", type, {x, y});
}

mlir::SymbolRefAttr Fortran::lower::getUnrestrictedIntrinsicSymbolRefAttr(
fir::FirOpBuilder &builder, mlir::Location loc, llvm::StringRef name,
mlir::FunctionType signature) {
return IntrinsicLibrary{builder, loc}.getUnrestrictedIntrinsicSymbolRefAttr(
name, signature);
}
8 changes: 5 additions & 3 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Expand Up @@ -820,9 +820,10 @@ mlir::LogicalResult ConstcOp::verify() {

void fir::ConvertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ConvertConvertOptPattern, RedundantConvertOptPattern,
CombineConvertOptPattern, ForwardConstantConvertPattern>(
context);
results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern,
ConvertDescendingIndexOptPattern, RedundantConvertOptPattern,
CombineConvertOptPattern, CombineConvertTruncOptPattern,
ForwardConstantConvertPattern>(context);
}

mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
Expand Down Expand Up @@ -875,6 +876,7 @@ mlir::LogicalResult ConvertOp::verify() {
(isIntegerCompatible(inType) && isPointerCompatible(outType)) ||
(isPointerCompatible(inType) && isIntegerCompatible(outType)) ||
(inType.isa<fir::BoxType>() && outType.isa<fir::BoxType>()) ||
(inType.isa<fir::BoxProcType>() && outType.isa<fir::BoxProcType>()) ||
(fir::isa_complex(inType) && fir::isa_complex(outType)))
return mlir::success();
return emitOpError("invalid type conversion");
Expand Down

0 comments on commit 5754bae

Please sign in to comment.