Skip to content

Commit

Permalink
[flang][openacc] Support array section with non constant bounds
Browse files Browse the repository at this point in the history
Add lowering for non constant lower and upper bounds
in array section.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D148840
  • Loading branch information
clementval committed Apr 24, 2023
1 parent 398d68f commit 9f548c1
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 25 deletions.
31 changes: 23 additions & 8 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,15 @@ genObjectList(const Fortran::parser::AccObjectList &objectList,
static llvm::SmallVector<mlir::Value>
genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext &stmtCtx,
const std::list<Fortran::parser::SectionSubscript> &subscripts,
std::stringstream &asFortran, const Fortran::parser::Name &name) {
int dimension = 0;
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<mlir::acc::DataBoundsType>();
llvm::SmallVector<mlir::Value> bounds;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*name.symbol);

for (const auto &subscript : subscripts) {
if (const auto *triplet{
std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}) {
Expand All @@ -136,7 +139,13 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
asFortran << *lval;
} else {
TODO(loc, "non constant lower bound in array section");
const Fortran::lower::SomeExpr *lexpr =
Fortran::semantics::GetExpr(*lower);
mlir::Value lb =
fir::getBase(converter.genExprValue(loc, *lexpr, stmtCtx));
lb = builder.createConvert(loc, baseLb.getType(), lb);
lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
asFortran << lexpr->AsFortran();
}
}
asFortran << ':';
Expand All @@ -152,7 +161,13 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
asFortran << *uval;
} else {
TODO(loc, "non constant upper bound in array section");
const Fortran::lower::SomeExpr *uexpr =
Fortran::semantics::GetExpr(*upper);
mlir::Value ub =
fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
ub = builder.createConvert(loc, baseLb.getType(), ub);
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
asFortran << uexpr->AsFortran();
}
}
if (lower && upper) {
Expand All @@ -169,9 +184,9 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
}
if (!ubound) {
mlir::Value ext =
fir::factory::readExtent(builder, loc, dataExv, dimension);
extent = builder.create<mlir::arith::SubIOp>(loc, ext, baseLb);
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
if (lbound)
extent = builder.create<mlir::arith::SubIOp>(loc, extent, lbound);
}
mlir::Value empty;
mlir::Value bound = builder.create<mlir::acc::DataBoundsOp>(
Expand Down Expand Up @@ -243,9 +258,9 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
asFortran << name.ToString();
if (!arrayElement->subscripts.empty()) {
asFortran << '(';
bounds =
genBoundsOps(builder, operandLocation, converter,
arrayElement->subscripts, asFortran, name);
bounds = genBoundsOps(builder, operandLocation, converter,
stmtCtx, arrayElement->subscripts,
asFortran, name);
}
asFortran << ')';
Op op = createOpAndAddOperand(*name.symbol, asFortran.str(),
Expand Down
70 changes: 53 additions & 17 deletions flang/test/Lower/OpenACC/acc-enter-data.f90
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ subroutine acc_enter_data

!$acc enter data copyin(a(1:,1:5))
!CHECK: %[[LB1:.*]] = arith.constant 0 : index
!CHECK: %[[EXTENT:.*]] = arith.subi %[[EXTENT_C10:.*]], %c1{{.*}} : index
!CHECK: %[[EXTENT:.*]] = arith.subi %[[EXTENT_C10:.*]], %c0{{.*}} : index
!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB1]] : index) extent(%[[EXTENT]] : index) startIdx(%c1{{.*}} : index)
!CHECK: %[[LB2:.*]] = arith.constant 0 : index
!CHECK: %[[UB2:.*]] = arith.constant 4 : index
Expand All @@ -108,32 +108,44 @@ subroutine acc_enter_data
!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data copyin(a(:10,1:5))
!CHECK: %[[ONE:.*]] = arith.constant 1 : index
!CHECK: %[[UB1:.*]] = arith.constant 9 : index
!CHECK: %[[BOUND1:.*]] = acc.bounds upperbound(%[[UB1]] : index)
!CHECK: %[[LB2:.*]] = arith.constant 0 : index
!CHECK: %[[BOUND1:.*]] = acc.bounds upperbound(%[[UB1]] : index) startIdx(%[[ONE]] : index)
!CHECK: %[[ONE:.*]] = arith.constant 1 : index
!CHECK: %[[LB2:.*]] = arith.constant 0 : index
!CHECK: %[[UB2:.*]] = arith.constant 4 : index
!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB2]] : index) upperbound(%[[UB2]] : index) startIdx(%c1{{.*}} : index)
!CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:10,1:5)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data copyin(a(:,:))
!CHECK: %[[C1:.*]] = arith.constant 1 : index
!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %[[C1]] : index
!CHECK: %[[BOUND1:.*]] = acc.bounds extent(%[[EXT]] : index) startIdx(%[[C1]] : index)
!CHECK: %[[C1:.*]] = arith.constant 1 : index
!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %[[C1]] : index
!CHECK: %[[BOUND2:.*]] = acc.bounds extent(%[[EXT]] : index) startIdx(%[[C1]] : index)
!CHECK: %[[ONE:.*]] = arith.constant 1 : index
!CHECK: %[[BOUND1:.*]] = acc.bounds extent(%c10{{.*}} : index) startIdx(%[[ONE]] : index)
!CHECK: %[[ONE:.*]] = arith.constant 1 : index
!CHECK: %[[BOUND2:.*]] = acc.bounds extent(%c10{{.*}} : index) startIdx(%[[ONE]] : index)
!CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:,:)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)

end subroutine acc_enter_data


subroutine acc_enter_data_dummy(a)
subroutine acc_enter_data_dummy(a, b, n, m)
integer :: n, m
real :: a(1:10)
real :: b(n:m)

!CHECK-LABEL: func.func @_QPacc_enter_data_dummy
!CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "a"}
!CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "b"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}, %[[M:.*]]: !fir.ref<i32> {fir.bindc_name = "m"}

!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref<i32>
!CHECK: %[[N_I64:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64
!CHECK: %[[N_IDX:.*]] = fir.convert %[[N_I64]] : (i64) -> index
!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref<i32>
!CHECK: %[[M_I64:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64
!CHECK: %[[M_IDX:.*]] = fir.convert %[[M_I64]] : (i64) -> index
!CHECK: %[[M_N:.*]] = arith.subi %[[M_IDX]], %[[N_IDX]] : index
!CHECK: %[[C1:.*]] = arith.constant 1 : index
!CHECK: %[[M_N_1:.*]] = arith.addi %[[M_N]], %[[C1]] : index
!CHECK: %[[C0:.*]] = arith.constant 0 : index
!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[M_N_1]], %[[C0]] : index
!CHECK: %[[EXT_B:.*]] = arith.select %[[CMP]], %[[M_N_1]], %[[C0]] : index

!$acc enter data create(a(5:10))
!CHECK: %[[LB1:.*]] = arith.constant 4 : index
Expand All @@ -142,6 +154,31 @@ subroutine acc_enter_data_dummy(a)
!CHECK: %[[CREATE1:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND1]]) -> !fir.ref<!fir.array<10xf32>> {name = "a(5:10)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE1]] : !fir.ref<!fir.array<10xf32>>)

!$acc enter data create(b(n:m))
!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref<i32>
!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index
!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[N_IDX]] : index
!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref<i32>
!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index
!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[N_IDX]] : index
!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) startIdx(%[[N_IDX]] : index)
!CHECK: %[[CREATE1:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND1]]) -> !fir.ref<!fir.array<?xf32>> {name = "b(n:m)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE1]] : !fir.ref<!fir.array<?xf32>>)

!$acc enter data create(b(n:))
!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref<i32>
!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index
!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[N_IDX]] : index
!CHECK: %[[EXT:.*]] = arith.subi %[[EXT_B]], %[[LB]] : index
!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) startIdx(%[[N_IDX]] : index)
!CHECK: %[[CREATE1:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND1]]) -> !fir.ref<!fir.array<?xf32>> {name = "b(n:)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE1]] : !fir.ref<!fir.array<?xf32>>)

!$acc enter data create(b(:))
!CHECK: %[[BOUND1:.*]] = acc.bounds extent(%[[EXT_B]] : index) startIdx(%[[N_IDX]] : index)
!CHECK: %[[CREATE1:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND1]]) -> !fir.ref<!fir.array<?xf32>> {name = "b(:)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE1]] : !fir.ref<!fir.array<?xf32>>)

end subroutine

! Test lowering of array section for non default lower bound.
Expand All @@ -162,8 +199,7 @@ subroutine acc_enter_data_non_default_lb()
!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<!fir.array<10xi32>>)

!$acc enter data create(a(:))
!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %[[BASELB]] : index
!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[EXT]] : index) startIdx(%[[BASELB]] : index)
!CHECK: %[[BOUND:.*]] = acc.bounds extent(%c10{{.*}} : index) startIdx(%[[BASELB]] : index)
!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xi32>> {name = "a(:)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<!fir.array<10xi32>>)

Expand All @@ -177,7 +213,7 @@ subroutine acc_enter_data_non_default_lb()
!$acc enter data create(a(4:))
!CHECK: %[[SECTIONLB:.*]] = arith.constant 4 : index
!CHECK: %[[LB:.*]] = arith.subi %[[SECTIONLB]], %[[BASELB]] : index
!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %[[BASELB]] : index
!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %[[LB]] : index
!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) startIdx(%[[BASELB]] : index)
!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xi32>> {name = "a(4:)", structured = false}
!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<!fir.array<10xi32>>)
Expand Down

0 comments on commit 9f548c1

Please sign in to comment.