Skip to content

Commit

Permalink
[flang][hlfir] add hlfir.shape_of
Browse files Browse the repository at this point in the history
This is an operation which returns the fir.shape for a hlfir.expr.

A hlfir.expr can be defined by:
  - A transformational intrinsic (e.g. hlfir.matmul)
  - hlfir.as_expr
  - hlfir.elemental

hlfir.elemental is easy because there is a compulsory shape operand.
hlfir.as_expr is defined as operating on a variable (defined using a
hlfir.declare). hlfir.declare has an optional shape argument. The
transformational intrinsics do not have an associated shape.

If all extents are known at compile time, the extents for the shape can
be fetched from the hlfir.expr's type. For example, the result of a
hlfir.matmul with arguments who's extents are known at compile time will
have constant extents which can be queried from the type. In this case
the hlfir.shape_of will be canonicalised to a fir.shape operation using
those extents.

If not all extents are known at compile time, shapes have to be read
from boxes after bufferization. In the case of the transformational
intrinsics, the shape read from the result box can be queried from the
hlfir.declare operation for the buffer allocated to that hlfir.expr (via
the hlfir.as_expr).

Differential Revision: https://reviews.llvm.org/D146830
  • Loading branch information
tblah committed Apr 17, 2023
1 parent 5b22df9 commit 5c92507
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 0 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ bool isI1Type(mlir::Type);
// scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
bool isMaskArgument(mlir::Type);

/// If an expression's extents are known at compile time, generate a fir.shape
/// for this expression. Otherwise return {}
mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
const hlfir::ExprType &expr);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
30 changes: 30 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -763,4 +763,34 @@ def hlfir_CopyOutOp : hlfir_Op<"copy_out", []> {
}];
}

def hlfir_ShapeOfOp : hlfir_Op<"shape_of", [Pure]> {
let summary = "Get the shape of a hlfir.expr";
let description = [{
Gets the runtime shape of a hlfir.expr. In lowering to FIR, the
hlfir.shape_of operation will be replaced by an fir.shape.
It is not valid to request the shape of a hlfir.expr which has no shape.
}];

let arguments = (ins hlfir_ExprType:$expr);

let results = (outs fir_ShapeType);

let hasVerifier = 1;

// If all extents are known at compile time, the hlfir.shape_of can be
// immediately folded into a fir.shape operation. This makes information
// available sooner to inform bufferization decisions
let hasCanonicalizeMethod = 1;

let extraClassDeclaration = [{
std::size_t getRank();
}];

let assemblyFormat = [{
$expr attr-dict `:` functional-type(operands, results)
}];

let builders = [OpBuilder<(ins "mlir::Value":$expr)>];
}

#endif // FORTRAN_DIALECT_HLFIR_OPS
23 changes: 23 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include "flang/Optimizer/HLFIR/HLFIRDialect.cpp.inc"
Expand Down Expand Up @@ -158,3 +161,23 @@ bool hlfir::isMaskArgument(mlir::Type type) {
// input is a scalar, so allow i1 too
return mlir::isa<fir::LogicalType>(elementType) || isI1Type(elementType);
}

mlir::Value hlfir::genExprShape(mlir::OpBuilder &builder,
const mlir::Location &loc,
const hlfir::ExprType &expr) {
mlir::IndexType indexTy = builder.getIndexType();
llvm::SmallVector<mlir::Value> extents;
extents.reserve(expr.getRank());

for (std::int64_t extent : expr.getShape()) {
if (extent == hlfir::ExprType::getUnknownExtent())
return {};
extents.emplace_back(builder.create<mlir::arith::ConstantOp>(
loc, indexTy, builder.getIntegerAttr(indexTy, extent)));
}

fir::ShapeType shapeTy =
fir::ShapeType::get(builder.getContext(), expr.getRank());
fir::ShapeOp shape = builder.create<fir::ShapeOp>(loc, shapeTy, extents);
return shape.getResult();
}
51 changes: 51 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -886,5 +887,55 @@ void hlfir::CopyInOp::build(mlir::OpBuilder &builder,
var_is_present);
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

void hlfir::ShapeOfOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value expr) {
hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
mlir::Type type = fir::ShapeType::get(builder.getContext(), exprTy.getRank());
build(builder, result, type, expr);
}

std::size_t hlfir::ShapeOfOp::getRank() {
mlir::Type resTy = getResult().getType();
fir::ShapeType shape = resTy.cast<fir::ShapeType>();
return shape.getRank();
}

mlir::LogicalResult hlfir::ShapeOfOp::verify() {
mlir::Value expr = getExpr();
hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
std::size_t exprRank = exprTy.getShape().size();

if (exprRank == 0)
return emitOpError("cannot get the shape of a shape-less expression");

std::size_t shapeRank = getRank();
if (shapeRank != exprRank)
return emitOpError("result rank and expr rank do not match");

return mlir::success();
}

mlir::LogicalResult
hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf,
mlir::PatternRewriter &rewriter) {
// if extent information is available at compile time, immediately fold the
// hlfir.shape_of into a fir.shape
mlir::Location loc = shapeOf.getLoc();
hlfir::ExprType expr = shapeOf.getExpr().getType().cast<hlfir::ExprType>();

mlir::Value shape = hlfir::genExprShape(rewriter, loc, expr);
if (!shape)
// shape information is not available at compile time
return mlir::LogicalResult::failure();

rewriter.replaceAllUsesWith(shapeOf.getResult(), shape);
rewriter.eraseOp(shapeOf);
return mlir::LogicalResult::success();
}

#define GET_OP_CLASSES
#include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc"
12 changes: 12 additions & 0 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,15 @@ func.func @bad_parent_comp6(%arg0: !fir.box<!fir.array<10x!fir.type<t2{i:i32,j:i
%2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box<!fir.array<10x!fir.type<t2{i:i32,j:i32}>>>, !fir.shape<1>) -> !fir.ref<!fir.array<10x!fir.type<t1{i:i32}>>>
return
}

// -----
func.func @bad_shapeof(%arg0: !hlfir.expr<!fir.char<1,10>>) {
// expected-error@+1 {{'hlfir.shape_of' op cannot get the shape of a shape-less expression}}
%0 = hlfir.shape_of %arg0 : (!hlfir.expr<!fir.char<1,10>>) -> !fir.shape<1>
}

// -----
func.func @bad_shapeof2(%arg0: !hlfir.expr<10xi32>) {
// expected-error@+1 {{'hlfir.shape_of' op result rank and expr rank do not match}}
%0 = hlfir.shape_of %arg0 : (!hlfir.expr<10xi32>) -> !fir.shape<42>
}
29 changes: 29 additions & 0 deletions flang/test/HLFIR/shapeof.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Test hlfir.shape_of operation parse, verify (no errors), and unparse
// RUN: fir-opt %s | fir-opt | FileCheck --check-prefix CHECK --check-prefix CHECK-ALL %s

// Test canonicalization
// RUN: fir-opt %s --canonicalize | FileCheck --check-prefix CHECK-CANON --check-prefix CHECK-ALL %s

func.func @shapeof(%arg0: !hlfir.expr<2x2xi32>) -> !fir.shape<2> {
%shape = hlfir.shape_of %arg0 : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>
return %shape : !fir.shape<2>
}
// CHECK-ALL-LABEL: func.func @shapeof
// CHECK-ALL: %[[EXPR:.*]]: !hlfir.expr<2x2xi32>

// CHECK-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>

// CHECK-CANON-NEXT: %[[C2:.*]] = arith.constant 2 : index
// CHECK-CANON-NEXT: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C2]] : (index, index) -> !fir.shape<2>

// CHECK-ALL-NEXT: return %[[SHAPE]]

// no canonicalization of expressions with unknown extents
func.func @shapeof2(%arg0: !hlfir.expr<?x2xi32>) -> !fir.shape<2> {
%shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
return %shape : !fir.shape<2>
}
// CHECK-ALL-LABEL: func.func @shapeof2
// CHECK-ALL: %[[EXPR:.*]]: !hlfir.expr<?x2xi32>
// CHECK-ALL-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
// CHECK-ALL-NEXT: return %[[SHAPE]]

0 comments on commit 5c92507

Please sign in to comment.