diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 5fdf9928b244b3..4603c6d0e256d2 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -24,7 +24,7 @@ // Return explicit extents. If the base is a fir.box, this won't read it to // return the extents and will instead return an empty vector. static llvm::SmallVector -getExplicitExtentsFromShape(mlir::Value shape) { +getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder) { llvm::SmallVector result; auto *shapeOp = shape.getDefiningOp(); if (auto s = mlir::dyn_cast_or_null(shapeOp)) { @@ -35,15 +35,23 @@ getExplicitExtentsFromShape(mlir::Value shape) { result.append(e.begin(), e.end()); } else if (mlir::dyn_cast_or_null(shapeOp)) { return {}; + } else if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + fir::ShapeType shapeTy = shape.getType().cast(); + result.reserve(shapeTy.getRank()); + for (unsigned i = 0; i < shapeTy.getRank(); ++i) { + auto op = builder.create(shape.getLoc(), shape, i); + result.emplace_back(op.getResult()); + } } else { TODO(shape.getLoc(), "read fir.shape to get extents"); } return result; } static llvm::SmallVector -getExplicitExtents(fir::FortranVariableOpInterface var) { +getExplicitExtents(fir::FortranVariableOpInterface var, + fir::FirOpBuilder &builder) { if (mlir::Value shape = var.getShape()) - return getExplicitExtentsFromShape(var.getShape()); + return getExplicitExtentsFromShape(var.getShape(), builder); return {}; } @@ -385,7 +393,7 @@ hlfir::genBounds(mlir::Location loc, fir::FirOpBuilder &builder, assert((shape.getType().isa() || shape.getType().isa()) && "shape must contain extents"); - auto extents = getExplicitExtentsFromShape(shape); + auto extents = getExplicitExtentsFromShape(shape, builder); auto lowers = getExplicitLboundsFromShape(shape); assert(lowers.empty() || lowers.size() == extents.size()); mlir::Type idxTy = builder.getIndexType(); @@ -440,7 +448,7 @@ llvm::SmallVector getVariableExtents(mlir::Location loc, llvm::SmallVector extents; if (fir::FortranVariableOpInterface varIface = variable.getIfVariableInterface()) { - extents = getExplicitExtents(varIface); + extents = getExplicitExtents(varIface, builder); if (!extents.empty()) return extents; } @@ -493,7 +501,8 @@ mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder, llvm::SmallVector hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) { - llvm::SmallVector extents = getExplicitExtentsFromShape(shape); + llvm::SmallVector extents = + getExplicitExtentsFromShape(shape, builder); mlir::Type indexType = builder.getIndexType(); for (auto &extent : extents) extent = builder.createConvert(loc, indexType, extent); @@ -504,7 +513,7 @@ mlir::Value hlfir::genExtent(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity entity, unsigned dim) { entity = followShapeInducingSource(entity); if (auto shape = tryRetrievingShapeOrShift(entity)) { - auto extents = getExplicitExtentsFromShape(shape); + auto extents = getExplicitExtentsFromShape(shape, builder); if (!extents.empty()) { assert(extents.size() > dim && "bad inquiry"); return extents[dim]; diff --git a/flang/test/HLFIR/extents-of-shape-of.f90 b/flang/test/HLFIR/extents-of-shape-of.f90 new file mode 100644 index 00000000000000..ff1a657dc0ea52 --- /dev/null +++ b/flang/test/HLFIR/extents-of-shape-of.f90 @@ -0,0 +1,31 @@ +! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck %s +subroutine foo(a, b) + real :: a(:, :), b(:, :) + interface + elemental subroutine elem_sub(x) + real, intent(in) :: x + end subroutine + end interface + call elem_sub(matmul(a, b)) +end subroutine +! CHECK-LABEL: func.func @_QPfoo +! CHECK: %[[A_ARG:.*]]: !fir.box> {fir.bindc_name = "a"} +! CHECK: %[[B_ARG:.*]]: !fir.box> {fir.bindc_name = "b"} +! CHECK-DAG: %[[A_VAR:.*]]:2 = hlfir.declare %[[A_ARG]] +! CHECK-DAG: %[[B_VAR:.*]]:2 = hlfir.declare %[[B_ARG]] +! CHECK-NEXT: %[[MUL:.*]] = hlfir.matmul %[[A_VAR]]#0 %[[B_VAR]]#0 +! CHECK-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[MUL]] : (!hlfir.expr) -> !fir.shape<2> +! CHECK-NEXT: %[[EXT0:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index +! CHECK-NEXT: %[[EXT1:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 1 : index} : (!fir.shape<2>) -> index +! CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +! CHECK-NEXT: fir.do_loop %[[ARG2:.*]] = %[[C1]] to %[[EXT1]] step %[[C1]] { +! CHECK-NEXT: fir.do_loop %[[ARG3:.*]] = %[[C1]] to %[[EXT0]] step %[[C1]] { +! CHECK-NEXT: %[[ELE:.*]] = hlfir.apply %[[MUL]], %[[ARG3]], %[[ARG2]] : (!hlfir.expr, index, index) -> f32 +! CHECK-NEXT: %[[ASSOC:.*]]:3 = hlfir.associate %[[ELE]] {uniq_name = "adapt.valuebyref"} : (f32) -> (!fir.ref, !fir.ref, i1) +! CHECK-NEXT: fir.call +! CHECK-NEXT: hlfir.end_associate +! CHECK-NEXT: } +! CHECK-NEXT: } +! CHECK-NEXT: hlfir.destroy %[[MUL]] +! CHECK-NEXT: return +! CHECK-NEXT: }