Skip to content

Commit

Permalink
multi mps trans symm
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jun 26, 2024
1 parent 88f2efb commit d31c2bd
Show file tree
Hide file tree
Showing 10 changed files with 312 additions and 79 deletions.
5 changes: 4 additions & 1 deletion pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4288,7 +4288,10 @@ def dmrg(
assert len(proj_weights) == len(proj_mpss)
dmrg.projection_weights = bw.VectorFP(proj_weights)
dmrg.ext_mpss = bw.bs.VectorMPS(proj_mpss)
impo = self.get_identity_mpo()
if metric_mpo is None:
impo = self.get_identity_mpo()
else:
impo = metric_mpo
for ext_mps in dmrg.ext_mpss:
if ext_mps.info.tag == ket.info.tag:
raise RuntimeError("Same tag for proj_mps and ket!!")
Expand Down
12 changes: 7 additions & 5 deletions src/dmrg/moving_environment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2040,7 +2040,7 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
shared_ptr<SparseMatrix<S, FLS>> cket = nullptr) {
return symm_context_convert_impl(
i, mps->info, cmps->info, dot, fuse_left, mask, forward,
is_wfn, infer_info,
is_wfn, infer_info, false,
ket == nullptr && !(!forward && infer_info) ? mps->tensors[i]
: ket,
cket == nullptr && !(forward && infer_info)
Expand All @@ -2049,14 +2049,15 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
nullptr, nullptr)
.first;
}
static shared_ptr<SparseMatrixGroup<S, FLS>> symm_context_convert_group(
static shared_ptr<SparseMatrixGroup<S, FLS>>
symm_context_convert_perturbative(
int i, const shared_ptr<MPS<S, FLS>> &mps,
const shared_ptr<MPS<S, FLS>> &cmps, int dot, bool fuse_left, bool mask,
bool forward, bool is_wfn, bool infer_info,
const shared_ptr<SparseMatrixGroup<S, FLS>> &pket) {
return symm_context_convert_impl(i, mps->info, cmps->info, dot,
fuse_left, mask, forward, is_wfn,
infer_info, mps->tensors[i],
infer_info, true, mps->tensors[i],
cmps->tensors[i], pket, nullptr)
.second;
}
Expand All @@ -2066,7 +2067,7 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
symm_context_convert_impl(int i, const shared_ptr<MPSInfo<S>> &info,
const shared_ptr<MPSInfo<S>> &cinfo, int dot,
bool fuse_left, bool mask, bool forward,
bool is_wfn, bool infer_info,
bool is_wfn, bool infer_info, bool is_pert,
shared_ptr<SparseMatrix<S, FLS>> ket,
shared_ptr<SparseMatrix<S, FLS>> cket,
shared_ptr<SparseMatrixGroup<S, FLS>> pket,
Expand Down Expand Up @@ -2139,7 +2140,8 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
shared_ptr<SparseMatrixGroup<S, FLS>> gr_wfn =
is_group ? make_shared<SparseMatrixGroup<S, FLS>>(d_alloc)
: nullptr;
if (is_group && infer_info) {
if (is_pert) {
assert(is_group && infer_info);
// FIXME: multi will have problem
vector<S> pket_dqs;
for (int iw = 0; iw < pket->n; iw++) {
Expand Down
242 changes: 185 additions & 57 deletions src/dmrg/mps_unfused.hpp

Large diffs are not rendered by default.

75 changes: 75 additions & 0 deletions src/dmrg/state_averaged.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,4 +587,79 @@ template <typename S, typename FL> struct MultiMPS : MPS<S, FL> {
}
};

template <typename S1, typename S2, typename = void, typename = void>
struct TransMultiMPSInfo {
static shared_ptr<MultiMPSInfo<S2>>
forward(const shared_ptr<MultiMPSInfo<S1>> &si, const vector<S2> &targets) {
return TransMultiMPSInfo<S2, S1>::backward(si, targets);
}
static shared_ptr<MultiMPSInfo<S1>>
backward(const shared_ptr<MultiMPSInfo<S2>> &si,
const vector<S1> &targets) {
return TransMultiMPSInfo<S2, S1>::forward(si, targets);
}
};

template <typename S> struct TransMultiMPSInfoAnyBase {
static shared_ptr<MultiMPSInfo<S>>
transform(const shared_ptr<MultiMPSInfo<S>> &si, const vector<S> &targets) {
int n_sites = si->n_sites;
S vacuum = TransStateInfo<S, S>::forward(
make_shared<StateInfo<S>>(si->vacuum), targets[0])
->quanta[0];
vector<shared_ptr<StateInfo<S>>> basis(n_sites);
for (int i = 0; i < n_sites; i++)
basis[i] = TransStateInfo<S, S>::forward(si->basis[i], vacuum);
shared_ptr<MultiMPSInfo<S>> so =
make_shared<MultiMPSInfo<S>>(n_sites, vacuum, targets, basis);
// handle the singlet embedding case
so->left_dims_fci[0] =
TransStateInfo<S, S>::forward(si->left_dims_fci[0], vacuum);
for (int i = 0; i < n_sites; i++)
so->left_dims_fci[i + 1] =
make_shared<StateInfo<S>>(StateInfo<S>::tensor_product(
*so->left_dims_fci[i], *basis[i], S(S::invalid)));
so->right_dims_fci[n_sites] =
TransStateInfo<S, S>::forward(si->right_dims_fci[n_sites], vacuum);
for (int i = n_sites - 1; i >= 0; i--)
so->right_dims_fci[i] =
make_shared<StateInfo<S>>(StateInfo<S>::tensor_product(
*basis[i], *so->right_dims_fci[i + 1], S(S::invalid)));
for (int i = 0; i <= n_sites; i++) {
StateInfo<S>::multi_target_filter(*so->left_dims_fci[i],
*so->right_dims_fci[i], targets);
StateInfo<S>::multi_target_filter(*so->right_dims_fci[i],
*so->left_dims_fci[i], targets);
}
for (int i = 0; i <= n_sites; i++)
so->left_dims_fci[i]->collect();
for (int i = n_sites; i >= 0; i--)
so->right_dims_fci[i]->collect();
for (int i = 0; i <= n_sites; i++)
so->left_dims[i] =
TransStateInfo<S, S>::forward(si->left_dims[i], vacuum);
for (int i = n_sites; i >= 0; i--)
so->right_dims[i] =
TransStateInfo<S, S>::forward(si->right_dims[i], vacuum);
so->check_bond_dimensions();
so->bond_dim = so->get_max_bond_dimension();
so->tag = si->tag;
return so;
}
};

// Translation between SAny MultiMPSInfo
template <typename S>
struct TransMultiMPSInfo<S, S, typename S::is_sany_t, typename S::is_sany_t>
: TransMultiMPSInfoAnyBase<S> {
static shared_ptr<MultiMPSInfo<S>>
forward(const shared_ptr<MultiMPSInfo<S>> &si, const vector<S> &targets) {
return TransMultiMPSInfoAnyBase<S>::transform(si, targets);
}
static shared_ptr<MultiMPSInfo<S>>
backward(const shared_ptr<MultiMPSInfo<S>> &si, const vector<S> &targets) {
return TransMultiMPSInfoAnyBase<S>::transform(si, targets);
}
};

} // namespace block2
15 changes: 9 additions & 6 deletions src/dmrg/sweep_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ template <typename S, typename FL, typename FLS> struct DMRG {
false);
xket = context_ket;
if (pket != nullptr) {
context_pket =
MovingEnvironment<S, FL, FLS>::symm_context_convert_group(
context_pket = MovingEnvironment<S, FL, FLS>::
symm_context_convert_perturbative(
i, me->ket, context_ket, 1,
!skip_decomp ? forward : fuse_left, false, true, true,
true, pket);
Expand Down Expand Up @@ -861,10 +861,13 @@ template <typename S, typename FL, typename FLS> struct DMRG {
xold_ket = context_old_ket;
xket = context_ket;
if (pket != nullptr) {
context_pket =
MovingEnvironment<S, FL, FLS>::symm_context_convert_group(
i, me->ket, context_ket, 2, true, false, true, true,
true, pket);
context_pket = MovingEnvironment<
S, FL, FLS>::symm_context_convert_perturbative(i, me->ket,
context_ket,
2, true,
false, true,
true, true,
pket);
xpket = context_pket;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/instantiation/block2_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,8 @@ extern template struct block2::AntiHermitianRuleQC<block2::SAny, double>;
extern template struct block2::MultiMPSInfo<block2::SAny>;
extern template struct block2::MultiMPS<block2::SAny, double>;

extern template struct block2::TransMultiMPSInfo<block2::SAny, block2::SAny>;

// sweep_algorithm.hpp
extern template struct block2::DMRG<block2::SAny, double, double>;
extern template struct block2::Linear<block2::SAny, double, double>;
Expand Down
2 changes: 2 additions & 0 deletions src/instantiation/dmrg_a/state_averaged.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@

template struct block2::MultiMPSInfo<block2::SAny>;
template struct block2::MultiMPS<block2::SAny, double>;

template struct block2::TransMultiMPSInfo<block2::SAny, block2::SAny>;
1 change: 1 addition & 0 deletions src/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ PYBIND11_MODULE(block2, m) {
#ifdef _USE_SANY
bind_dmrg<SAny, double>(m_sany, "SAny");
bind_trans_mps<SAny, SAny>(m_sany, "sany");
bind_trans_multi_mps<SAny, SAny>(m_sany, "sany");
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(m_sany, "sany");
#ifdef _USE_COMPLEX
bind_dmrg<SAny, complex<double>>(m_sany_cpx, "SAny");
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/dmrg_a/trans_mps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "../pybind_dmrg.hpp"

template void bind_trans_mps<SAny, SAny>(py::module &m, const string &aux_name);
template void bind_trans_multi_mps<SAny, SAny>(py::module &m,
const string &aux_name);
template auto
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(py::module &m,
const string &aux_name)
Expand Down
35 changes: 25 additions & 10 deletions src/pybind/pybind_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,22 +484,28 @@ template <typename S, typename FL> void bind_fl_mps(py::module &m) {
.def_readwrite("canonical_form", &UnfusedMPS<S, FL>::canonical_form)
.def_static("forward_left_fused",
&UnfusedMPS<S, FL>::forward_left_fused, py::arg("i"),
py::arg("mps"), py::arg("wfn"))
py::arg("info"), py::arg("mat"), py::arg("wfn"))
.def_static("forward_right_fused",
&UnfusedMPS<S, FL>::forward_right_fused, py::arg("i"),
py::arg("mps"), py::arg("wfn"))
py::arg("info"), py::arg("mat"), py::arg("wfn"))
.def_static("forward_mps_tensor",
&UnfusedMPS<S, FL>::forward_mps_tensor, py::arg("i"),
py::arg("mps"))
.def_static("forward_multi_mps_tensor",
&UnfusedMPS<S, FL>::forward_multi_mps_tensor, py::arg("i"),
py::arg("mmps"))
.def_static("backward_left_fused",
&UnfusedMPS<S, FL>::backward_left_fused, py::arg("i"),
py::arg("mps"), py::arg("spt"), py::arg("wfn"))
py::arg("info"), py::arg("spt"), py::arg("wfn"))
.def_static("backward_right_fused",
&UnfusedMPS<S, FL>::backward_right_fused, py::arg("i"),
py::arg("mps"), py::arg("spt"), py::arg("wfn"))
py::arg("info"), py::arg("spt"), py::arg("wfn"))
.def_static("backward_mps_tensor",
&UnfusedMPS<S, FL>::backward_mps_tensor, py::arg("i"),
py::arg("mps"), py::arg("spt"))
.def_static("backward_multi_mps_tensor",
&UnfusedMPS<S, FL>::backward_multi_mps_tensor, py::arg("i"),
py::arg("mmps"), py::arg("spt"))
.def("initialize", &UnfusedMPS<S, FL>::initialize)
.def("finalize", &UnfusedMPS<S, FL>::finalize,
py::arg("para_rule") = nullptr)
Expand Down Expand Up @@ -1039,12 +1045,12 @@ void bind_fl_moving_environment(py::module &m, const string &name) {
py::arg("forward"), py::arg("is_wfn"),
py::arg("infer_info"), py::arg("ket") = nullptr,
py::arg("cket") = nullptr)
.def_static("symm_context_convert_group",
&MovingEnvironment<S, FL, FLS>::symm_context_convert_group,
py::arg("i"), py::arg("mps"), py::arg("cmps"),
py::arg("dot"), py::arg("fuse_left"), py::arg("mask"),
py::arg("forward"), py::arg("is_wfn"),
py::arg("infer_info"), py::arg("pket"));
.def_static(
"symm_context_convert_perturbative",
&MovingEnvironment<S, FL, FLS>::symm_context_convert_perturbative,
py::arg("i"), py::arg("mps"), py::arg("cmps"), py::arg("dot"),
py::arg("fuse_left"), py::arg("mask"), py::arg("forward"),
py::arg("is_wfn"), py::arg("infer_info"), py::arg("pket"));

py::bind_vector<vector<shared_ptr<MovingEnvironment<S, FL, FLS>>>>(
m, ("Vector" + name).c_str());
Expand Down Expand Up @@ -2225,6 +2231,13 @@ void bind_trans_mps(py::module &m, const string &aux_name) {
&TransMPSInfo<S, T>::forward);
}

template <typename S, typename T>
void bind_trans_multi_mps(py::module &m, const string &aux_name) {

m.def(("trans_multi_mps_info_to_" + aux_name).c_str(),
&TransMultiMPSInfo<S, T>::forward);
}

template <typename S, typename FL1, typename FL2>
void bind_fl_trans_mps(py::module &m, const string &aux_name) {

Expand Down Expand Up @@ -2898,6 +2911,8 @@ extern template auto bind_fl_spin_specific<SAny, double>(py::module &m)

extern template void bind_trans_mps<SAny, SAny>(py::module &m,
const string &aux_name);
extern template void bind_trans_multi_mps<SAny, SAny>(py::module &m,
const string &aux_name);
extern template auto
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(py::module &m,
const string &aux_name)
Expand Down

0 comments on commit d31c2bd

Please sign in to comment.