Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_FUNCCONVERSIONS_H_
#define MLIR_DIALECT_FUNC_TRANSFORMS_FUNCCONVERSIONS_H_

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"

Expand All @@ -29,7 +30,8 @@ class RewritePatternSet;
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
const TypeConverter &converter);
const TypeConverter &converter,
PatternBenefit benefit = 1);

/// Add a pattern to the given pattern list to rewrite branch operations to use
/// operands that have been legalized by the conversion framework. This can only
Expand All @@ -43,7 +45,8 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
void populateBranchOpInterfaceTypeConversionPattern(
RewritePatternSet &patterns, const TypeConverter &converter,
function_ref<bool(BranchOpInterface branchOp, int idx)>
shouldConvertBranchOperand = nullptr);
shouldConvertBranchOperand = nullptr,
PatternBenefit benefit = 1);

/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
Expand All @@ -53,7 +56,8 @@ bool isLegalForBranchOpInterfaceTypeConversionPattern(
/// Add a pattern to the given pattern list to rewrite `return` ops to use
/// operands that have been legalized by the conversion framework.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
const TypeConverter &converter);
const TypeConverter &converter,
PatternBenefit benefit = 1);

/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
/// returnOpAlwaysLegal is false, legalize op according to converter. Otherwise,
Expand Down
23 changes: 15 additions & 8 deletions mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
} // namespace

void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
const TypeConverter &converter) {
patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
const TypeConverter &converter,
PatternBenefit benefit) {
patterns.add<CallOpSignatureConversion>(converter, patterns.getContext(),
benefit);
}

namespace {
Expand All @@ -81,8 +83,9 @@ class BranchOpInterfaceTypeConversion

BranchOpInterfaceTypeConversion(
const TypeConverter &typeConverter, MLIRContext *ctx,
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
: OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
PatternBenefit benefit)
: OpInterfaceConversionPattern(typeConverter, ctx, benefit),
shouldConvertBranchOperand(shouldConvertBranchOperand) {}

LogicalResult
Expand Down Expand Up @@ -135,9 +138,11 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {

void mlir::populateBranchOpInterfaceTypeConversionPattern(
RewritePatternSet &patterns, const TypeConverter &typeConverter,
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
PatternBenefit benefit) {
patterns.add<BranchOpInterfaceTypeConversion>(
typeConverter, patterns.getContext(), shouldConvertBranchOperand);
typeConverter, patterns.getContext(), shouldConvertBranchOperand,
benefit);
}

bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
Expand All @@ -157,8 +162,10 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
}

void mlir::populateReturnOpTypeConversionPattern(
RewritePatternSet &patterns, const TypeConverter &typeConverter) {
patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
RewritePatternSet &patterns, const TypeConverter &typeConverter,
PatternBenefit benefit) {
patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext(),
benefit);
}

bool mlir::isLegalForReturnOpTypeConversionPattern(
Expand Down