Skip to content

Commit

Permalink
[flang][hlfir] get extents from hlfir.shape_of
Browse files Browse the repository at this point in the history
If the extents were known, this should have been canonicalised into a
fir.shape operation. Therefore, the extents at this point are not known at
compile time. Use hlfir.get_extents to delay resolving the real extent
until after the expression is bufferized.

Depends On: D146831

Differential Revision: https://reviews.llvm.org/D146832
  • Loading branch information
tblah committed Apr 17, 2023
1 parent 08b09d7 commit 5ab5cdc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
23 changes: 16 additions & 7 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
// Return explicit extents. If the base is a fir.box, this won't read it to
// return the extents and will instead return an empty vector.
static llvm::SmallVector<mlir::Value>
getExplicitExtentsFromShape(mlir::Value shape) {
getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder) {
llvm::SmallVector<mlir::Value> result;
auto *shapeOp = shape.getDefiningOp();
if (auto s = mlir::dyn_cast_or_null<fir::ShapeOp>(shapeOp)) {
Expand All @@ -35,15 +35,23 @@ getExplicitExtentsFromShape(mlir::Value shape) {
result.append(e.begin(), e.end());
} else if (mlir::dyn_cast_or_null<fir::ShiftOp>(shapeOp)) {
return {};
} else if (auto s = mlir::dyn_cast_or_null<hlfir::ShapeOfOp>(shapeOp)) {
fir::ShapeType shapeTy = shape.getType().cast<fir::ShapeType>();
result.reserve(shapeTy.getRank());
for (unsigned i = 0; i < shapeTy.getRank(); ++i) {
auto op = builder.create<hlfir::GetExtentOp>(shape.getLoc(), shape, i);
result.emplace_back(op.getResult());
}
} else {
TODO(shape.getLoc(), "read fir.shape to get extents");
}
return result;
}
static llvm::SmallVector<mlir::Value>
getExplicitExtents(fir::FortranVariableOpInterface var) {
getExplicitExtents(fir::FortranVariableOpInterface var,
fir::FirOpBuilder &builder) {
if (mlir::Value shape = var.getShape())
return getExplicitExtentsFromShape(var.getShape());
return getExplicitExtentsFromShape(var.getShape(), builder);
return {};
}

Expand Down Expand Up @@ -385,7 +393,7 @@ hlfir::genBounds(mlir::Location loc, fir::FirOpBuilder &builder,
assert((shape.getType().isa<fir::ShapeShiftType>() ||
shape.getType().isa<fir::ShapeType>()) &&
"shape must contain extents");
auto extents = getExplicitExtentsFromShape(shape);
auto extents = getExplicitExtentsFromShape(shape, builder);
auto lowers = getExplicitLboundsFromShape(shape);
assert(lowers.empty() || lowers.size() == extents.size());
mlir::Type idxTy = builder.getIndexType();
Expand Down Expand Up @@ -440,7 +448,7 @@ llvm::SmallVector<mlir::Value> getVariableExtents(mlir::Location loc,
llvm::SmallVector<mlir::Value> extents;
if (fir::FortranVariableOpInterface varIface =
variable.getIfVariableInterface()) {
extents = getExplicitExtents(varIface);
extents = getExplicitExtents(varIface, builder);
if (!extents.empty())
return extents;
}
Expand Down Expand Up @@ -493,7 +501,8 @@ mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder,
llvm::SmallVector<mlir::Value>
hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Value shape) {
llvm::SmallVector<mlir::Value> extents = getExplicitExtentsFromShape(shape);
llvm::SmallVector<mlir::Value> extents =
getExplicitExtentsFromShape(shape, builder);
mlir::Type indexType = builder.getIndexType();
for (auto &extent : extents)
extent = builder.createConvert(loc, indexType, extent);
Expand All @@ -504,7 +513,7 @@ mlir::Value hlfir::genExtent(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity, unsigned dim) {
entity = followShapeInducingSource(entity);
if (auto shape = tryRetrievingShapeOrShift(entity)) {
auto extents = getExplicitExtentsFromShape(shape);
auto extents = getExplicitExtentsFromShape(shape, builder);
if (!extents.empty()) {
assert(extents.size() > dim && "bad inquiry");
return extents[dim];
Expand Down
31 changes: 31 additions & 0 deletions flang/test/HLFIR/extents-of-shape-of.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck %s
subroutine foo(a, b)
real :: a(:, :), b(:, :)
interface
elemental subroutine elem_sub(x)
real, intent(in) :: x
end subroutine
end interface
call elem_sub(matmul(a, b))
end subroutine
! CHECK-LABEL: func.func @_QPfoo
! CHECK: %[[A_ARG:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "a"}
! CHECK: %[[B_ARG:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "b"}
! CHECK-DAG: %[[A_VAR:.*]]:2 = hlfir.declare %[[A_ARG]]
! CHECK-DAG: %[[B_VAR:.*]]:2 = hlfir.declare %[[B_ARG]]
! CHECK-NEXT: %[[MUL:.*]] = hlfir.matmul %[[A_VAR]]#0 %[[B_VAR]]#0
! CHECK-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[MUL]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
! CHECK-NEXT: %[[EXT0:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index
! CHECK-NEXT: %[[EXT1:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 1 : index} : (!fir.shape<2>) -> index
! CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
! CHECK-NEXT: fir.do_loop %[[ARG2:.*]] = %[[C1]] to %[[EXT1]] step %[[C1]] {
! CHECK-NEXT: fir.do_loop %[[ARG3:.*]] = %[[C1]] to %[[EXT0]] step %[[C1]] {
! CHECK-NEXT: %[[ELE:.*]] = hlfir.apply %[[MUL]], %[[ARG3]], %[[ARG2]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
! CHECK-NEXT: %[[ASSOC:.*]]:3 = hlfir.associate %[[ELE]] {uniq_name = "adapt.valuebyref"} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK-NEXT: fir.call
! CHECK-NEXT: hlfir.end_associate
! CHECK-NEXT: }
! CHECK-NEXT: }
! CHECK-NEXT: hlfir.destroy %[[MUL]]
! CHECK-NEXT: return
! CHECK-NEXT: }

0 comments on commit 5ab5cdc

Please sign in to comment.