-
Notifications
You must be signed in to change notification settings - Fork 10.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[flang][hlfir] inline hlfir.transpose as hlfir.elemental
Inlining as a hlfir.elemental will allow the transpose to be inlined into subsequent operations in some cases. For example, `y = TRANSPOSE(x)` followed by `z = y * 2` will operate in a single loop without creating a temporary for the TRANSPOSE (unlike the runtime call, which always allocates). This is in a new SimplifyHLFIRIntrinsics pass. The intention is that some day that pass might replace the FIR SimplifyIntrinsics pass. Depends On: D149060 Reviewed By: jeanPerier, vzakhari Differential Revision: https://reviews.llvm.org/D149067
- Loading branch information
Showing
10 changed files
with
236 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// Normally transformational intrinsics are lowered to calls to runtime | ||
// functions. However, some cases of the intrinsics are faster when inlined | ||
// into the calling function. | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "flang/Optimizer/Builder/FIRBuilder.h" | ||
#include "flang/Optimizer/Builder/HLFIRTools.h" | ||
#include "flang/Optimizer/Dialect/FIRDialect.h" | ||
#include "flang/Optimizer/Dialect/Support/KindMapping.h" | ||
#include "flang/Optimizer/HLFIR/HLFIRDialect.h" | ||
#include "flang/Optimizer/HLFIR/HLFIROps.h" | ||
#include "flang/Optimizer/HLFIR/Passes.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/IR/Location.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace hlfir { | ||
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS | ||
#include "flang/Optimizer/HLFIR/Passes.h.inc" | ||
} // namespace hlfir | ||
|
||
namespace { | ||
|
||
class TransposeAsElementalConversion | ||
: public mlir::OpRewritePattern<hlfir::TransposeOp> { | ||
public: | ||
using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern; | ||
|
||
mlir::LogicalResult | ||
matchAndRewrite(hlfir::TransposeOp transpose, | ||
mlir::PatternRewriter &rewriter) const override { | ||
mlir::Location loc = transpose.getLoc(); | ||
fir::KindMapping kindMapping{rewriter.getContext()}; | ||
fir::FirOpBuilder builder{rewriter, kindMapping}; | ||
hlfir::ExprType expr = transpose.getType(); | ||
mlir::Type elementType = expr.getElementType(); | ||
hlfir::Entity array = hlfir::Entity{transpose.getArray()}; | ||
mlir::Value resultShape = genResultShape(loc, builder, array); | ||
llvm::SmallVector<mlir::Value, 1> typeParams; | ||
hlfir::genLengthParameters(loc, builder, array, typeParams); | ||
|
||
auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder, | ||
mlir::ValueRange inputIndices) -> hlfir::Entity { | ||
assert(inputIndices.size() == 2 && "checked in TransposeOp::validate"); | ||
mlir::ValueRange transposedIndices{{inputIndices[1], inputIndices[0]}}; | ||
hlfir::Entity element = | ||
hlfir::getElementAt(loc, builder, array, transposedIndices); | ||
hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element); | ||
return val; | ||
}; | ||
hlfir::ElementalOp elementalOp = hlfir::genElementalOp( | ||
loc, builder, elementType, resultShape, typeParams, genKernel); | ||
|
||
rewriter.replaceOp(transpose, elementalOp.getResult()); | ||
return mlir::success(); | ||
} | ||
|
||
private: | ||
static mlir::Value genResultShape(mlir::Location loc, | ||
fir::FirOpBuilder &builder, | ||
hlfir::Entity array) { | ||
mlir::Value inShape = hlfir::genShape(loc, builder, array); | ||
llvm::SmallVector<mlir::Value> inExtents = | ||
hlfir::getExplicitExtentsFromShape(inShape, builder); | ||
if (inShape.getUses().empty()) | ||
inShape.getDefiningOp()->erase(); | ||
|
||
// transpose indices | ||
assert(inExtents.size() == 2 && "checked in TransposeOp::validate"); | ||
return builder.create<fir::ShapeOp>( | ||
loc, mlir::ValueRange{inExtents[1], inExtents[0]}); | ||
} | ||
}; | ||
|
||
class SimplifyHLFIRIntrinsics | ||
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> { | ||
public: | ||
void runOnOperation() override { | ||
mlir::func::FuncOp func = this->getOperation(); | ||
mlir::MLIRContext *context = &getContext(); | ||
mlir::RewritePatternSet patterns(context); | ||
patterns.insert<TransposeAsElementalConversion>(context); | ||
mlir::ConversionTarget target(*context); | ||
// don't transform transpose of polymorphic arrays (not currently supported | ||
// by hlfir.elemental) | ||
target.addDynamicallyLegalOp<hlfir::TransposeOp>( | ||
[](hlfir::TransposeOp transpose) { | ||
return transpose.getType().cast<hlfir::ExprType>().isPolymorphic(); | ||
}); | ||
target.markUnknownOpDynamicallyLegal( | ||
[](mlir::Operation *) { return true; }); | ||
if (mlir::failed( | ||
mlir::applyFullConversion(func, target, std::move(patterns)))) { | ||
mlir::emitError(func->getLoc(), | ||
"failure in HLFIR intrinsic simplification"); | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<mlir::Pass> hlfir::createSimplifyHLFIRIntrinsicsPass() { | ||
return std::make_unique<SimplifyHLFIRIntrinsics>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s | ||
|
||
// box with known extents | ||
func.func @transpose0(%arg0: !fir.box<!fir.array<1x2xi32>>) { | ||
%res = hlfir.transpose %arg0 : (!fir.box<!fir.array<1x2xi32>>) -> !hlfir.expr<2x1xi32> | ||
return | ||
} | ||
// CHECK-LABEL: func.func @transpose0( | ||
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<1x2xi32>>) { | ||
// CHECK: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK: %[[C2:.*]] = arith.constant 2 : index | ||
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2> | ||
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> { | ||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): | ||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<1x2xi32>>, index) -> (index, index, index) | ||
// CHECK: %[[C1_1:.*]] = arith.constant 1 : index | ||
// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box<!fir.array<1x2xi32>>, index) -> (index, index, index) | ||
// CHECK: %[[C1_2:.*]] = arith.constant 1 : index | ||
// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index | ||
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index | ||
// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1]]#0, %[[C1_2]] : index | ||
// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index | ||
// CHECK: %[[ELEMENT_REF:.*]] = hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box<!fir.array<1x2xi32>>, index, index) -> !fir.ref<i32> | ||
// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELEMENT_REF]] : !fir.ref<i32> | ||
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
|
||
// expr with known extents | ||
func.func @transpose1(%arg0: !hlfir.expr<1x2xi32>) { | ||
%res = hlfir.transpose %arg0 : (!hlfir.expr<1x2xi32>) -> !hlfir.expr<2x1xi32> | ||
return | ||
} | ||
// CHECK-LABEL: func.func @transpose1( | ||
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<1x2xi32>) { | ||
// CHECK: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK: %[[C2:.*]] = arith.constant 2 : index | ||
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2> | ||
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> { | ||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): | ||
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<1x2xi32>, index, index) -> i32 | ||
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
|
||
// box with unknown extent | ||
func.func @transpose2(%arg0: !fir.box<!fir.array<?x2xi32>>) { | ||
%res = hlfir.transpose %arg0 : (!fir.box<!fir.array<?x2xi32>>) -> !hlfir.expr<2x?xi32> | ||
return | ||
} | ||
// CHECK-LABEL: func.func @transpose2( | ||
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?x2xi32>>) { | ||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index) | ||
// CHECK: %[[C2:.*]] = arith.constant 2 : index | ||
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[DIMS0]]#1 : (index, index) -> !fir.shape<2> | ||
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> { | ||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): | ||
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index | ||
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0_1]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index) | ||
// CHECK: %[[C1_1:.*]] = arith.constant 1 : index | ||
// CHECK: %[[DIMS1_1:.*]]:3 = fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index) | ||
// CHECK: %[[C1_2:.*]] = arith.constant 1 : index | ||
// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index | ||
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index | ||
// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1_1]]#0, %[[C1_2]] : index | ||
// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index | ||
// CHECK: %[[ELE_REF:.*]] = hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box<!fir.array<?x2xi32>>, index, index) -> !fir.ref<i32> | ||
// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELE_REF]] : !fir.ref<i32> | ||
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
|
||
// expr with unknown extent | ||
func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) { | ||
%res = hlfir.transpose %arg0 : (!hlfir.expr<?x2xi32>) -> !hlfir.expr<2x?xi32> | ||
return | ||
} | ||
// CHECK-LABEL: func.func @transpose3( | ||
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<?x2xi32>) { | ||
// CHECK: %[[IN_SHAPE:.*]] = hlfir.shape_of %[[ARG0]] : (!hlfir.expr<?x2xi32>) -> !fir.shape<2> | ||
// CHECK: %[[EXTENT0:.*]] = hlfir.get_extent %[[IN_SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index | ||
// CHECK: %[[C2:.*]] = arith.constant 2 : index | ||
// CHECK: %[[OUT_SHAPE:.*]] = fir.shape %[[C2]], %[[EXTENT0]] : (index, index) -> !fir.shape<2> | ||
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[OUT_SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> { | ||
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): | ||
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<?x2xi32>, index, index) -> i32 | ||
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } |