Skip to content

Commit

Permalink
feat(fmatrix): adds non-conjugate Dot product
Browse files Browse the repository at this point in the history
  • Loading branch information
orlandini committed Jun 28, 2024
1 parent e25bd5c commit 027fe71
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
42 changes: 24 additions & 18 deletions Matrix/pzfmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2393,60 +2393,66 @@ int TPZFMatrix<TVar>::Subst_Diag( TPZFMatrix<TVar>* B ) const

/** @brief Implement dot product for matrices */
template<class TVar>
TVar Dot(const TPZFMatrix<TVar> &A, const TPZFMatrix<TVar> &B) {
TVar Dot(const TPZFMatrix<TVar> &A, const TPZFMatrix<TVar> &B,bool conj) {
int64_t size = (A.Rows())*A.Cols();
TVar result = 0.;
if(!size) return result;
const TVar *fpA = &A.g(0,0), *fpB = &B.g(0,0);
const TVar *fpLast = fpA+size;
while(fpA < fpLast)
{
const bool must_conj = conj && is_complex<TVar>::value;
if(must_conj){
if constexpr (is_complex<TVar>::value){
result += *fpA++ * std::conj(*fpB++);
//always evaluates to true, but we dont want compiler errors
while(fpA < fpLast)
{
result += *fpA++ * std::conj(*fpB++);
}
}
else{
}else{
while(fpA < fpLast)
{
result += (*fpA++ * *fpB++);
}
}

return result;
// #endif
}

template
std::complex<float> Dot(const TPZFMatrix< std::complex<float> > &A, const TPZFMatrix< std::complex<float> > &B);
std::complex<float> Dot(const TPZFMatrix< std::complex<float> > &A, const TPZFMatrix< std::complex<float> > &B, bool conj);

template
std::complex<double> Dot(const TPZFMatrix< std::complex<double> > &A, const TPZFMatrix< std::complex<double> > &B);
std::complex<double> Dot(const TPZFMatrix< std::complex<double> > &A, const TPZFMatrix< std::complex<double> > &B, bool conj);

template
std::complex<long double> Dot(const TPZFMatrix< std::complex<long double> > &A, const TPZFMatrix< std::complex<long double> > &B);
std::complex<long double> Dot(const TPZFMatrix< std::complex<long double> > &A, const TPZFMatrix< std::complex<long double> > &B, bool conj);

template
long double Dot(const TPZFMatrix<long double> &A, const TPZFMatrix<long double> &B);
long double Dot(const TPZFMatrix<long double> &A, const TPZFMatrix<long double> &B, bool conj);

template
double Dot(const TPZFMatrix<double> &A, const TPZFMatrix<double> &B);
double Dot(const TPZFMatrix<double> &A, const TPZFMatrix<double> &B, bool conj);

template
float Dot(const TPZFMatrix<float> &A, const TPZFMatrix<float> &B);
float Dot(const TPZFMatrix<float> &A, const TPZFMatrix<float> &B, bool conj);

template
int64_t Dot(const TPZFMatrix<int64_t> &A, const TPZFMatrix<int64_t> &B);
int64_t Dot(const TPZFMatrix<int64_t> &A, const TPZFMatrix<int64_t> &B, bool conj);

template
int Dot(const TPZFMatrix<int> &A, const TPZFMatrix<int> &B);
int Dot(const TPZFMatrix<int> &A, const TPZFMatrix<int> &B, bool conj);

template
Fad<float> Dot(const TPZFMatrix<Fad<float> > &A, const TPZFMatrix<Fad<float> > &B);
Fad<float> Dot(const TPZFMatrix<Fad<float> > &A, const TPZFMatrix<Fad<float> > &B, bool conj);

template
Fad<double> Dot(const TPZFMatrix<Fad<double> > &A, const TPZFMatrix<Fad<double> > &B);
Fad<double> Dot(const TPZFMatrix<Fad<double> > &A, const TPZFMatrix<Fad<double> > &B, bool conj);

template
Fad<long double> Dot(const TPZFMatrix<Fad<long double> > &A, const TPZFMatrix<Fad<long double> > &B);
Fad<long double> Dot(const TPZFMatrix<Fad<long double> > &A, const TPZFMatrix<Fad<long double> > &B, bool conj);

template
TPZFlopCounter Dot(const TPZFMatrix<TPZFlopCounter> &A, const TPZFMatrix<TPZFlopCounter> &B);
TPZFlopCounter Dot(const TPZFMatrix<TPZFlopCounter> &A, const TPZFMatrix<TPZFlopCounter> &B, bool conj);

/** @brief Increments value over all entries of the matrix A. */
template <class TVar>
Expand Down
2 changes: 1 addition & 1 deletion Matrix/pzfmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ template <class TVar> class TPZVerySparseMatrix;

/** @brief Returns a dot product to matrices */
template<class TVar>
TVar Dot(const TPZFMatrix<TVar> &A,const TPZFMatrix<TVar> &B);
TVar Dot(const TPZFMatrix<TVar> &A,const TPZFMatrix<TVar> &B,bool conj=true);

/** @brief Returns the norm of the matrix A */
template<class TVar>
Expand Down

0 comments on commit 027fe71

Please sign in to comment.