diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 4b631b2f99a5d..21fe2d9f89372 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -27,8 +27,9 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include +#include "llvm/ADT/TypeSwitch.h" namespace hlfir { #define GEN_PASS_DEF_BUFFERIZEHLFIR @@ -169,6 +170,38 @@ struct AsExprOpConversion : public mlir::OpConversionPattern { } }; +struct ShapeOfOpConversion + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(hlfir::ShapeOfOp shapeOf, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = shapeOf.getLoc(); + mlir::ModuleOp mod = shapeOf->getParentOfType(); + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + + mlir::Value shape; + hlfir::Entity bufferizedExpr{getBufferizedExprStorage(adaptor.getExpr())}; + if (bufferizedExpr.isVariable()) { + shape = hlfir::genShape(loc, builder, bufferizedExpr); + } else { + // everything else failed so try to create a shape from static type info + hlfir::ExprType exprTy = + adaptor.getExpr().getType().dyn_cast_or_null(); + if (exprTy) + shape = hlfir::genExprShape(builder, loc, exprTy); + } + // expected to never happen + if (!shape) + return emitError(loc, + "Unresolvable hlfir.shape_of where extents are unknown"); + + rewriter.replaceOp(shapeOf, shape); + return mlir::success(); + } +}; + struct ApplyOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit ApplyOpConversion(mlir::MLIRContext *ctx) @@ -529,11 +562,11 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase { auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns - .insert(context); + patterns.insert(context); mlir::ConversionTarget target(*context); target.addIllegalOp>>) -> !fir.shape<1> { + %c0 = arith.constant 0 : index + %59:3 = fir.box_dims %arg0, %c0 : (!fir.box>>, index) -> (index, index, index) + %60 = fir.box_addr %arg0 : (!fir.box>>) -> !fir.heap> + %61 = fir.shape_shift %59#0, %59#1 : (index, index) -> !fir.shapeshift<1> + %62:2 = hlfir.declare %60(%61) {uniq_name = ".tmp.intrinsic_result"} : (!fir.heap>, !fir.shapeshift<1>) -> (!fir.box>, !fir.heap>) + %true = arith.constant true + %63 = hlfir.as_expr %62#0 move %true : (!fir.box>, i1) -> !hlfir.expr + %64 = hlfir.shape_of %63 : (!hlfir.expr) -> !fir.shape<1> + return %64 : !fir.shape<1> +} +// CHECK-LABEL: @shapeof_asexpr +// CHECK: %[[ARG0:.*]]: !fir.box>> +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 +// CHECK-NEXT: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] +// CHECK-NEXT: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] +// CHECK-NEXT: %[[SHPE_SHFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[VAR:.*]]:2 = hlfir.declare %[[BOX_ADDR]](%[[SHPE_SHFT]]) +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[TUPLE0:.*]] = fir.undefined tuple +// CHECK-NEXT: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE]] +// CHECK-NEXT: %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[VAR]]#0 +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 +// CHECK-NEXT: return %[[SHAPE]] + +func.func @shapeof_elemental() -> !fir.shape<1> { + %c1 = arith.constant 1 : index + %0 = fir.shape %c1 : (index) -> !fir.shape<1> + %1 = hlfir.elemental %0 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg3: index): + hlfir.yield_element %arg3 : index + } + %2 = hlfir.shape_of %1 : (!hlfir.expr) -> !fir.shape<1> + return %2 : !fir.shape<1> +} +// CHECK-LABEL: @shapeof_elemental +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]] +// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %[[C1:.*]] +// CHECK: return %[[SHAPE]] + +func.func @shapeof_fallback(%arg0: !hlfir.expr<1x2x3xi32>) -> !fir.shape<3> { + %shape = hlfir.shape_of %arg0 : (!hlfir.expr<1x2x3xi32>) -> !fir.shape<3> + return %shape : !fir.shape<3> +} +// CHECK-LABEL: @shapeof_fallback +// CHECK: %[[EXPR:.*]]: !hlfir.expr<1x2x3xi32> +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]], %[[C2]], %[[C3]] : +// CHECK-NEXT: return %[[SHAPE]]