Skip to content

Commit

Permalink
[Flang] Change complex divide lowering
Browse files Browse the repository at this point in the history
Currently complex division is lowered to a fir.divc operation and the
fir.divc is later converted to a sequence of llvm operations to perform
complex division, however this causes issues for extreme values when
the calculations overflow.

This patch changes the lowering of complex division to use the Intrinsic
Call functionality to lower into library calls (for single, double,
extended and quad precisions) or an MLIR complex dialect division operation
(for half and bfloat precisions).

 A new wrapper function `genLibSplitComplexArgsCall` is written to handle
 the case of the arguments of the Complex Library calls being split to
its real and imaginary real components.

Note 1: If the Complex To Standard conversion of division operation
matures then we can use it for all precisions. Currently it has the
same issues as the conversion of fir.divc.
Note 2: A previous patch (D145808) did the same but during conversion of
the fir.divc operation. But using function calls at that stage leads to
ABI issues since the conversion to LLVM is not aware of the complex target
rewrite.
Note 3: If the patch is accepted, fir.divc can be removed from FIR. We
can use the complex.div operation where any transformation is required.

Reviewed By: vzakhari, PeteSteinfeld, DavidTruby, jeanPerier

Differential Revision: https://reviews.llvm.org/D149546
  • Loading branch information
kiranchandramohan committed May 11, 2023
1 parent 3db7d0d commit c3a0df1
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 13 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Expand Up @@ -98,6 +98,11 @@ mlir::Value genMax(fir::FirOpBuilder &, mlir::Location,
mlir::Value genMin(fir::FirOpBuilder &, mlir::Location,
llvm::ArrayRef<mlir::Value> args);

/// Generate Complex divide with the given expected
/// result type.
mlir::Value genDivC(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type resultType, mlir::Value x, mlir::Value y);

/// Generate power function x**y with the given expected
/// result type.
mlir::Value genPow(fir::FirOpBuilder &, mlir::Location, mlir::Type resultType,
Expand Down
27 changes: 25 additions & 2 deletions flang/lib/Lower/ConvertExpr.cpp
Expand Up @@ -1080,7 +1080,16 @@ class ScalarExprLowering {
GENBIN(Multiply, Complex, fir::MulcOp)
GENBIN(Divide, Integer, mlir::arith::DivSIOp)
GENBIN(Divide, Real, mlir::arith::DivFOp)
GENBIN(Divide, Complex, fir::DivcOp)

template <int KIND>
ExtValue genval(const Fortran::evaluate::Divide<Fortran::evaluate::Type<
Fortran::common::TypeCategory::Complex, KIND>> &op) {
mlir::Type ty =
converter.genType(Fortran::common::TypeCategory::Complex, KIND);
mlir::Value lhs = genunbox(op.left());
mlir::Value rhs = genunbox(op.right());
return fir::genDivC(builder, getLoc(), ty, lhs, rhs);
}

template <Fortran::common::TypeCategory TC, int KIND>
ExtValue genval(
Expand Down Expand Up @@ -5082,7 +5091,21 @@ class ArrayExprLowering {
GENBIN(Multiply, Complex, fir::MulcOp)
GENBIN(Divide, Integer, mlir::arith::DivSIOp)
GENBIN(Divide, Real, mlir::arith::DivFOp)
GENBIN(Divide, Complex, fir::DivcOp)

template <int KIND>
CC genarr(const Fortran::evaluate::Divide<Fortran::evaluate::Type<
Fortran::common::TypeCategory::Complex, KIND>> &x) {
mlir::Location loc = getLoc();
mlir::Type ty =
converter.genType(Fortran::common::TypeCategory::Complex, KIND);
auto lf = genarr(x.left());
auto rf = genarr(x.right());
return [=](IterSpace iters) -> ExtValue {
mlir::Value lhs = fir::getBase(lf(iters));
mlir::Value rhs = fir::getBase(rf(iters));
return fir::genDivC(builder, loc, ty, lhs, rhs);
};
}

template <Fortran::common::TypeCategory TC, int KIND>
CC genarr(
Expand Down
17 changes: 16 additions & 1 deletion flang/lib/Lower/ConvertExprToHLFIR.cpp
Expand Up @@ -948,7 +948,22 @@ GENBIN(Multiply, Real, mlir::arith::MulFOp)
GENBIN(Multiply, Complex, fir::MulcOp)
GENBIN(Divide, Integer, mlir::arith::DivSIOp)
GENBIN(Divide, Real, mlir::arith::DivFOp)
GENBIN(Divide, Complex, fir::DivcOp)

template <int KIND>
struct BinaryOp<Fortran::evaluate::Divide<
Fortran::evaluate::Type<Fortran::common::TypeCategory::Complex, KIND>>> {
using Op = Fortran::evaluate::Divide<
Fortran::evaluate::Type<Fortran::common::TypeCategory::Complex, KIND>>;
static hlfir::EntityWithAttributes gen(mlir::Location loc,
fir::FirOpBuilder &builder, const Op &,
hlfir::Entity lhs, hlfir::Entity rhs) {
mlir::Type ty = Fortran::lower::getFIRType(
builder.getContext(), Fortran::common::TypeCategory::Complex, KIND,
/*params=*/std::nullopt);
return hlfir::EntityWithAttributes{
fir::genDivC(builder, loc, ty, lhs, rhs)};
}
};

template <Fortran::common::TypeCategory TC, int KIND>
struct BinaryOp<Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>>> {
Expand Down
69 changes: 69 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Expand Up @@ -1184,6 +1184,54 @@ static mlir::Value genLibCall(fir::FirOpBuilder &builder, mlir::Location loc,
return libCall.getResult(0);
}

static mlir::Value genLibSplitComplexArgsCall(
fir::FirOpBuilder &builder, mlir::Location loc, llvm::StringRef libFuncName,
mlir::FunctionType libFuncType, llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2 && "Incorrect #args to genLibSplitComplexArgsCall");

auto getSplitComplexArgsType = [&builder, &args]() -> mlir::FunctionType {
mlir::Type ctype = args[0].getType();
auto fKind = ctype.cast<fir::ComplexType>().getFKind();
mlir::Type ftype;

if (fKind == 2)
ftype = builder.getF16Type();
else if (fKind == 3)
ftype = builder.getBF16Type();
else if (fKind == 4)
ftype = builder.getF32Type();
else if (fKind == 8)
ftype = builder.getF64Type();
else if (fKind == 10)
ftype = builder.getF80Type();
else if (fKind == 16)
ftype = builder.getF128Type();
else
assert(0 && "Unsupported Complex Type");

return builder.getFunctionType({ftype, ftype, ftype, ftype}, {ctype});
};

llvm::SmallVector<mlir::Value, 4> splitArgs;
mlir::Value cplx1 = args[0];
auto real1 = fir::factory::Complex{builder, loc}.extractComplexPart(
cplx1, /*isImagPart=*/false);
splitArgs.push_back(real1);
auto imag1 = fir::factory::Complex{builder, loc}.extractComplexPart(
cplx1, /*isImagPart=*/true);
splitArgs.push_back(imag1);
mlir::Value cplx2 = args[1];
auto real2 = fir::factory::Complex{builder, loc}.extractComplexPart(
cplx2, /*isImagPart=*/false);
splitArgs.push_back(real2);
auto imag2 = fir::factory::Complex{builder, loc}.extractComplexPart(
cplx2, /*isImagPart=*/true);
splitArgs.push_back(imag2);

return genLibCall(builder, loc, libFuncName, getSplitComplexArgsType(),
splitArgs);
}

template <typename T>
static mlir::Value genMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::StringRef mathLibFuncName,
Expand Down Expand Up @@ -1345,6 +1393,22 @@ static constexpr MathOperation mathOperations[] = {
{"cosh", "cosh", genF64F64FuncType, genLibCall},
{"cosh", "ccoshf", genComplexComplexFuncType<4>, genLibCall},
{"cosh", "ccosh", genComplexComplexFuncType<8>, genLibCall},
{"divc",
{},
genComplexComplexComplexFuncType<2>,
genComplexMathOp<mlir::complex::DivOp>},
{"divc",
{},
genComplexComplexComplexFuncType<3>,
genComplexMathOp<mlir::complex::DivOp>},
{"divc", "__divsc3", genComplexComplexComplexFuncType<4>,
genLibSplitComplexArgsCall},
{"divc", "__divdc3", genComplexComplexComplexFuncType<8>,
genLibSplitComplexArgsCall},
{"divc", "__divxc3", genComplexComplexComplexFuncType<10>,
genLibSplitComplexArgsCall},
{"divc", "__divtc3", genComplexComplexComplexFuncType<16>,
genLibSplitComplexArgsCall},
{"erf", "erff", genF32F32FuncType, genMathOp<mlir::math::ErfOp>},
{"erf", "erf", genF64F64FuncType, genMathOp<mlir::math::ErfOp>},
{"erfc", "erfcf", genF32F32FuncType, genLibCall},
Expand Down Expand Up @@ -5661,6 +5725,11 @@ mlir::Value fir::genMin(fir::FirOpBuilder &builder, mlir::Location loc,
args);
}

mlir::Value fir::genDivC(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value x, mlir::Value y) {
return IntrinsicLibrary{builder, loc}.genRuntimeCall("divc", type, {x, y});
}

mlir::Value fir::genPow(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value x, mlir::Value y) {
// TODO: since there is no libm version of pow with integer exponent,
Expand Down
7 changes: 5 additions & 2 deletions flang/test/Lower/HLFIR/binary-ops.f90
Expand Up @@ -131,8 +131,11 @@ subroutine complex_div(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_8:.*]] = fir.divc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>

! CHECK: %[[VAL_8:.*]] = fir.extract_value %[[VAL_6]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[VAL_9:.*]] = fir.extract_value %[[VAL_6]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[VAL_10:.*]] = fir.extract_value %[[VAL_7]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[VAL_11:.*]] = fir.extract_value %[[VAL_7]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[VAL_12:.*]] = fir.call @__divsc3(%[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]]) fastmath<contract> : (f32, f32, f32, f32) -> !fir.complex<4>

subroutine int_power(x, y, z)
integer :: x, y, z
Expand Down
6 changes: 5 additions & 1 deletion flang/test/Lower/assignment.f90
Expand Up @@ -251,7 +251,11 @@ real function divf(a, b)
! CHECK: %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[DIV:.*]] = fir.divc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
! CHECK: %[[A_REAL:.*]] = fir.extract_value %[[A_VAL]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[A_IMAG:.*]] = fir.extract_value %[[A_VAL]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[B_REAL:.*]] = fir.extract_value %[[B_VAL]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[B_IMAG:.*]] = fir.extract_value %[[B_VAL]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[DIV:.*]] = fir.call @__divsc3(%[[A_REAL]], %[[A_IMAG]], %[[B_REAL]], %[[B_IMAG]]) fastmath<contract> : (f32, f32, f32, f32) -> !fir.complex<4>
! CHECK: fir.store %[[DIV]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: return %[[RET]] : !fir.complex<4>
Expand Down
93 changes: 86 additions & 7 deletions flang/test/Lower/complex-operations.f90
Expand Up @@ -27,11 +27,90 @@ subroutine mul_test(a,b,c)
a = b * c
end subroutine mul_test

! CHECK-LABEL: @_QPdiv_test
subroutine div_test(a,b,c)
complex :: a, b, c
! CHECK-NOT: fir.extract_value
! CHECK-NOT: fir.insert_value
! CHECK: fir.divc {{.*}}: !fir.complex
! CHECK-LABEL: @_QPdiv_test_half
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<2>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<2>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<2>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<2>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<2>>
! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<2>) -> complex<f16>
! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<2>) -> complex<f16>
! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex<f16>
! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex<f16>) -> !fir.complex<2>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<2>>
subroutine div_test_half(a,b,c)
complex(kind=2) :: a, b, c
a = b / c
end subroutine div_test_half

! CHECK-LABEL: @_QPdiv_test_bfloat
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<3>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<3>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<3>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<3>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<3>>
! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<3>) -> complex<bf16>
! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<3>) -> complex<bf16>
! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex<bf16>
! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex<bf16>) -> !fir.complex<3>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<3>>
subroutine div_test_bfloat(a,b,c)
complex(kind=3) :: a, b, c
a = b / c
end subroutine div_test_bfloat

! CHECK-LABEL: @_QPdiv_test_single
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<4>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<4>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<4>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[BREAL:.*]] = fir.extract_value %[[BVAL]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[BIMAG:.*]] = fir.extract_value %[[BVAL]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[CREAL:.*]] = fir.extract_value %[[CVAL]], [0 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[CIMAG:.*]] = fir.extract_value %[[CVAL]], [1 : index] : (!fir.complex<4>) -> f32
! CHECK: %[[AVAL:.*]] = fir.call @__divsc3(%[[BREAL]], %[[BIMAG]], %[[CREAL]], %[[CIMAG]]) fastmath<contract> : (f32, f32, f32, f32) -> !fir.complex<4>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<4>>
subroutine div_test_single(a,b,c)
complex(kind=4) :: a, b, c
a = b / c
end subroutine div_test_single

! CHECK-LABEL: @_QPdiv_test_double
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<8>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<8>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<8>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<8>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<8>>
! CHECK: %[[BREAL:.*]] = fir.extract_value %[[BVAL]], [0 : index] : (!fir.complex<8>) -> f64
! CHECK: %[[BIMAG:.*]] = fir.extract_value %[[BVAL]], [1 : index] : (!fir.complex<8>) -> f64
! CHECK: %[[CREAL:.*]] = fir.extract_value %[[CVAL]], [0 : index] : (!fir.complex<8>) -> f64
! CHECK: %[[CIMAG:.*]] = fir.extract_value %[[CVAL]], [1 : index] : (!fir.complex<8>) -> f64
! CHECK: %[[AVAL:.*]] = fir.call @__divdc3(%[[BREAL]], %[[BIMAG]], %[[CREAL]], %[[CIMAG]]) fastmath<contract> : (f64, f64, f64, f64) -> !fir.complex<8>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<8>>
subroutine div_test_double(a,b,c)
complex(kind=8) :: a, b, c
a = b / c
end subroutine div_test_double

! CHECK-LABEL: @_QPdiv_test_extended
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<10>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<10>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<10>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<10>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<10>>
! CHECK: %[[BREAL:.*]] = fir.extract_value %[[BVAL]], [0 : index] : (!fir.complex<10>) -> f80
! CHECK: %[[BIMAG:.*]] = fir.extract_value %[[BVAL]], [1 : index] : (!fir.complex<10>) -> f80
! CHECK: %[[CREAL:.*]] = fir.extract_value %[[CVAL]], [0 : index] : (!fir.complex<10>) -> f80
! CHECK: %[[CIMAG:.*]] = fir.extract_value %[[CVAL]], [1 : index] : (!fir.complex<10>) -> f80
! CHECK: %[[AVAL:.*]] = fir.call @__divxc3(%[[BREAL]], %[[BIMAG]], %[[CREAL]], %[[CIMAG]]) fastmath<contract> : (f80, f80, f80, f80) -> !fir.complex<10>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<10>>
subroutine div_test_extended(a,b,c)
complex(kind=10) :: a, b, c
a = b / c
end subroutine div_test_extended

! CHECK-LABEL: @_QPdiv_test_quad
! CHECK-SAME: %[[AREF:.*]]: !fir.ref<!fir.complex<16>> {{.*}}, %[[BREF:.*]]: !fir.ref<!fir.complex<16>> {{.*}}, %[[CREF:.*]]: !fir.ref<!fir.complex<16>> {{.*}})
! CHECK: %[[BVAL:.*]] = fir.load %[[BREF]] : !fir.ref<!fir.complex<16>>
! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<16>>
! CHECK: %[[BREAL:.*]] = fir.extract_value %[[BVAL]], [0 : index] : (!fir.complex<16>) -> f128
! CHECK: %[[BIMAG:.*]] = fir.extract_value %[[BVAL]], [1 : index] : (!fir.complex<16>) -> f128
! CHECK: %[[CREAL:.*]] = fir.extract_value %[[CVAL]], [0 : index] : (!fir.complex<16>) -> f128
! CHECK: %[[CIMAG:.*]] = fir.extract_value %[[CVAL]], [1 : index] : (!fir.complex<16>) -> f128
! CHECK: %[[AVAL:.*]] = fir.call @__divtc3(%[[BREAL]], %[[BIMAG]], %[[CREAL]], %[[CIMAG]]) fastmath<contract> : (f128, f128, f128, f128) -> !fir.complex<16>
! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<16>>
subroutine div_test_quad(a,b,c)
complex(kind=16) :: a, b, c
a = b / c
end subroutine div_test
end subroutine div_test_quad

0 comments on commit c3a0df1

Please sign in to comment.