diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td index 52fd824f65e74..a47dc9927dd8c 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td @@ -57,8 +57,8 @@ def Complex_NumberAttr : Complex_Attr<"Number", "number", ]; let extraClassDeclaration = [{ - std::complex getValue() { - return std::complex(getReal(), getImag()); + mlir::Complex getValue() { + return mlir::Complex(getReal(), getImag()); } }]; diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h index 714e664dd0f4e..7e2190dc28084 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h @@ -22,6 +22,7 @@ #include "mlir/ExecutionEngine/SparseTensor/MapRef.h" #include "mlir/ExecutionEngine/SparseTensor/Storage.h" +#include "mlir/Support/Complex.h" #include @@ -36,6 +37,9 @@ struct is_complex final : public std::false_type {}; template struct is_complex> final : public std::true_type {}; +template +struct is_complex> final : public std::true_type {}; + /// Returns an element-value of non-complex type. If `IsPattern` is true, /// then returns an arbitrary value. If `IsPattern` is false, then /// reads the value from the current line buffer beginning at `linePtr`. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 1f805882db276..c7eddf44fb29b 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -10,9 +10,9 @@ #define MLIR_IR_BUILTINATTRIBUTES_H #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/Support/Complex.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" -#include #include namespace mlir { @@ -75,6 +75,8 @@ template struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; +template +struct is_complex_t> : public std::true_type {}; } // namespace detail /// An attribute that represents a reference to a dense vector or tensor @@ -167,7 +169,7 @@ class DenseElementsAttr : public Attribute { /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. static DenseElementsAttr get(ShapedType type, - ArrayRef> values); + ArrayRef> values); /// Constructs a dense float elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the @@ -180,7 +182,7 @@ class DenseElementsAttr : public Attribute { /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. static DenseElementsAttr get(ShapedType type, - ArrayRef> values); + ArrayRef> values); /// Construct a dense elements attribute for an initializer_list of values. /// Each value is expected to be the same bitwidth of the element type of @@ -298,11 +300,11 @@ class DenseElementsAttr : public Attribute { /// values. class ComplexIntElementIterator : public detail::DenseElementIndexedIteratorImpl< - ComplexIntElementIterator, std::complex, std::complex, - std::complex> { + ComplexIntElementIterator, mlir::Complex, + mlir::Complex, mlir::Complex> { public: - /// Accesses the raw std::complex value at this iterator position. - std::complex operator*() const; + /// Accesses the raw mlir::Complex value at this iterator position. + mlir::Complex operator*() const; private: friend DenseElementsAttr; @@ -339,10 +341,10 @@ class DenseElementsAttr : public Attribute { class ComplexFloatElementIterator final : public llvm::mapped_iterator_base> { + mlir::Complex> { public: /// Map the element to the iterator result type. - std::complex mapElement(const std::complex &value) const { + mlir::Complex mapElement(const mlir::Complex &value) const { return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())}; } @@ -442,7 +444,7 @@ class DenseElementsAttr : public Attribute { ElementIterator(rawData, splat, getNumElements())); } - /// Try to get the held element values as a range of std::complex. + /// Try to get the held element values as a range of mlir::Complex. template using ComplexValueTemplateCheckT = std::enable_if_t::value && @@ -545,7 +547,7 @@ class DenseElementsAttr : public Attribute { /// element type of this attribute must be a complex of integer type. template using ComplexAPIntValueTemplateCheckT = - std::enable_if_t>::value>; + std::enable_if_t>::value>; template > FailureOr> tryGetValues() const { @@ -566,7 +568,7 @@ class DenseElementsAttr : public Attribute { /// element type of this attribute must be a complex of float type. template using ComplexAPFloatValueTemplateCheckT = - std::enable_if_t>::value>; + std::enable_if_t>::value>; template > FailureOr> tryGetValues() const { diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 299200788136a..6165a24c0d34f 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -296,19 +296,19 @@ def Builtin_DenseTypedElementsAttr : Builtin_Attr< uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, short, unsigned short, int, unsigned, long, unsigned long, - std::complex, std::complex, std::complex, - std::complex, - std::complex, std::complex, std::complex, - std::complex, + mlir::Complex, mlir::Complex, mlir::Complex, + mlir::Complex, + mlir::Complex, mlir::Complex, mlir::Complex, + mlir::Complex, // Float types. - float, double, std::complex, std::complex + float, double, mlir::Complex, mlir::Complex >; using NonContiguousIterableTypesT = std::tuple< Attribute, // Integer types. - APInt, bool, std::complex, + APInt, bool, mlir::Complex, // Float types. - APFloat, std::complex + APFloat, mlir::Complex >; /// Provide a `try_value_begin_impl` to enable iteration within @@ -931,12 +931,12 @@ def Builtin_SparseElementsAttr : Builtin_Attr< APInt, bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, short, unsigned short, int, unsigned, long, unsigned long, - std::complex, std::complex, std::complex, - std::complex, std::complex, std::complex, - std::complex, std::complex, std::complex, + mlir::Complex, mlir::Complex, mlir::Complex, + mlir::Complex, mlir::Complex, mlir::Complex, + mlir::Complex, mlir::Complex, mlir::Complex, // Float types. APFloat, float, double, - std::complex, std::complex, std::complex, + mlir::Complex, mlir::Complex, mlir::Complex, // String types. StringRef >; @@ -978,7 +978,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< return getZeroAPInt(); } template - std::enable_if_t, T>::value, T> + std::enable_if_t, T>::value, T> getZeroValue() const { APInt intZero = getZeroAPInt(); return {intZero, intZero}; @@ -990,7 +990,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< return getZeroAPFloat(); } template - std::enable_if_t, T>::value, T> + std::enable_if_t, T>::value, T> getZeroValue() const { APFloat floatZero = getZeroAPFloat(); return {floatZero, floatZero}; @@ -1002,8 +1002,8 @@ def Builtin_SparseElementsAttr : Builtin_Attr< DenseElementsAttr::is_valid_cpp_fp_type::value || std::is_same::value || (detail::is_complex_t::value && - !llvm::is_one_of, - std::complex>::value), + !llvm::is_one_of, + mlir::Complex>::value), T> getZeroValue() const { return T(); diff --git a/mlir/include/mlir/Support/Complex.h b/mlir/include/mlir/Support/Complex.h new file mode 100644 index 0000000000000..86c06f230b8ef --- /dev/null +++ b/mlir/include/mlir/Support/Complex.h @@ -0,0 +1,269 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the mlir::NonFloatComplex type and +/// mlir::Complex type alias. The interface is intended to match the +/// std::complex type, and the mlir::Complex alias defers to std::complex for +/// builtin floating point types. +/// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_COMPLEX_H +#define MLIR_SUPPORT_COMPLEX_H + +#include +#include + +namespace mlir { + +// The copy constructors should only be implicit iff the underlying constructors +// are explicit and the conversion would not narrow. This is the case if the +// underlying destination type is copy-list-initializeable from the source type, +// so define a helper to determine if that is the case. +namespace detail { +// NOLINTBEGIN +template +auto test_copy_list_initializable(int) + -> decltype(void(std::declval() = {std::declval()}), + std::true_type{}); + +template +auto test_copy_list_initializable(...) -> std::false_type; + +template +struct is_copy_list_initializable + : std::bool_constant< + decltype(detail::test_copy_list_initializable(0))::value> { +}; + +template +constexpr bool is_copy_list_initializable_v = + is_copy_list_initializable::value; +// NOLINTEND +} // namespace detail + +template +class NonFloatComplex { +public: + using value_type = T; + +private: + T re; + T im; + +public: + constexpr NonFloatComplex(const T &re = T{}, const T &im = T{}) + : re(re), im(im) {} + + constexpr NonFloatComplex(const NonFloatComplex &other) = default; + + template >...> + constexpr NonFloatComplex(const NonFloatComplex &other) + : re{other.re}, im{other.im} {} + + template >...> + constexpr explicit NonFloatComplex(const NonFloatComplex &other) + : re(other.re), im(other.im) {} + + template >...> + constexpr NonFloatComplex(const std::complex &other) + : re{other.real()}, im{other.imag()} {} + + template >...> + constexpr explicit NonFloatComplex(const std::complex &other) + : re(other.real()), im(other.imag()) {} + + [[nodiscard]] constexpr T real() const { return re; } + constexpr void real(T value) { re = value; } + [[nodiscard]] constexpr T imag() const { return im; } + constexpr void imag(T value) { im = value; } + + constexpr NonFloatComplex &operator=(const NonFloatComplex &other) = default; + + constexpr NonFloatComplex &operator=(const T &real) { + re = real; + im = T{}; + return *this; + } + + constexpr NonFloatComplex &operator+=(const T &real) { + re += real; + return *this; + } + + constexpr NonFloatComplex &operator-=(const T &real) { + re -= real; + return *this; + } + + constexpr NonFloatComplex &operator*=(const T &real) { + re *= real; + im *= real; + return *this; + } + + constexpr NonFloatComplex &operator/=(const T &real) { + re /= real; + im /= real; + return *this; + } + + constexpr NonFloatComplex &operator+=(const NonFloatComplex &other) { + re += other.re; + im += other.im; + return *this; + } + + constexpr NonFloatComplex &operator-=(const NonFloatComplex &other) { + re -= other.re; + im -= other.im; + return *this; + } + + constexpr NonFloatComplex &operator*=(const NonFloatComplex &other) { + *this = *this * NonFloatComplex{other.re, other.im}; + return *this; + } + + constexpr NonFloatComplex &operator/=(const NonFloatComplex &other) { + *this = *this / NonFloatComplex{other.re, other.im}; + return *this; + } + + template + constexpr NonFloatComplex &operator=(const std::complex &other) { + re = other.real(); + im = other.imag(); + return *this; + } +}; + +template +[[nodiscard]] constexpr NonFloatComplex +operator+(const NonFloatComplex &x, const U &y) { + NonFloatComplex t{x}; + t += y; + return t; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator-(const NonFloatComplex &x, const U &y) { + NonFloatComplex t{x}; + t -= y; + return t; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator*(const NonFloatComplex &x, const NonFloatComplex &y) { + T a = x.real(); + T b = x.imag(); + T c = y.real(); + T d = y.imag(); + + return {(a * c) - (b * d), (a * d) + (b * c)}; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator*(const NonFloatComplex &x, const U &y) { + NonFloatComplex t{x}; + t *= y; + return t; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator/(const NonFloatComplex &x, const NonFloatComplex &y) { + T a = x.real(); + T b = x.imag(); + T c = y.real(); + T d = y.imag(); + + T denom = c * c + d * d; + return {(a * c + b * d) / denom, (b * c - a * d) / denom}; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator/(const NonFloatComplex &x, const U &y) { + NonFloatComplex t{x}; + t /= y; + return t; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator+(const NonFloatComplex &x) { + return x; +} + +template +[[nodiscard]] constexpr NonFloatComplex +operator-(const NonFloatComplex &x) { + return {-x.real(), -x.imag()}; +} + +template +[[nodiscard]] constexpr bool operator==(const NonFloatComplex &x, + const NonFloatComplex &y) { + return x.real() == y.real() && x.imag() == y.imag(); +} + +template +[[nodiscard]] constexpr bool operator==(const NonFloatComplex &x, + const U &y) { + return x == NonFloatComplex{y}; +} + +template +[[nodiscard]] constexpr bool operator==(const T &x, + const NonFloatComplex &y) { + return NonFloatComplex{x} == y; +} + +template +[[nodiscard]] constexpr bool operator!=(const NonFloatComplex &x, + const NonFloatComplex &y) { + return !(x == y); +} + +template +[[nodiscard]] constexpr bool operator!=(const NonFloatComplex &x, + const U &y) { + return !(x == y); +} + +template +[[nodiscard]] constexpr bool operator!=(const U &x, + const NonFloatComplex &y) { + return !(y == x); +} + +template +[[nodiscard]] constexpr T real(const NonFloatComplex &x) { + return x.real(); +} + +template +[[nodiscard]] constexpr T imag(const NonFloatComplex &x) { + return x.imag(); +} + +template +using Complex = std::conditional_t, std::complex, + NonFloatComplex>; +} // namespace mlir + +#endif diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index d7075b795ccb9..ca8e4ae2cecbc 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -592,7 +592,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { if (isComplex) { // If this is a complex, treat the parsed values as complex values. auto complexData = llvm::ArrayRef( - reinterpret_cast *>(intValues.data()), + reinterpret_cast *>(intValues.data()), intValues.size() / 2); return DenseElementsAttr::get(type, complexData); } @@ -606,7 +606,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { if (isComplex) { // If this is a complex, treat the parsed values as complex values. auto complexData = llvm::ArrayRef( - reinterpret_cast *>(floatValues.data()), + reinterpret_cast *>(floatValues.data()), floatValues.size() / 2); return DenseElementsAttr::get(type, complexData); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index cdcb3cba55752..18b3c030090e2 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -151,8 +151,8 @@ struct ImOpConversion : public ConvertOpToLLVMPattern { }; struct BinaryComplexOperands { - std::complex lhs; - std::complex rhs; + mlir::Complex lhs; + mlir::Complex rhs; }; template diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 75008d6cc2591..ec270db189081 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2688,7 +2688,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr, // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 // and hence was replaced. if (llvm::isa(complexElementType)) { - auto valueIt = attr.value_begin>(); + auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); os << "("; @@ -2698,7 +2698,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr, os << ")"; }); } else { - auto valueIt = attr.value_begin>(); + auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); os << "("; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index c06ae5b178624..9fda2ef8e5059 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -649,15 +649,15 @@ APInt DenseElementsAttr::IntElementIterator::operator*() const { DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( DenseElementsAttr attr, size_t dataIndex) - : DenseElementIndexedIteratorImpl, std::complex, - std::complex>( - attr.getRawData().data(), attr.isSplat(), dataIndex) { + : DenseElementIndexedIteratorImpl< + ComplexIntElementIterator, mlir::Complex, mlir::Complex, + mlir::Complex>(attr.getRawData().data(), attr.isSplat(), + dataIndex) { auto complexType = llvm::cast(attr.getElementType()); bitWidth = getDenseElementBitWidth(complexType.getElementType()); } -std::complex +mlir::Complex DenseElementsAttr::ComplexIntElementIterator::operator*() const { size_t storageWidth = getDenseElementStorageWidth(bitWidth); size_t offset = getDataIndex() * storageWidth * 2; @@ -922,8 +922,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseTypedElementsAttr::getRaw(type, storageBitWidth, values); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef> values) { +DenseElementsAttr +DenseElementsAttr::get(ShapedType type, ArrayRef> values) { ComplexType complex = llvm::cast(type.getElementType()); assert(llvm::isa(complex.getElementType())); assert(hasSameNumElementsOrSplat(type, values)); @@ -945,7 +945,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, } DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef> values) { + ArrayRef> values) { ComplexType complex = llvm::cast(type.getElementType()); assert(llvm::isa(complex.getElementType())); assert(hasSameNumElementsOrSplat(type, values)); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 900cacabd592e..642f3dd717b05 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -191,28 +191,28 @@ TEST(DenseSplatTest, StringAttrSplat) { TEST(DenseComplexTest, ComplexFloatSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(Float32Type::get(&context)); - std::complex value(10.0, 15.0); + mlir::Complex value(10.0, 15.0); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexIntSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); - std::complex value(10, 15); + mlir::Complex value(10, 15); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPFloatSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(Float32Type::get(&context)); - std::complex value(APFloat(10.0f), APFloat(15.0f)); + mlir::Complex value(APFloat(10.0f), APFloat(15.0f)); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPIntSplat) { MLIRContext context; ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); - std::complex value(APInt(64, 10), APInt(64, 15)); + mlir::Complex value(APInt(64, 10), APInt(64, 15)); testSplat(complexType, value); } diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt index 3a6365b401d49..f4a5ab7d5b9a0 100644 --- a/mlir/unittests/Support/CMakeLists.txt +++ b/mlir/unittests/Support/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRSupportTests CyclicReplacerCacheTest.cpp + ComplexTest.cpp IndentedOstreamTest.cpp StorageUniquerTest.cpp ) diff --git a/mlir/unittests/Support/ComplexTest.cpp b/mlir/unittests/Support/ComplexTest.cpp new file mode 100644 index 0000000000000..a91199727737e --- /dev/null +++ b/mlir/unittests/Support/ComplexTest.cpp @@ -0,0 +1,256 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file contains the tests for the mlir::NonFloatComplex type. +/// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/Complex.h" +#include "gtest/gtest.h" + +namespace mlir { +// Provide ostream operator so that tests pretty print NonFloatComplex values +template +static std::ostream &operator<<(std::ostream &os, const NonFloatComplex c) { + os << "(" << c.real() << "," << c.imag() << ")"; + return os; +} + +} // namespace mlir + +// The majority of these tests just check that NonFloatComplex does exactly the +// same as std::complex. + +TEST(ComplexTest, Typedef) { + EXPECT_TRUE((std::is_same_v, std::complex>)); + + EXPECT_TRUE((std::is_same_v, mlir::NonFloatComplex>)); +} + +TEST(ComplexTest, DefaultConstructor) { + mlir::NonFloatComplex mc; + std::complex sc; + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, RealConstructor) { + mlir::NonFloatComplex mc{10}; + std::complex sc{10}; + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, MemberConstructor) { + mlir::NonFloatComplex mc{10, 20}; + std::complex sc{10, 20}; + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, ExplicitCopyConstructor) { + std::complex sc{5, 10}; + mlir::NonFloatComplex mc{sc}; + EXPECT_EQ(mc, sc); + + // check the explicit constructors were used + EXPECT_FALSE((std::is_convertible_v)); +} + +TEST(ComplexTest, ImplicitCopyConstructor) { + std::complex sc{}; + mlir::NonFloatComplex mc = sc; + EXPECT_EQ(mc, sc); + + // check the implicit constructors were used + EXPECT_TRUE((std::is_convertible_v)); +} + +TEST(ComplexTest, RealAccessor) { + mlir::NonFloatComplex mc{5}; + std::complex sc{5}; + EXPECT_EQ(mc.real(), sc.real()); +} + +TEST(ComplexTest, RealSetter) { + mlir::NonFloatComplex mc{5}; + mc.real(7); + std::complex sc{5}; + sc.real(7); + EXPECT_EQ(mc.real(), sc.real()); +} + +TEST(ComplexTest, ImagAccessor) { + mlir::NonFloatComplex mc{2, 5}; + std::complex sc{2, 5}; + EXPECT_EQ(mc.imag(), sc.imag()); +} + +TEST(ComplexTest, ImagSetter) { + mlir::NonFloatComplex mc{2, 5}; + mc.imag(8); + std::complex sc{2, 5}; + sc.imag(8); + EXPECT_EQ(mc.imag(), sc.imag()); +} + +TEST(ComplexTest, CopyAssignment) { + mlir::NonFloatComplex mc{2, 5}; + mlir::NonFloatComplex mc2 = mc; + + EXPECT_EQ(mc, mc2); +} + +TEST(ComplexTest, StdCopyAssignment) { + std::complex sc{2, 5}; + mlir::NonFloatComplex mc = sc; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, RealAssignment) { + std::complex sc = 2.f; + mlir::NonFloatComplex mc = 2.f; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, PlusEqualsReal) { + mlir::NonFloatComplex mc{2, 5}; + mc += 7; + std::complex sc{2, 5}; + sc += 7; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, MinusEqualsReal) { + mlir::NonFloatComplex mc{3, 6}; + mc -= 8; + std::complex sc{3, 6}; + sc -= 8; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, TimesEqualsReal) { + mlir::NonFloatComplex mc{1, 4}; + mc *= 2; + std::complex sc{1, 4}; + sc *= 2; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, DivideEqualsReal) { + mlir::NonFloatComplex mc{1, 4}; + mc /= 2; + std::complex sc{1, 4}; + sc /= 2; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, AssignmentOp) { + mlir::NonFloatComplex mc{2, 5}; + mlir::NonFloatComplex mc2 = mc; + + EXPECT_EQ(mc, mc2); +} + +TEST(ComplexTest, StdAssignmentOp) { + + std::complex sc{2, 5}; + mlir::NonFloatComplex mc = sc; + + EXPECT_EQ(mc, sc); +} + +TEST(ComplexTest, AddOp) { + mlir::NonFloatComplex mc1{2, 5}; + mlir::NonFloatComplex mc2{3, 7}; + std::complex sc1{2, 5}; + std::complex sc2{3, 7}; + + EXPECT_EQ(mc1 + mc2, sc1 + sc2); + EXPECT_EQ(mc1 + 5.f, sc1 + 5.f); +} + +TEST(ComplexTest, MinusOp) { + mlir::NonFloatComplex mc1{2, 5}; + mlir::NonFloatComplex mc2{3, 7}; + std::complex sc1{2, 5}; + std::complex sc2{3, 7}; + + EXPECT_EQ(mc1 - mc2, sc1 - sc2); + EXPECT_EQ(mc1 - 5.f, sc1 - 5.f); +} + +TEST(ComplexTest, TimesOp) { + mlir::NonFloatComplex mc1{2, 5}; + mlir::NonFloatComplex mc2{3, 7}; + std::complex sc1{2, 5}; + std::complex sc2{3, 7}; + + EXPECT_EQ(mc1 * mc2, sc1 * sc2); + EXPECT_EQ(mc1 * 5.f, sc1 * 5.f); +} + +TEST(ComplexTest, DivideOp) { + mlir::NonFloatComplex mc1{5, 10}; + mlir::NonFloatComplex mc2{3, 4}; + std::complex sc1{5, 10}; + std::complex sc2{3, 4}; + + EXPECT_EQ(mc1 / mc2, sc1 / sc2); + EXPECT_EQ(mc1 / 5.f, sc1 / 5.f); +} + +TEST(ComplexTest, EqualityOp) { + mlir::NonFloatComplex mc1{3, 4}; + mlir::NonFloatComplex mc2{3, 4}; + + EXPECT_EQ(mc1, mc2); + EXPECT_EQ(mc2, mc1); +} + +TEST(ComplexTest, StdEqualityOp) { + mlir::NonFloatComplex mc{7, 8}; + std::complex sc{7, 8}; + + EXPECT_EQ(mc, sc); + EXPECT_EQ(sc, mc); +} + +TEST(ComplexTest, InequalityOp) { + mlir::NonFloatComplex mc1{3, 4}; + mlir::NonFloatComplex mc2{7, 8}; + + EXPECT_NE(mc1, mc2); + EXPECT_NE(mc2, mc1); +} + +TEST(ComplexTest, StdInequalityOp) { + mlir::NonFloatComplex mc{7, 8}; + std::complex sc{3, 4}; + + EXPECT_NE(mc, sc); + EXPECT_NE(sc, mc); +} + +TEST(ComplexTest, RealFn) { + mlir::NonFloatComplex mc{4, 6}; + std::complex sc{4, 6}; + + EXPECT_EQ(real(mc), real(sc)); +} + +TEST(ComplexTest, ImagFn) { + mlir::NonFloatComplex mc{4, 6}; + std::complex sc{4, 6}; + + EXPECT_EQ(imag(mc), imag(sc)); +}