Skip to content

Commit

Permalink
[flang][hlfir] inline hlfir.transpose as hlfir.elemental
Browse files Browse the repository at this point in the history
Inlining as a hlfir.elemental will allow the transpose to be inlined
into subsequent operations in some cases. For example,

y = TRANSPOSE(x)
z = y * 2

Will operate in a single loop without creating a temporary for the
TRANSPOSE (unlike the runtime call, which always allocates).

This is in a new SimplifyHLFIRIntrinsics pass. The intention is that some
day that pass might replace the FIR SimplifyIntrinsics pass.

Depends On: D149060

Reviewed By: jeanPerier, vzakhari

Differential Revision: https://reviews.llvm.org/D149067
  • Loading branch information
tblah committed Apr 25, 2023
1 parent 5156d1a commit 64ea60e
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 7 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ llvm::SmallVector<mlir::Value> getIndexExtents(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value shape);

/// Return explicit extents. If the base is a fir.box, this won't read it to
/// return the extents and will instead return an empty vector.
llvm::SmallVector<mlir::Value>
getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder);

/// Read length parameters into result if this entity has any.
void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity,
Expand Down
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef FORTRAN_OPTIMIZER_HLFIR_PASSES_H
#define FORTRAN_OPTIMIZER_HLFIR_PASSES_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include <memory>
Expand All @@ -24,6 +25,7 @@ namespace hlfir {
std::unique_ptr<mlir::Pass> createConvertHLFIRtoFIRPass();
std::unique_ptr<mlir::Pass> createBufferizeHLFIRPass();
std::unique_ptr<mlir::Pass> createLowerHLFIRIntrinsicsPass();
std::unique_ptr<mlir::Pass> createSimplifyHLFIRIntrinsicsPass();

#define GEN_PASS_REGISTRATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,9 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> {
let constructor = "hlfir::createLowerHLFIRIntrinsicsPass()";
}

// Pass definition for "simplify-hlfir-intrinsics". It is anchored on
// ::mlir::func::FuncOp (unlike LowerHLFIRIntrinsics above, which is anchored
// on ::mlir::ModuleOp) and is instantiated through
// hlfir::createSimplifyHLFIRIntrinsicsPass().
def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics", "::mlir::func::FuncOp"> {
let summary = "Simplify HLFIR intrinsic operations that don't need to result in runtime calls";
let constructor = "hlfir::createSimplifyHLFIRIntrinsicsPass()";
}

#endif //FORTRAN_DIALECT_HLFIR_PASSES
4 changes: 3 additions & 1 deletion flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
/// passes pipeline
inline void createHLFIRToFIRPassPipeline(
mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) {
if (optLevel.isOptimizingForSpeed())
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
pm.addPass(hlfir::createSimplifyHLFIRIntrinsicsPass());
}
pm.addPass(hlfir::createLowerHLFIRIntrinsicsPass());
pm.addPass(hlfir::createBufferizeHLFIRPass());
pm.addPass(hlfir::createConvertHLFIRtoFIRPass());
Expand Down
13 changes: 7 additions & 6 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@

// Return explicit extents. If the base is a fir.box, this won't read it to
// return the extents and will instead return an empty vector.
static llvm::SmallVector<mlir::Value>
getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder) {
llvm::SmallVector<mlir::Value>
hlfir::getExplicitExtentsFromShape(mlir::Value shape,
fir::FirOpBuilder &builder) {
llvm::SmallVector<mlir::Value> result;
auto *shapeOp = shape.getDefiningOp();
if (auto s = mlir::dyn_cast_or_null<fir::ShapeOp>(shapeOp)) {
Expand Down Expand Up @@ -62,7 +63,7 @@ static llvm::SmallVector<mlir::Value>
getExplicitExtents(fir::FortranVariableOpInterface var,
fir::FirOpBuilder &builder) {
if (mlir::Value shape = var.getShape())
return getExplicitExtentsFromShape(var.getShape(), builder);
return hlfir::getExplicitExtentsFromShape(var.getShape(), builder);
return {};
}

Expand Down Expand Up @@ -404,7 +405,7 @@ hlfir::genBounds(mlir::Location loc, fir::FirOpBuilder &builder,
assert((shape.getType().isa<fir::ShapeShiftType>() ||
shape.getType().isa<fir::ShapeType>()) &&
"shape must contain extents");
auto extents = getExplicitExtentsFromShape(shape, builder);
auto extents = hlfir::getExplicitExtentsFromShape(shape, builder);
auto lowers = getExplicitLboundsFromShape(shape);
assert(lowers.empty() || lowers.size() == extents.size());
mlir::Type idxTy = builder.getIndexType();
Expand Down Expand Up @@ -527,7 +528,7 @@ llvm::SmallVector<mlir::Value>
hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Value shape) {
llvm::SmallVector<mlir::Value> extents =
getExplicitExtentsFromShape(shape, builder);
hlfir::getExplicitExtentsFromShape(shape, builder);
mlir::Type indexType = builder.getIndexType();
for (auto &extent : extents)
extent = builder.createConvert(loc, indexType, extent);
Expand All @@ -538,7 +539,7 @@ mlir::Value hlfir::genExtent(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity, unsigned dim) {
entity = followShapeInducingSource(entity);
if (auto shape = tryRetrievingShapeOrShift(entity)) {
auto extents = getExplicitExtentsFromShape(shape, builder);
auto extents = hlfir::getExplicitExtentsFromShape(shape, builder);
if (!extents.empty()) {
assert(extents.size() > dim && "bad inquiry");
return extents[dim];
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_flang_library(HLFIRTransforms
BufferizeHLFIR.cpp
ConvertToFIR.cpp
LowerHLFIRIntrinsics.cpp
SimplifyHLFIRIntrinsics.cpp

DEPENDS
FIRDialect
Expand Down
114 changes: 114 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Normally transformational intrinsics are lowered to calls to runtime
// functions. However, some cases of the intrinsics are faster when inlined
// into the calling function.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace hlfir {
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

namespace {

/// Rewrite pattern replacing an hlfir.transpose with an hlfir.elemental whose
/// element (i, j) is read from element (j, i) of the transposed array.
///
/// The result shape is the input shape with its two extents swapped, and the
/// length parameters of the input array are forwarded to the elemental.
class TransposeAsElementalConversion
    : public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
  using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;

  /// Build the replacement hlfir.elemental and substitute it for \p transpose.
  /// Always returns success; callers are expected to filter out unsupported
  /// operations before applying the pattern.
  mlir::LogicalResult
  matchAndRewrite(hlfir::TransposeOp transpose,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = transpose.getLoc();
    fir::KindMapping kindMapping{rewriter.getContext()};
    fir::FirOpBuilder builder{rewriter, kindMapping};
    hlfir::ExprType expr = transpose.getType();
    mlir::Type elementType = expr.getElementType();
    hlfir::Entity array = hlfir::Entity{transpose.getArray()};
    mlir::Value resultShape = genResultShape(loc, builder, array);
    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);

    // Body generator for the elemental: fetch array(j, i) for result(i, j).
    auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
                              mlir::ValueRange inputIndices) -> hlfir::Entity {
      assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
      // mlir::ValueRange does not own its elements. The swapped indices must
      // live in owning storage that outlives the getElementAt call; building
      // the range directly from a braced list would leave it pointing at a
      // destroyed temporary.
      llvm::SmallVector<mlir::Value, 2> transposedIndices{inputIndices[1],
                                                          inputIndices[0]};
      hlfir::Entity element =
          hlfir::getElementAt(loc, builder, array, transposedIndices);
      return hlfir::loadTrivialScalar(loc, builder, element);
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, typeParams, genKernel);

    rewriter.replaceOp(transpose, elementalOp.getResult());
    return mlir::success();
  }

private:
  /// Create a fir.shape describing the transposed result: the two extents of
  /// \p array in reverse order.
  static mlir::Value genResultShape(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    hlfir::Entity array) {
    mlir::Value inShape = hlfir::genShape(loc, builder, array);
    llvm::SmallVector<mlir::Value> inExtents =
        hlfir::getExplicitExtentsFromShape(inShape, builder);
    // The input shape was only needed to obtain the extents; drop it when
    // nothing else refers to it. A value without a defining operation (for
    // example a block argument) has nothing to erase.
    if (inShape.getUses().empty())
      if (mlir::Operation *shapeOp = inShape.getDefiningOp())
        shapeOp->erase();

    // transpose indices
    assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
    return builder.create<fir::ShapeOp>(
        loc, mlir::ValueRange{inExtents[1], inExtents[0]});
  }
};

/// Function pass that applies TransposeAsElementalConversion to every
/// non-polymorphic hlfir.transpose in the function and signals failure when
/// the conversion cannot be completed.
class SimplifyHLFIRIntrinsics
    : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
  void runOnOperation() override {
    mlir::func::FuncOp funcOp = getOperation();
    mlir::MLIRContext *ctx = &getContext();

    mlir::RewritePatternSet rewrites(ctx);
    rewrites.insert<TransposeAsElementalConversion>(ctx);

    mlir::ConversionTarget conversionTarget(*ctx);
    // A transpose producing a polymorphic expression is reported as legal so
    // that it is left untouched: hlfir.elemental does not currently support
    // polymorphic results.
    auto isPolymorphicTranspose = [](hlfir::TransposeOp op) {
      return op.getType().cast<hlfir::ExprType>().isPolymorphic();
    };
    conversionTarget.addDynamicallyLegalOp<hlfir::TransposeOp>(
        isPolymorphicTranspose);
    // Every other operation is accepted as-is.
    conversionTarget.markUnknownOpDynamicallyLegal(
        [](mlir::Operation *) { return true; });

    if (mlir::succeeded(mlir::applyFullConversion(funcOp, conversionTarget,
                                                  std::move(rewrites))))
      return;

    mlir::emitError(funcOp->getLoc(),
                    "failure in HLFIR intrinsic simplification");
    signalPassFailure();
  }
};
} // namespace

/// Factory returning a new instance of the SimplifyHLFIRIntrinsics pass.
std::unique_ptr<mlir::Pass> hlfir::createSimplifyHLFIRIntrinsicsPass() {
  std::unique_ptr<mlir::Pass> pass =
      std::make_unique<SimplifyHLFIRIntrinsics>();
  return pass;
}
2 changes: 2 additions & 0 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

! ALL: Fortran::lower::VerifierPass
! O2-NEXT: Canonicalizer
! O2-NEXT: 'func.func' Pipeline
! O2-NEXT: SimplifyHLFIRIntrinsics
! ALL-NEXT: LowerHLFIRIntrinsics
! ALL-NEXT: BufferizeHLFIR
! ALL-NEXT: ConvertHLFIRtoFIR
Expand Down
2 changes: 2 additions & 0 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ func.func @_QQmain() {
// PASSES: Pass statistics report

// PASSES: Canonicalizer
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: SimplifyHLFIRIntrinsics
// PASSES-NEXT: LowerHLFIRIntrinsics
// PASSES-NEXT: BufferizeHLFIR
// PASSES-NEXT: ConvertHLFIRtoFIR
Expand Down
95 changes: 95 additions & 0 deletions flang/test/HLFIR/simplify-hlfir-intrinsics.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s

// box with known extents: the argument is a fir.box of a 1x2 i32 array and
// the transposed result is a 2x1 hlfir.expr. The result value is not used.
func.func @transpose0(%arg0: !fir.box<!fir.array<1x2xi32>>) {
%res = hlfir.transpose %arg0 : (!fir.box<!fir.array<1x2xi32>>) -> !hlfir.expr<2x1xi32>
return
}
// CHECK-LABEL: func.func @transpose0(
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<1x2xi32>>) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<1x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box<!fir.array<1x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index
// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1]]#0, %[[C1_2]] : index
// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index
// CHECK: %[[ELEMENT_REF:.*]] = hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box<!fir.array<1x2xi32>>, index, index) -> !fir.ref<i32>
// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELEMENT_REF]] : !fir.ref<i32>
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
// CHECK: }
// CHECK: return
// CHECK: }

// expr with known extents: the argument is already an hlfir.expr of shape
// 1x2 (i32) and the transposed result is a 2x1 hlfir.expr.
func.func @transpose1(%arg0: !hlfir.expr<1x2xi32>) {
%res = hlfir.transpose %arg0 : (!hlfir.expr<1x2xi32>) -> !hlfir.expr<2x1xi32>
return
}
// CHECK-LABEL: func.func @transpose1(
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<1x2xi32>) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<1x2xi32>, index, index) -> i32
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
// CHECK: }
// CHECK: return
// CHECK: }

// box with unknown extent: the first dimension of the boxed i32 array is
// dynamic (?x2), so the transposed result has type 2x?.
func.func @transpose2(%arg0: !fir.box<!fir.array<?x2xi32>>) {
%res = hlfir.transpose %arg0 : (!fir.box<!fir.array<?x2xi32>>) -> !hlfir.expr<2x?xi32>
return
}
// CHECK-LABEL: func.func @transpose2(
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?x2xi32>>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[DIMS0]]#1 : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0_1]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMS1_1:.*]]:3 = fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index
// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1_1]]#0, %[[C1_2]] : index
// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index
// CHECK: %[[ELE_REF:.*]] = hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box<!fir.array<?x2xi32>>, index, index) -> !fir.ref<i32>
// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELE_REF]] : !fir.ref<i32>
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
// CHECK: }
// CHECK: return
// CHECK: }

// expr with unknown extent: the argument is an hlfir.expr whose first
// dimension is dynamic (?x2), so the transposed result has type 2x?.
func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
%res = hlfir.transpose %arg0 : (!hlfir.expr<?x2xi32>) -> !hlfir.expr<2x?xi32>
return
}
// CHECK-LABEL: func.func @transpose3(
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<?x2xi32>) {
// CHECK: %[[IN_SHAPE:.*]] = hlfir.shape_of %[[ARG0]] : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
// CHECK: %[[EXTENT0:.*]] = hlfir.get_extent %[[IN_SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[OUT_SHAPE:.*]] = fir.shape %[[C2]], %[[EXTENT0]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[OUT_SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<?x2xi32>, index, index) -> i32
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
// CHECK: }
// CHECK: return
// CHECK: }

0 comments on commit 64ea60e

Please sign in to comment.