Skip to content

Commit

Permalink
core: optimize integral memory
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed May 1, 2023
1 parent 6574017 commit fa56a86
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 31 deletions.
76 changes: 59 additions & 17 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,8 @@ def get_qc_mpo(
cutoff=1e-20,
integral_cutoff=1e-20,
post_integral_cutoff=1e-20,
fast_cutoff=1e-20,
unpack_g2e=True,
algo_type=None,
normal_order_ref=None,
normal_order_single_ref=None,
Expand All @@ -1005,10 +1007,11 @@ def get_qc_mpo(

bw = self.bw

if isinstance(g2e, np.ndarray):
g2e = self.unpack_g2e(g2e)
elif isinstance(g2e, tuple):
g2e = tuple(self.unpack_g2e(x) for x in g2e)
if unpack_g2e:
if isinstance(g2e, np.ndarray):
g2e = self.unpack_g2e(g2e)
elif isinstance(g2e, tuple):
g2e = tuple(self.unpack_g2e(x) for x in g2e)

if SymmetryTypes.SZ in bw.symm_type:
if h1e is not None and isinstance(h1e, np.ndarray) and h1e.ndim == 2:
Expand Down Expand Up @@ -1166,23 +1169,55 @@ def get_qc_mpo(
if normal_order_ref is None:
if SymmetryTypes.SU2 in bw.symm_type:
if h1e is not None:
b.add_sum_term("(C+D)0", np.sqrt(2) * h1e)
b.add_sum_term("(C+D)0", h1e, cutoff=fast_cutoff, factor=np.sqrt(2))
if g2e is not None:
b.add_sum_term("((C+(C+D)0)1+D)0", g2e.transpose(0, 2, 3, 1))
if not unpack_g2e and g2e.ndim == 1:
b.data.exprs.append("((C+(C+D)0)1+D)0")
b.data.add_eight_fold_term(g2e, cutoff=fast_cutoff, factor=1.0)
else:
b.add_sum_term(
"((C+(C+D)0)1+D)0", g2e, cutoff=fast_cutoff, perm=[0, 2, 3, 1]
)
elif SymmetryTypes.SZ in bw.symm_type:
if h1e is not None:
b.add_sum_term("cd", h1e[0])
b.add_sum_term("CD", h1e[1])
b.add_sum_term("cd", h1e[0], cutoff=fast_cutoff)
b.add_sum_term("CD", h1e[1], cutoff=fast_cutoff)
if g2e is not None:
b.add_sum_term("ccdd", 0.5 * g2e[0].transpose(0, 2, 3, 1))
b.add_sum_term("cCDd", 0.5 * g2e[1].transpose(0, 2, 3, 1))
b.add_sum_term("CcdD", 0.5 * g2e[1].transpose(2, 0, 1, 3))
b.add_sum_term("CCDD", 0.5 * g2e[2].transpose(0, 2, 3, 1))
b.add_sum_term(
"ccdd",
g2e[0],
cutoff=fast_cutoff,
perm=[0, 2, 3, 1],
factor=0.5,
)
b.add_sum_term(
"cCDd",
g2e[1],
cutoff=fast_cutoff,
perm=[0, 2, 3, 1],
factor=0.5,
)
b.add_sum_term(
"CcdD",
g2e[1],
cutoff=fast_cutoff,
perm=[2, 0, 1, 3],
factor=0.5,
)
b.add_sum_term(
"CCDD",
g2e[2],
cutoff=fast_cutoff,
perm=[0, 2, 3, 1],
factor=0.5,
)
elif SymmetryTypes.SGF in bw.symm_type:
if h1e is not None:
b.add_sum_term("CD", h1e)
b.add_sum_term("CD", h1e, cutoff=fast_cutoff)
if g2e is not None:
b.add_sum_term("CCDD", 0.5 * g2e.transpose(0, 2, 3, 1))
b.add_sum_term(
"CCDD", g2e, cutoff=fast_cutoff, perm=[0, 2, 3, 1], factor=0.5
)
elif SymmetryTypes.SGB in bw.symm_type:
h_terms = FermionTransform.jordan_wigner(h1e, g2e)
for k, (x, v) in h_terms.items():
Expand Down Expand Up @@ -3530,18 +3565,25 @@ def add_term(self, expr, idx, val):
self.data.data.append(self.bw.VectorFL(val))
return self

def add_sum_term(self, expr, arr, cutoff=1e-12, fast=True):
def add_sum_term(self, expr, arr, cutoff=1e-12, fast=True, factor=1.0, perm=None):
import numpy as np

self.data.exprs.append(expr)
if fast:
self.data.add_sum_term(np.ascontiguousarray(arr), cutoff)
self.data.add_sum_term(
np.ascontiguousarray(arr),
cutoff,
factor,
self.bw.b.VectorUInt16([] if perm is None else perm),
)
else:
idx, dt = [], []
if perm is not None:
arr = arr.transpose(*perm)
for ix in np.ndindex(*arr.shape):
if abs(arr[ix]) > cutoff:
idx.extend(ix)
dt.append(arr[ix])
dt.append(arr[ix] * factor)
self.data.indices.append(self.bw.b.VectorUInt16(idx))
self.data.data.append(self.bw.VectorFL(dt))
return self
Expand Down
97 changes: 94 additions & 3 deletions src/dmrg/general_mpo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,105 @@ template <typename FL> struct GeneralFCIDUMP {
}
return r;
}
void add_eight_fold_term(const FL *vals, size_t len, FP cutoff = (FP)0.0,
FL factor = (FL)1.0) {
size_t n = 0, m = 0;
for (n = 1; n < len; n++) {
m = n * (n + 1) >> 1;
if ((m * (m + 1) >> 1) >= len)
break;
}
assert((m * (m + 1) >> 1) == len && m >= n);
vector<size_t> xm(m + 1, 0), gm(m + 1, 2), pm(m + 1, 0), qm(m + 1, 0);
for (size_t im = 1; im <= m; im++) {
xm[im] = xm[im - 1] + im;
if (xm[im] - 1 <= m) {
gm[xm[im] - 1] = 1;
for (size_t jm = xm[im - 1]; jm < xm[im]; jm++)
pm[jm] = im - 1, qm[jm] = jm - xm[im - 1];
}
}
assert(xm[m] == len);
int ntg = threading->activate_global();
vector<size_t> ms(ntg + 1, 0);
const size_t plm = m / ntg + !!(m % ntg);
#pragma omp parallel num_threads(ntg)
{
int tid = threading->get_thread_id();
for (size_t im = plm * tid; im < min(m, plm * (tid + 1)); im++)
for (size_t jm = 0; jm <= im; jm++)
ms[tid] += (abs(factor * vals[xm[im] + jm]) > cutoff) *
gm[im] * gm[jm] * (2 - (im == jm));
}
ms[ntg] = accumulate(&ms[0], &ms[ntg], (size_t)0);
indices.push_back(vector<uint16_t>(ms[ntg] * 4));
data.push_back(vector<FL>(ms[ntg]));
#pragma omp parallel num_threads(ntg)
{
int tid = threading->get_thread_id();
size_t istart = 0;
for (int i = 0; i < tid; i++)
istart += ms[i];
for (size_t im = plm * tid; im < min(m, plm * (tid + 1)); im++)
for (size_t jm = 0; jm <= im; jm++)
if (abs(factor * vals[xm[im] + jm]) > cutoff) {
for (size_t xxm = 0, xim = im, xjm = jm,
xs = istart * 4;
xxm < (2 - (im == jm));
xxm++, xim = jm, xjm = im) {
indices.back()[xs + 0] = pm[xim];
indices.back()[xs + 1] = pm[xjm];
indices.back()[xs + 2] = qm[xjm];
indices.back()[xs + 3] = qm[xim];
data.back()[istart] = factor * vals[xm[im] + jm];
istart++, xs += 4;
if (gm[xim] == 2) {
indices.back()[xs + 0] = qm[xim];
indices.back()[xs + 1] = pm[xjm];
indices.back()[xs + 2] = qm[xjm];
indices.back()[xs + 3] = pm[xim];
data.back()[istart] =
factor * vals[xm[im] + jm];
istart++, xs += 4;
}
if (gm[xjm] == 2) {
indices.back()[xs + 0] = pm[xim];
indices.back()[xs + 1] = qm[xjm];
indices.back()[xs + 2] = pm[xjm];
indices.back()[xs + 3] = qm[xim];
data.back()[istart] =
factor * vals[xm[im] + jm];
istart++, xs += 4;
}
if (gm[xim] == 2 && gm[xjm] == 2) {
indices.back()[xs + 0] = qm[xim];
indices.back()[xs + 1] = qm[xjm];
indices.back()[xs + 2] = pm[xjm];
indices.back()[xs + 3] = pm[xim];
data.back()[istart] =
factor * vals[xm[im] + jm];
istart++, xs += 4;
}
}
}
for (int i = 0; i < tid + 1; i++)
istart -= ms[i];
assert(istart == 0);
}
threading->activate_normal();
}
// array must have the min strides == 1
void add_sum_term(const FL *vals, size_t len, const vector<int> &shape,
const vector<size_t> &strides, FP cutoff = (FP)0.0,
FL factor = (FL)1.0,
const vector<int> &orb_sym = vector<int>()) {
const vector<int> &orb_sym = vector<int>(),
vector<uint16_t> rperm = vector<uint16_t>()) {
int ntg = threading->activate_global();
vector<size_t> lens(ntg + 1, 0);
const size_t plen = len / ntg + !!(len % ntg);
if (rperm.size() == 0)
for (size_t i = 0; i < shape.size(); i++)
rperm.push_back(i);
#pragma omp parallel num_threads(ntg)
{
int tid = threading->get_thread_id();
Expand All @@ -296,7 +387,7 @@ template <typename FL> struct GeneralFCIDUMP {
for (size_t i = plen * tid; i < min(len, plen * (tid + 1)); i++)
if (abs(factor * vals[i]) > cutoff) {
for (int j = 0; j < (int)shape.size(); j++)
indices.back()[istart * shape.size() + j] =
indices.back()[istart * shape.size() + rperm[j]] =
i / strides[j] % shape[j];
data.back()[istart] = factor * vals[i];
istart++;
Expand All @@ -306,7 +397,7 @@ template <typename FL> struct GeneralFCIDUMP {
if (abs(factor * vals[i]) > cutoff) {
int irrep = 0;
for (int j = 0; j < (int)shape.size(); j++) {
indices.back()[istart * shape.size() + j] =
indices.back()[istart * shape.size() + rperm[j]] =
i / strides[j] % shape[j];
irrep ^= orb_sym[i / strides[j] % shape[j]];
}
Expand Down
40 changes: 29 additions & 11 deletions src/pybind/pybind_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1846,17 +1846,35 @@ template <typename FL> void bind_general_fcidump(py::module &m) {
.def_readwrite("data", &GeneralFCIDUMP<FL>::data)
.def_readwrite("elem_type", &GeneralFCIDUMP<FL>::elem_type)
.def_readwrite("order_adjusted", &GeneralFCIDUMP<FL>::order_adjusted)
.def("add_sum_term",
[](GeneralFCIDUMP<FL> *self, const py::array_t<FL> &v,
typename GeneralFCIDUMP<FL>::FP cutoff) {
vector<int> shape(v.ndim());
vector<size_t> strides(v.ndim());
for (int i = 0; i < v.ndim(); i++)
shape[i] = v.shape()[i],
strides[i] = v.strides()[i] / sizeof(FL);
self->add_sum_term(v.data(), (size_t)v.size(), shape, strides,
cutoff);
})
.def(
"add_eight_fold_term",
[](GeneralFCIDUMP<FL> *self, const py::array_t<FL> &v,
typename GeneralFCIDUMP<FL>::FP cutoff, FL factor) {
self->add_eight_fold_term(v.data(), (size_t)v.size(), cutoff,
factor);
},
py::arg("v"), py::arg("cutoff"), py::arg("factor"))
.def(
"add_sum_term",
[](GeneralFCIDUMP<FL> *self, const py::array_t<FL> &v,
typename GeneralFCIDUMP<FL>::FP cutoff, FL factor,
const vector<uint16_t> &perm) {
vector<int> shape(v.ndim());
vector<size_t> strides(v.ndim());
for (int i = 0; i < v.ndim(); i++)
shape[i] = v.shape()[i],
strides[i] = v.strides()[i] / sizeof(FL);
vector<uint16_t> rperm(v.ndim());
if (perm.size() == 0)
for (int i = 0; i < v.ndim(); i++)
rperm[i] = i;
else
for (int i = 0; i < v.ndim(); i++)
rperm[perm[i]] = i;
self->add_sum_term(v.data(), (size_t)v.size(), shape, strides,
cutoff, factor, vector<int>(), rperm);
},
py::arg("v"), py::arg("cutoff"), py::arg("factor"), py::arg("perm"))
.def_static("initialize_from_qc",
&GeneralFCIDUMP<FL>::initialize_from_qc, py::arg("fcidump"),
py::arg("elem_type"),
Expand Down

0 comments on commit fa56a86

Please sign in to comment.