Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Orbparams #3626

Merged
merged 11 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepchem/utils/attribute_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_attr(obj: object, name: str):

Examples
--------
>>> from deepchem.utils import get_attr
>>> from deepchem.utils.attribute_utils import get_attr
>>> class MyClass:
... def __init__(self):
... self.a = 1
Expand Down
4 changes: 4 additions & 0 deletions deepchem/utils/dft_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from deepchem.utils.dft_utils.datastruct import ZType

from deepchem.utils.dft_utils.hamilton.orbparams import BaseOrbParams
from deepchem.utils.dft_utils.hamilton.orbparams import QROrbParams
from deepchem.utils.dft_utils.hamilton.orbparams import MatExpOrbParams

from deepchem.utils.dft_utils.api.parser import parse_moldesc
except ModuleNotFoundError as e:
logger_.warning(
Expand Down
276 changes: 276 additions & 0 deletions deepchem/utils/dft_utils/hamilton/orbparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
"""
Derived from: https://github.com/diffqc/dqc/blob/master/dqc/hamilton/orbparams.py
"""
from typing import List
import torch

__all__ = ["BaseOrbParams", "QROrbParams", "MatExpOrbParams"]


class BaseOrbParams(object):
"""Class that provides free-parameterization of orthogonal orbitals.

Examples
--------
>>> import torch
>>> from deepchem.utils.dft_utils import BaseOrbParams
>>> class MyOrbParams(BaseOrbParams):
... @staticmethod
... def params2orb(params, coeffs, with_penalty):
... return params, coeffs
... @staticmethod
... def orb2params(orb):
... return orb, torch.tensor([0], dtype=orb.dtype, device=orb.device)
>>> params = torch.randn(3, 4, 5)
>>> coeffs = torch.randn(3, 4, 5)
>>> with_penalty = 0.1
>>> orb, penalty = MyOrbParams.params2orb(params, coeffs, with_penalty)
>>> params2, coeffs2 = MyOrbParams.orb2params(orb)
>>> torch.allclose(params, params2)
True

"""

@staticmethod
def params2orb( # type: ignore[empty-body]
params: torch.Tensor,
coeffs: torch.Tensor,
with_penalty: float = 0.0) -> List[torch.Tensor]:
"""
Convert the parameters & coefficients to the orthogonal orbitals.
``params`` is the tensor to be optimized in variational method, while
``coeffs`` is a tensor that is needed to get the orbital, but it is not
optimized in the variational method.

Parameters
----------
params: torch.Tensor
The free parameters to be optimized.
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals.
with_penalty: float (default 0.0)
If not 0.0, return the penalty term for the free parameters.

Returns
-------
orb: torch.Tensor
The orthogonal orbitals.
penalty: torch.Tensor
The penalty term for the free parameters. If ``with_penalty`` is 0.0,
this is not returned.

"""
pass

@staticmethod
def orb2params( # type: ignore[empty-body]
orb: torch.Tensor) -> List[torch.Tensor]:
"""
Get the free parameters from the orthogonal orbitals. Returns ``params``
and ``coeffs`` described in ``params2orb``.

Parameters
----------
orb: torch.Tensor
The orthogonal orbitals.

Returns
-------
params: torch.Tensor
The free parameters to be optimized.
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals.

"""
pass


class QROrbParams(BaseOrbParams):
"""
Orthogonal orbital parameterization using QR decomposition.
The orthogonal orbital is represented by:

P = QR

Where Q is the parameters defining the rotation of the orthogonal tensor,
and R is the coefficients tensor.

Examples
--------
>>> import torch
>>> from deepchem.utils.dft_utils import QROrbParams
>>> params = torch.randn(3, 3)
>>> coeffs = torch.randn(4, 3)
>>> with_penalty = 0.1
>>> orb, penalty = QROrbParams.params2orb(params, coeffs, with_penalty)
>>> params2, coeffs2 = QROrbParams.orb2params(orb)

"""

@staticmethod
def params2orb(params: torch.Tensor,
coeffs: torch.Tensor,
with_penalty: float = 0.0) -> List[torch.Tensor]:
"""
Convert the parameters & coefficients to the orthogonal orbitals.
``params`` is the tensor to be optimized in variational method, while
``coeffs`` is a tensor that is needed to get the orbital, but it is not
optimized in the variational method.

Parameters
----------
params: torch.Tensor
The free parameters to be optimized.
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals.
with_penalty: float (default 0.0)
If not 0.0, return the penalty term for the free parameters.

Returns
-------
orb: torch.Tensor
The orthogonal orbitals.
penalty: torch.Tensor
The penalty term for the free parameters. If ``with_penalty`` is 0.0,
this is not returned.

"""
orb, _ = torch.linalg.qr(params)
if with_penalty == 0.0:
return [orb]
else:
# QR decomposition's solution is not unique in a way that every column
# can be multiplied by -1 and it still a solution
# So, to remove the non-uniqueness, we will make the sign of the sum
# positive.
s1 = torch.sign(orb.sum(dim=-2, keepdim=True)) # (*BD, 1, norb)
s2 = torch.sign(params.sum(dim=-2, keepdim=True))
penalty = torch.mean((orb * s1 - params * s2)**2) * with_penalty
return [orb, penalty]

@staticmethod
def orb2params(orb: torch.Tensor) -> List[torch.Tensor]:
"""
Get the free parameters from the orthogonal orbitals. Returns ``params``
and ``coeffs`` described in ``params2orb``.

Parameters
----------
orb: torch.Tensor
The orthogonal orbitals.

Returns
-------
params: torch.Tensor
The free parameters to be optimized.
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals.

"""
coeffs = torch.tensor([0], dtype=orb.dtype, device=orb.device)
return [orb, coeffs]


class MatExpOrbParams(BaseOrbParams):
"""
Orthogonal orbital parameterization using matrix exponential.
The orthogonal orbital is represented by:

P = matrix_exp(Q) @ C

where C is an orthogonal coefficient tensor, and Q is the parameters defining
the rotation of the orthogonal tensor.

Examples
--------
>>> from deepchem.utils.dft_utils import MatExpOrbParams
>>> params = torch.randn(3, 3)
>>> coeffs = torch.randn(4, 3)
>>> with_penalty = 0.1
>>> orb, penalty = MatExpOrbParams.params2orb(params, coeffs, with_penalty)
>>> params2, coeffs2 = MatExpOrbParams.orb2params(orb)

"""

@staticmethod
def params2orb(params: torch.Tensor,
coeffs: torch.Tensor,
with_penalty: float = 0.0) -> List[torch.Tensor]:
"""
Convert the parameters & coefficients to the orthogonal orbitals.
``params`` is the tensor to be optimized in variational method, while
``coeffs`` is a tensor that is needed to get the orbital, but it is not
optimized in the variational method.

Parameters
----------
params: torch.Tensor
The free parameters to be optimized. (*, nparams)
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals. (*, nao, norb)
with_penalty: float (default 0.0)
If not 0.0, return the penalty term for the free parameters.

Returns
-------
orb: torch.Tensor
The orthogonal orbitals.
penalty: torch.Tensor
The penalty term for the free parameters. If ``with_penalty`` is 0.0,
this is not returned.

"""
nao = coeffs.shape[-2]
norb = coeffs.shape[-1] # noqa: F841
nparams = params.shape[-1]
bshape = params.shape[:-1]

# construct the rotation parameters
triu_idxs = torch.triu_indices(nao, nao, offset=1)[..., :nparams]
rotmat = torch.zeros((*bshape, nao, nao),
dtype=params.dtype,
device=params.device)
rotmat[..., triu_idxs[0], triu_idxs[1]] = params
rotmat = rotmat - rotmat.transpose(-2, -1).conj()

# calculate the orthogonal orbital
ortho_orb = torch.matrix_exp(rotmat) @ coeffs

if with_penalty != 0.0:
penalty = torch.zeros((1,),
dtype=params.dtype,
device=params.device)
return [ortho_orb, penalty]
else:
return [ortho_orb]

@staticmethod
def orb2params(orb: torch.Tensor) -> List[torch.Tensor]:
rbharath marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the free parameters from the orthogonal orbitals. Returns ``params``
and ``coeffs`` described in ``params2orb``.

Parameters
----------
orb: torch.Tensor
The orthogonal orbitals.

Returns
-------
params: torch.Tensor
The free parameters to be optimized.
coeffs: torch.Tensor
The coefficients to get the orthogonal orbitals.

"""
# orb: (*, nao, norb)
nao = orb.shape[-2]
norb = orb.shape[-1]
nparams = norb * (nao - norb) + norb * (norb - 1) // 2

# the orbital becomes the coefficients while params is all zeros (no rotation)
coeffs = orb
params = torch.zeros((*orb.shape[:-2], nparams),
dtype=orb.dtype,
device=orb.device)
return [params, coeffs]
4 changes: 2 additions & 2 deletions deepchem/utils/test/test_attribute_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Test Attribute Utils
"""
import deepchem.utils as utils


def test_get_attr():
from deepchem.utils.attribute_utils import get_attr

class MyClass:

Expand All @@ -13,7 +13,7 @@ def __init__(self):
self.b = 2

obj = MyClass()
assert utils.get_attr(obj, "a") == 1
assert get_attr(obj, "a") == 1


def test_set_attr():
Expand Down
43 changes: 43 additions & 0 deletions deepchem/utils/test/test_dft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,49 @@
import numpy as np


@pytest.mark.torch
rbharath marked this conversation as resolved.
Show resolved Hide resolved
def test_base_orb_params():
from deepchem.utils.dft_utils import BaseOrbParams

class MyOrbParams(BaseOrbParams):

@staticmethod
def params2orb(params, coeffs, with_penalty):
return params, coeffs

@staticmethod
def orb2params(orb):
return orb, torch.tensor([0], dtype=orb.dtype, device=orb.device)

params = torch.randn(3, 4, 5)
coeffs = torch.randn(3, 4, 5)
with_penalty = 0.1
orb, penalty = MyOrbParams.params2orb(params, coeffs, with_penalty)
params2, coeffs2 = MyOrbParams.orb2params(orb)
assert torch.allclose(params, params2)


@pytest.mark.torch
def test_qr_orb_params():
from deepchem.utils.dft_utils import QROrbParams
params = torch.randn(3, 3)
coeffs = torch.randn(4, 3)
with_penalty = 0.1
orb, penalty = QROrbParams.params2orb(params, coeffs, with_penalty)
params2, coeffs2 = QROrbParams.orb2params(orb)
assert torch.allclose(orb, params2)


@pytest.mark.torch
def test_mat_exp_orb_params():
from deepchem.utils.dft_utils import MatExpOrbParams
params = torch.randn(3, 3)
coeffs = torch.randn(4, 3)
orb = MatExpOrbParams.params2orb(params, coeffs)[0]
params2, coeffs2 = MatExpOrbParams.orb2params(orb)
assert coeffs2.shape == orb.shape


@pytest.mark.torch
def test_lattice():
"""Test lattice object.
Expand Down
9 changes: 9 additions & 0 deletions docs/source/api_reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ The utilites here are used to create an object that contains information about a
.. autoclass:: deepchem.utils.dft_utils.config._Config
:members:

.. autoclass:: deepchem.utils.dft_utils.BaseOrbParams
:members:

.. autoclass:: deepchem.utils.dft_utils.QROrbParams
:members:

.. autoclass:: deepchem.utils.dft_utils.MatExpOrbParams
:members:

.. autoclass:: deepchem.utils.dft_utils.api.parser.parse_moldesc
:members:

Expand Down