Skip to content

Commit 355dbbc

Browse files
authored
[flang][FIR] enable fir.box_addr codegen inside fir.global (#157120)
FIR lowering of the fir.box type inside fir.global is special (it is an actual descriptor struct value instead of being a descriptor in memory) and causes builtin.unrealized_conversion_cast to be inserted under the hood by MLIR dialect conversion framework after each operation producing a fir.box is translated. These builtin.unrealized_conversion_cast must be removed before the code generation of operation of using the fir.box in order to get the right "by value" code generation required in global initial value definitions.
1 parent 05e3143 commit 355dbbc

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,31 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
329329
} // namespace
330330

331331
namespace {
332+
333+
static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) {
334+
auto *thisBlock = rewriter.getInsertionBlock();
335+
return thisBlock && mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp());
336+
}
337+
338+
// Inside a fir.global, the input box was produced as an llvm.struct<>
339+
// because objects cannot be handled in memory inside a fir.global body that
340+
// must be constant foldable. However, the type translation are not
341+
// contextual, so the fir.box<T> type of the operation that produced the
342+
// fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass
343+
// manager inserted a builtin.unrealized_conversion_cast that was inserted
344+
// and needs to be removed here.
345+
// This should be called by any pattern operating on operations that are
346+
// accepting fir.box inputs and are used in fir.global.
347+
static mlir::Value
348+
fixBoxInputInsideGlobalOp(mlir::ConversionPatternRewriter &rewriter,
349+
mlir::Value box) {
350+
if (isInGlobalOp(rewriter))
351+
if (auto unrealizedCast =
352+
box.getDefiningOp<mlir::UnrealizedConversionCastOp>())
353+
return unrealizedCast.getInputs()[0];
354+
return box;
355+
}
356+
332357
/// Lower `fir.box_addr` to the sequence of operations to extract the first
333358
/// element of the box.
334359
struct BoxAddrOpConversion : public fir::FIROpConversion<fir::BoxAddrOp> {
@@ -341,6 +366,7 @@ struct BoxAddrOpConversion : public fir::FIROpConversion<fir::BoxAddrOp> {
341366
auto loc = boxaddr.getLoc();
342367
if (auto argty =
343368
mlir::dyn_cast<fir::BaseBoxType>(boxaddr.getVal().getType())) {
369+
a = fixBoxInputInsideGlobalOp(rewriter, a);
344370
TypePair boxTyPair = getBoxTypePair(argty);
345371
rewriter.replaceOp(boxaddr,
346372
getBaseAddrFromBox(loc, boxTyPair, a, rewriter));
@@ -1737,12 +1763,6 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
17371763
xbox.getSubcomponent().size());
17381764
}
17391765

1740-
static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) {
1741-
auto *thisBlock = rewriter.getInsertionBlock();
1742-
return thisBlock &&
1743-
mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp());
1744-
}
1745-
17461766
/// If the embox is not in a globalOp body, allocate storage for the box;
17471767
/// store the value inside and return the generated alloca. Return the input
17481768
/// value otherwise.
@@ -2076,21 +2096,10 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20762096
mlir::ConversionPatternRewriter &rewriter) const override {
20772097
mlir::Location loc = rebox.getLoc();
20782098
mlir::Type idxTy = lowerTy().indexType();
2079-
mlir::Value loweredBox = adaptor.getOperands()[0];
2099+
mlir::Value loweredBox =
2100+
fixBoxInputInsideGlobalOp(rewriter, adaptor.getBox());
20802101
mlir::ValueRange operands = adaptor.getOperands();
20812102

2082-
// Inside a fir.global, the input box was produced as an llvm.struct<>
2083-
// because objects cannot be handled in memory inside a fir.global body that
2084-
// must be constant foldable. However, the type translation are not
2085-
// contextual, so the fir.box<T> type of the operation that produced the
2086-
// fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass
2087-
// manager inserted a builtin.unrealized_conversion_cast that was inserted
2088-
// and needs to be removed here.
2089-
if (isInGlobalOp(rewriter))
2090-
if (auto unrealizedCast =
2091-
loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>())
2092-
loweredBox = unrealizedCast.getInputs()[0];
2093-
20942103
TypePair inputBoxTyPair = getBoxTypePair(rebox.getBox().getType());
20952104

20962105
// Create new descriptor and fill its non-shape related data.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Test codegen of fir.box_addr inside fir.global
2+
// RUN: tco %s | FileCheck %s
3+
4+
fir.global @x_addr constant : !fir.type<sometype{p:i64}> {
5+
%c-1 = arith.constant -1 : index
6+
%c5 = arith.constant 5 : index
7+
%c3 = arith.constant 3 : index
8+
%c-3 = arith.constant -3 : index
9+
%c2 = arith.constant 2 : index
10+
%c1 = arith.constant 1 : index
11+
%0 = fir.undefined !fir.type<sometype{p:i64}>
12+
%1 = fir.address_of(@_QFEx) : !fir.ref<!fir.array<2x3x5x!fir.type<_QFTt1{c:i32}>>>
13+
%2 = fir.shape_shift %c1, %c2, %c-3, %c3, %c1, %c5 : (index, index, index, index, index, index) -> !fir.shapeshift<3>
14+
%3 = fir.field_index c, !fir.type<_QFTt1{c:i32}>
15+
%4 = fir.slice %c1, %c2, %c1, %c-3, %c-1, %c1, %c1, %c5, %c1 path %3 : (index, index, index, index, index, index, index, index, index, !fir.field) -> !fir.slice<3>
16+
%5 = fir.embox %1(%2) [%4] : (!fir.ref<!fir.array<2x3x5x!fir.type<_QFTt1{c:i32}>>>, !fir.shapeshift<3>, !fir.slice<3>) -> !fir.box<!fir.ref<!fir.array<2x3x5xi32>>>
17+
%6 = fir.box_addr %5 : (!fir.box<!fir.ref<!fir.array<2x3x5xi32>>>) -> !fir.ref<!fir.array<2x3x5xi32>>
18+
%7 = fir.convert %6 : (!fir.ref<!fir.array<2x3x5xi32>>) -> i64
19+
%8 = fir.insert_value %0, %7, ["p", !fir.type<sometype{p:i64}>] : (!fir.type<sometype{p:i64}>, i64) -> !fir.type<sometype{p:i64}>
20+
fir.has_value %8 : !fir.type<sometype{p:i64}>
21+
}
22+
fir.global @_QFEx target : !fir.array<2x3x5x!fir.type<_QFTt1{c:i32}>>
23+
24+
// CHECK: @x_addr = constant %sometype { i64 ptrtoint (ptr @_QFEx to i64) }

0 commit comments

Comments
 (0)