Skip to content

Commit

Permalink
update cal_elem func (#3953)
Browse files Browse the repository at this point in the history
avoiding matrix transposition can computational efficiency.
  • Loading branch information
haozhihan committed Apr 11, 2024
1 parent aacf171 commit 49737e5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 121 deletions.
165 changes: 46 additions & 119 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
basis,
this->hphi,
this->hcc,
this->scc,
true);
this->scc);

this->diag_zhegvx(nbase,
this->n_band,
Expand Down Expand Up @@ -175,8 +174,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
basis,
this->hphi,
this->hcc,
this->scc,
false);
this->scc);

this->diag_zhegvx(nbase,
this->n_band,
Expand Down Expand Up @@ -399,84 +397,43 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
const psi::Psi<T, Device>& basis,
const T* hphi,
T* hcc,
T* scc,
bool init)
T* scc)
{
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");

if (init)
{
assert(nbase == 0);
assert(this->n_band == notconv);
gemm_op<T, Device>()(this->ctx,
'C',
'N',
notconv,
notconv,
this->dim,
this->one,
&basis(0, 0),
this->dim,
hphi,
this->dim,
this->zero,
hcc,
this->nbase_x);

gemm_op<T, Device>()(this->ctx,
'C',
'N',
notconv,
notconv,
this->dim,
this->one,
&basis(0, 0),
this->dim,
&basis(0, 0),
this->dim,
this->zero,
scc,
this->nbase_x);
}
else
{
gemm_op<T, Device>()(this->ctx,
'C',
'N',
notconv,
nbase + notconv,
this->dim,
this->one,
&hphi[nbase * this->dim],
this->dim,
&basis(0, 0),
this->dim,
this->zero,
hcc + nbase,
this->nbase_x);
gemm_op<T, Device>()(this->ctx,
'C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
&basis(0, 0),
this->dim,
&hphi[nbase * this->dim],
this->dim,
this->zero,
&hcc[nbase * this->nbase_x],
this->nbase_x);

gemm_op<T, Device>()(this->ctx,
'C',
'N',
notconv,
nbase + notconv,
this->dim,
this->one,
&basis(nbase, 0),
this->dim,
&basis(0, 0),
this->dim,
this->zero,
scc + nbase,
this->nbase_x);
}
gemm_op<T, Device>()(this->ctx,
'C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
&basis(0, 0),
this->dim,
&basis(nbase, 0),
this->dim,
this->zero,
&scc[nbase * this->nbase_x],
this->nbase_x);

#ifdef __MPI
if (GlobalV::NPROC_IN_POOL > 1)
{
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, hcc, hcc);
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, scc, scc);

auto* swap = new T[notconv * this->nbase_x];
syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);

Expand Down Expand Up @@ -532,64 +489,34 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
}
}
delete[] swap;

matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, hcc, hcc);
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, scc, scc);
}
#endif

nbase += notconv;
int nb1 = nbase - notconv;
// reset:
if (init)
const size_t last_nbase = nbase; // init: last_nbase = 0
nbase = nbase + notconv;

for (size_t i = 0; i < nbase; i++)
{
for (size_t i = 0; i < nbase; i++)
if (i >= last_nbase)
{

hcc[i * this->nbase_x + i] = set_real_tocomplex(hcc[i * this->nbase_x + i]);
scc[i * this->nbase_x + i] = set_real_tocomplex(scc[i * this->nbase_x + i]);

for (size_t j = i + 1; j < nbase; j++)
{
hcc[j * this->nbase_x + i] = get_conj(hcc[i * this->nbase_x + j]);
scc[j * this->nbase_x + i] = get_conj(scc[i * this->nbase_x + j]);
}
}
for (size_t i = nbase; i < this->nbase_x; i++)
for (size_t j = std::max(i + 1, last_nbase); j < nbase; j++)
{
for (size_t j = nbase; j < this->nbase_x; j++)
{
hcc[i * this->nbase_x + j] = cs.zero;
scc[i * this->nbase_x + j] = cs.zero;
hcc[j * this->nbase_x + i] = cs.zero;
scc[j * this->nbase_x + i] = cs.zero;
}
hcc[i * this->nbase_x + j] = get_conj(hcc[j * this->nbase_x + i]);
scc[i * this->nbase_x + j] = get_conj(scc[j * this->nbase_x + i]);
}
}
else

for (size_t i = nbase; i < this->nbase_x; i++)
{
for (size_t i = 0; i < nbase; i++)
for (size_t j = nbase; j < this->nbase_x; j++)
{
if (i >= nb1)
{
hcc[i * this->nbase_x + i] = set_real_tocomplex(hcc[i * this->nbase_x + i]);
scc[i * this->nbase_x + i] = set_real_tocomplex(scc[i * this->nbase_x + i]);
}
for (size_t j = std::max(i + 1, (size_t)nb1); j < nbase; j++)
{
hcc[j * this->nbase_x + i] = get_conj(hcc[i * this->nbase_x + j]);
scc[j * this->nbase_x + i] = get_conj(scc[i * this->nbase_x + j]);
}
}
for (size_t i = nbase; i < this->nbase_x; i++)
{
for (size_t j = nbase; j < this->nbase_x; j++)
{
hcc[i * this->nbase_x + j] = cs.zero;
scc[i * this->nbase_x + j] = cs.zero;
hcc[j * this->nbase_x + i] = cs.zero;
scc[j * this->nbase_x + i] = cs.zero;
}
hcc[i * this->nbase_x + j] = cs.zero;
scc[i * this->nbase_x + j] = cs.zero;
hcc[j * this->nbase_x + i] = cs.zero;
scc[j * this->nbase_x + i] = cs.zero;
}
}

Expand Down
3 changes: 1 addition & 2 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
const psi::Psi<T, Device>& basis,
const T* hphi,
T* hcc,
T* scc,
bool init);
T* scc);

void refresh(const int& dim,
const int& nband,
Expand Down

0 comments on commit 49737e5

Please sign in to comment.