Skip to content

Commit

Permalink
[mlir] Add conversion and tests for complex.[sqrt|atan2] to Arith.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D126799
  • Loading branch information
pifon2a committed Jun 1, 2022
1 parent 86f9cf8 commit f711785
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 1 deletion.
112 changes: 112 additions & 0 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Expand Up @@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};

// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto type = op.getType().cast<ComplexType>();
Type elementType = type.getElementType();

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();

Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
Value rhsSquaredPlusLhsSquared =
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
Value sqrtOfRhsSquaredPlusLhsSquared =
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);

Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
Value i = b.create<complex::CreateOp>(type, zero, one);
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);

Value divResult =
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
Value logResult = b.create<complex::LogOp>(divResult);

Value negativeOne = b.create<arith::ConstantOp>(
elementType, b.getFloatAttr(elementType, -1));
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);

rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
return success();
}
};

template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
Expand Down Expand Up @@ -700,6 +743,73 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
}
};

// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto type = op.getType().cast<ComplexType>();
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();

Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));

Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());

Value absLhs = b.create<math::AbsOp>(real);
Value absArg = b.create<complex::AbsOp>(elementType, arg);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);

Value half = b.create<arith::ConstantOp>(
elementType, b.getFloatAttr(elementType, 0.5));
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);

Value realIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
Value imagIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);

Value resultReal = sqrtAddAbs;

Value imagDivTwoResultReal = b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultReal, resultReal));

Value negativeResultReal = b.create<arith::NegFOp>(resultReal);

Value resultImag = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
resultReal),
imagDivTwoResultReal);

resultReal = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
resultReal);

Value realIsZero =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
Value imagIsZero =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);

resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);

rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
}
};

struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;

Expand Down Expand Up @@ -782,6 +892,7 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
Expand All @@ -796,6 +907,7 @@ void mlir::populateComplexToStandardConversionPatterns(
NegOpConversion,
SignOpConversion,
SinOpConversion,
SqrtOpConversion,
TanOpConversion,
TanhOpConversion>(patterns.getContext());
// clang-format on
Expand Down
20 changes: 19 additions & 1 deletion mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
// RUN: FileCheck %s

// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
Expand All @@ -16,6 +17,15 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {

// -----

// CHECK-LABEL: func @complex_atan2
func.func @complex_atan2(%lhs: complex<f32>,
%rhs: complex<f32>) -> complex<f32> {
%atan2 = complex.atan2 %lhs, %rhs : complex<f32>
return %atan2 : complex<f32>
}

// -----

// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
Expand Down Expand Up @@ -645,3 +655,11 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_sqrt
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}
119 changes: 119 additions & 0 deletions mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -0,0 +1,119 @@
// RUN: mlir-opt %s \
// RUN: -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
// RUN: -convert-scf-to-cf --convert-complex-to-standard \
// RUN: -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
// RUN: -convert-vector-to-llvm -convert-complex-to-llvm \
// RUN: -convert-func-to-llvm -reconcile-unrealized-casts |\
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
// RUN: FileCheck %s

func.func @test_unary(%input: tensor<?xcomplex<f32>>,
%func: (complex<f32>) -> complex<f32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>

scf.for %i = %c0 to %size step %c1 {
%elem = tensor.extract %input[%i]: tensor<?xcomplex<f32>>

%val = func.call_indirect %func(%elem) : (complex<f32>) -> complex<f32>
%real = complex.re %val : complex<f32>
%imag = complex.im %val: complex<f32>
vector.print %real : f32
vector.print %imag : f32
scf.yield
}
func.return
}

func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
func.return %sqrt : complex<f32>
}

// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
%func: (complex<f32>, complex<f32>) -> complex<f32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>

scf.for %i = %c0 to %size step %c2 {
%lhs = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
%i_next = arith.addi %i, %c1 : index
%rhs = tensor.extract %input[%i_next]: tensor<?xcomplex<f32>>

%val = func.call_indirect %func(%lhs, %rhs)
: (complex<f32>, complex<f32>) -> complex<f32>
%real = complex.re %val : complex<f32>
%imag = complex.im %val: complex<f32>
vector.print %real : f32
vector.print %imag : f32
scf.yield
}
func.return
}

func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%atan2 = complex.atan2 %lhs, %rhs : complex<f32>
func.return %atan2 : complex<f32>
}


func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
(-1.0, -1.0),
// CHECK: 0.455
// CHECK-NEXT: -1.098
(-1.0, 1.0),
// CHECK-NEXT: 0.455
// CHECK-NEXT: 1.098
(0.0, 0.0),
// CHECK-NEXT: 0
// CHECK-NEXT: 0
(0.0, 1.0),
// CHECK-NEXT: 0.707
// CHECK-NEXT: 0.707
(1.0, -1.0),
// CHECK-NEXT: 1.098
// CHECK-NEXT: -0.455
(1.0, 0.0),
// CHECK-NEXT: 1
// CHECK-NEXT: 0
(1.0, 1.0)
// CHECK-NEXT: 1.098
// CHECK-NEXT: 0.455
]> : tensor<7xcomplex<f32>>
%sqrt_test_cast = tensor.cast %sqrt_test
: tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

%sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
call @test_unary(%sqrt_test_cast, %sqrt_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

// complex.atan2 test
%atan2_test = arith.constant dense<[
(1.0, 2.0), (2.0, 1.0),
// CHECK: 0.785
// CHECK-NEXT: 0.346
(1.0, 1.0), (1.0, 0.0),
// CHECK-NEXT: 1.017
// CHECK-NEXT: 0.402
(1.0, 1.0), (1.0, 1.0)
// CHECK-NEXT: 0.785
// CHECK-NEXT: 0
]> : tensor<6xcomplex<f32>>
%atan2_test_cast = tensor.cast %atan2_test
: tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>

%atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
-> complex<f32>
call @test_binary(%atan2_test_cast, %atan2_func)
: (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
-> complex<f32>) -> ()
func.return
}

0 comments on commit f711785

Please sign in to comment.