Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,213 @@ class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
}
};

/// Return a value that can be used as a variable for \p val, paired with the
/// hlfir.associate operation that was created to obtain it.
///
/// When \p val has hlfir::ExprType, it is associated to a variable and the
/// base of that association is returned together with the associate op.
/// Otherwise \p val is forwarded unchanged and the returned associate op is
/// null (default-constructed), so callers can test it before emitting the
/// matching end-associate.
static std::pair<mlir::Value, hlfir::AssociateOp>
getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
  if (mlir::isa<hlfir::ExprType>(val.getType())) {
    // Expression: materialize it as a variable through hlfir.associate.
    hlfir::Entity exprEntity{val};
    mlir::NamedAttribute adaptToByRef = fir::getAdaptToByRefAttr(builder);
    hlfir::AssociateOp exprAssociate = hlfir::genAssociateExpr(
        loc, builder, exprEntity, exprEntity.getType(), "", adaptToByRef);
    return {exprAssociate.getBase(), exprAssociate};
  }

  // Not an expression: forward the value, there is nothing to associate.
  hlfir::AssociateOp noAssociate;
  return {val, noAssociate};
}

/// Rewrite pattern that simplifies hlfir.index when the substring length is
/// a compile-time constant.
///
/// Depending on the operands the operation is either replaced with a value
/// computed from the string length, replaced with a constant, or expanded
/// into an inline single-character search loop. In every other case the
/// pattern reports a match failure, leaving the operation for the runtime
/// implementation.
class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
public:
  using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;

  /// Try to simplify \p op.
  ///
  /// \returns success when \p op has been replaced, or a match failure
  /// (with a diagnostic note) when the case is not handled here.
  llvm::LogicalResult
  matchAndRewrite(hlfir::IndexOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // We simplify only limited cases:
    //  1) a substring length shall be known at compile time
    //  2) if a substring length is 0 then replace with 1 for forward search,
    //     or otherwise with the string length + 1 (builder shall const-fold if
    //     lookup direction is known at compile time).
    //  3) for known string length at compile time, if it is
    //     shorter than substring => replace with zero.
    //  4) if a substring length is one => inline as simple search loop
    //  5) for forward search with input strings of kind=1 runtime is faster.
    //  Do not simplify in all the other cases relying on a runtime call.

    fir::FirOpBuilder builder{rewriter, op.getOperation()};
    mlir::Location loc = op->getLoc();

    auto resultTy = op.getType();
    // Optional BACK operand; a null value means it was not provided.
    mlir::Value back = op.getBack();
    mlir::Value substrLen =
        hlfir::genCharLength(loc, builder, hlfir::Entity{op.getSubstr()});

    auto substrLenCst = fir::getIntIfConstant(substrLen);
    if (!substrLenCst) {
      return rewriter.notifyMatchFailure(
          op, "substring length unknown at compile time");
    }
    mlir::Value strLen =
        hlfir::genCharLength(loc, builder, hlfir::Entity{op.getStr()});
    auto i1Ty = builder.getI1Type();
    auto idxTy = builder.getIndexType();
    if (*substrLenCst == 0) {
      mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
      // zero length substring. For back search replace with
      // strLen+1, or otherwise with 1.
      mlir::Value strEnd = mlir::arith::AddIOp::create(
          builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
      // An absent BACK operand is treated as a false (forward) search.
      if (back)
        back = builder.createConvert(loc, i1Ty, back);
      else
        back = builder.createIntegerConstant(loc, i1Ty, 0);
      mlir::Value result =
          mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx);

      rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
      return mlir::success();
    }

    // The zero-length substring case has returned above, so a zero-length
    // string is shorter than the (non-empty) substring and is covered by the
    // comparison below; no separate "both lengths are zero" check is needed.
    if (auto strLenCst = fir::getIntIfConstant(strLen)) {
      if (*strLenCst < *substrLenCst) {
        rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
        return mlir::success();
      }
    }
    if (*substrLenCst != 1) {
      return rewriter.notifyMatchFailure(
          op, "rely on runtime implementation if substring length > 1");
    }
    // For forward search and character kind=1 the runtime uses memchr
    // which is well optimized. But it looks like memchr idiom is not
    // recognized in LLVM yet. On a micro-kernel test with strings of length 40
    // runtime had ~2x less execution time vs inlined code. For unknown search
    // direction at compile time pessimistically assume "forward".
    //
    // isBack holds the search direction when it is known at compile time and
    // stays empty when BACK is present but not a constant.
    std::optional<bool> isBack;
    if (back) {
      if (auto backCst = fir::getIntIfConstant(back))
        isBack = *backCst != 0;
    } else {
      isBack = false;
    }
    auto charTy = mlir::cast<fir::CharacterType>(
        hlfir::getFortranElementType(op.getSubstr().getType()));
    unsigned kind = charTy.getFKind();
    if (kind == 1 && (!isBack || !*isBack)) {
      return rewriter.notifyMatchFailure(
          op, "rely on runtime implementation for character kind 1");
    }

    // All checks are passed here. Generate single character search loop.
    // Expression operands are associated to variables first; the matching
    // end-associate operations are emitted after the loop.
    auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
    auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr());
    hlfir::Entity str{strV};
    hlfir::Entity substr{substrV};
    mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);

    // Designate the one-character substring [index:index] of charStr, load
    // it and extract its element 0 as an integer whose width is the
    // character bit size for this kind.
    auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
                                      kind](mlir::Location loc,
                                            fir::FirOpBuilder &builder,
                                            hlfir::Entity &charStr,
                                            mlir::Value index) {
      auto bits = builder.getKindMap().getCharacterBitsize(kind);
      auto intTy = builder.getIntegerType(bits);
      auto charLen1Ty =
          fir::CharacterType::getSingleton(builder.getContext(), kind);
      mlir::Type designatorTy =
          fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
      auto idxAttr = builder.getIntegerAttr(idxTy, 0);

      auto singleChr = hlfir::DesignateOp::create(
          builder, loc, designatorTy, charStr, /*component=*/{},
          /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
          /*substring=*/mlir::ValueRange{index, index},
          /*complexPart=*/std::nullopt,
          /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
          fir::FortranVariableFlagsAttr{});
      auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
      mlir::Value intVal = fir::ExtractValueOp::create(
          builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
      return intVal;
    };

    // The character to look for: the only character of the substring.
    auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);

    // Generate search loop body with the following C equivalent:
    //   idx_t result = 0;
    //   idx_t end = strlen + 1;
    //   char want = substr[0];
    //   for (idx_t idx = 1; idx < end; ++idx) {
    //     if (result == 0) {
    //       idx_t at = back ? end - idx: idx;
    //       result = str[at-1] == want ? at : result;
    //     }
    //   }
    if (!back)
      back = builder.createIntegerConstant(loc, i1Ty, 0);
    else
      back = builder.createConvert(loc, i1Ty, back);
    mlir::Value strEnd = mlir::arith::AddIOp::create(
        builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
    mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
    // Loop body: the single reduction value carries the position found so
    // far (zero meaning "not found yet") and is passed through unchanged once
    // it becomes non-zero.
    auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                             mlir::ValueRange index,
                             mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      assert(index.size() == 1 && "expected single loop");
      assert(reductionArgs.size() == 1 && "expected single reduction value");
      mlir::Value inRes = reductionArgs[0];
      auto resEQzero = mlir::arith::CmpIOp::create(
          builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);

      mlir::Value res =
          builder
              .genIfOp(loc, {idxTy}, resEQzero,
                       /*withElseRegion=*/true)
              .genThen([&]() {
                mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
                // offset = back ? end - idx : idx;
                mlir::Value offset = mlir::arith::SelectOp::create(
                    builder, loc, back,
                    mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
                    idx);

                auto haveChar =
                    genExtractAndConvertToInt(loc, builder, str, offset);
                auto charsEQ = mlir::arith::CmpIOp::create(
                    builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
                    wantChar);
                mlir::Value newVal = mlir::arith::SelectOp::create(
                    builder, loc, charsEQ, offset, inRes);

                fir::ResultOp::create(builder, loc, newVal);
              })
              .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
              .getResults()[0];
      return {res};
    };

    // Ordered loop: the first match in iteration order has to win.
    llvm::SmallVector<mlir::Value, 1> loopOut =
        hlfir::genLoopNestWithReductions(loc, builder, {strLen},
                                         /*reductionInits=*/{zeroIdx},
                                         genSearchBody,
                                         /*isUnordered=*/false);
    mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);

    // Close the associations opened by getVariable for expression operands.
    if (strAssociate)
      hlfir::EndAssociateOp::create(builder, loc, strAssociate);
    if (substrAssociate)
      hlfir::EndAssociateOp::create(builder, loc, substrAssociate);

    rewriter.replaceOp(op, result);
    return mlir::success();
  }
};

template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
Expand Down Expand Up @@ -2955,6 +3162,7 @@ class SimplifyHLFIRIntrinsics
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
patterns.insert<CmpCharOpConversion>(context);
patterns.insert<IndexOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
Expand Down
Loading