Skip to content

Commit

Permalink
[mlir][complex] Convert complex.abs to libm
Browse files Browse the repository at this point in the history
Convert complex.abs to libm library

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D127476
  • Loading branch information
Lewuathe committed Jul 8, 2022
1 parent 8576867 commit eaba6e0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
59 changes: 45 additions & 14 deletions mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
Expand Up @@ -16,14 +16,43 @@
using namespace mlir;

namespace {
// Functor to resolve the function name corresponding to the given complex
// result type.
struct ComplexTypeResolver {
llvm::Optional<bool> operator()(Type type) const {
auto complexType = type.cast<ComplexType>();
auto elementType = complexType.getElementType();
if (!elementType.isa<Float32Type, Float64Type>())
return {};

return elementType.getIntOrFloatBitWidth() == 64;
}
};

// Functor to resolve the function name corresponding to the given float result
// type.
struct FloatTypeResolver {
llvm::Optional<bool> operator()(Type type) const {
auto elementType = type.cast<FloatType>();
if (!elementType.isa<Float32Type, Float64Type>())
return {};

return elementType.getIntOrFloatBitWidth() == 64;
}
};

// Pattern to convert scalar complex operations to calls to libm functions.
// Additionally the libm function signatures are declared.
template <typename Op>
// TypeResolver is a functor returning the libm function name according to the
// expected type double or float.
template <typename Op, typename TypeResolver = ComplexTypeResolver>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
StringRef doubleFunc, PatternBenefit benefit)
ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
StringRef floatFunc,
StringRef doubleFunc,
PatternBenefit benefit)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};

Expand All @@ -34,18 +63,16 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
};
} // namespace

template <typename Op>
LogicalResult
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
PatternRewriter &rewriter) const {
template <typename Op, typename TypeResolver>
LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
Op op, PatternRewriter &rewriter) const {
auto module = SymbolTable::getNearestSymbolTable(op);
auto type = op.getType().template cast<ComplexType>();
Type elementType = type.getElementType();
if (!elementType.isa<Float32Type, Float64Type>())
auto isDouble = TypeResolver()(op.getType());
if (!isDouble.hasValue())
return failure();

auto name =
elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
auto name = isDouble.value() ? doubleFunc : floatFunc;

auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
SymbolTable::lookupSymbolIn(module, name));
// Forward declare function if it hasn't already been
Expand All @@ -60,7 +87,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));

rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
op->getOperands());

return success();
}
Expand All @@ -79,6 +107,8 @@ void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
"csinf", "csin", benefit);
patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
"conjf", "conj", benefit);
patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
patterns.getContext(), "cabsf", "cabs", benefit);
}

namespace {
Expand All @@ -96,7 +126,8 @@ void ConvertComplexToLibmPass::runOnOperation() {

ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
complex::AbsOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
Expand Up @@ -9,6 +9,7 @@
// CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
// CHECK-DAG: @csin(complex<f64>) -> complex<f64>
// CHECK-DAG: @conj(complex<f64>) -> complex<f64>
// CHECK-DAG: @cabs(complex<f64>) -> f64

// CHECK-LABEL: func @cpow_caller
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
Expand Down Expand Up @@ -80,4 +81,16 @@ func.func @conj_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<
%double_result = complex.conj %double : complex<f64>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : complex<f32>, complex<f64>
}

// CHECK-LABEL: func @cabs_caller
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64) {
// CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
%float_result = complex.abs %float : complex<f32>
// CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
%double_result = complex.abs %double : complex<f64>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}

0 comments on commit eaba6e0

Please sign in to comment.