Skip to content

Commit

Permalink
[flang][openacc] Always lower bounds with lb, ub and stride information
Browse files Browse the repository at this point in the history
Since we have all the information while lowering, always
add the lowerbound, upperbound and stride information for acc.bounds
create from the Flang frontend.

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D149704
  • Loading branch information
clementval committed May 3, 2023
1 parent fe710ff commit be09327
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 152 deletions.
55 changes: 35 additions & 20 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,24 @@ genObjectList(const Fortran::parser::AccObjectList &objectList,
static llvm::SmallVector<mlir::Value>
genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymbolRef sym, mlir::Value box, int rank) {
fir::ExtendedValue dataExv, mlir::Value box) {
llvm::SmallVector<mlir::Value> bounds;
mlir::Type idxTy = builder.getIndexType();
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
mlir::Type boundTy = builder.getType<mlir::acc::DataBoundsType>();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
assert(box.getType().isa<fir::BaseBoxType>() && "expect firbox or fir.class");
for (int dim = 0; dim < rank; ++dim) {
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
mlir::Value baseLb =
fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
auto dimInfo =
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d);
mlir::Value empty;
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
mlir::Value ub =
builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
mlir::Value bound = builder.create<mlir::acc::DataBoundsOp>(
loc, boundTy, empty, empty, dimInfo.getExtent(),
dimInfo.getByteStride(), true, baseLb);
loc, boundTy, lb, ub, mlir::Value(), dimInfo.getByteStride(), true,
baseLb);
bounds.push_back(bound);
}
return bounds;
Expand All @@ -140,14 +141,19 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
if (dataExv.rank() == 0)
return bounds;

mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
for (std::size_t dim = 0; dim < dataExv.rank(); ++dim) {
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value startIdx =
mlir::Value baseLb =
fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
mlir::Value extent = fir::factory::readExtent(builder, loc, dataExv, dim);
mlir::Value ext = fir::factory::readExtent(builder, loc, dataExv, dim);
mlir::Value lb =
baseLb == one ? builder.createIntegerConstant(loc, idxTy, 0) : baseLb;

// ub = baseLb + extent - 1
mlir::Value lbExt = builder.create<mlir::arith::AddIOp>(loc, ext, baseLb);
mlir::Value ub = builder.create<mlir::arith::SubIOp>(loc, lbExt, one);
mlir::Value bound = builder.create<mlir::acc::DataBoundsOp>(
loc, boundTy, mlir::Value(), mlir::Value(), extent, mlir::Value(),
false, startIdx);
loc, boundTy, lb, ub, mlir::Value(), one, false, baseLb);
bounds.push_back(bound);
}
return bounds;
Expand All @@ -167,18 +173,19 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type boundTy = builder.getType<mlir::acc::DataBoundsType>();
llvm::SmallVector<mlir::Value> bounds;

mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
for (const auto &subscript : subscripts) {
if (const auto *triplet{
std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}) {
if (dimension != 0)
asFortran << ',';
mlir::Value lbound, ubound, extent;
std::optional<std::int64_t> lval, uval;
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value baseLb =
fir::factory::readLowerBound(builder, loc, dataExv, dimension, one);
bool defaultLb = baseLb == one;
mlir::Value stride;
mlir::Value stride = one;
bool strideInBytes = false;

if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
Expand Down Expand Up @@ -209,6 +216,8 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
asFortran << lexpr->AsFortran();
}
} else {
lbound = defaultLb ? zero : baseLb;
}
asFortran << ':';
const auto &upper{std::get<1>(triplet->t)};
Expand Down Expand Up @@ -245,10 +254,13 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
}
}
// ub = baseLb + extent - 1
if (!ubound) {
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
if (lbound)
extent = builder.create<mlir::arith::SubIOp>(loc, extent, lbound);
mlir::Value ext =
fir::factory::readExtent(builder, loc, dataExv, dimension);
mlir::Value lbExt =
builder.create<mlir::arith::AddIOp>(loc, ext, baseLb);
ubound = builder.create<mlir::arith::SubIOp>(loc, lbExt, one);
}
mlir::Value bound = builder.create<mlir::acc::DataBoundsOp>(
loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb);
Expand Down Expand Up @@ -378,8 +390,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
// genExprAddr will be the result of a fir.box_addr operation.
// Retrieve the box so we handle it like other descriptor.
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
addr.getDefiningOp()))
addr.getDefiningOp())) {
addr = boxAddrOp.getVal();
bounds = genBoundsOpsFromBox(builder, operandLocation,
converter, compExv, addr);
}

createOpAndAddOperand(addr, (*expr).AsFortran(),
operandLocation, bounds);
Expand All @@ -396,9 +411,9 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVector<mlir::Value> bounds;
if (fir::unwrapRefType(baseAddr.getType())
.isa<fir::BaseBoxType>())
bounds = genBoundsOpsFromBox(builder, operandLocation,
converter, *name.symbol,
baseAddr, (*expr).Rank());
bounds =
genBoundsOpsFromBox(builder, operandLocation,
converter, dataExv, baseAddr);
else if (fir::unwrapRefType(baseAddr.getType())
.isa<fir::SequenceType>())
bounds = genBaseBoundsOps(builder, operandLocation,
Expand Down
Loading

0 comments on commit be09327

Please sign in to comment.