diff --git a/source/module_base/complexarray.cpp b/source/module_base/complexarray.cpp index f90205e253..a2708e09cf 100644 --- a/source/module_base/complexarray.cpp +++ b/source/module_base/complexarray.cpp @@ -126,7 +126,7 @@ ComplexArray &ComplexArray::operator=(const ComplexArray & cd) } // Assignment of scalar: all entries set to c -inline void ComplexArray::operator=(const std::complex < double> c) +void ComplexArray::operator=(const std::complex < double> c) { const int size = this->getSize(); for (int i = 0; i < size; i++) @@ -239,6 +239,50 @@ void ComplexArray::operator*=(double r) ptr[i] *= r; } +/* Judge if two ComplexArray is equal */ +bool ComplexArray::operator==(const ComplexArray &cd2)const +{ + const int size1 = this->getSize(); + const int size2 = cd2.getSize(); + const int b11 = this->getBound1(); + const int b12 = this->getBound2(); + const int b13 = this->getBound3(); + const int b14 = this->getBound4(); + const int b21 = cd2.getBound1(); + const int b22 = cd2.getBound2(); + const int b23 = cd2.getBound3(); + const int b24 = cd2.getBound4(); + if (size1 != size2) {return false;} + if (b11 != b21) {return false;} + if (b12 != b22) {return false;} + if (b13 != b23) {return false;} + if (b14 != b24) {return false;} + for ( int i = 0;i ptr[i] != cd2.ptr[i]) {return false;} } + return true; +} + +/* Judge if two ComplexArray is not equal */ +bool ComplexArray::operator!=(const ComplexArray &cd2)const +{ + const int size1 = this->getSize(); + const int size2 = cd2.getSize(); + const int b11 = this->getBound1(); + const int b12 = this->getBound2(); + const int b13 = this->getBound3(); + const int b14 = this->getBound4(); + const int b21 = cd2.getBound1(); + const int b22 = cd2.getBound2(); + const int b23 = cd2.getBound3(); + const int b24 = cd2.getBound4(); + if (size1 != size2) {return true;} + if (b11 != b21) {return true;} + if (b12 != b22) {return true;} + if (b13 != b23) {return true;} + if (b14 != b24) {return true;} + for ( int i = 0;i ptr[i] != cd2.ptr[i]) {return true;} } + return false; +} + ///////////////////////////////////////////// // // // MEMBER FUNCTIONS: // diff --git a/source/module_base/complexarray.h b/source/module_base/complexarray.h index 959549e983..ee2b278906 100644 --- a/source/module_base/complexarray.h +++ b/source/module_base/complexarray.h @@ -32,7 +32,7 @@ class ComplexArray // void release_temp(); ComplexArray &operator=(const ComplexArray &cd); - inline void operator=(std::complex c); + void operator=(std::complex c); // inline std::complex < double> &operator()(int i, int j, int k)const //{return d[(i * bound2 + j) * bound3 +k];} // inline void operator=(std::complex < double> c); @@ -47,6 +47,8 @@ class ComplexArray void operator*=(const std::complex < double> c); void operator*=(const ComplexArray &in); + bool operator== (const ComplexArray &cd2)const; + bool operator!= (const ComplexArray &cd2)const; // subscript operator std::complex < double> &operator() diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 3c09e07a40..7c3188e062 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -1,10 +1,12 @@ AddTest( TARGET base_matrix3 - LIBS base parallel SOURCES matrix3_test.cpp ) AddTest( TARGET base_blas_connector - LIBS base parallel SOURCES blas_connector_test.cpp ) +AddTest( + TARGET base_complexarray + SOURCES complexarray_test.cpp +) diff --git a/source/module_base/test/complexarray_test.cpp b/source/module_base/test/complexarray_test.cpp new file mode 100644 index 0000000000..eef4016259 --- /dev/null +++ b/source/module_base/test/complexarray_test.cpp @@ -0,0 +1,169 @@ +#include"../complexarray.h" +#include"gtest/gtest.h" + +/************************************************ +* unit test of class ComplexArray +***********************************************/ + +/** +* - Tested Function: +* - operator "=": +* - assign a complex to all elements of a ComplexArray +* - assign a ComplexArray to another ComplexArray +* +* - operator "+": +* - one ComplexArray plus another ComplexArray that has +* the same dimension. +* - throw error when one ComplexArray plus another ComplexArray +* that has different dimension. +* +* - operator "+=": +* - one ComplexArray plus another ComplexArray, and assign the +* value to first ComplexArray +* - throw error when one ComplexArray plus another ComplexArray +* that has different dimension. +* +* - operator "-": +* - one ComplexArray minus another ComplexArray that has +* the same dimension. +* - throw error when one ComplexArray minus another ComplexArray +* that has different dimension. +* +* - operator "-=": +* - one ComplexArray minus another ComplexArray, and assign the +* value to first ComplexArray +* - throw error when one ComplexArray minus another ComplexArray +* that has different dimension. +* +* - operator "*": +* - one ComplexArray is multiplied by a double +* - one ComplexArray is multiplied by a complex +* - one ComplexArray is multiplied by another ComplexArray that has same dimension +* - throw error when one ComplexArray is miltiplied by another ComplexArray +* that has different dimension. +* +* - operator "*=": +* similar as "*" +* +* - oprator "()": +* - access the element +*/ + + +class ComplexArray_test : public testing::Test +{ + protected: + ModuleBase::ComplexArray a2,a4,b2,b4,c2,c4,d2; + ModuleBase::ComplexArray a2_plus_b2,a2_minus_b2,a2_mult_b2,a2_mult_3; + std::complex com1 {1.0,2.0}; + std::complex com2 {3.0,4.0}; + + void SetUp() + { + a2 = ModuleBase::ComplexArray(2,1,1,1); // if define this class as a matrix + b2 = ModuleBase::ComplexArray(2,1,1,1); // of 4 dimesions, + c2 = ModuleBase::ComplexArray(2,1,1,1); // is a2 +/-/* d2 allowed or not??? + d2 = ModuleBase::ComplexArray(1,2,1,1); // it does not matter, this situation will not appear in ABACUS + a2_plus_b2 = ModuleBase::ComplexArray(2,1,1,1); + a2_minus_b2 = ModuleBase::ComplexArray(2,1,1,1); + a2_mult_b2 = ModuleBase::ComplexArray(2,1,1,1); + a2_mult_3 = ModuleBase::ComplexArray(2,1,1,1); + a4 = ModuleBase::ComplexArray(2,2,1,1); + b4 = ModuleBase::ComplexArray(2,2,1,1); + c4 = ModuleBase::ComplexArray(2,2,1,1); + a2 = com1; + b2 = com2; + a4 = com1; + b4 = com2; + d2 = com1; + a2_plus_b2 = com1 + com2; + a2_minus_b2 = com1 - com2; + a2_mult_b2 = com1 * com2; + a2_mult_3 = com1 * 3.0; + } + +}; + +TEST_F(ComplexArray_test,operator_equal) +{ + c2 = a2; //c2 is just constructed as a CompleArray, but not assigned by any value + b2 = a2; //b2 is constructed as a ComplexArray, and is also assigned by a complex + EXPECT_EQ(c2,a2); + EXPECT_EQ(b2,a2); + EXPECT_NE(c2,a4); + EXPECT_DEATH(c2=a4,""); +} + +TEST_F(ComplexArray_test,operator_plus) +{ + c2 = a2 + b2; + EXPECT_EQ(c2,a2_plus_b2); + EXPECT_DEATH(a2+a4,""); + //EXPECT_DEATH(a2+d2,""); +} + +TEST_F(ComplexArray_test,operator_plus_equal) +{ + a2 += b2; + EXPECT_EQ(a2,a2_plus_b2); + EXPECT_DEATH(a2+=a4,""); + //EXPECT_DEATH(a2+=d2,""); +} + +TEST_F(ComplexArray_test,operator_minus) +{ + c2 = a2 - b2; + EXPECT_EQ(c2,a2_minus_b2); + EXPECT_DEATH(a2-a4,""); + //EXPECT_DEATH(a2-d2,""); +} + +TEST_F(ComplexArray_test,operator_minus_equal) +{ + a2 -= b2; + EXPECT_EQ(a2,a2_minus_b2); + EXPECT_DEATH(a2-=a4,""); + //EXPECT_DEATH(a2-=d2,""); +} + + +TEST_F(ComplexArray_test,operator_multiply) +{ + c2 = a2 * com2; + EXPECT_EQ(c2,a2_mult_b2); + + c2 = a2*3.0; + EXPECT_EQ(c2,a2_mult_3); + // EXPECT_ANY_THROW(a2*a4); + // EXPECT_ANY_THROW(a2*d2); +} + +TEST_F(ComplexArray_test,operator_multiply_equal) +{ + c2 = a2; + c2 *= b2; + EXPECT_EQ(c2,a2_mult_b2); + + c2 = a2; + c2 *= 3.0; + EXPECT_EQ(c2,a2_mult_3); + + c2 = a2; + c2 *= com2; + EXPECT_EQ(c2,a2_mult_b2); + + EXPECT_DEATH(a2*=a4,""); + // EXPECT_DEATH(a2*=d2,""); +} + +TEST_F(ComplexArray_test,operator_parentheses) +{ + c2 = a2; + EXPECT_EQ(c2(0,0,0,0), com1); + + c2(1,0,0,0) = com2; + EXPECT_NE(c2,a2); + + EXPECT_DEATH(a2(1,1,1,0),""); +} +