Skip to content

Commit

Permalink
[flang][hlfir] Fix hlfir.set_length codegen
Browse files Browse the repository at this point in the history
The bufferization pass was propagating the raw alloca storage
(which may not allow to later retrieve the length) instead of
the hlfir variable value (which is guaranteed to hold the
character length).

Fix this and makes packageBufferizedExpr "storage" argument and
getBufferizedExprStorage return an hlfir::Entity to avoid similar
error in the future (the caller of packageBufferizedExpr will have
to think a bit when adding the explicit hlfir::Entity{} cast).

Differential Revision: https://reviews.llvm.org/D148307
  • Loading branch information
jeanPerier committed Apr 17, 2023
1 parent df8c78c commit 88ac741
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
40 changes: 22 additions & 18 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ namespace hlfir {
namespace {

/// Helper to create tuple from a bufferized expr storage and clean up
/// instruction flag.
/// instruction flag. The storage is an HLFIR variable so that it can
/// be manipulated as a variable later (all shape and length information
/// cam be retrieved from it).
static mlir::Value packageBufferizedExpr(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value storage,
hlfir::Entity storage,
mlir::Value mustFree) {
auto tupleType = mlir::TupleType::get(
builder.getContext(),
Expand All @@ -61,7 +63,7 @@ static mlir::Value packageBufferizedExpr(mlir::Location loc,
/// boolean clean-up flag.
static mlir::Value packageBufferizedExpr(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value storage, bool mustFree) {
hlfir::Entity storage, bool mustFree) {
mlir::Value mustFreeValue = builder.createBool(loc, mustFree);
return packageBufferizedExpr(loc, builder, storage, mustFreeValue);
}
Expand All @@ -70,14 +72,14 @@ static mlir::Value packageBufferizedExpr(mlir::Location loc,
/// It assumes no tuples are used as HLFIR operation operands, which is
/// currently enforced by the verifiers that only accept HLFIR value or
/// variable types which do not include tuples.
static mlir::Value getBufferizedExprStorage(mlir::Value bufferizedExpr) {
static hlfir::Entity getBufferizedExprStorage(mlir::Value bufferizedExpr) {
auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
if (!tupleType)
return bufferizedExpr;
return hlfir::Entity{bufferizedExpr};
assert(tupleType.size() == 2 && "unexpected tuple type");
if (auto insert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>())
if (insert.getVal().getType() == tupleType.getType(0))
return insert.getVal();
return hlfir::Entity{insert.getVal()};
TODO(bufferizedExpr.getLoc(), "general extract storage case");
}

Expand Down Expand Up @@ -152,7 +154,7 @@ struct AsExprOpConversion : public mlir::OpConversionPattern<hlfir::AsExprOp> {
if (asExpr.isMove()) {
// Move variable storage for the hlfir.expr buffer.
mlir::Value bufferizedExpr = packageBufferizedExpr(
loc, builder, adaptor.getVar(), adaptor.getMustFree());
loc, builder, hlfir::Entity{adaptor.getVar()}, adaptor.getMustFree());
rewriter.replaceOp(asExpr, bufferizedExpr);
return mlir::success();
}
Expand All @@ -175,7 +177,7 @@ struct ApplyOpConversion : public mlir::OpConversionPattern<hlfir::ApplyOp> {
matchAndRewrite(hlfir::ApplyOp apply, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = apply->getLoc();
hlfir::Entity bufferizedExpr{getBufferizedExprStorage(adaptor.getExpr())};
hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
mlir::Type resultType = hlfir::getVariableElementType(bufferizedExpr);
mlir::Value result = rewriter.create<hlfir::DesignateOp>(
loc, resultType, bufferizedExpr, adaptor.getIndices(),
Expand Down Expand Up @@ -216,8 +218,8 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
if (adaptor.getStrings().size() > 2)
TODO(loc, "codegen of optimized chained concatenation of more than two "
"strings");
hlfir::Entity lhs{getBufferizedExprStorage(adaptor.getStrings()[0])};
hlfir::Entity rhs{getBufferizedExprStorage(adaptor.getStrings()[1])};
hlfir::Entity lhs = getBufferizedExprStorage(adaptor.getStrings()[0]);
hlfir::Entity rhs = getBufferizedExprStorage(adaptor.getStrings()[1]);
auto [lhsExv, c1] = hlfir::translateToExtendedValue(loc, builder, lhs);
auto [rhsExv, c2] = hlfir::translateToExtendedValue(loc, builder, rhs);
assert(!c1 && !c2 && "expected variables");
Expand All @@ -229,9 +231,10 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
hlfir::getFortranElementType(concat.getResult().getType()));
mlir::Value cast = builder.createConvert(loc, addrType, fir::getBase(res));
res = fir::substBase(res, cast);
auto hlfirTempRes = hlfir::genDeclare(loc, builder, res, "tmp",
fir::FortranVariableFlagsAttr{})
.getBase();
hlfir::Entity hlfirTempRes =
hlfir::Entity{hlfir::genDeclare(loc, builder, res, "tmp",
fir::FortranVariableFlagsAttr{})
.getBase()};
mlir::Value bufferizedExpr =
packageBufferizedExpr(loc, builder, hlfirTempRes, false);
rewriter.replaceOp(concat, bufferizedExpr);
Expand All @@ -251,7 +254,7 @@ struct SetLengthOpConversion
auto module = setLength->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
// Create a temp with the new length.
hlfir::Entity string{getBufferizedExprStorage(adaptor.getString())};
hlfir::Entity string = getBufferizedExprStorage(adaptor.getString());
auto charType = hlfir::getFortranElementType(setLength.getType());
llvm::StringRef tmpName{".tmp"};
llvm::SmallVector<mlir::Value, 1> lenParams{adaptor.getLength()};
Expand All @@ -260,10 +263,11 @@ struct SetLengthOpConversion
auto declareOp = builder.create<hlfir::DeclareOp>(
loc, alloca, tmpName, /*shape=*/mlir::Value{}, lenParams,
fir::FortranVariableFlagsAttr{});
hlfir::Entity temp{declareOp.getBase()};
// Assign string value to the created temp.
builder.create<hlfir::AssignOp>(loc, string, declareOp.getBase());
builder.create<hlfir::AssignOp>(loc, string, temp);
mlir::Value bufferizedExpr =
packageBufferizedExpr(loc, builder, alloca, false);
packageBufferizedExpr(loc, builder, temp, false);
rewriter.replaceOp(setLength, bufferizedExpr);
return mlir::success();
}
Expand Down Expand Up @@ -390,12 +394,12 @@ struct DestroyOpConversion
mlir::ConversionPatternRewriter &rewriter) const override {
// If expr was bufferized on the heap, now is time to deallocate the buffer.
mlir::Location loc = destroy->getLoc();
mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
if (!fir::isa_trivial(bufferizedExpr.getType())) {
auto module = destroy->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr());
mlir::Value firBase = hlfir::Entity(bufferizedExpr).getFirBase();
mlir::Value firBase = bufferizedExpr.getFirBase();
genFreeIfMustFree(loc, builder, firBase, mustFree);
}
rewriter.eraseOp(destroy);
Expand Down
8 changes: 4 additions & 4 deletions flang/test/HLFIR/set_length-codegen.fir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @test_cst_len(%str : !fir.boxchar<1>) {
// CHECK: %[[VAL_4:.*]] = arith.constant false
// CHECK: %[[VAL_5:.*]] = fir.undefined tuple<!fir.ref<!fir.char<1,10>>, i1>
// CHECK: %[[VAL_6:.*]] = fir.insert_value %[[VAL_5]], %[[VAL_4]], [1 : index] : (tuple<!fir.ref<!fir.char<1,10>>, i1>, i1) -> tuple<!fir.ref<!fir.char<1,10>>, i1>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_1]], [0 : index] : (tuple<!fir.ref<!fir.char<1,10>>, i1>, !fir.ref<!fir.char<1,10>>) -> tuple<!fir.ref<!fir.char<1,10>>, i1>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]]#0, [0 : index] : (tuple<!fir.ref<!fir.char<1,10>>, i1>, !fir.ref<!fir.char<1,10>>) -> tuple<!fir.ref<!fir.char<1,10>>, i1>

func.func @test_dyn_len(%str : !fir.ref<!fir.char<1,10>>, %len : index) {
%0 = hlfir.set_length %str len %len : (!fir.ref<!fir.char<1,10>>, index) -> !hlfir.expr<!fir.char<1,?>>
Expand All @@ -28,6 +28,6 @@ func.func @test_dyn_len(%str : !fir.ref<!fir.char<1,10>>, %len : index) {
// CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] typeparams %[[VAL_1]] {uniq_name = ".tmp"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_3]]#0 : !fir.ref<!fir.char<1,10>>, !fir.boxchar<1>
// CHECK: %[[VAL_4:.*]] = arith.constant false
// CHECK: %[[VAL_5:.*]] = fir.undefined tuple<!fir.ref<!fir.char<1,?>>, i1>
// CHECK: %[[VAL_6:.*]] = fir.insert_value %[[VAL_5]], %[[VAL_4]], [1 : index] : (tuple<!fir.ref<!fir.char<1,?>>, i1>, i1) -> tuple<!fir.ref<!fir.char<1,?>>, i1>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_2]], [0 : index] : (tuple<!fir.ref<!fir.char<1,?>>, i1>, !fir.ref<!fir.char<1,?>>) -> tuple<!fir.ref<!fir.char<1,?>>, i1>
// CHECK: %[[VAL_5:.*]] = fir.undefined tuple<!fir.boxchar<1>, i1>
// CHECK: %[[VAL_6:.*]] = fir.insert_value %[[VAL_5]], %[[VAL_4]], [1 : index] : (tuple<!fir.boxchar<1>, i1>, i1) -> tuple<!fir.boxchar<1>, i1>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]]#0, [0 : index] : (tuple<!fir.boxchar<1>, i1>, !fir.boxchar<1>) -> tuple<!fir.boxchar<1>, i1>

0 comments on commit 88ac741

Please sign in to comment.