Skip to content

Commit

Permalink
[flang][LoopVersioning] support reboxed operands
Browse files Browse the repository at this point in the history
Since https://reviews.llvm.org/D158119, many boxes lowered via HLFIR are
reboxed with better lower bounds information after they are declared.

For the loop versioning pass to support FIR lowered via HLFIR, it needs
to dereference fir.rebox operations to figure out that the variable was
a function argument.

I decided to modify the existing dereferencing of fir.declare so that
the declared/reboxed value is used in the versioned loop instead of the
function argument. This makes it easier for the improved lower bounds
information to be accessed. In doing this, I changed ArgInfo to store
ArgInfo::arg by value instead of by pointer because mlir::Value has
value-type semantics.

Differential Revision: https://reviews.llvm.org/D158408
  • Loading branch information
tblah committed Aug 23, 2023
1 parent 762629b commit 8d24b73
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
32 changes: 24 additions & 8 deletions flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
return val;
}

/// if a value comes from a fir.rebox, follow the rebox to the original source,
/// of the value, otherwise return the value
static mlir::Value unwrapReboxOp(mlir::Value val) {
// don't support reboxes of reboxes
if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
val = rebox.getBox();
return val;
}

/// normalize a value (removing fir.declare and fir.rebox) so that we can
/// more conveniently spot values which came from function arguments
static mlir::Value normaliseVal(mlir::Value val) {
return unwrapFirDeclare(unwrapReboxOp(val));
}

void LoopVersioningPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::func::FuncOp func = getOperation();
Expand All @@ -112,7 +127,7 @@ void LoopVersioningPass::runOnOperation() {
/// A structure to hold an argument, the size of the argument and dimension
/// information.
struct ArgInfo {
mlir::Value *arg;
mlir::Value arg;
size_t size;
unsigned rank;
fir::BoxDimsOp dims[CFI_MAX_RANK];
Expand All @@ -138,7 +153,7 @@ void LoopVersioningPass::runOnOperation() {
else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
if (typeSize)
argsOfInterest.push_back({&arg, typeSize, rank, {}});
argsOfInterest.push_back({arg, typeSize, rank, {}});
else
LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
}
Expand Down Expand Up @@ -166,7 +181,9 @@ void LoopVersioningPass::runOnOperation() {
return;
mlir::Value operand = op->getOperand(0);
for (auto a : argsOfInterest) {
if (*a.arg == unwrapFirDeclare(operand)) {
if (a.arg == normaliseVal(operand)) {
// use the reboxed value, not the block arg when re-creating the loop:
a.arg = operand;
// Only add if it's not already in the list.
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
return it.arg == a.arg;
Expand Down Expand Up @@ -211,7 +228,7 @@ void LoopVersioningPass::runOnOperation() {
for (unsigned i = 0; i < ndims; i++) {
mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
arg.dims[i] = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
*arg.arg, dimIdx);
arg.arg, dimIdx);
}
// We only care about lowest order dimension, here.
mlir::Value elemSize =
Expand All @@ -238,11 +255,11 @@ void LoopVersioningPass::runOnOperation() {
for (auto &arg : op.argsAndDims) {
fir::SequenceType::Shape newShape;
newShape.push_back(fir::SequenceType::getUnknownExtent());
auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg->getType());
auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg.getType());
mlir::Type arrTy = fir::SequenceType::get(newShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Type refArrTy = builder.getRefType(arrTy);
auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, *arg.arg);
auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, arg.arg);
auto caddr = builder.create<fir::BoxAddrOp>(loc, refArrTy, carg);
auto insPt = builder.saveInsertionPoint();
// Use caddr instead of arg.
Expand All @@ -254,8 +271,7 @@ void LoopVersioningPass::runOnOperation() {
// arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
// where stride is the distance between elements in the dimensions
// 0, 1 and 2 or x, y and z.
if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
coop->getOperands().size() >= 2) {
if (coop->getOperand(0) == arg.arg && coop->getOperands().size() >= 2) {
builder.setInsertionPoint(coop);
mlir::Value totalIndex;
for (unsigned i = arg.rank - 1; i > 0; i--) {
Expand Down
10 changes: 6 additions & 4 deletions flang/test/Transforms/loop-versioning.fir
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
module {
func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
%decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
%rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
%1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
%cst = arith.constant 0.000000e+00 : f64
Expand All @@ -31,7 +32,7 @@ module {
%9 = fir.convert %8 : (i32) -> i64
%c1_i64 = arith.constant 1 : i64
%10 = arith.subi %9, %c1_i64 : i64
%11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
%11 = fir.coordinate_of %rebox, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
%12 = fir.load %11 : !fir.ref<f64>
%13 = arith.addf %7, %12 fastmath<contract> : f64
fir.store %13 to %1 : !fir.ref<f64>
Expand All @@ -49,12 +50,13 @@ module {
// CHECK-LABEL: func.func @sum1d(
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
// CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
// CHECK: %[[REBOX:.*]] = fir.rebox %[[DECL]]
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[REBOX]], %[[ZERO]] : {{.*}}
// CHECK: %[[SIZE:.*]] = arith.constant 8 : index
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS]]#2, %[[SIZE]]
// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}}
// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]]
// CHECK: %[[NEWARR:.*]] = fir.convert %[[REBOX]]
// CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref<!fir.array<?xf64>>
// CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}}
// CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %{{.*}} : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
Expand All @@ -64,7 +66,7 @@ module {
// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
// CHECK: } else {
// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[REBOX]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
// CHECK: fir.result %{{.*}}, %{{.*}}
// CHECK: }
Expand Down

0 comments on commit 8d24b73

Please sign in to comment.