Skip to content

Commit

Permalink
[fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion pa…
Browse files Browse the repository at this point in the history
…tterns

The `fir.select` and `fir.select_rank` are lowered to llvm.switch.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113089

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
  • Loading branch information
3 people committed Nov 5, 2021
1 parent 3a11fb5 commit 8c23990
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 3 deletions.
11 changes: 10 additions & 1 deletion flang/include/flang/Optimizer/Dialect/FIROps.td
Expand Up @@ -480,14 +480,17 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :

// The number of destination conditions that may be tested
unsigned getNumConditions() {
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).size();
return getCases().size();
}

// The selector is the value being tested to determine the destination
mlir::Value getSelector() { return selector(); }
mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
return operands[0];
}
mlir::Value getSelector(mlir::ValueRange operands) {
return operands.front();
}

// The number of blocks that may be branched to
unsigned getNumDest() { return (*this)->getNumSuccessors(); }
Expand All @@ -498,6 +501,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :

llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
mlir::ValueRange operands, unsigned cond);
using BranchOpInterfaceTrait::getSuccessorOperands;

// Helper function to deal with Optional operand forms
Expand All @@ -510,6 +515,10 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
p.printSuccessor(succ);
}

mlir::ArrayAttr getCases() {
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr());
}

unsigned targetOffsetSize();
}];
}
Expand Down
77 changes: 75 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -174,6 +174,78 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
}
};

template <typename OP>
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.selector();
auto loc = select.getLoc();
assert(conds > 0 && "select must have cases");

llvm::SmallVector<mlir::Block *> destinations;
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
mlir::Block *defaultDestination;
mlir::ValueRange defaultOperands;
llvm::SmallVector<int32_t> caseValues;

for (unsigned t = 0; t != conds; ++t) {
mlir::Block *dest = select.getSuccessor(t);
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
const mlir::Attribute &attr = cases[t];
if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
destinations.push_back(dest);
destinationsOperands.push_back(destOps.hasValue() ? *destOps
: ValueRange());
caseValues.push_back(intAttr.getInt());
continue;
}
assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
assert((t + 1 == conds) && "unit must be last");
defaultDestination = dest;
defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
}

// LLVM::SwitchOp takes a i32 type for the selector.
if (select.getSelector().getType() != rewriter.getI32Type())
selector =
rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);

rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
select, selector,
/*defaultDestination=*/defaultDestination,
/*defaultOperands=*/defaultOperands,
/*caseValues=*/caseValues,
/*caseDestinations=*/destinations,
/*caseOperands=*/destinationsOperands,
/*branchWeights=*/ArrayRef<int32_t>());
}

/// conversion of fir::SelectOp to an if-then-else ladder
struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
using FIROpConversion::FIROpConversion;

mlir::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
return success();
}
};

/// conversion of fir::SelectRankOp to an if-then-else ladder
struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
using FIROpConversion::FIROpConversion;

mlir::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
return success();
}
};

// convert to LLVM IR dialect `undef`
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
using FIROpConversion::FIROpConversion;
Expand Down Expand Up @@ -318,8 +390,9 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
fir::LLVMTypeConverter typeConverter{getModule()};
mlir::OwningRewritePatternList pattern(context);
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
InsertOnRangeOpConversion, UndefOpConversion,
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
InsertOnRangeOpConversion, SelectOpConversion,
SelectRankOpConversion, UnreachableOpConversion,
ZeroOpConversion, UndefOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
Expand Down
19 changes: 19 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Expand Up @@ -2264,6 +2264,15 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

llvm::Optional<mlir::ValueRange>
fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) {
auto a =
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

unsigned fir::SelectOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));
Expand Down Expand Up @@ -2557,6 +2566,16 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

llvm::Optional<mlir::ValueRange>
fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
unsigned oper) {
auto a =
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

unsigned fir::SelectRankOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));
Expand Down
92 changes: 92 additions & 0 deletions flang/test/Fir/convert-to-llvm.fir
Expand Up @@ -167,3 +167,95 @@ func @zero_test_float() {
func @test_unreachable() {
fir.unreachable
}

// -----

// Test `fir.select` operation conversion pattern.
// Check that the if-then-else ladder is correctly constructed and that we
// branch to the correct block.

func @select(%arg : index, %arg2 : i32) -> i32 {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
%2 = arith.constant 3 : i32
%3 = arith.constant 4 : i32
fir.select %arg:index [ 1, ^bb1(%0:i32),
2, ^bb2(%2,%arg,%arg2:i32,index,i32),
3, ^bb3(%arg2,%2:i32,i32),
4, ^bb4(%1:i32),
unit, ^bb5 ]
^bb1(%a : i32) :
return %a : i32
^bb2(%b : i32, %b2 : index, %b3:i32) :
%castidx = arith.index_cast %b2 : index to i32
%4 = arith.addi %b, %castidx : i32
%5 = arith.addi %4, %b3 : i32
return %5 : i32
^bb3(%c:i32, %c2:i32) :
%6 = arith.addi %c, %c2 : i32
return %6 : i32
^bb4(%d : i32) :
return %d : i32
^bb5 :
%zero = arith.constant 0 : i32
return %zero : i32
}

// CHECK-LABEL: func @select(
// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]],
// CHECK-SAME: %[[ARG1:.*]]: i32)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
// CHECK: 4: ^bb4(%[[C1]] : i32)
// CHECK: ]

// -----

// Test `fir.select_rank` operation conversion pattern.
// Check that the if-then-else ladder is correctly constructed and that we
// branch to the correct block.

func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
%2 = arith.constant 3 : i32
%3 = arith.constant 4 : i32
fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32),
2, ^bb2(%2,%arg,%arg2:i32,i32,i32),
3, ^bb3(%arg2,%2:i32,i32),
4, ^bb4(%1:i32),
unit, ^bb5 ]
^bb1(%a : i32) :
return %a : i32
^bb2(%b : i32, %b2 : i32, %b3:i32) :
%4 = arith.addi %b, %b2 : i32
%5 = arith.addi %4, %b3 : i32
return %5 : i32
^bb3(%c:i32, %c2:i32) :
%6 = arith.addi %c, %c2 : i32
return %6 : i32
^bb4(%d : i32) :
return %d : i32
^bb5 :
%zero = arith.constant 0 : i32
return %zero : i32
}

// CHECK-LABEL: func @select_rank(
// CHECK-SAME: %[[SELECTVALUE:.*]]: i32,
// CHECK-SAME: %[[ARG1:.*]]: i32)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
// CHECK: 4: ^bb4(%[[C1]] : i32)
// CHECK: ]

0 comments on commit 8c23990

Please sign in to comment.