In [18]:
import numpy as np
import opt_einsum as oe

nao = 241

test = oe.contract_path(
    "ijkl,i,j,kl->",
    (nao, nao, nao, nao),
    (nao,),
    (nao,),
    (nao, nao),
    optimize="optimal",
    memory_limit="max_input",
)

print(test)

ValueError: Einstein sum subscript 'ijkl' does not contain the correct number of indices for operand 0.

In [45]:
import torch

from pyscf import gto
from pyscf import dft

_zeta = 1.5 ** np.array([2, 1, 0])
# _zeta = 1.5 ** np.array([17, 13, 10, 7, 5, 3, 2, 1, 0, -1, -2, -3])
_coef = np.diag(np.ones(_zeta.size)) - np.diag(np.ones(_zeta.size - 1), k=1)
_table = np.concatenate([_zeta.reshape(-1, 1), _coef], axis=1)
DEFAULT_BASIS = [[0, *_table.tolist()], [1, *_table.tolist()], [2, *_table.tolist()]]
DEFAULT_SYMB = "Ne"


def load_basis(basis):
    if basis is None:
        return DEFAULT_BASIS


def get_shell_sec(basis):
    if not isinstance(basis, (list, tuple)):
        basis = load_basis(basis)
    shell_sec = []
    for l, c0, *cr in basis:
        nb = c0 if isinstance(c0, int) else (len(c0) - 1)
        shell_sec.extend([2 * l + 1] * nb)
    return shell_sec


def t_make_pdm_shells(dm, ovlp_shells):
    """return projected density matrix by shell"""
    # (D^I_rl)_mm' = \sum_i < alpha^I_rlm | phi_i >< phi_i | aplha^I_rlm' >
    pdm_shells = [
        torch.einsum("rap,...rs,saq->...apq", po, dm, po) for po in ovlp_shells
    ]
    return pdm_shells


def t_make_pdm(dm, ovlp):
    """return projected density matrix by shell"""
    # (D^I_rl)_mm' = \sum_i < alpha^I_rlm | phi_i >< phi_i | aplha^I_rlm' >
    pdm = torch.einsum("rp,...rs,sq->...pq", ovlp, dm, ovlp)
    return pdm


def gen_proj_mol(mol, basis):
    mole_coords = mol.atom_coords(unit="Ang")
    mole_ele = mol.elements
    test_mol = gto.Mole()
    test_mol.atom = [
        ["X", coord]
        for coord, ele in zip(mole_coords, mole_ele)
        if not ele.startswith("X")
    ]
    test_mol.basis = basis
    test_mol.build(0, 0, unit="Ang")
    return test_mol


class NetMixin:

    def __init__(self, proj_basis=None):
        self._pbas = load_basis(proj_basis)
        # [1,1,1,...,3,3,3,...,5,5,5,...]
        self._shell_sec = get_shell_sec(self._pbas)
        # total number of projected basis per atom
        self.nproj = sum(self._shell_sec)
        print(self.nproj)
        # prepare overlap integrals used in projection
        self.prepare_integrals()

    def prepare_integrals(self):
        """overlap between origin and projected basis, reshaped"""
        # a virtual molecule to be projected on
        self._pmol = gen_proj_mol(self.mol, self._pbas)

        nao = self.mol.nao
        natm = self._pmol.natm
        pnao = self._pmol.nao
        proj = gto.intor_cross("int1e_ovlp", self.mol, self._pmol)
        # shape [nao, natom x nproj]
        proj_back = gto.intor_cross("int1e_ovlp", self._pmol, self.mol)
        # shape [nao, natom x nproj]

        self.t_proj_ovlp = torch.from_numpy(proj).double()
        self.t_proj_ovlp_back = torch.from_numpy(
            proj_back @ np.linalg.inv(proj @ proj_back)
        ).double()

        # < mol_ao | alpha^I_rlm >, shape=[nao x natom x nproj]
        t_proj_ovlp = torch.from_numpy(
            proj.reshape(
                nao,
                natm,
                pnao // natm,
            )
        ).double()
        # split the projected coeffs by shell (different r and l)
        self._t_ovlp_shells = torch.split(t_proj_ovlp, self._shell_sec, -1)

    def make_pdm(self, dm=None, flatten=False):
        """return projected density matrix by shell"""
        if dm is None:
            dm = self.make_rdm1()
        t_dm = torch.from_numpy(dm).double()
        t_pdm_shells = t_make_pdm_shells(t_dm, self._t_ovlp_shells)

        t_dm_ovlp = t_make_pdm(t_dm, self.t_proj_ovlp)
        t_dm_back = t_make_pdm(t_dm_ovlp, self.t_proj_ovlp_back)
        print(t_dm_ovlp)
        print(np.linalg.norm(t_dm - t_dm_back))

        if not flatten:
            return [s.detach().cpu().numpy() for s in t_pdm_shells]
        else:
            return (
                torch.cat([s.flatten(-2) for s in t_pdm_shells], dim=-1)
                .detach()
                .cpu()
                .numpy()
            )


class DSCF(NetMixin, dft.rks.RKS):
    """Restricted SCF solver for given NN energy model"""

    def __init__(self, mol, xc="HF", proj_basis=None):
        # base method must be initialized first
        dft.rks.RKS.__init__(self, mol, xc=xc)
        # correction mixin initialization
        NetMixin.__init__(self, proj_basis=proj_basis)
        # update keys to avoid pyscf warning
        self._keys.update(self.__dict__.keys())

In [48]:
mol = gto.Mole()
mol.verbose = -1
mol.output = None
mol.atom = [
    ["O", 1.0, 0, 0],
    ["O", -1.0, 0, 0],
]
mol.basis = "ccpvdz"
mol.build(0, 0)

dscf = DSCF(mol)
energy = dscf.kernel()
# print(np.shape(dscf.make_rdm1()))
# print((dscf._t_ovlp_shells))
print(dscf.make_pdm())

27
tensor([[ 1.4287e+00,  1.6570e-01,  6.4681e-02,  ..., -2.2801e-02,
         -2.9770e-17,  3.9532e-02],
        [ 1.6570e-01,  1.2342e+00,  1.0381e+00,  ..., -2.6712e-02,
         -2.9741e-17,  4.6310e-02],
        [ 6.4681e-02,  1.0381e+00,  8.7837e-01,  ..., -2.2651e-02,
         -1.7892e-17,  3.9267e-02],
        ...,
        [-2.2801e-02, -2.6712e-02, -2.2651e-02,  ...,  2.3279e-03,
          3.2829e-17, -4.0347e-03],
        [-2.9770e-17, -2.9741e-17, -1.7892e-17,  ...,  3.2829e-17,
          1.2324e-03, -5.6800e-17],
        [ 3.9532e-02,  4.6310e-02,  3.9267e-02,  ..., -4.0347e-03,
         -5.6800e-17,  6.9929e-03]], dtype=torch.float64)
1.9738310814951244e-14
[array([[[1.4286951]],

       [[1.4286951]]]), array([[[1.23416373]],

       [[1.23416373]]]), array([[[0.87837112]],

       [[0.87837112]]]), array([[[ 5.41967969e-01,  8.69458171e-16,  1.29249874e-14],
        [ 8.69458171e-16,  1.23375976e+00, -1.58441311e-01],
        [ 1.29249874e-14, -1.58441311e-01,  5.8087431

In [39]:
for i in range(36):
    # print(dscf._t_ovlp_shells[i].shape)
    print(dscf.make_pdm()[i].shape)

(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 1, 1)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 3, 3)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)
(2, 5, 5)


In [40]:
12 * (1**2 + 3**2 + 5**2)

420