In [1]:
import torch
import dqc
import dqc.xc
import dqc.utils

In [2]:
class MyLDAX(dqc.xc.CustomXC):
    def __init__(self, a, p):
        super().__init__()
        self.a = a
        self.p = p
        self.number_of_parameters = 2

    @property
    def family(self):
        # 1 for LDA, 2 for GGA, 4 for MGGA
        return 1

    def get_edensityxc(self, densinfo):
        # densinfo has up and down components
        if isinstance(densinfo, dqc.utils.SpinParam):
            # spin-scaling of the exchange energy
            return 0.5 * (self.get_edensityxc(densinfo.u * 2) + self.get_edensityxc(densinfo.d * 2))
        else:
            rho = densinfo.value.abs() + 1e-15  # safeguarding from nan
            return self.a * rho ** self.p
        
    def get_edensityxc_derivative(self, densinfo, parameter_number):
        # densinfo has up and down components
        if isinstance(densinfo, dqc.utils.SpinParam):
            # spin-scaling of the exchange energy
            return 0.5 * (self.get_edensityxc_derivative(densinfo.u * 2, number_of_parameter) 
                          + self.get_edensityxc_derivative(densinfo.d * 2, number_of_parameter))
        else:
            rho = densinfo.value.abs() + 1e-15  # safeguarding from nan
            if parameter_number == 0: # parameter a
                return rho ** self.p
            elif parameter_number == 1: # parameter p
                return self.a * rho ** (self.p - 1)

In [3]:
a = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.double))
p = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.double))
myxc = MyLDAX(a, p)

In [4]:
mol = dqc.Mol(moldesc="H -1 0 0; H 1 0 0", basis="3-21G")
qc = dqc.KS(mol, xc=myxc).run()
ene = qc.energy()
print(ene)

tensor(-0.4645, dtype=torch.float64, grad_fn=<AddBackward0>)


In [5]:
dm = qc._dm.detach().clone() # density matrix
dm

tensor([[ 4.3848e-31, -1.4465e-16, -6.4918e-31, -9.2522e-16],
        [-1.4465e-16,  4.7720e-02,  2.1416e-16,  3.0523e-01],
        [-6.4918e-31,  2.1416e-16,  9.6112e-31,  1.3698e-15],
        [-9.2522e-16,  3.0523e-01,  1.3698e-15,  1.9523e+00]],
       dtype=torch.float64)

We have two types of equations.
First, orthonomality equations:
$$ r_{a}(\textbf{C}) = \sum_i\sum_jC_{ai}\delta_{ij}C_{aj}- 1 = \sum_iC_{ai}C_{ai}- 1 = 0$$
Obviosly, they do not depend on $\vec{\theta}$, so
$$\frac{\partial r_{a}(\textbf{C})}{\partial\vec{\theta}}= 0$$ 

In [42]:
nao = qc.get_system().get_hamiltonian().nao
norb = qc._engine.norb

derivative_of_normalization_first = torch.zeros((1, norb))
derivative_of_normalization_second = torch.zeros((1, norb))

print(derivative_of_normalization_first)
print(derivative_of_normalization_second)

tensor([[0.]])
tensor([[0.]])


Second, Roothan equations:
$$r_{ia}(\textbf{C};\;\vec{\theta}) = \sum_j (F_{ij}[\rho](\vec{\theta})C_{aj} - \epsilon_aS_{ij}C_{aj}) = 0$$
Derivative:
$$\frac{\partial r_{ia}(\textbf{C};\;\vec{\theta})}{\partial \vec{\theta}} =
\sum_j C_{aj}\int b_i(\vec{r}) \frac{\partial V_{XC}[\rho](\vec{r};\;\vec{\theta})}{\partial \vec{\theta}} b_j(\vec{r})d\vec{r}$$
As we know, $$V_{XC}[\rho](\vec{r};\;\vec{\theta}) = \frac{\partial E_{XC}[\rho]}{\partial\rho(\vec{r})} = \frac{\int \rho(\vec{r})\epsilon_{XC}[\rho](\vec{r};\;\vec{\theta})d\vec{r}}{\partial\rho(\vec{r})}$$
So
$$\frac{\partial r_{ia}(\textbf{C};\;\vec{\theta})}{\partial \vec{\theta}} =
\sum_j C_{aj}\int b_i(\vec{r}) \frac{\int \rho(\vec{r})\epsilon_{XC}[\rho](\vec{r};\;\vec{\theta})d\vec{r}}{\partial \vec{\theta}\partial\rho(\vec{r})} b_j(\vec{r})d\vec{r} = 
\sum_j C_{aj}\int b_i(\vec{r}) \frac{\int \rho(\vec{r}) \frac{\partial \epsilon_{XC}[\rho](\vec{r};\;\vec{\theta})}{\partial \vec{\theta}}d\vec{r} }{\partial\rho(\vec{r})} b_j(\vec{r})d\vec{r}
$$
It means, we can use $\frac{\partial \epsilon_{XC}[\rho](\vec{r};\;\vec{\theta})}{\partial \vec{\theta}}$ instead of $\epsilon_{XC}[\rho](\vec{r};\;\vec{\theta})$ in DQC function`get_vxc()`and get suitable result. This function takes densinfo and, I guess, somehow takes functional derivative with respect to density.

In [27]:
from dqc.utils.datastruct import ValGrad, SpinParam

def get_vxc_derivative(xc, densinfo, number_of_parameter):
    """
    Returns the ValGrad for the xc potential given the density info
    for unpolarized case.
    """
    # This is the default implementation of vxc if there is no implementation
    # in the specific class of XC.

    # densinfo.value & lapl: (*BD, nr)
    # densinfo.grad: (*BD, ndim, nr)
    # return:
    # potentialinfo.value & lapl: (*BD, nr)
    # potentialinfo.grad: (*BD, ndim, nr)

    # mark the densinfo components as requiring grads
    with xc._enable_grad_densinfo(densinfo):
        with torch.enable_grad():
            edensity_derivative = xc.get_edensityxc_derivative(densinfo, number_of_parameter)  # (*BD, nr)
        grad_outputs = torch.ones_like(edensity_derivative)
        grad_enabled = torch.is_grad_enabled()

        if not isinstance(densinfo, ValGrad):  # polarized case
            if xc.family == 1:  # LDA
                params = (densinfo.u.value, densinfo.d.value)
                dedn_u, dedn_d = torch.autograd.grad(
                    edensity_derivative, params, create_graph=grad_enabled, grad_outputs=grad_outputs)

                return SpinParam(u=ValGrad(value=dedn_u), d=ValGrad(value=dedn_d))

            elif xc.family == 2:  # GGA
                params = (densinfo.u.value, densinfo.d.value, densinfo.u.grad, densinfo.d.grad)
                dedn_u, dedn_d, dedg_u, dedg_d = torch.autograd.grad(
                    edensity_derivative, params, create_graph=grad_enabled, grad_outputs=grad_outputs)

                return SpinParam(
                    u=ValGrad(value=dedn_u, grad=dedg_u),
                    d=ValGrad(value=dedn_d, grad=dedg_d))

            else:
                raise NotImplementedError(
                    "Default polarized vxc for family %s is not implemented" % self.family)
        else:  # unpolarized case
            if xc.family == 1:  # LDA
                dedn, = torch.autograd.grad(
                    edensity_derivative, densinfo.value, create_graph=grad_enabled,
                    grad_outputs=grad_outputs)

                return ValGrad(value=dedn)

            elif xc.family == 2:  # GGA
                dedn, dedg = torch.autograd.grad(
                    edensity_derivative, (densinfo.value, densinfo.grad), create_graph=grad_enabled,
                    grad_outputs=grad_outputs)

                return ValGrad(value=dedn, grad=dedg)

            else:
                raise NotImplementedError("Default vxc for family %d is not implemented" % self.family)

In [22]:
def hamitonian_get_vxc_derivative(hamiltonian, dm, number_of_parameter):
        # dm: (*BD, nao, nao)
    assert hamiltonian.xc is not None, "Please call .setup_grid with the xc object"

    densinfo = SpinParam.apply_fcn(
        lambda dm_: hamiltonian._dm2densinfo(dm_), dm)  # value: (*BD, nr)
    potinfo = get_vxc_derivative(hamiltonian.xc, densinfo, number_of_parameter) # value: (*BD, nr)
    vxc_linop = SpinParam.apply_fcn(
        lambda potinfo_: hamiltonian._get_vxc_from_potinfo(potinfo_), potinfo)
    return vxc_linop

In [29]:
def dm2fock_derivative(engine, dm, number_of_parameter):
    vxc_derivative = hamitonian_get_vxc_derivative(engine.hamilton, dm, number_of_parameter)  # spin param or tensor (..., nao, nao)
    return SpinParam.apply_fcn(lambda vxc_: vxc_, vxc_derivative)

In [19]:
from typing import Optional, Dict, Any, List, Union, overload, Tuple
from dqc.utils.datastruct import SpinParam

def dm2scp_derivative(engine, dm: Union[torch.Tensor, SpinParam[torch.Tensor]], number_of_parameter) -> torch.Tensor:
    # convert from density matrix to a self-consistent parameter (scp)
    if isinstance(dm, torch.Tensor):  # unpolarized
        # scp is the fock matrix
        return dm2fock_derivative(engine, dm, number_of_parameter).fullmatrix()
    else:  # polarized
        # scp is the concatenated fock matrix
        fock = self.dm2fock_derivative(dm)
        mat_u = fock.u.fullmatrix().unsqueeze(0)
        mat_d = fock.d.fullmatrix().unsqueeze(0)
        return torch.cat((mat_u, mat_d), dim=0)

In [38]:
fock_derivative_first = dm2scp_derivative(qc._engine, dm, 0)
fock_derivative_first

tensor([[ 4.1203e-02, -1.8399e-16,  4.9318e-02,  9.4651e-17],
        [-1.8399e-16,  8.3114e-02, -1.5266e-16, -5.9360e-02],
        [ 4.9318e-02, -1.5266e-16,  1.1935e-01,  5.4644e-17],
        [ 9.4651e-17, -5.9360e-02,  5.4644e-17,  8.9399e-02]],
       dtype=torch.float64, grad_fn=<MulBackward0>)

In [37]:
fock_derivative_second = dm2scp_derivative(qc._engine, dm, 1)
fock_derivative_second

tensor([[ 9.9999e-01, -7.0777e-16,  3.0187e-06,  3.4694e-17],
        [-7.0777e-16,  1.0000e+00, -5.3429e-16, -1.2964e-06],
        [ 3.0187e-06, -5.3429e-16,  1.0000e+00, -1.2490e-16],
        [ 3.4694e-17, -1.2964e-06, -1.2490e-16,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MulBackward0>)

In [32]:
coeff = qc.get_system().get_hamiltonian().dm2ao_orb_params(dm, norb) # full C matrix
coeff

tensor([[ 4.6823e-16],
        [-1.5447e-01],
        [-6.9322e-16],
        [-9.8800e-01]], dtype=torch.float64)

In [39]:
equation_derivative_wrt_first = torch.matmul(coeff.t(), fock_derivative_first)
equation_derivative_wrt_first

tensor([[-7.9990e-17,  4.5810e-02, -9.0049e-17, -7.9157e-02]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [40]:
equation_derivative_wrt_second = torch.matmul(coeff.t(), fock_derivative_second)
equation_derivative_wrt_second

tensor([[ 5.4328e-16, -1.5447e-01, -4.8729e-16, -9.8800e-01]],
       dtype=torch.float64, grad_fn=<MmBackward0>)