In [1]:
import argparse
import copy
import gc
from itertools import product
from pathlib import Path
from timeit import default_timer as timer

import pyscf
import torch
import numpy as np
import pandas as pd
import opt_einsum as oe
from torch.utils.data import DataLoader

from cadft import CC_DFT_DATA, add_args, gen_logger
from cadft.utils import ModelDict
from cadft.utils import Mol
from cadft.utils import MAIN_PATH, DATA_PATH

AU2KCALMOL = 627.5094740631


class DIIS:
    """
    DIIS (direct inversion in the iterative subspace) for the Fock matrix.

    A rolling window of the last ``n`` Fock matrices and their error matrices
    is kept.  ``hybrid`` returns the linear combination of the stored Fock
    matrices whose coefficients solve the usual DIIS linear system (error
    overlap matrix bordered by a Lagrange row/column that forces the
    coefficients to sum to one).

    Parameters
    ----------
    nao : int
        Dimension of the (square) Fock and error matrices.
    n : int, optional
        Maximum number of stored iterations (window length).
    """

    def __init__(self, nao, n=50):
        self.n = n
        self.errors = np.zeros((n, nao, nao))
        self.mat_fock = np.zeros((n, nao, nao))
        # Number of (Fock, error) pairs pushed so far.  It is advanced by
        # ``add`` so that ``hybrid`` is a pure read of the stored history.
        self.step = 0

    def add(self, mat_fock_, error):
        """
        Add the new Fock matrix and error.

        The oldest entry is dropped once the window is full; the newest entry
        always lives in the last slot.
        """
        # rolling back [_, _, 1, 2, 3] -> [_, 1, 2, 3, _]
        self.mat_fock = np.roll(self.mat_fock, -1, axis=0)
        self.errors = np.roll(self.errors, -1, axis=0)
        self.mat_fock[-1, :, :] = mat_fock_
        self.errors[-1, :, :] = error
        self.step += 1

    def hybrid(self):
        """
        Return the hybrid Fock matrix.

        Only the slots that have actually been filled take part in the
        extrapolation.  Calling this method repeatedly without a new ``add``
        returns the same matrix.

        Returns
        -------
        numpy.ndarray
            ``(nao, nao)`` extrapolated Fock matrix.  If the DIIS system is
            exactly singular the most recently added Fock matrix is returned
            unchanged instead of raising.

        Raises
        ------
        ValueError
            If no Fock matrix has been added yet.
        """
        if self.step == 0:
            raise ValueError("DIIS.hybrid() called before any Fock matrix was added")

        # Number of valid history entries: everything added so far, capped by
        # the window length.  They occupy the trailing slots of the buffers.
        n_used = min(self.step, self.n)
        errors = self.errors[-n_used:]

        mat = np.zeros((n_used + 1, n_used + 1))
        mat[:-1, :-1] = np.einsum("inm,jnm->ij", errors, errors)
        mat[-1, :] = -1
        mat[:, -1] = -1
        mat[-1, -1] = 0

        b = np.zeros(n_used + 1)
        b[-1] = -1

        try:
            c = np.linalg.solve(mat, b)
        except np.linalg.LinAlgError:
            # Linearly dependent (e.g. identical or all-zero) error matrices
            # make the system singular; fall back to the plain latest Fock
            # matrix rather than aborting the SCF loop.
            return self.mat_fock[-1].copy()

        # The last coefficient is the Lagrange multiplier, not a mixing weight.
        return np.sum(
            c[:-1, np.newaxis, np.newaxis] * self.mat_fock[-n_used:], axis=0
        )

MAIN_PATH: /home/dhem/workspace/2024.1
DATA_PATH: /home/dhem/workspace/2024.1/data/grids_mrks
DATA_SAVE_PATH: /home/dhem/workspace/2024.1/data/grids_mrks/saved_data
DATA_CC_PATH: /home/dhem/workspace/2024.1/data/test


In [7]:
# Configuration cell: selects the compute device, the molecule/basis to test
# and loads the trained networks. Every name defined here is a notebook-level
# global that the later cells read.

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name_mol = "methane"
extend_atom = 0
extend_xyz = 0
distance = 0
basis = "cc-pCVTZ"

# Work on an independent copy of the Mol entry for the chosen molecule.
molecular = copy.deepcopy(Mol[name_mol])
# Identifier for this geometry/basis; the "Load data from ..." message printed
# by the later cells shows a file name built from the same fields
# (data_methane_cc-pCVTZ_0_0_0.0000.npz).
name = f"{name_mol}_{basis}_{extend_atom}_{extend_xyz}_{distance:.4f}"

# 1. Init the model
# The positional arguments of ModelDict are not documented in this notebook.
# The checkpoint directory reported by load_model()
# (checkpoint-ccdft_2024-08-01-11-28-15_64_3_0) contains the first string and
# the values 64, 3 and 0 — presumably run timestamp, hidden size, number of
# layers and a residual/variant flag; TODO confirm against cadft.utils.ModelDict.
modeldict = ModelDict(
    "2024-08-01-11-28-15",
    1,
    64,
    1,
    3,
    0,
    device,
    "float32",
    if_mkdir=False,
)
modeldict.load_model()
# NOTE(review): the networks are switched with train() rather than eval()
# even though the following cells only run inference. If the networks contain
# dropout or batch-norm layers this changes their outputs — confirm that this
# is intentional.
modeldict.train()

Loading from /home/dhem/workspace/2024.1/checkpoints/checkpoint-ccdft_2024-08-01-11-28-15_64_3_0
Model loaded from /home/dhem/workspace/2024.1/checkpoints/checkpoint-ccdft_2024-08-01-11-28-15_64_3_0/1-4300.pth
Model loaded from /home/dhem/workspace/2024.1/checkpoints/checkpoint-ccdft_2024-08-01-11-28-15_64_3_0/2-4300.pth


In [19]:
# Smoke test: differentiate the output of model "2" with respect to its input
# density and check that a loss built from that gradient can be back-propagated.
dft2cc = CC_DFT_DATA(
    molecular,
    name=name,
    basis=basis,
    if_basis_str=True,
)
dft2cc.test_mol()
nocc = dft2cc.mol.nelec[0]
mdft = pyscf.scf.RKS(dft2cc.mol)

# Start from the Hartree-Fock density matrix stored with the data set.
dm1_scf = dft2cc.dm1_hf
inv_r_3 = pyscf.dft.numint.eval_rho(dft2cc.mol, dft2cc.ao_1, dm1_scf, xctype="GGA")
# Only the first row of the GGA output is fed to the network; a singleton axis
# is inserted at position 1 before the tensor is handed to the model.
input_mat = dft2cc.grids.vector_to_matrix(inv_r_3[0, :])
# Send the tensor to the same `device` that was passed to ModelDict instead of
# a hard-coded "cuda", so the cell also runs on CPU-only machines.
input_mat = torch.tensor(
    input_mat[:, np.newaxis, :, :], dtype=modeldict.dtype
).to(device)
input_mat = input_mat.requires_grad_(True)
middle_mat = modeldict.model_dict["2"](input_mat)
# NOTE(review): autograd.grad is called without create_graph=True, so the
# returned gradient carries no graph back to the model; requires_grad_(True)
# only turns it into a fresh leaf and the backward() below stops at grad_dms.
# Confirm whether create_graph=True was intended here.
grad_dms = torch.autograd.grad(torch.sum(middle_mat), input_mat)[0].requires_grad_(True)

loss_fn1 = torch.nn.MSELoss()

# The loss compares the gradient with itself and only exercises the backward
# call.
loss_ = loss_fn1(grad_dms, grad_dms)
loss_.backward()

[['C', np.float64(-1.955894763843074e-06), np.float64(-7.0707938248870635e-06), np.float64(-5.012649493400673e-06)], ['H', np.float64(-0.06585756260255843), np.float64(0.7778665572064299), np.float64(-0.7638602855427739)], ['H', np.float64(-0.036415348375505), np.float64(-0.9821334371267781), np.float64(-0.4763801541101622)], ['H', np.float64(-0.8375850064161591), np.float64(0.09935611191229492), np.float64(0.6939025614565855)], ['H', np.float64(0.939881224528902), np.float64(0.10499502609136857), np.float64(0.5463976107029269)]]
Load data from /home/dhem/workspace/2024.1/data/test/data_methane_cc-pCVTZ_0_0_0.0000.npz


In [3]:
# Self-consistent field loop driven by the learned potential (model "1"),
# followed by a comparison of the converged density and energy (model "2")
# against the coupled-cluster and DFT reference data stored in `dft2cc`.
dft2cc = CC_DFT_DATA(
    molecular,
    name=name,
    basis=basis,
    if_basis_str=True,
)
dft2cc.test_mol()
nocc = dft2cc.mol.nelec[0]
mdft = pyscf.scf.RKS(dft2cc.mol)

MAX_SCF_STEPS = 100  # upper bound on SCF iterations
DIIS_HISTORY = 10  # number of Fock matrices kept by the DIIS extrapolation

# Start from the Hartree-Fock density matrix stored with the data set.
dm1_scf = dft2cc.dm1_hf
# Pre-compiled contraction sum_p v_p w_p ao_pa ao_pb; the two AO arrays are
# bound as constants so that only the two grid vectors are passed per call.
oe_fock = oe.contract_expression(
    "p,p,pa,pb->ab",
    np.shape(dft2cc.ao_0[:, 0]),
    np.shape(dft2cc.ao_0[:, 0]),
    dft2cc.ao_0,
    dft2cc.ao_0,
    constants=[2, 3],
    optimize="optimal",
)

# Looser convergence threshold when the networks run in single precision.
if modeldict.dtype == torch.float32:
    max_error_scf = 1e-4
else:
    max_error_scf = 1e-8


def rho_and_model_input(dm1):
    """Return the GGA grid quantities of ``dm1`` and the matching network input.

    The first row of the ``eval_rho`` result is reshaped with
    ``grids.vector_to_matrix``, given a singleton axis at position 1 and moved
    to ``device`` (the device handed to ``ModelDict``) instead of a hard-coded
    "cuda", so the cell also works on CPU-only machines.
    """
    rho_gga = pyscf.dft.numint.eval_rho(
        dft2cc.mol, dft2cc.ao_1, dm1, xctype="GGA"
    )
    rho_mat = dft2cc.grids.vector_to_matrix(rho_gga[0, :])
    model_in = torch.tensor(
        rho_mat[:, np.newaxis, :, :], dtype=modeldict.dtype
    ).to(device)
    return rho_gga, model_in


diis = DIIS(dft2cc.mol.nao, n=DIIS_HISTORY)

for i in range(MAX_SCF_STEPS):
    inv_r_3, input_mat = rho_and_model_input(dm1_scf)

    with torch.no_grad():
        middle_mat = modeldict.model_dict["1"](input_mat).detach().cpu().numpy()
    middle_mat = middle_mat[:, 0, :, :]

    vxc_scf = dft2cc.grids.matrix_to_vector(middle_mat)
    # NOTE(review): element [0] of eval_xc is the energy density (exc), yet it
    # is added to the potential here — presumably the network was trained on
    # this exact combination; confirm against the training code.
    exc_b3lyp = pyscf.dft.libxc.eval_xc("b3lyp", inv_r_3)[0]
    vxc_scf += exc_b3lyp

    vxc_mat = oe_fock(
        vxc_scf,
        dft2cc.grids.weights,
    )
    vj_scf = mdft.get_j(dft2cc.mol, dm1_scf)
    mat_fock = dft2cc.h1e + vj_scf + vxc_mat

    # DIIS error: commutator S D F - F D S.
    diis.add(
        mat_fock,
        dft2cc.mat_s @ dm1_scf @ mat_fock - mat_fock @ dm1_scf @ dft2cc.mat_s,
    )
    mat_fock = diis.hybrid()

    # Diagonalise the Fock matrix transformed with mat_hs on both sides and
    # map the eigenvectors back with mat_hs.
    _, mo_scf = np.linalg.eigh(dft2cc.mat_hs @ mat_fock @ dft2cc.mat_hs)
    mo_scf = dft2cc.mat_hs @ mo_scf

    dm1_scf_old = dm1_scf.copy()
    dm1_scf = 2 * mo_scf[:, :nocc] @ mo_scf[:, :nocc].T
    error_dm1 = np.linalg.norm(dm1_scf - dm1_scf_old)

    converged = (i > 0) and (error_dm1 < max_error_scf)
    # One progress line every 10 steps and one on convergence (never both for
    # the same step).
    if converged or i % 10 == 0:
        print(
            f"step:{i:<8}",
            f"dm: {error_dm1:<10.5e}",
        )
    if converged:
        # Keep the density that the last Fock matrix was built from.
        dm1_scf = dm1_scf_old.copy()
        break
else:
    print(
        f"SCF did not converge in {MAX_SCF_STEPS} steps (dm: {error_dm1:.5e})",
        flush=True,
    )

# Rebuild every density-dependent quantity from the final density matrix so
# that the evaluation below is consistent even when the loop ran out of steps
# (in the converged case these are the values the loop already held).
inv_r_3, input_mat = rho_and_model_input(dm1_scf)
vj_scf = mdft.get_j(dft2cc.mol, dm1_scf)

# 2.2 check the difference of density (on grids) and dipole
scf_rho_r = pyscf.dft.numint.eval_rho(
    dft2cc.mol,
    dft2cc.ao_0,
    dm1_scf,
)
cc_rho_r = pyscf.dft.numint.eval_rho(
    dft2cc.mol,
    dft2cc.ao_0,
    dft2cc.dm1_cc,
)
dft_rho_r = pyscf.dft.numint.eval_rho(
    dft2cc.mol,
    dft2cc.ao_0,
    dft2cc.dm1_dft,
)
# Weighted sum over grid points of the absolute density difference to the
# coupled-cluster reference.
error_scf_rho_r = np.sum(np.abs(scf_rho_r - cc_rho_r) * dft2cc.grids.weights)
error_dft_rho_r = np.sum(np.abs(dft_rho_r - cc_rho_r) * dft2cc.grids.weights)
print(
    f"error_scf_rho_r: {error_scf_rho_r:.2e}, error_dft_rho_r: {error_dft_rho_r:.2e}",
    flush=True,
)

# Inference only: no autograd graph is needed for the energy network.
with torch.no_grad():
    output_mat = modeldict.model_dict["2"](input_mat).detach().cpu().numpy()
output_mat = output_mat[:, 0, :, :]
output_mat_exc = output_mat * dft2cc.grids.vector_to_matrix(
    scf_rho_r * dft2cc.grids.weights
)

exc_b3lyp = pyscf.dft.libxc.eval_xc("b3lyp", inv_r_3)[0]
b3lyp_ene = np.sum(exc_b3lyp * scf_rho_r * dft2cc.grids.weights)

# Total energy: one-electron + Coulomb + nuclear repulsion + learned term
# + B3LYP term.
ene_scf = (
    oe.contract("ij,ji->", dft2cc.h1e, dm1_scf)
    + 0.5 * oe.contract("ij,ji->", vj_scf, dm1_scf)
    + dft2cc.mol.energy_nuc()
    + np.sum(output_mat_exc)
    + b3lyp_ene
)
# Energy errors are scaled by AU2KCALMOL before printing.
error_ene_scf = AU2KCALMOL * (ene_scf - dft2cc.e_cc)
error_ene_dft = AU2KCALMOL * (dft2cc.e_dft - dft2cc.e_cc)
print(
    f"error_scf_ene: {error_ene_scf:.2e}, error_dft_ene: {error_ene_dft:.2e}",
    flush=True,
)

[['C', np.float64(-1.9558947639107437e-06), np.float64(-7.070793824880598e-06), np.float64(-5.012649493434181e-06)], ['H', np.float64(-0.06585756258607793), np.float64(0.7778665572081223), np.float64(-0.7638602855424723)], ['H', np.float64(-0.036415348387424544), np.float64(-0.9821334371257942), np.float64(-0.47638015411128126)], ['H', np.float64(-0.8375850064189369), np.float64(0.09935611192433486), np.float64(0.6939025614515092)], ['H', np.float64(0.9398812245271196), np.float64(0.10499502607665197), np.float64(0.546397610708821)]]
Load data from /home/dhem/workspace/2024.1/data/test/data_methane_cc-pCVTZ_0_0_0.0000.npz




step:0        dm: 4.28580e-01
step:5        dm: 8.38501e-05
error_scf_rho_r: 2.17e-02, error_dft_rho_r: 8.93e-02
error_scf_ene: 9.10e-01, error_dft_ene: -3.57e+01
