From eaeaf2f4e96136367cb15ba18fd93ae493eff40f Mon Sep 17 00:00:00 2001 From: s-watanabe314 Date: Tue, 28 Jan 2025 13:59:17 +0900 Subject: [PATCH] [mlir][complex] Add complex-range option and select complex division algorithm This patch adds the `complex-range` option and two calculation methods for complex number division (algebraic method and Smith's algorithm) to both the `ComplexToLLVM` and `ComplexToStandard` passes, allowing the calculation method to be controlled by the option. See also the discussion in the following discourse post. https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772 --- .../ComplexCommon/DivisionConverter.h | 48 ++ .../Conversion/ComplexToLLVM/ComplexToLLVM.h | 8 +- .../ComplexToStandard/ComplexToStandard.h | 9 +- mlir/include/mlir/Conversion/Passes.td | 22 + .../mlir/Dialect/Complex/IR/CMakeLists.txt | 2 + .../include/mlir/Dialect/Complex/IR/Complex.h | 6 + .../mlir/Dialect/Complex/IR/ComplexBase.td | 16 + mlir/lib/Conversion/CMakeLists.txt | 1 + .../Conversion/ComplexCommon/CMakeLists.txt | 12 + .../ComplexCommon/DivisionConverter.cpp | 456 ++++++++++++++++++ .../Conversion/ComplexToLLVM/CMakeLists.txt | 1 + .../ComplexToLLVM/ComplexToLLVM.cpp | 44 +- .../ComplexToStandard/CMakeLists.txt | 1 + .../ComplexToStandard/ComplexToStandard.cpp | 243 ++-------- .../ComplexToLLVM/complex-range-option.mlir | 303 ++++++++++++ .../ComplexToLLVM/convert-to-llvm.mlir | 4 +- .../ComplexToLLVM/full-conversion.mlir | 2 +- .../complex-range-option.mlir | 277 +++++++++++ 18 files changed, 1223 insertions(+), 232 deletions(-) create mode 100644 mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h create mode 100644 mlir/lib/Conversion/ComplexCommon/CMakeLists.txt create mode 100644 mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp create mode 100644 mlir/test/Conversion/ComplexToLLVM/complex-range-option.mlir create mode 100644 mlir/test/Conversion/ComplexToStandard/complex-range-option.mlir diff --git a/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h b/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h new file mode 100644 index 0000000000000..df97dc2c4eb7d --- /dev/null +++ b/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h @@ -0,0 +1,48 @@ +//===- DivisionConverter.h - Complex division conversion ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H +#define MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir { +namespace complex { +/// convert a complex division to the LLVM dialect using algebraic method +void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, + Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, + LLVM::FastmathFlagsAttr fmf, + Value *resultRe, Value *resultIm); + +/// convert a complex division to the arith/math dialects using algebraic method +void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, + Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, + arith::FastMathFlagsAttr fmf, + Value *resultRe, Value *resultIm); + +/// convert a complex division to the LLVM dialect using Smith's method +void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, + Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, + LLVM::FastmathFlagsAttr fmf, + Value *resultRe, Value *resultIm); + +/// convert a complex division to the arith/math dialects using Smith's method +void convertDivToStandardUsingRangeReduction( + ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, + Value *resultIm); + +} // namespace complex +} // namespace mlir + +#endif // MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h index 8266442cf5db8..1db75563fe304 100644 --- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h +++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h @@ -9,6 +9,8 @@ #define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ #include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Pass/Pass.h" namespace mlir { class DialectRegistry; @@ -39,8 +41,10 @@ class ComplexStructBuilder : public StructBuilder { }; /// Populate the given list with patterns that convert from Complex to LLVM. -void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); +void populateComplexToLLVMConversionPatterns( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + mlir::complex::ComplexRangeFlags complexRange = + mlir::complex::ComplexRangeFlags::basic); void registerConvertComplexToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h index 39c4a1ae54617..30b86cac9cd4e 100644 --- a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h +++ b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h @@ -8,6 +8,8 @@ #ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_ #define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_ +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Pass/Pass.h" #include namespace mlir { @@ -18,10 +20,15 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" /// Populate the given list with patterns that convert from Complex to Standard. -void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns); +void populateComplexToStandardConversionPatterns( + RewritePatternSet &patterns, + mlir::complex::ComplexRangeFlags complexRange = + mlir::complex::ComplexRangeFlags::improved); /// Create a pass to convert Complex operations to the Standard dialect. std::unique_ptr createConvertComplexToStandardPass(); +std::unique_ptr +createConvertComplexToStandardPass(ConvertComplexToStandardOptions options); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index ff79a1226c047..5203838a6eb35 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -274,6 +274,17 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> { def ConvertComplexToLLVMPass : Pass<"convert-complex-to-llvm"> { let summary = "Convert Complex dialect to LLVM dialect"; let dependentDialects = ["LLVM::LLVMDialect"]; + + let options = [ + Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags", + /*default=*/"::mlir::complex::ComplexRangeFlags::basic", + "Control the intermediate calculation of complex number division", + [{::llvm::cl::values( + clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved"), + clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic (default)"), + clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none") + )}]>, + ]; } //===----------------------------------------------------------------------===// @@ -308,6 +319,17 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard"> { let summary = "Convert Complex dialect to standard dialect"; let constructor = "mlir::createConvertComplexToStandardPass()"; let dependentDialects = ["math::MathDialect"]; + + let options = [ + Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags", + /*default=*/"::mlir::complex::ComplexRangeFlags::improved", + "Control the intermediate calculation of complex number division", + [{::llvm::cl::values( + clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved (default)"), + clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic"), + clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none") + )}]>, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt index f41888d01a2fd..837664e25b3c2 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_dialect(ComplexOps complex) add_mlir_doc(ComplexOps ComplexOps Dialects/ -gen-dialect-doc -dialect=complex) set(LLVM_TARGET_DEFINITIONS ComplexAttributes.td) +mlir_tablegen(ComplexEnums.h.inc -gen-enum-decls) +mlir_tablegen(ComplexEnums.cpp.inc -gen-enum-defs) mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRComplexAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h index fb024fa2e951e..be7e50d656385 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -22,6 +22,12 @@ #include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Complex Dialect Enums +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Complex/IR/ComplexEnums.h.inc" + //===----------------------------------------------------------------------===// // Complex Dialect Operations //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td index 31135fc8c8ce7..c8af498f44829 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -9,6 +9,7 @@ #ifndef COMPLEX_BASE #define COMPLEX_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Complex_Dialect : Dialect { @@ -24,4 +25,19 @@ def Complex_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; } +//===----------------------------------------------------------------------===// +// Complex_ComplexRangeFlags +//===----------------------------------------------------------------------===// + +def Complex_CRF_improved : I32BitEnumAttrCaseBit<"improved", 0>; +def Complex_CRF_basic : I32BitEnumAttrCaseBit<"basic", 1>; +def Complex_CRF_none : I32BitEnumAttrCaseBit<"none", 2>; + +def Complex_ComplexRangeFlags : I32BitEnumAttr< + "ComplexRangeFlags", + "Complex range flags", + [Complex_CRF_improved, Complex_CRF_basic, Complex_CRF_none]> { + let cppNamespace = "::mlir::complex"; +} + #endif // COMPLEX_BASE diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 0bd08ec6333e6..fa904a33ebf96 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(ArmSMEToSCF) add_subdirectory(ArmSMEToLLVM) add_subdirectory(AsyncToLLVM) add_subdirectory(BufferizationToMemRef) +add_subdirectory(ComplexCommon) add_subdirectory(ComplexToLibm) add_subdirectory(ComplexToLLVM) add_subdirectory(ComplexToSPIRV) diff --git a/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt b/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt new file mode 100644 index 0000000000000..2560a4a5631f4 --- /dev/null +++ b/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_conversion_library(MLIRComplexDivisionConversion + DivisionConverter.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRComplexDialect + MLIRLLVMDialect + MLIRMathDialect + ) diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp new file mode 100644 index 0000000000000..cce9cc77c3a4c --- /dev/null +++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp @@ -0,0 +1,456 @@ +//===- DivisionConverter.cpp - Complex division conversion ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functions for two different complex number division +// algorithms, the `algebraic formula` and `Smith's range reduction method`. +// These are used in two conversions: `ComplexToLLVM` and `ComplexToStandard`. +// When modifying the algorithms, both `ToLLVM` and `ToStandard` must be +// changed. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ComplexCommon/DivisionConverter.h" +#include "mlir/Dialect/Math/IR/Math.h" + +using namespace mlir; + +void mlir::complex::convertDivToLLVMUsingAlgebraic( + ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, + Value *resultIm) { + Value rhsSqNorm = rewriter.create( + loc, rewriter.create(loc, rhsRe, rhsRe, fmf), + rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + + Value realNumerator = rewriter.create( + loc, rewriter.create(loc, lhsRe, rhsRe, fmf), + rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + + Value imagNumerator = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRe, fmf), + rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + + *resultRe = rewriter.create(loc, realNumerator, rhsSqNorm, fmf); + *resultIm = rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); +} + +void mlir::complex::convertDivToStandardUsingAlgebraic( + ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, + Value *resultIm) { + Value rhsSqNorm = rewriter.create( + loc, rewriter.create(loc, rhsRe, rhsRe, fmf), + rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + + Value realNumerator = rewriter.create( + loc, rewriter.create(loc, lhsRe, rhsRe, fmf), + rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + Value imagNumerator = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRe, fmf), + rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + + *resultRe = + rewriter.create(loc, realNumerator, rhsSqNorm, fmf); + *resultIm = + rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); +}; + +// Smith's algorithm to divide complex numbers. It is just a bit smarter +// way to compute the following algebraic formula: +// (lhsRe + lhsIm * i) / (rhsRe + rhsIm * i) +// = (lhsRe + lhsIm * i) (rhsRe - rhsIm * i) / +// ((rhsRe + rhsIm * i)(rhsRe - rhsIm * i)) +// = ((lhsRe * rhsRe + lhsIm * rhsIm) + +// (lhsIm * rhsRe - lhsRe * rhsIm) * i) / ||rhs||^2 +// +// Depending on whether |rhsRe| < |rhsIm| we compute either +// rhsRealImagRatio = rhsRe / rhsIm +// rhsRealImagDenom = rhsIm + rhsRe * rhsRealImagRatio +// resultRe = (lhsRe * rhsRealImagRatio + lhsIm) / +// rhsRealImagDenom +// resultIm = (lhsIm * rhsRealImagRatio - lhsRe) / +// rhsRealImagDenom +// +// or +// +// rhsImagRealRatio = rhsIm / rhsRe +// rhsImagRealDenom = rhsRe + rhsIm * rhsImagRealRatio +// resultRe = (lhsRe + lhsIm * rhsImagRealRatio) / +// rhsImagRealDenom +// resultIm = (lhsIm - lhsRe * rhsImagRealRatio) / +// rhsImagRealDenom +// +// See https://dl.acm.org/citation.cfm?id=368661 for more details. + +void mlir::complex::convertDivToLLVMUsingRangeReduction( + ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, + Value *resultIm) { + auto elementType = cast(rhsRe.getType()); + + Value rhsRealImagRatio = + rewriter.create(loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = rewriter.create( + loc, rhsIm, + rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), + lhsIm, fmf); + Value resultReal1 = + rewriter.create(loc, realNumerator1, rhsRealImagDenom, fmf); + Value imagNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), + lhsRe, fmf); + Value resultImag1 = + rewriter.create(loc, imagNumerator1, rhsRealImagDenom, fmf); + + Value rhsImagRealRatio = + rewriter.create(loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = rewriter.create( + loc, rhsRe, + rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = rewriter.create( + loc, lhsRe, + rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = + rewriter.create(loc, realNumerator2, rhsImagRealDenom, fmf); + Value imagNumerator2 = rewriter.create( + loc, lhsIm, + rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = + rewriter.create(loc, imagNumerator2, rhsImagRealDenom, fmf); + + // Consider corner cases. + // Case 1. Zero denominator, numerator contains at most one NaN value. + Value zero = rewriter.create( + loc, elementType, rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); + Value rhsRealIsZero = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); + Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); + Value rhsImagIsZero = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); + Value lhsRealIsNotNaN = + rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsRe, zero); + Value lhsImagIsNotNaN = + rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsIm, zero); + Value lhsContainsNotNaNValue = + rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = rewriter.create( + loc, lhsContainsNotNaNValue, + rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = rewriter.create( + loc, elementType, + rewriter.getFloatAttr(elementType, + APFloat::getInf(elementType.getFloatSemantics()))); + Value infWithSignOfrhsReal = + rewriter.create(loc, inf, rhsRe); + Value infinityResultReal = + rewriter.create(loc, infWithSignOfrhsReal, lhsRe, fmf); + Value infinityResultImag = + rewriter.create(loc, infWithSignOfrhsReal, lhsIm, fmf); + + // Case 2. Infinite numerator, finite denominator. + Value rhsRealFinite = rewriter.create( + loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); + Value rhsImagFinite = rewriter.create( + loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); + Value rhsFinite = + rewriter.create(loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); + Value lhsRealInfinite = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); + Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); + Value lhsImagInfinite = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); + Value lhsInfinite = + rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + Value infNumFiniteDenom = + rewriter.create(loc, lhsInfinite, rhsFinite); + Value one = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, lhsRealInfinite, one, zero), + lhsRe); + Value lhsImagIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, lhsImagInfinite, one, zero), + lhsIm); + Value lhsRealIsInfWithSignTimesrhsReal = + rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + Value lhsImagIsInfWithSignTimesrhsImag = + rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = rewriter.create( + loc, inf, + rewriter.create(loc, lhsRealIsInfWithSignTimesrhsReal, + lhsImagIsInfWithSignTimesrhsImag, fmf), + fmf); + Value lhsRealIsInfWithSignTimesrhsImag = + rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + Value lhsImagIsInfWithSignTimesrhsReal = + rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = rewriter.create( + loc, inf, + rewriter.create(loc, lhsImagIsInfWithSignTimesrhsReal, + lhsRealIsInfWithSignTimesrhsImag, fmf), + fmf); + + // Case 3: Finite numerator, infinite denominator. + Value lhsRealFinite = rewriter.create( + loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); + Value lhsImagFinite = rewriter.create( + loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); + Value lhsFinite = + rewriter.create(loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); + Value rhsImagInfinite = rewriter.create( + loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); + Value rhsInfinite = + rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + Value finiteNumInfiniteDenom = + rewriter.create(loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, rhsRealInfinite, one, zero), + rhsRe); + Value rhsImagIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, rhsImagInfinite, one, zero), + rhsIm); + Value rhsRealIsInfWithSignTimeslhsReal = + rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + Value rhsImagIsInfWithSignTimeslhsImag = + rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = rewriter.create( + loc, zero, + rewriter.create(loc, rhsRealIsInfWithSignTimeslhsReal, + rhsImagIsInfWithSignTimeslhsImag, fmf), + fmf); + Value rhsRealIsInfWithSignTimeslhsImag = + rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + Value rhsImagIsInfWithSignTimeslhsReal = + rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = rewriter.create( + loc, zero, + rewriter.create(loc, rhsRealIsInfWithSignTimeslhsImag, + rhsImagIsInfWithSignTimeslhsReal, fmf), + fmf); + + Value realAbsSmallerThanImagAbs = rewriter.create( + loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); + Value resultReal5 = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = rewriter.create( + loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = rewriter.create( + loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = rewriter.create( + loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = rewriter.create( + loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = rewriter.create( + loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = rewriter.create( + loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + + Value resultRealIsNaN = rewriter.create( + loc, LLVM::FCmpPredicate::uno, resultReal5, zero); + Value resultImagIsNaN = rewriter.create( + loc, LLVM::FCmpPredicate::uno, resultImag5, zero); + Value resultIsNaN = + rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + + *resultRe = rewriter.create( + loc, resultIsNaN, resultRealSpecialCase1, resultReal5); + *resultIm = rewriter.create( + loc, resultIsNaN, resultImagSpecialCase1, resultImag5); +} + +void mlir::complex::convertDivToStandardUsingRangeReduction( + ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, + Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, + Value *resultIm) { + auto elementType = cast(rhsRe.getType()); + + Value rhsRealImagRatio = + rewriter.create(loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = rewriter.create( + loc, rhsIm, + rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), + lhsIm, fmf); + Value resultReal1 = rewriter.create(loc, realNumerator1, + rhsRealImagDenom, fmf); + Value imagNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), + lhsRe, fmf); + Value resultImag1 = rewriter.create(loc, imagNumerator1, + rhsRealImagDenom, fmf); + + Value rhsImagRealRatio = + rewriter.create(loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = rewriter.create( + loc, rhsRe, + rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = rewriter.create( + loc, lhsRe, + rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = rewriter.create(loc, realNumerator2, + rhsImagRealDenom, fmf); + Value imagNumerator2 = rewriter.create( + loc, lhsIm, + rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = rewriter.create(loc, imagNumerator2, + rhsImagRealDenom, fmf); + + // Consider corner cases. + // Case 1. Zero denominator, numerator contains at most one NaN value. + Value zero = rewriter.create( + loc, elementType, rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); + Value rhsRealIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); + Value rhsImagIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsRe, zero); + Value lhsImagIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsIm, zero); + Value lhsContainsNotNaNValue = + rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = rewriter.create( + loc, lhsContainsNotNaNValue, + rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = rewriter.create( + loc, elementType, + rewriter.getFloatAttr(elementType, + APFloat::getInf(elementType.getFloatSemantics()))); + Value infWithSignOfRhsReal = + rewriter.create(loc, inf, rhsRe); + Value infinityResultReal = + rewriter.create(loc, infWithSignOfRhsReal, lhsRe, fmf); + Value infinityResultImag = + rewriter.create(loc, infWithSignOfRhsReal, lhsIm, fmf); + + // Case 2. Infinite numerator, finite denominator. + Value rhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsFinite = + rewriter.create(loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); + Value lhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); + Value lhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + Value lhsInfinite = + rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + Value infNumFiniteDenom = + rewriter.create(loc, lhsInfinite, rhsFinite); + Value one = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, lhsRealInfinite, one, zero), + lhsRe); + Value lhsImagIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, lhsImagInfinite, one, zero), + lhsIm); + Value lhsRealIsInfWithSignTimesRhsReal = + rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + Value lhsImagIsInfWithSignTimesRhsImag = + rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = rewriter.create( + loc, inf, + rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag, fmf), + fmf); + Value lhsRealIsInfWithSignTimesRhsImag = + rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + Value lhsImagIsInfWithSignTimesRhsReal = + rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = rewriter.create( + loc, inf, + rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag, fmf), + fmf); + + // Case 3: Finite numerator, infinite denominator. + Value lhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsFinite = + rewriter.create(loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + Value rhsInfinite = + rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + Value finiteNumInfiniteDenom = + rewriter.create(loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, rhsRealInfinite, one, zero), + rhsRe); + Value rhsImagIsInfWithSign = rewriter.create( + loc, rewriter.create(loc, rhsImagInfinite, one, zero), + rhsIm); + Value rhsRealIsInfWithSignTimesLhsReal = + rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + Value rhsImagIsInfWithSignTimesLhsImag = + rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = rewriter.create( + loc, zero, + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, + rhsImagIsInfWithSignTimesLhsImag, fmf), + fmf); + Value rhsRealIsInfWithSignTimesLhsImag = + rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + Value rhsImagIsInfWithSignTimesLhsReal = + rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = rewriter.create( + loc, zero, + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal, fmf), + fmf); + + Value realAbsSmallerThanImagAbs = rewriter.create( + loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value resultReal5 = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = rewriter.create( + loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = rewriter.create( + loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = rewriter.create( + loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = rewriter.create( + loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = rewriter.create( + loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = rewriter.create( + loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + + Value resultRealIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultReal5, zero); + Value resultImagIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultImag5, zero); + Value resultIsNaN = + rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + + *resultRe = rewriter.create( + loc, resultIsNaN, resultRealSpecialCase1, resultReal5); + *resultIm = rewriter.create( + loc, resultIsNaN, resultImagSpecialCase1, resultImag5); +} diff --git a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt index d3a5bf2aa2f05..074bc4bc3f0d7 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRComplexToLLVM LINK_LIBS PUBLIC MLIRArithAttrToLLVMConversion + MLIRComplexDivisionConversion MLIRComplexDialect MLIRLLVMCommonConversion MLIRLLVMDialect diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index ad86fe362076b..6956c81f2a2d3 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Conversion/ComplexCommon/DivisionConverter.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -204,6 +205,11 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { }; struct DivOpConversion : public ConvertOpToLLVMPattern { + DivOpConversion(const LLVMTypeConverter &converter, + complex::ComplexRangeFlags target) + : ConvertOpToLLVMPattern(converter), + complexRange(target) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -227,28 +233,26 @@ struct DivOpConversion : public ConvertOpToLLVMPattern { Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); - Value rhsSqNorm = rewriter.create( - loc, rewriter.create(loc, rhsRe, rhsRe, fmf), - rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); - - Value resultReal = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRe, fmf), - rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + Value resultRe, resultIm; - Value resultImag = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + if (complexRange == complex::ComplexRangeFlags::basic || + complexRange == complex::ComplexRangeFlags::none) { + mlir::complex::convertDivToLLVMUsingAlgebraic( + rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); + } else if (complexRange == complex::ComplexRangeFlags::improved) { + mlir::complex::convertDivToLLVMUsingRangeReduction( + rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); + } - result.setReal( - rewriter, loc, - rewriter.create(loc, resultReal, rhsSqNorm, fmf)); - result.setImaginary( - rewriter, loc, - rewriter.create(loc, resultImag, rhsSqNorm, fmf)); + result.setReal(rewriter, loc, resultRe); + result.setImaginary(rewriter, loc, resultIm); rewriter.replaceOp(op, {result}); return success(); } + +private: + complex::ComplexRangeFlags complexRange; }; struct MulOpConversion : public ConvertOpToLLVMPattern { @@ -324,19 +328,21 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { } // namespace void mlir::populateComplexToLLVMConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + complex::ComplexRangeFlags complexRange) { // clang-format off patterns.add< AbsOpConversion, AddOpConversion, ConstantOpLowering, CreateOpConversion, - DivOpConversion, ImOpConversion, MulOpConversion, ReOpConversion, SubOpConversion >(converter); + + patterns.add(converter, complexRange); // clang-format on } @@ -353,7 +359,7 @@ void ConvertComplexToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect using the converter defined above. RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); - populateComplexToLLVMConversionPatterns(converter, patterns); + populateComplexToLLVMConversionPatterns(converter, patterns, complexRange); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); diff --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt index e74c212b9a4f1..c3cf92f57b48d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_conversion_library(MLIRComplexToStandard LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDivisionConversion MLIRComplexDialect MLIRIR MLIRMathDialect diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 473b1da4f701c..3df8ad47e4d33 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/ComplexCommon/DivisionConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -262,6 +263,9 @@ struct CosOpConversion : public TrigonometricOpConversion { }; struct DivOpConversion : public OpConversionPattern { + DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target) + : OpConversionPattern(context), complexRange(target) {} + using OpConversionPattern::OpConversionPattern; LogicalResult @@ -281,214 +285,27 @@ struct DivOpConversion : public OpConversionPattern { Value rhsImag = rewriter.create(loc, elementType, adaptor.getRhs()); - // Smith's algorithm to divide complex numbers. It is just a bit smarter - // way to compute the following formula: - // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) - // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / - // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) - // = ((lhsReal * rhsReal + lhsImag * rhsImag) + - // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 - // - // Depending on whether |rhsReal| < |rhsImag| we compute either - // rhsRealImagRatio = rhsReal / rhsImag - // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio - // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom - // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom - // - // or - // - // rhsImagRealRatio = rhsImag / rhsReal - // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio - // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom - // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom - // - // See https://dl.acm.org/citation.cfm?id=368661 for more details. - Value rhsRealImagRatio = - rewriter.create(loc, rhsReal, rhsImag, fmf); - Value rhsRealImagDenom = rewriter.create( - loc, rhsImag, - rewriter.create(loc, rhsRealImagRatio, rhsReal, fmf), - fmf); - Value realNumerator1 = rewriter.create( - loc, - rewriter.create(loc, lhsReal, rhsRealImagRatio, fmf), - lhsImag, fmf); - Value resultReal1 = rewriter.create(loc, realNumerator1, - rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create( - loc, - rewriter.create(loc, lhsImag, rhsRealImagRatio, fmf), - lhsReal, fmf); - Value resultImag1 = rewriter.create(loc, imagNumerator1, - rhsRealImagDenom, fmf); - - Value rhsImagRealRatio = - rewriter.create(loc, rhsImag, rhsReal, fmf); - Value rhsImagRealDenom = rewriter.create( - loc, rhsReal, - rewriter.create(loc, rhsImagRealRatio, rhsImag, fmf), - fmf); - Value realNumerator2 = rewriter.create( - loc, lhsReal, - rewriter.create(loc, lhsImag, rhsImagRealRatio, fmf), - fmf); - Value resultReal2 = rewriter.create(loc, realNumerator2, - rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create( - loc, lhsImag, - rewriter.create(loc, lhsReal, rhsImagRealRatio, fmf), - fmf); - Value resultImag2 = rewriter.create(loc, imagNumerator2, - rhsImagRealDenom, fmf); - - // Consider corner cases. - // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create( - loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsReal, fmf); - Value rhsRealIsZero = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsImag, fmf); - Value rhsImagIsZero = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = rewriter.create( - loc, arith::CmpFPredicate::ORD, lhsReal, zero); - Value lhsImagIsNotNaN = rewriter.create( - loc, arith::CmpFPredicate::ORD, lhsImag, zero); - Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( - loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( - loc, elementType, - rewriter.getFloatAttr( - elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value infWithSignOfRhsReal = - rewriter.create(loc, inf, rhsReal); - Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsReal, fmf); - Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsImag, fmf); - - // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); - Value rhsFinite = - rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsReal, fmf); - Value lhsRealInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsImag, fmf); - Value lhsImagInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); - Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); - Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), - lhsReal); - Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), - lhsImag); - Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsReal, fmf); - Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsImag, fmf); - Value resultReal3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag, fmf), - fmf); - Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsImag, fmf); - Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsReal, fmf); - Value resultImag3 = rewriter.create( - loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag, fmf), - fmf); + Value resultReal, resultImag; + + if (complexRange == complex::ComplexRangeFlags::basic || + complexRange == complex::ComplexRangeFlags::none) { + mlir::complex::convertDivToStandardUsingAlgebraic( + rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal, + &resultImag); + } else if (complexRange == complex::ComplexRangeFlags::improved) { + mlir::complex::convertDivToStandardUsingRangeReduction( + rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal, + &resultImag); + } - // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create( - loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); - Value lhsFinite = - rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); - Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); - Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), - rhsReal); - Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), - rhsImag); - Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsRealIsInfWithSign, fmf); - Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag, fmf), - fmf); - Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsRealIsInfWithSign, fmf); - Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create( - loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal, fmf), - fmf); + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); - Value resultReal = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag = rewriter.create( - loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultReal4, resultReal); - Value resultImagSpecialCase3 = rewriter.create( - loc, finiteNumInfiniteDenom, resultImag4, resultImag); - Value resultRealSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( - loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( - loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - - Value resultRealIsNaN = rewriter.create( - loc, arith::CmpFPredicate::UNO, resultReal, zero); - Value resultImagIsNaN = rewriter.create( - loc, arith::CmpFPredicate::UNO, resultImag, zero); - Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); - Value resultRealWithSpecialCases = rewriter.create( - loc, resultIsNaN, resultRealSpecialCase1, resultReal); - Value resultImagWithSpecialCases = rewriter.create( - loc, resultIsNaN, resultImagSpecialCase1, resultImag); - - rewriter.replaceOpWithNewOp( - op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); return success(); } + +private: + complex::ComplexRangeFlags complexRange; }; struct ExpOpConversion : public OpConversionPattern { @@ -1219,7 +1036,7 @@ struct AngleOpConversion : public OpConversionPattern { } // namespace void mlir::populateComplexToStandardConversionPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) { // clang-format off patterns.add< AbsOpConversion, @@ -1231,7 +1048,6 @@ void mlir::populateComplexToStandardConversionPatterns( ComparisonOpConversion, ConjOpConversion, CosOpConversion, - DivOpConversion, ExpOpConversion, Expm1OpConversion, Log1pOpConversion, @@ -1246,19 +1062,27 @@ void mlir::populateComplexToStandardConversionPatterns( PowOpConversion, RsqrtOpConversion >(patterns.getContext()); + + patterns.add(patterns.getContext(), complexRange); + // clang-format on } namespace { struct ConvertComplexToStandardPass : public impl::ConvertComplexToStandardBase { + ConvertComplexToStandardPass() = default; + ConvertComplexToStandardPass(const ConvertComplexToStandardOptions &options) + : impl::ConvertComplexToStandardBase( + options) {} + void runOnOperation() override; }; void ConvertComplexToStandardPass::runOnOperation() { // Convert to the Standard dialect using the converter defined above. RewritePatternSet patterns(&getContext()); - populateComplexToStandardConversionPatterns(patterns); + populateComplexToStandardConversionPatterns(patterns, complexRange); ConversionTarget target(getContext()); target.addLegalDialect(); @@ -1272,3 +1096,8 @@ void ConvertComplexToStandardPass::runOnOperation() { std::unique_ptr mlir::createConvertComplexToStandardPass() { return std::make_unique(); } + +std::unique_ptr mlir::createConvertComplexToStandardPass( + ConvertComplexToStandardOptions options) { + return std::make_unique(std::move(options)); +} diff --git a/mlir/test/Conversion/ComplexToLLVM/complex-range-option.mlir b/mlir/test/Conversion/ComplexToLLVM/complex-range-option.mlir new file mode 100644 index 0000000000000..78e8db795788a --- /dev/null +++ b/mlir/test/Conversion/ComplexToLLVM/complex-range-option.mlir @@ -0,0 +1,303 @@ +// RUN: mlir-opt %s -convert-complex-to-llvm=complex-range=improved | FileCheck %s --check-prefix=DIV-SMITH +// RUN: mlir-opt %s -convert-complex-to-llvm=complex-range=basic | FileCheck %s --check-prefix=DIV-ALGEBRAIC +// RUN: mlir-opt %s -convert-complex-to-llvm=complex-range=none | FileCheck %s --check-prefix=DIV-ALGEBRAIC + + +func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs : complex + return %div : complex +} +// DIV-SMITH-LABEL: func @complex_div +// DIV-SMITH-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// DIV-SMITH-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// DIV-SMITH-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// DIV-SMITH: %[[LHS_REAL:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// DIV-SMITH: %[[LHS_IMAG:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// DIV-SMITH: %[[RHS_REAL:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// DIV-SMITH: %[[RHS_IMAG:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// DIV-SMITH: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]] + +// DIV-SMITH: %[[RHS_REAL_IMAG_RATIO:.*]] = llvm.fdiv %[[RHS_REAL]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[RHS_REAL_IMAG_RATIO]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_REAL_IMAG_DENOM:.*]] = llvm.fadd %[[RHS_IMAG]], %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_1:.*]] = llvm.fadd %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_1:.*]] = llvm.fdiv %[[REAL_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_1:.*]] = llvm.fsub %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_1:.*]] = llvm.fdiv %[[IMAG_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] : f32 + +// DIV-SMITH: %[[RHS_IMAG_REAL_RATIO:.*]] = llvm.fdiv %[[RHS_IMAG]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[RHS_IMAG_REAL_RATIO]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_IMAG_REAL_DENOM:.*]] = llvm.fadd %[[RHS_REAL]], %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_2:.*]] = llvm.fadd %[[LHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[RESULT_REAL_2:.*]] = llvm.fdiv %[[REAL_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_2:.*]] = llvm.fsub %[[LHS_IMAG]], %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_2:.*]] = llvm.fdiv %[[IMAG_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] : f32 + +// Case 1. Zero denominator, numerator contains at most one NaN value. +// DIV-SMITH: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// DIV-SMITH: %[[RHS_REAL_ABS:.*]] = llvm.intr.fabs(%[[RHS_REAL]]) : (f32) -> f32 +// DIV-SMITH: %[[RHS_REAL_ABS_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[RHS_REAL_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS:.*]] = llvm.intr.fabs(%[[RHS_IMAG]]) : (f32) -> f32 +// DIV-SMITH: %[[RHS_IMAG_ABS_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[RHS_IMAG_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_NOT_NAN:.*]] = llvm.fcmp "ord" %[[LHS_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_NOT_NAN:.*]] = llvm.fcmp "ord" %[[LHS_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_CONTAINS_NOT_NAN_VALUE:.*]] = llvm.or %[[LHS_REAL_IS_NOT_NAN]], %[[LHS_IMAG_IS_NOT_NAN]] : i1 +// DIV-SMITH: %[[RHS_IS_ZERO:.*]] = llvm.and %[[RHS_REAL_ABS_IS_ZERO]], %[[RHS_IMAG_ABS_IS_ZERO]] : i1 +// DIV-SMITH: %[[RESULT_IS_INFINITY:.*]] = llvm.and %[[LHS_CONTAINS_NOT_NAN_VALUE]], %[[RHS_IS_ZERO]] : i1 +// DIV-SMITH: %[[INF:.*]] = llvm.mlir.constant(0x7F800000 : f32) : f32 +// DIV-SMITH: %[[INF_WITH_SIGN_OF_RHS_REAL:.*]] = llvm.intr.copysign(%[[INF]], %[[RHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[INFINITY_RESULT_REAL:.*]] = llvm.fmul %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[INFINITY_RESULT_IMAG:.*]] = llvm.fmul %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_IMAG]] : f32 + +// Case 2. Infinite numerator, finite denominator. +// DIV-SMITH: %[[RHS_REAL_FINITE:.*]] = llvm.fcmp "one" %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_FINITE:.*]] = llvm.fcmp "one" %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_FINITE:.*]] = llvm.and %[[RHS_REAL_FINITE]], %[[RHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[LHS_REAL_ABS:.*]] = llvm.intr.fabs(%[[LHS_REAL]]) : (f32) -> f32 +// DIV-SMITH: %[[LHS_REAL_INFINITE:.*]] = llvm.fcmp "oeq" %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_ABS:.*]] = llvm.intr.fabs(%[[LHS_IMAG]]) : (f32) -> f32 +// DIV-SMITH: %[[LHS_IMAG_INFINITE:.*]] = llvm.fcmp "oeq" %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_INFINITE:.*]] = llvm.or %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[INF_NUM_FINITE_DENOM:.*]] = llvm.and %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1 +// DIV-SMITH: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF:.*]] = llvm.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[LHS_REAL_IS_INF]], %[[LHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF:.*]] = llvm.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = llvm.fmul %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_1:.*]] = llvm.fadd %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_3:.*]] = llvm.fmul %[[INF]], %[[INF_MULTIPLICATOR_1]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = llvm.fmul %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = llvm.fmul %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_2:.*]] = llvm.fsub %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_3:.*]] = llvm.fmul %[[INF]], %[[INF_MULTIPLICATOR_2]] : f32 + +// Case 3. Finite numerator, infinite denominator. +// DIV-SMITH: %[[LHS_REAL_FINITE:.*]] = llvm.fcmp "one" %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_FINITE:.*]] = llvm.fcmp "one" %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_FINITE:.*]] = llvm.and %[[LHS_REAL_FINITE]], %[[LHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_INFINITE:.*]] = llvm.fcmp "oeq" %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_INFINITE:.*]] = llvm.fcmp "oeq" %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_INFINITE:.*]] = llvm.or %[[RHS_REAL_INFINITE]], %[[RHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[FINITE_NUM_INFINITE_DENOM:.*]] = llvm.and %[[LHS_IS_FINITE]], %[[RHS_IS_INFINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_IS_INF:.*]] = llvm.select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[RHS_REAL_IS_INF]], %[[RHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF:.*]] = llvm.select %[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[RHS_IMAG_IS_INF]], %[[RHS_IMAG]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_REAL_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_1:.*]] = llvm.fadd %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_4:.*]] = llvm.fmul %[[ZERO]], %[[ZERO_MULTIPLICATOR_1]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_REAL_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_2:.*]] = llvm.fsub %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_4:.*]] = llvm.fmul %[[ZERO]], %[[ZERO_MULTIPLICATOR_2]] : f32 + +// DIV-SMITH: %[[REAL_ABS_SMALLER_THAN_IMAG_ABS:.*]] = llvm.fcmp "olt" %[[RHS_REAL_ABS]], %[[RHS_IMAG_ABS]] : f32 +// DIV-SMITH: %[[RESULT_REAL:.*]] = llvm.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG:.*]] = llvm.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = llvm.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = llvm.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = llvm.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = llvm.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = llvm.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = llvm.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_IS_NAN:.*]] = llvm.fcmp "uno" %[[RESULT_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_IS_NAN:.*]] = llvm.fcmp "uno" %[[RESULT_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IS_NAN:.*]] = llvm.and %[[RESULT_REAL_IS_NAN]], %[[RESULT_IMAG_IS_NAN]] : i1 +// DIV-SMITH: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = llvm.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = llvm.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : i1, f32 +// DIV-SMITH: %[[RESULT_1:.*]] = llvm.insertvalue %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_0]][0] : ![[C_TY]] +// DIV-SMITH: %[[RESULT_2:.*]] = llvm.insertvalue %[[RESULT_IMAG_WITH_SPECIAL_CASES]], %[[RESULT_1]][1] : ![[C_TY]] +// DIV-SMITH: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// DIV-SMITH: return %[[CASTED_RESULT]] : complex + + +// DIV-ALGEBRAIC-LABEL: func @complex_div +// DIV-ALGEBRAIC-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// DIV-ALGEBRAIC-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// DIV-ALGEBRAIC-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// DIV-ALGEBRAIC: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// DIV-ALGEBRAIC: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]] + +// DIV-ALGEBRAIC-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] : f32 + +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32 + +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 + +// DIV-ALGEBRAIC: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32 +// DIV-ALGEBRAIC: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 +// DIV-ALGEBRAIC: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] +// +// DIV-ALGEBRAIC: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// DIV-ALGEBRAIC: return %[[CASTED_RESULT]] : complex + + +func.func @complex_div_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs fastmath : complex + return %div : complex +} +// DIV-SMITH-LABEL: func @complex_div_with_fmf +// DIV-SMITH-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// DIV-SMITH-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// DIV-SMITH-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// DIV-SMITH: %[[LHS_REAL:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// DIV-SMITH: %[[LHS_IMAG:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// DIV-SMITH: %[[RHS_REAL:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// DIV-SMITH: %[[RHS_IMAG:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// DIV-SMITH: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]] + +// DIV-SMITH: %[[RHS_REAL_IMAG_RATIO:.*]] = llvm.fdiv %[[RHS_REAL]], %[[RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[RHS_REAL_IMAG_RATIO]], %[[RHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_REAL_IMAG_DENOM:.*]] = llvm.fadd %[[RHS_IMAG]], %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_REAL_IMAG_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_1:.*]] = llvm.fadd %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_REAL_1:.*]] = llvm.fdiv %[[REAL_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_REAL_IMAG_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_1:.*]] = llvm.fsub %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_IMAG_1:.*]] = llvm.fdiv %[[IMAG_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] {fastmathFlags = #llvm.fastmath} : f32 + +// DIV-SMITH: %[[RHS_IMAG_REAL_RATIO:.*]] = llvm.fdiv %[[RHS_IMAG]], %[[RHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[RHS_IMAG_REAL_RATIO]], %[[RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_IMAG_REAL_DENOM:.*]] = llvm.fadd %[[RHS_REAL]], %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_IMAG_REAL_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_2:.*]] = llvm.fadd %[[LHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_REAL_2:.*]] = llvm.fdiv %[[REAL_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_IMAG_REAL_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_2:.*]] = llvm.fsub %[[LHS_IMAG]], %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_IMAG_2:.*]] = llvm.fdiv %[[IMAG_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] {fastmathFlags = #llvm.fastmath} : f32 + +// Case 1. Zero denominator, numerator contains at most one NaN value. +// DIV-SMITH: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// DIV-SMITH: %[[RHS_REAL_ABS:.*]] = llvm.intr.fabs(%[[RHS_REAL]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 +// DIV-SMITH: %[[RHS_REAL_ABS_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[RHS_REAL_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS:.*]] = llvm.intr.fabs(%[[RHS_IMAG]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 +// DIV-SMITH: %[[RHS_IMAG_ABS_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[RHS_IMAG_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_NOT_NAN:.*]] = llvm.fcmp "ord" %[[LHS_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_NOT_NAN:.*]] = llvm.fcmp "ord" %[[LHS_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_CONTAINS_NOT_NAN_VALUE:.*]] = llvm.or %[[LHS_REAL_IS_NOT_NAN]], %[[LHS_IMAG_IS_NOT_NAN]] : i1 +// DIV-SMITH: %[[RHS_IS_ZERO:.*]] = llvm.and %[[RHS_REAL_ABS_IS_ZERO]], %[[RHS_IMAG_ABS_IS_ZERO]] : i1 +// DIV-SMITH: %[[RESULT_IS_INFINITY:.*]] = llvm.and %[[LHS_CONTAINS_NOT_NAN_VALUE]], %[[RHS_IS_ZERO]] : i1 +// DIV-SMITH: %[[INF:.*]] = llvm.mlir.constant(0x7F800000 : f32) : f32 +// DIV-SMITH: %[[INF_WITH_SIGN_OF_RHS_REAL:.*]] = llvm.intr.copysign(%[[INF]], %[[RHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[INFINITY_RESULT_REAL:.*]] = llvm.fmul %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[INFINITY_RESULT_IMAG:.*]] = llvm.fmul %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 + +// Case 2. Infinite numerator, finite denominator. +// DIV-SMITH: %[[RHS_REAL_FINITE:.*]] = llvm.fcmp "one" %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_FINITE:.*]] = llvm.fcmp "one" %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_FINITE:.*]] = llvm.and %[[RHS_REAL_FINITE]], %[[RHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[LHS_REAL_ABS:.*]] = llvm.intr.fabs(%[[LHS_REAL]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 +// DIV-SMITH: %[[LHS_REAL_INFINITE:.*]] = llvm.fcmp "oeq" %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_ABS:.*]] = llvm.intr.fabs(%[[LHS_IMAG]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 +// DIV-SMITH: %[[LHS_IMAG_INFINITE:.*]] = llvm.fcmp "oeq" %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_INFINITE:.*]] = llvm.or %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[INF_NUM_FINITE_DENOM:.*]] = llvm.and %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1 +// DIV-SMITH: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF:.*]] = llvm.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[LHS_REAL_IS_INF]], %[[LHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF:.*]] = llvm.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = llvm.fmul %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_1:.*]] = llvm.fadd %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_REAL_3:.*]] = llvm.fmul %[[INF]], %[[INF_MULTIPLICATOR_1]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = llvm.fmul %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = llvm.fmul %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_2:.*]] = llvm.fsub %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_IMAG_3:.*]] = llvm.fmul %[[INF]], %[[INF_MULTIPLICATOR_2]] {fastmathFlags = #llvm.fastmath} : f32 + +// Case 3. Finite numerator, infinite denominator. +// DIV-SMITH: %[[LHS_REAL_FINITE:.*]] = llvm.fcmp "one" %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_FINITE:.*]] = llvm.fcmp "one" %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_FINITE:.*]] = llvm.and %[[LHS_REAL_FINITE]], %[[LHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_INFINITE:.*]] = llvm.fcmp "oeq" %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_INFINITE:.*]] = llvm.fcmp "oeq" %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_INFINITE:.*]] = llvm.or %[[RHS_REAL_INFINITE]], %[[RHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[FINITE_NUM_INFINITE_DENOM:.*]] = llvm.and %[[LHS_IS_FINITE]], %[[RHS_IS_INFINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_IS_INF:.*]] = llvm.select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[RHS_REAL_IS_INF]], %[[RHS_REAL]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF:.*]] = llvm.select %[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : i1, f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN:.*]] = llvm.intr.copysign(%[[RHS_IMAG_IS_INF]], %[[RHS_IMAG]]) : (f32, f32) -> f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_REAL_IS_INF_WITH_SIGN]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_1:.*]] = llvm.fadd %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_REAL_4:.*]] = llvm.fmul %[[ZERO]], %[[ZERO_MULTIPLICATOR_1]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = llvm.fmul %[[LHS_IMAG]], %[[RHS_REAL_IS_INF_WITH_SIGN]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = llvm.fmul %[[LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_2:.*]] = llvm.fsub %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-SMITH: %[[RESULT_IMAG_4:.*]] = llvm.fmul %[[ZERO]], %[[ZERO_MULTIPLICATOR_2]] {fastmathFlags = #llvm.fastmath} : f32 + +// DIV-SMITH: %[[REAL_ABS_SMALLER_THAN_IMAG_ABS:.*]] = llvm.fcmp "olt" %[[RHS_REAL_ABS]], %[[RHS_IMAG_ABS]] : f32 +// DIV-SMITH: %[[RESULT_REAL:.*]] = llvm.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG:.*]] = llvm.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = llvm.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = llvm.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = llvm.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = llvm.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = llvm.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = llvm.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : i1, f32 +// DIV-SMITH: %[[RESULT_REAL_IS_NAN:.*]] = llvm.fcmp "uno" %[[RESULT_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_IS_NAN:.*]] = llvm.fcmp "uno" %[[RESULT_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IS_NAN:.*]] = llvm.and %[[RESULT_REAL_IS_NAN]], %[[RESULT_IMAG_IS_NAN]] : i1 +// DIV-SMITH: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = llvm.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : i1, f32 +// DIV-SMITH: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = llvm.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : i1, f32 +// DIV-SMITH: %[[RESULT_1:.*]] = llvm.insertvalue %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_0]][0] : ![[C_TY]] +// DIV-SMITH: %[[RESULT_2:.*]] = llvm.insertvalue %[[RESULT_IMAG_WITH_SPECIAL_CASES]], %[[RESULT_1]][1] : ![[C_TY]] +// DIV-SMITH: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// DIV-SMITH: return %[[CASTED_RESULT]] : complex + + +// DIV-ALGEBRAIC-LABEL: func @complex_div_with_fmf +// DIV-ALGEBRAIC-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// DIV-ALGEBRAIC-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// DIV-ALGEBRAIC-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// DIV-ALGEBRAIC: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// DIV-ALGEBRAIC: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]] + +// DIV-ALGEBRAIC-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] {fastmathFlags = #llvm.fastmath} : f32 + +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// DIV-ALGEBRAIC: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// DIV-ALGEBRAIC: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] +// DIV-ALGEBRAIC: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] +// +// DIV-ALGEBRAIC: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// DIV-ALGEBRAIC: return %[[CASTED_RESULT]] : complex diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir index 40f8af7de44aa..ad1b6658fbe78 100644 --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -103,8 +103,8 @@ func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 // CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32 -// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] // // CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex @@ -221,8 +221,8 @@ func.func @complex_substraction_with_fmf() { // CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 // CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 -// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] // // CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex diff --git a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir index deae4f618f789..2e27d694d1e71 100644 --- a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir @@ -26,8 +26,8 @@ func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 // CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32 -// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] // CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] diff --git a/mlir/test/Conversion/ComplexToStandard/complex-range-option.mlir b/mlir/test/Conversion/ComplexToStandard/complex-range-option.mlir new file mode 100644 index 0000000000000..97f37d8ebe77e --- /dev/null +++ b/mlir/test/Conversion/ComplexToStandard/complex-range-option.mlir @@ -0,0 +1,277 @@ +// RUN: mlir-opt %s -convert-complex-to-standard=complex-range=improved | FileCheck %s --check-prefix=DIV-SMITH +// RUN: mlir-opt %s -convert-complex-to-standard=complex-range=basic | FileCheck %s --check-prefix=DIV-ALGEBRAIC +// RUN: mlir-opt %s -convert-complex-to-standard=complex-range=none | FileCheck %s --check-prefix=DIV-ALGEBRAIC + + +func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs : complex + return %div : complex +} +// DIV-SMITH-LABEL: func @complex_div +// DIV-SMITH-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex + +// DIV-SMITH: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex +// DIV-SMITH: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex +// DIV-SMITH: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex +// DIV-SMITH: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex + +// DIV-SMITH: %[[RHS_REAL_IMAG_RATIO:.*]] = arith.divf %[[RHS_REAL]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[RHS_REAL_IMAG_RATIO]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_REAL_IMAG_DENOM:.*]] = arith.addf %[[RHS_IMAG]], %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_1:.*]] = arith.addf %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_1:.*]] = arith.divf %[[REAL_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL_IMAG_RATIO]] : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_1:.*]] = arith.subf %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_1:.*]] = arith.divf %[[IMAG_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] : f32 + +// DIV-SMITH: %[[RHS_IMAG_REAL_RATIO:.*]] = arith.divf %[[RHS_IMAG]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[RHS_IMAG_REAL_RATIO]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_IMAG_REAL_DENOM:.*]] = arith.addf %[[RHS_REAL]], %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_2:.*]] = arith.addf %[[LHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[RESULT_REAL_2:.*]] = arith.divf %[[REAL_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_2:.*]] = arith.subf %[[LHS_IMAG]], %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_2:.*]] = arith.divf %[[IMAG_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] : f32 + +// Case 1. Zero denominator, numerator contains at most one NaN value. +// DIV-SMITH: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// DIV-SMITH: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_REAL_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_CONTAINS_NOT_NAN_VALUE:.*]] = arith.ori %[[LHS_REAL_IS_NOT_NAN]], %[[LHS_IMAG_IS_NOT_NAN]] : i1 +// DIV-SMITH: %[[RHS_IS_ZERO:.*]] = arith.andi %[[RHS_REAL_ABS_IS_ZERO]], %[[RHS_IMAG_ABS_IS_ZERO]] : i1 +// DIV-SMITH: %[[RESULT_IS_INFINITY:.*]] = arith.andi %[[LHS_CONTAINS_NOT_NAN_VALUE]], %[[RHS_IS_ZERO]] : i1 +// DIV-SMITH: %[[INF:.*]] = arith.constant 0x7F800000 : f32 +// DIV-SMITH: %[[INF_WITH_SIGN_OF_RHS_REAL:.*]] = math.copysign %[[INF]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[INFINITY_RESULT_REAL:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[INFINITY_RESULT_IMAG:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_IMAG]] : f32 + +// Case 2. Infinite numerator, finite denominator. +// DIV-SMITH: %[[RHS_REAL_FINITE:.*]] = arith.cmpf one, %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_FINITE:.*]] = arith.andi %[[RHS_REAL_FINITE]], %[[RHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[LHS_REAL_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] : f32 +// DIV-SMITH: %[[LHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_INFINITE:.*]] = arith.ori %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[INF_NUM_FINITE_DENOM:.*]] = arith.andi %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1 +// DIV-SMITH: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF:.*]] = arith.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_REAL_IS_INF]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF:.*]] = arith.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_1:.*]] = arith.addf %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_1]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_2:.*]] = arith.subf %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_2]] : f32 + +// Case 3. Finite numerator, infinite denominator. +// DIV-SMITH: %[[LHS_REAL_FINITE:.*]] = arith.cmpf one, %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_FINITE:.*]] = arith.andi %[[LHS_REAL_FINITE]], %[[LHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_INFINITE:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_INFINITE:.*]] = arith.ori %[[RHS_REAL_INFINITE]], %[[RHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[FINITE_NUM_INFINITE_DENOM:.*]] = arith.andi %[[LHS_IS_FINITE]], %[[RHS_IS_INFINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_IS_INF:.*]] = arith.select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_REAL_IS_INF]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF:.*]] = arith.select %[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_IMAG_IS_INF]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_1:.*]] = arith.addf %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_4:.*]] = arith.mulf %[[ZERO]], %[[ZERO_MULTIPLICATOR_1]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_2:.*]] = arith.subf %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_4:.*]] = arith.mulf %[[ZERO]], %[[ZERO_MULTIPLICATOR_2]] : f32 + +// DIV-SMITH: %[[REAL_ABS_SMALLER_THAN_IMAG_ABS:.*]] = arith.cmpf olt, %[[RHS_REAL_ABS]], %[[RHS_IMAG_ABS]] : f32 +// DIV-SMITH: %[[RESULT_REAL:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : f32 +// DIV-SMITH: %[[RESULT_IMAG:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : f32 +// DIV-SMITH: %[[RESULT_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IS_NAN:.*]] = arith.andi %[[RESULT_REAL_IS_NAN]], %[[RESULT_IMAG_IS_NAN]] : i1 +// DIV-SMITH: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32 +// DIV-SMITH: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex +// DIV-SMITH: return %[[RESULT]] : complex + + +// DIV-ALGEBRAIC-LABEL: func @complex_div +// DIV-ALGEBRAIC-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex + +// DIV-ALGEBRAIC: %[[LHS_RE:.*]] = complex.re %[[LHS]] : complex +// DIV-ALGEBRAIC: %[[LHS_IM:.*]] = complex.im %[[LHS]] : complex +// DIV-ALGEBRAIC: %[[RHS_RE:.*]] = complex.re %[[RHS]] : complex +// DIV-ALGEBRAIC: %[[RHS_IM:.*]] = complex.im %[[RHS]] : complex + +// DIV-ALGEBRAIC-DAG: %[[RHS_RE_SQ:.*]] = arith.mulf %[[RHS_RE]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[RHS_IM_SQ:.*]] = arith.mulf %[[RHS_IM]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[SQ_NORM:.*]] = arith.addf %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] : f32 + +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_0:.*]] = arith.mulf %[[LHS_RE]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_1:.*]] = arith.mulf %[[LHS_IM]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[REAL_TMP_2:.*]] = arith.addf %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32 + +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_0:.*]] = arith.mulf %[[LHS_IM]], %[[RHS_RE]] : f32 +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_1:.*]] = arith.mulf %[[LHS_RE]], %[[RHS_IM]] : f32 +// DIV-ALGEBRAIC: %[[IMAG_TMP_2:.*]] = arith.subf %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 + +// DIV-ALGEBRAIC: %[[REAL:.*]] = arith.divf %[[REAL_TMP_2]], %[[SQ_NORM]] : f32 +// DIV-ALGEBRAIC: %[[IMAG:.*]] = arith.divf %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 +// DIV-ALGEBRAIC: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex +// DIV-ALGEBRAIC: return %[[RESULT]] : complex + + +func.func @complex_div_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs fastmath : complex + return %div : complex +} +// DIV-SMITH-LABEL: func @complex_div_with_fmf +// DIV-SMITH-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex + +// DIV-SMITH: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex +// DIV-SMITH: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex +// DIV-SMITH: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex +// DIV-SMITH: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex + +// DIV-SMITH: %[[RHS_REAL_IMAG_RATIO:.*]] = arith.divf %[[RHS_REAL]], %[[RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[RHS_REAL_IMAG_RATIO]], %[[RHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[RHS_REAL_IMAG_DENOM:.*]] = arith.addf %[[RHS_IMAG]], %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]] fastmath : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IMAG_RATIO]] fastmath : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_1:.*]] = arith.addf %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RESULT_REAL_1:.*]] = arith.divf %[[REAL_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] fastmath : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL_IMAG_RATIO]] fastmath : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_1:.*]] = arith.subf %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[RESULT_IMAG_1:.*]] = arith.divf %[[IMAG_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] fastmath : f32 + +// DIV-SMITH: %[[RHS_IMAG_REAL_RATIO:.*]] = arith.divf %[[RHS_IMAG]], %[[RHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[RHS_IMAG_REAL_RATIO]], %[[RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RHS_IMAG_REAL_DENOM:.*]] = arith.addf %[[RHS_REAL]], %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] fastmath : f32 +// DIV-SMITH: %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_REAL_RATIO]] fastmath : f32 +// DIV-SMITH: %[[REAL_NUMERATOR_2:.*]] = arith.addf %[[LHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] fastmath : f32 +// DIV-SMITH: %[[RESULT_REAL_2:.*]] = arith.divf %[[REAL_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] fastmath : f32 +// DIV-SMITH: %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG_REAL_RATIO]] fastmath : f32 +// DIV-SMITH: %[[IMAG_NUMERATOR_2:.*]] = arith.subf %[[LHS_IMAG]], %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO]] fastmath : f32 +// DIV-SMITH: %[[RESULT_IMAG_2:.*]] = arith.divf %[[IMAG_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] fastmath : f32 + +// Case 1. Zero denominator, numerator contains at most one NaN value. +// DIV-SMITH: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// DIV-SMITH: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[RHS_REAL_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RHS_IMAG_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_CONTAINS_NOT_NAN_VALUE:.*]] = arith.ori %[[LHS_REAL_IS_NOT_NAN]], %[[LHS_IMAG_IS_NOT_NAN]] : i1 +// DIV-SMITH: %[[RHS_IS_ZERO:.*]] = arith.andi %[[RHS_REAL_ABS_IS_ZERO]], %[[RHS_IMAG_ABS_IS_ZERO]] : i1 +// DIV-SMITH: %[[RESULT_IS_INFINITY:.*]] = arith.andi %[[LHS_CONTAINS_NOT_NAN_VALUE]], %[[RHS_IS_ZERO]] : i1 +// DIV-SMITH: %[[INF:.*]] = arith.constant 0x7F800000 : f32 +// DIV-SMITH: %[[INF_WITH_SIGN_OF_RHS_REAL:.*]] = math.copysign %[[INF]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[INFINITY_RESULT_REAL:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[INFINITY_RESULT_IMAG:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_IMAG]] fastmath : f32 + +// Case 2. Infinite numerator, finite denominator. +// DIV-SMITH: %[[RHS_REAL_FINITE:.*]] = arith.cmpf one, %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_FINITE:.*]] = arith.andi %[[RHS_REAL_FINITE]], %[[RHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[LHS_REAL_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[LHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_INFINITE:.*]] = arith.ori %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[INF_NUM_FINITE_DENOM:.*]] = arith.andi %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1 +// DIV-SMITH: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF:.*]] = arith.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_REAL_IS_INF]], %[[LHS_REAL]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF:.*]] = arith.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]] : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_1:.*]] = arith.addf %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RESULT_REAL_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_1]] fastmath : f32 +// DIV-SMITH: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[INF_MULTIPLICATOR_2:.*]] = arith.subf %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RESULT_IMAG_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_2]] fastmath : f32 + +// Case 3. Finite numerator, infinite denominator. +// DIV-SMITH: %[[LHS_REAL_FINITE:.*]] = arith.cmpf one, %[[LHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[LHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[LHS_IS_FINITE:.*]] = arith.andi %[[LHS_REAL_FINITE]], %[[LHS_IMAG_FINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_INFINITE:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32 +// DIV-SMITH: %[[RHS_IS_INFINITE:.*]] = arith.ori %[[RHS_REAL_INFINITE]], %[[RHS_IMAG_INFINITE]] : i1 +// DIV-SMITH: %[[FINITE_NUM_INFINITE_DENOM:.*]] = arith.andi %[[LHS_IS_FINITE]], %[[RHS_IS_INFINITE]] : i1 +// DIV-SMITH: %[[RHS_REAL_IS_INF:.*]] = arith.select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_REAL_IS_INF]], %[[RHS_REAL]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF:.*]] = arith.select %[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_IMAG_IS_INF]], %[[RHS_IMAG]] : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IS_INF_WITH_SIGN]] fastmath : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] fastmath : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_1:.*]] = arith.addf %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]] fastmath : f32 +// DIV-SMITH: %[[RESULT_REAL_4:.*]] = arith.mulf %[[ZERO]], %[[ZERO_MULTIPLICATOR_1]] fastmath : f32 +// DIV-SMITH: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL_IS_INF_WITH_SIGN]] fastmath : f32 +// DIV-SMITH: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] fastmath : f32 +// DIV-SMITH: %[[ZERO_MULTIPLICATOR_2:.*]] = arith.subf %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_REAL]] fastmath : f32 +// DIV-SMITH: %[[RESULT_IMAG_4:.*]] = arith.mulf %[[ZERO]], %[[ZERO_MULTIPLICATOR_2]] fastmath : f32 + +// DIV-SMITH: %[[REAL_ABS_SMALLER_THAN_IMAG_ABS:.*]] = arith.cmpf olt, %[[RHS_REAL_ABS]], %[[RHS_IMAG_ABS]] : f32 +// DIV-SMITH: %[[RESULT_REAL:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : f32 +// DIV-SMITH: %[[RESULT_IMAG:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : f32 +// DIV-SMITH: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : f32 +// DIV-SMITH: %[[RESULT_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_REAL]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_IMAG]], %[[ZERO]] : f32 +// DIV-SMITH: %[[RESULT_IS_NAN:.*]] = arith.andi %[[RESULT_REAL_IS_NAN]], %[[RESULT_IMAG_IS_NAN]] : i1 +// DIV-SMITH: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32 +// DIV-SMITH: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32 +// DIV-SMITH: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex +// DIV-SMITH: return %[[RESULT]] : complex + + +// DIV-ALGEBRAIC-LABEL: func @complex_div_with_fmf +// DIV-ALGEBRAIC-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex + +// DIV-ALGEBRAIC: %[[LHS_RE:.*]] = complex.re %[[LHS]] : complex +// DIV-ALGEBRAIC: %[[LHS_IM:.*]] = complex.im %[[LHS]] : complex +// DIV-ALGEBRAIC: %[[RHS_RE:.*]] = complex.re %[[RHS]] : complex +// DIV-ALGEBRAIC: %[[RHS_IM:.*]] = complex.im %[[RHS]] : complex + +// DIV-ALGEBRAIC-DAG: %[[RHS_RE_SQ:.*]] = arith.mulf %[[RHS_RE]], %[[RHS_RE]] fastmath : f32 +// DIV-ALGEBRAIC-DAG: %[[RHS_IM_SQ:.*]] = arith.mulf %[[RHS_IM]], %[[RHS_IM]] fastmath : f32 +// DIV-ALGEBRAIC: %[[SQ_NORM:.*]] = arith.addf %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] fastmath : f32 + +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_0:.*]] = arith.mulf %[[LHS_RE]], %[[RHS_RE]] fastmath : f32 +// DIV-ALGEBRAIC-DAG: %[[REAL_TMP_1:.*]] = arith.mulf %[[LHS_IM]], %[[RHS_IM]] fastmath : f32 +// DIV-ALGEBRAIC: %[[REAL_TMP_2:.*]] = arith.addf %[[REAL_TMP_0]], %[[REAL_TMP_1]] fastmath : f32 + +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_0:.*]] = arith.mulf %[[LHS_IM]], %[[RHS_RE]] fastmath : f32 +// DIV-ALGEBRAIC-DAG: %[[IMAG_TMP_1:.*]] = arith.mulf %[[LHS_RE]], %[[RHS_IM]] fastmath : f32 +// DIV-ALGEBRAIC: %[[IMAG_TMP_2:.*]] = arith.subf %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] fastmath : f32 + +// DIV-ALGEBRAIC: %[[REAL:.*]] = arith.divf %[[REAL_TMP_2]], %[[SQ_NORM]] fastmath : f32 +// DIV-ALGEBRAIC: %[[IMAG:.*]] = arith.divf %[[IMAG_TMP_2]], %[[SQ_NORM]] fastmath : f32 +// DIV-ALGEBRAIC: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex +// DIV-ALGEBRAIC: return %[[RESULT]] : complex