Skip to content

Commit

Permalink
[flang] Finish substring lowering
Browse files Browse the repository at this point in the history
Hlfir.designate was made to support substrings but so far substrings
were not yet lowered to it. Implement support for them.

Differential Revision: https://reviews.llvm.org/D140310
  • Loading branch information
jeanPerier committed Dec 20, 2022
1 parent 5bc703f commit d0018c9
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 37 deletions.
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Builder/Character.h
Expand Up @@ -47,6 +47,13 @@ class CharacterExprHelper {
fir::CharBoxValue createSubstring(const fir::CharBoxValue &str,
llvm::ArrayRef<mlir::Value> bounds);

/// Compute substring base address given the raw address (not fir.boxchar) of
/// a scalar string, a substring / lower bound, and the substring type.
mlir::Value genSubstringBase(mlir::Value stringRawAddr,
mlir::Value lowerBound,
mlir::Type substringAddrType,
mlir::Value one = {});

/// Return blank character of given \p type !fir.char<kind>
mlir::Value createBlankConstant(fir::CharacterType type);

Expand Down
14 changes: 10 additions & 4 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Expand Up @@ -174,10 +174,11 @@ translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
fir::FortranVariableOpInterface fortranVariable);

/// Generate declaration for a fir::ExtendedValue in memory.
EntityWithAttributes genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv,
llvm::StringRef name,
fir::FortranVariableFlagsAttr flags);
fir::FortranVariableOpInterface genDeclare(mlir::Location loc,
fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv,
llvm::StringRef name,
fir::FortranVariableFlagsAttr flags);

/// Generate an hlfir.associate to build a variable from an expression value.
/// The type of the variable must be provided so that scalar logicals are
Expand Down Expand Up @@ -238,6 +239,11 @@ void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity,
llvm::SmallVectorImpl<mlir::Value> &result);

/// Get the length of a character entity. Crashes if the entity is not
/// a character entity.
mlir::Value genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity);

/// Return the fir base, shape, and type parameters for a variable. Note that
/// type parameters are only added if the entity is not a box and the type
/// parameters is not a constant in the base type. This matches the arguments
Expand Down
136 changes: 131 additions & 5 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Expand Up @@ -24,6 +24,7 @@
#include "flang/Optimizer/Builder/Runtime/Character.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "llvm/ADT/TypeSwitch.h"

namespace {

Expand Down Expand Up @@ -65,6 +66,13 @@ class HlfirDesignatorBuilder {
designatorVariant);
}

hlfir::EntityWithAttributes
gen(const Fortran::evaluate::NamedEntity &namedEntity) {
if (namedEntity.IsSymbol())
return gen(Fortran::evaluate::SymbolRef{namedEntity.GetLastSymbol()});
return gen(namedEntity.GetComponent());
}

private:
/// Struct that is filled while visiting a part-ref (in the "visit" member
/// function) before the top level "gen" generates an hlfir.declare for the
Expand All @@ -75,6 +83,7 @@ class HlfirDesignatorBuilder {
hlfir::DesignateOp::Subscripts subscripts;
mlir::Value resultShape;
llvm::SmallVector<mlir::Value> typeParams;
llvm::SmallVector<mlir::Value, 2> substring;
};

/// Generate an hlfir.declare for a part-ref given a filled PartInfo and the
Expand All @@ -100,11 +109,11 @@ class HlfirDesignatorBuilder {
resultType = fir::ReferenceType::get(resultValueType);

std::optional<bool> complexPart;
llvm::SmallVector<mlir::Value> substring;
auto designate = getBuilder().create<hlfir::DesignateOp>(
getLoc(), resultType, partInfo.base.getBase(), "",
/*componentShape=*/mlir::Value{}, partInfo.subscripts, substring,
complexPart, partInfo.resultShape, partInfo.typeParams);
/*componentShape=*/mlir::Value{}, partInfo.subscripts,
partInfo.substring, complexPart, partInfo.resultShape,
partInfo.typeParams);
return mlir::cast<fir::FortranVariableOpInterface>(
designate.getOperation());
}
Expand Down Expand Up @@ -132,6 +141,9 @@ class HlfirDesignatorBuilder {
gen(const Fortran::evaluate::CoarrayRef &coarrayRef) {
TODO(getLoc(), "lowering CoarrayRef to HLFIR");
}
mlir::Type visit(const Fortran::evaluate::CoarrayRef &, PartInfo &) {
TODO(getLoc(), "lowering CoarrayRef to HLFIR");
}

hlfir::EntityWithAttributes
gen(const Fortran::evaluate::ComplexPart &complexPart) {
Expand All @@ -140,7 +152,95 @@ class HlfirDesignatorBuilder {

hlfir::EntityWithAttributes
gen(const Fortran::evaluate::Substring &substring) {
TODO(getLoc(), "lowering substrings to HLFIR");
PartInfo partInfo;
mlir::Type baseStringType = std::visit(
[&](const auto &x) { return visit(x, partInfo); }, substring.parent());
assert(partInfo.typeParams.size() == 1 && "expect base string length");
// Compute the substring lower and upper bound.
partInfo.substring.push_back(genSubscript(substring.lower()));
if (Fortran::evaluate::MaybeExtentExpr upperBound = substring.upper())
partInfo.substring.push_back(genSubscript(*upperBound));
else
partInfo.substring.push_back(partInfo.typeParams[0]);
fir::FirOpBuilder &builder = getBuilder();
mlir::Location loc = getLoc();
mlir::Type idxTy = builder.getIndexType();
partInfo.substring[0] =
builder.createConvert(loc, idxTy, partInfo.substring[0]);
partInfo.substring[1] =
builder.createConvert(loc, idxTy, partInfo.substring[1]);
// Try using constant length if available. mlir::arith folding would
// most likely be able to fold "max(ub-lb+1,0)" too, but getting
// the constant length in the FIR types would be harder.
std::optional<int64_t> cstLen =
Fortran::evaluate::ToInt64(Fortran::evaluate::Fold(
getConverter().getFoldingContext(), substring.LEN()));
if (cstLen) {
partInfo.typeParams[0] =
builder.createIntegerConstant(loc, idxTy, *cstLen);
} else {
// Compute "len = max(ub-lb+1,0)" (Fortran 2018 9.4.1).
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
auto boundsDiff = builder.create<mlir::arith::SubIOp>(
loc, partInfo.substring[1], partInfo.substring[0]);
auto rawLen = builder.create<mlir::arith::AddIOp>(loc, boundsDiff, one);
partInfo.typeParams[0] =
fir::factory::genMaxWithZero(builder, loc, rawLen);
}
mlir::Type resultType = changeLengthInCharacterType(
loc, baseStringType,
cstLen ? *cstLen : fir::CharacterType::unknownLen());
return genDeclare(resultType, partInfo);
}

static mlir::Type changeLengthInCharacterType(mlir::Location loc,
mlir::Type type,
int64_t newLen) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
.Case<fir::CharacterType>([&](fir::CharacterType charTy) -> mlir::Type {
return fir::CharacterType::get(charTy.getContext(), charTy.getFKind(),
newLen);
})
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(
seqTy.getShape(),
changeLengthInCharacterType(loc, seqTy.getEleTy(), newLen));
})
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
fir::BoxType>([&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(
changeLengthInCharacterType(loc, t.getEleTy(), newLen));
})
.Default([loc](mlir::Type t) -> mlir::Type {
fir::emitFatalError(loc, "expected character type");
});
}

mlir::Type visit(const Fortran::evaluate::DataRef &dataRef,
PartInfo &partInfo) {
return std::visit([&](const auto &x) { return visit(x, partInfo); },
dataRef.u);
}

mlir::Type
visit(const Fortran::evaluate::StaticDataObject::Pointer &staticObject,
PartInfo &partInfo) {
fir::FirOpBuilder &builder = getBuilder();
mlir::Location loc = getLoc();
std::optional<std::string> string = staticObject->AsString();
// TODO: see if StaticDataObject can be replaced by something based on
// Constant<T> to avoid dealing with endianness here for KIND>1.
// This will also avoid making string copies here.
if (!string)
TODO(loc, "StaticDataObject::Pointer substring with kind > 1");
fir::ExtendedValue exv =
fir::factory::createStringLiteral(builder, getLoc(), *string);
auto flags = fir::FortranVariableFlagsAttr::get(
builder.getContext(), fir::FortranVariableFlagsEnum::parameter);
partInfo.base = hlfir::genDeclare(loc, builder, exv, ".stringlit", flags);
partInfo.typeParams.push_back(fir::getLen(exv));
return partInfo.base.getElementOrSequenceType();
}

mlir::Type visit(const Fortran::evaluate::SymbolRef &symbolRef,
Expand Down Expand Up @@ -845,7 +945,33 @@ class HlfirBuilder {

hlfir::EntityWithAttributes
gen(const Fortran::evaluate::DescriptorInquiry &desc) {
TODO(getLoc(), "lowering descriptor inquiry to HLFIR");
mlir::Location loc = getLoc();
auto &builder = getBuilder();
hlfir::EntityWithAttributes entity =
HlfirDesignatorBuilder(getLoc(), getConverter(), getSymMap(),
getStmtCtx())
.gen(desc.base());
using ResTy = Fortran::evaluate::DescriptorInquiry::Result;
mlir::Type resultType =
getConverter().genType(ResTy::category, ResTy::kind);
auto castResult = [&](mlir::Value v) {
return hlfir::EntityWithAttributes{
builder.createConvert(loc, resultType, v)};
};
switch (desc.field()) {
case Fortran::evaluate::DescriptorInquiry::Field::Len:
return castResult(hlfir::genCharLength(loc, builder, entity));
case Fortran::evaluate::DescriptorInquiry::Field::LowerBound:
TODO(loc, "lower bound inquiry in HLFIR");
case Fortran::evaluate::DescriptorInquiry::Field::Extent:
TODO(loc, "extent inquiry in HLFIR");
case Fortran::evaluate::DescriptorInquiry::Field::Rank:
TODO(loc, "rank inquiry on assumed rank");
case Fortran::evaluate::DescriptorInquiry::Field::Stride:
// So far the front end does not generate this inquiry.
TODO(loc, "stride inquiry");
}
llvm_unreachable("unknown descriptor inquiry");
}

hlfir::EntityWithAttributes
Expand Down
19 changes: 14 additions & 5 deletions flang/lib/Optimizer/Builder/Character.cpp
Expand Up @@ -473,6 +473,17 @@ fir::CharBoxValue fir::factory::CharacterExprHelper::createConcatenate(
return temp;
}

mlir::Value fir::factory::CharacterExprHelper::genSubstringBase(
mlir::Value stringRawAddr, mlir::Value lowerBound,
mlir::Type substringAddrType, mlir::Value one) {
if (!one)
one = builder.createIntegerConstant(loc, lowerBound.getType(), 1);
auto offset =
builder.create<mlir::arith::SubIOp>(loc, lowerBound, one).getResult();
auto addr = createElementAddr(stringRawAddr, offset);
return builder.createConvert(loc, substringAddrType, addr);
}

fir::CharBoxValue fir::factory::CharacterExprHelper::createSubstring(
const fir::CharBoxValue &box, llvm::ArrayRef<mlir::Value> bounds) {
// Constant need to be materialize in memory to use fir.coordinate_of.
Expand All @@ -488,14 +499,12 @@ fir::CharBoxValue fir::factory::CharacterExprHelper::createSubstring(
builder.createConvert(loc, builder.getCharacterLengthType(), bound));
auto lowerBound = castBounds[0];
// FIR CoordinateOp is zero based but Fortran substring are one based.
auto one = builder.createIntegerConstant(loc, lowerBound.getType(), 1);
auto offset =
builder.create<mlir::arith::SubIOp>(loc, lowerBound, one).getResult();
auto addr = createElementAddr(box.getBuffer(), offset);
auto kind = getCharacterKind(box.getBuffer().getType());
auto charTy = fir::CharacterType::getUnknownLen(builder.getContext(), kind);
auto resultType = builder.getRefType(charTy);
auto substringRef = builder.createConvert(loc, resultType, addr);
auto one = builder.createIntegerConstant(loc, lowerBound.getType(), 1);
auto substringRef =
genSubstringBase(box.getBuffer(), lowerBound, resultType, one);

// Compute the length.
mlir::Value substringLen;
Expand Down
10 changes: 9 additions & 1 deletion flang/lib/Optimizer/Builder/HLFIRTools.cpp
Expand Up @@ -139,7 +139,7 @@ hlfir::translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
return firBase;
}

hlfir::EntityWithAttributes
fir::FortranVariableOpInterface
hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv, llvm::StringRef name,
fir::FortranVariableFlagsAttr flags) {
Expand Down Expand Up @@ -457,6 +457,14 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
TODO(loc, "inquire PDTs length parameters in HLFIR");
}

mlir::Value hlfir::genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity) {
llvm::SmallVector<mlir::Value, 1> lenParams;
genLengthParameters(loc, builder, entity, lenParams);
assert(lenParams.size() == 1 && "characters must have one length parameters");
return lenParams[0];
}

std::pair<mlir::Value, mlir::Value> hlfir::genVariableFirBaseShapeAndParams(
mlir::Location loc, fir::FirOpBuilder &builder, Entity entity,
llvm::SmallVectorImpl<mlir::Value> &typeParams) {
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Expand Up @@ -220,7 +220,8 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
mlir::Value cast = builder.createConvert(loc, addrType, fir::getBase(res));
res = fir::substBase(res, cast);
auto hlfirTempRes = hlfir::genDeclare(loc, builder, res, "tmp",
fir::FortranVariableFlagsAttr{});
fir::FortranVariableFlagsAttr{})
.getBase();
mlir::Value bufferizedExpr =
packageBufferizedExpr(loc, builder, hlfirTempRes, false);
rewriter.replaceOp(concat, bufferizedExpr);
Expand Down
58 changes: 38 additions & 20 deletions flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
Expand Up @@ -8,6 +8,7 @@
// This file defines a pass to lower HLFIR to FIR
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/MutableBox.h"
Expand Down Expand Up @@ -183,11 +184,8 @@ class DesignateOpConversion
auto module = designate->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));

if (designate.getComponent() || designate.getComplexPart() ||
!designate.getSubstring().empty()) {
// build path.
TODO(loc, "hlfir::designate with complex part or substring or component");
}
if (designate.getComponent() || designate.getComplexPart())
TODO(loc, "hlfir::designate with complex part or component");

hlfir::Entity baseEntity(designate.getMemref());
if (baseEntity.isMutableBox())
Expand Down Expand Up @@ -216,8 +214,20 @@ class DesignateOpConversion
triples.push_back(undef);
}
}
llvm::SmallVector<mlir::Value, 2> substring;
if (!designate.getSubstring().empty()) {
substring.push_back(designate.getSubstring()[0]);
mlir::Type idxTy = builder.getIndexType();
// fir.slice op substring expects the zero based lower bound.
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
substring[0] = builder.createConvert(loc, idxTy, substring[0]);
substring[0] =
builder.create<mlir::arith::SubIOp>(loc, substring[0], one);
substring.push_back(designate.getTypeparams()[0]);
}

mlir::Value slice = builder.create<fir::SliceOp>(
loc, triples, /*path=*/mlir::ValueRange{});
loc, triples, /*fields=*/mlir::ValueRange{}, substring);
llvm::SmallVector<mlir::Type> resultType{designateResultType};
mlir::Value resultBox;
if (base.getType().isa<fir::BoxType>())
Expand All @@ -230,29 +240,37 @@ class DesignateOpConversion
return mlir::success();
}

// Indexing a single element (use fir.array_coor of fir.coordinate_of).
// Otherwise, the result is the address of a scalar. The base may be an
// array, or a scalar.
mlir::Type resultAddressType = designateResultType;
if (auto boxCharType = designateResultType.dyn_cast<fir::BoxCharType>())
resultAddressType = fir::ReferenceType::get(boxCharType.getEleTy());

if (designate.getIndices().empty()) {
// Scalar substring or complex part.
// generate fir.coordinate_of.
TODO(loc, "hlfir::designate to fir.coordinate_of");
// Array element indexing.
if (!designate.getIndices().empty()) {
auto eleTy = hlfir::getFortranElementType(base.getType());
auto arrayCoorType = fir::ReferenceType::get(eleTy);
base = builder.create<fir::ArrayCoorOp>(loc, arrayCoorType, base, shape,
/*slice=*/mlir::Value{},
designate.getIndices(),
firBaseTypeParameters);
}

// Generate fir.array_coor
mlir::Type resultType = designateResultType;
if (auto boxCharType = designateResultType.dyn_cast<fir::BoxCharType>())
resultType = fir::ReferenceType::get(boxCharType.getEleTy());
auto arrayCoor = builder.create<fir::ArrayCoorOp>(
loc, resultType, base, shape,
/*slice=*/mlir::Value{}, designate.getIndices(), firBaseTypeParameters);
// Scalar substring (potentially on the previously built array element).
if (!designate.getSubstring().empty())
base = fir::factory::CharacterExprHelper{builder, loc}.genSubstringBase(
base, designate.getSubstring()[0], resultAddressType);

// Cast/embox the computed scalar address if needed.
if (designateResultType.isa<fir::BoxCharType>()) {
assert(designate.getTypeparams().size() == 1 &&
"must have character length");
auto emboxChar = builder.create<fir::EmboxCharOp>(
loc, designateResultType, arrayCoor, designate.getTypeparams()[0]);
loc, designateResultType, base, designate.getTypeparams()[0]);
rewriter.replaceOp(designate, emboxChar.getResult());
} else {
rewriter.replaceOp(designate, arrayCoor.getResult());
base = builder.createConvert(loc, designateResultType, base);
rewriter.replaceOp(designate, base);
}
return mlir::success();
}
Expand Down

0 comments on commit d0018c9

Please sign in to comment.