Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion source/module_base/complexarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ComplexArray &ComplexArray::operator=(const ComplexArray & cd)
}

// Assignment of scalar: all entries set to c
inline void ComplexArray::operator=(const std::complex < double> c)
void ComplexArray::operator=(const std::complex < double> c)
{
const int size = this->getSize();
for (int i = 0; i < size; i++)
Expand Down Expand Up @@ -239,6 +239,50 @@ void ComplexArray::operator*=(double r)
ptr[i] *= r;
}

/* Judge if two ComplexArray is equal */
bool ComplexArray::operator==(const ComplexArray &cd2)const
{
const int size1 = this->getSize();
const int size2 = cd2.getSize();
const int b11 = this->getBound1();
const int b12 = this->getBound2();
const int b13 = this->getBound3();
const int b14 = this->getBound4();
const int b21 = cd2.getBound1();
const int b22 = cd2.getBound2();
const int b23 = cd2.getBound3();
const int b24 = cd2.getBound4();
if (size1 != size2) {return false;}
if (b11 != b21) {return false;}
if (b12 != b22) {return false;}
if (b13 != b23) {return false;}
if (b14 != b24) {return false;}
for ( int i = 0;i <size1;++i) {if (this->ptr[i] != cd2.ptr[i]) {return false;} }
return true;
}

/* Judge if two ComplexArray is not equal */
bool ComplexArray::operator!=(const ComplexArray &cd2)const
{
const int size1 = this->getSize();
const int size2 = cd2.getSize();
const int b11 = this->getBound1();
const int b12 = this->getBound2();
const int b13 = this->getBound3();
const int b14 = this->getBound4();
const int b21 = cd2.getBound1();
const int b22 = cd2.getBound2();
const int b23 = cd2.getBound3();
const int b24 = cd2.getBound4();
if (size1 != size2) {return true;}
if (b11 != b21) {return true;}
if (b12 != b22) {return true;}
if (b13 != b23) {return true;}
if (b14 != b24) {return true;}
for ( int i = 0;i <size1;++i) {if (this->ptr[i] != cd2.ptr[i]) {return true;} }
return false;
}

/////////////////////////////////////////////
// //
// MEMBER FUNCTIONS: //
Expand Down
4 changes: 3 additions & 1 deletion source/module_base/complexarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ComplexArray
// void release_temp();

ComplexArray &operator=(const ComplexArray &cd);
inline void operator=(std::complex <double> c);
void operator=(std::complex <double> c);
// inline std::complex < double> &operator()(int i, int j, int k)const
//{return d[(i * bound2 + j) * bound3 +k];}
// inline void operator=(std::complex < double> c);
Expand All @@ -47,6 +47,8 @@ class ComplexArray
void operator*=(const std::complex < double> c);

void operator*=(const ComplexArray &in);
bool operator== (const ComplexArray &cd2)const;
bool operator!= (const ComplexArray &cd2)const;

// subscript operator
std::complex < double> &operator()
Expand Down
6 changes: 4 additions & 2 deletions source/module_base/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
AddTest(
TARGET base_matrix3
LIBS base parallel
SOURCES matrix3_test.cpp
)
AddTest(
TARGET base_blas_connector
LIBS base parallel
SOURCES blas_connector_test.cpp
)
AddTest(
TARGET base_complexarray
SOURCES complexarray_test.cpp
)
169 changes: 169 additions & 0 deletions source/module_base/test/complexarray_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#include"../complexarray.h"
#include"gtest/gtest.h"

/************************************************
* unit test of class ComplexArray
***********************************************/

/**
* - Tested Function:
* - operator "=":
* - assign a complex to all elements of a ComplexArray
* - assign a ComplexArray to another ComplexArray
*
* - operator "+":
* - one ComplexArray plus another ComplexArray that has
* the same dimension.
* - throw error when one ComplexArray plus another ComplexArray
* that has different dimension.
*
* - operator "+=":
* - one ComplexArray plus another ComplexArray, and assign the
* value to first ComplexArray
* - throw error when one ComplexArray plus another ComplexArray
* that has different dimension.
*
* - operator "-":
* - one ComplexArray minus another ComplexArray that has
* the same dimension.
* - throw error when one ComplexArray minus another ComplexArray
* that has different dimension.
*
* - operator "-=":
* - one ComplexArray minus another ComplexArray, and assign the
* value to first ComplexArray
* - throw error when one ComplexArray minus another ComplexArray
* that has different dimension.
*
* - operator "*":
* - one ComplexArray is multiplied by a double
* - one ComplexArray is multiplied by a complex
* - one ComplexArray is multiplied by another ComplexArray that has same dimension
* - throw error when one ComplexArray is miltiplied by another ComplexArray
* that has different dimension.
*
* - operator "*=":
* similar as "*"
*
* - oprator "()":
* - access the element
*/


class ComplexArray_test : public testing::Test
{
protected:
ModuleBase::ComplexArray a2,a4,b2,b4,c2,c4,d2;
ModuleBase::ComplexArray a2_plus_b2,a2_minus_b2,a2_mult_b2,a2_mult_3;
std::complex<double> com1 {1.0,2.0};
std::complex<double> com2 {3.0,4.0};

void SetUp()
{
a2 = ModuleBase::ComplexArray(2,1,1,1); // if define this class as a matrix
b2 = ModuleBase::ComplexArray(2,1,1,1); // of 4 dimesions,
c2 = ModuleBase::ComplexArray(2,1,1,1); // is a2 +/-/* d2 allowed or not???
d2 = ModuleBase::ComplexArray(1,2,1,1); // it does not matter, this situation will not appear in ABACUS
a2_plus_b2 = ModuleBase::ComplexArray(2,1,1,1);
a2_minus_b2 = ModuleBase::ComplexArray(2,1,1,1);
a2_mult_b2 = ModuleBase::ComplexArray(2,1,1,1);
a2_mult_3 = ModuleBase::ComplexArray(2,1,1,1);
a4 = ModuleBase::ComplexArray(2,2,1,1);
b4 = ModuleBase::ComplexArray(2,2,1,1);
c4 = ModuleBase::ComplexArray(2,2,1,1);
a2 = com1;
b2 = com2;
a4 = com1;
b4 = com2;
d2 = com1;
a2_plus_b2 = com1 + com2;
a2_minus_b2 = com1 - com2;
a2_mult_b2 = com1 * com2;
a2_mult_3 = com1 * 3.0;
}

};

TEST_F(ComplexArray_test,operator_equal)
{
c2 = a2; //c2 is just constructed as a CompleArray, but not assigned by any value
b2 = a2; //b2 is constructed as a ComplexArray, and is also assigned by a complex
EXPECT_EQ(c2,a2);
EXPECT_EQ(b2,a2);
EXPECT_NE(c2,a4);
EXPECT_DEATH(c2=a4,"");
}

TEST_F(ComplexArray_test,operator_plus)
{
c2 = a2 + b2;
EXPECT_EQ(c2,a2_plus_b2);
EXPECT_DEATH(a2+a4,"");
//EXPECT_DEATH(a2+d2,"");
}

TEST_F(ComplexArray_test,operator_plus_equal)
{
a2 += b2;
EXPECT_EQ(a2,a2_plus_b2);
EXPECT_DEATH(a2+=a4,"");
//EXPECT_DEATH(a2+=d2,"");
}

TEST_F(ComplexArray_test,operator_minus)
{
c2 = a2 - b2;
EXPECT_EQ(c2,a2_minus_b2);
EXPECT_DEATH(a2-a4,"");
//EXPECT_DEATH(a2-d2,"");
}

TEST_F(ComplexArray_test,operator_minus_equal)
{
a2 -= b2;
EXPECT_EQ(a2,a2_minus_b2);
EXPECT_DEATH(a2-=a4,"");
//EXPECT_DEATH(a2-=d2,"");
}


TEST_F(ComplexArray_test,operator_multiply)
{
c2 = a2 * com2;
EXPECT_EQ(c2,a2_mult_b2);

c2 = a2*3.0;
EXPECT_EQ(c2,a2_mult_3);
// EXPECT_ANY_THROW(a2*a4);
// EXPECT_ANY_THROW(a2*d2);
}

TEST_F(ComplexArray_test,operator_multiply_equal)
{
c2 = a2;
c2 *= b2;
EXPECT_EQ(c2,a2_mult_b2);

c2 = a2;
c2 *= 3.0;
EXPECT_EQ(c2,a2_mult_3);

c2 = a2;
c2 *= com2;
EXPECT_EQ(c2,a2_mult_b2);

EXPECT_DEATH(a2*=a4,"");
// EXPECT_DEATH(a2*=d2,"");
}

TEST_F(ComplexArray_test,operator_parentheses)
{
c2 = a2;
EXPECT_EQ(c2(0,0,0,0), com1);

c2(1,0,0,0) = com2;
EXPECT_NE(c2,a2);

EXPECT_DEATH(a2(1,1,1,0),"");
}