symeig utility functions - 1 #3756

Merged (8 commits, Jan 6, 2024)
2 changes: 2 additions & 0 deletions deepchem/utils/__init__.py
@@ -139,6 +139,8 @@
from deepchem.utils.pytorch_utils import get_memory
from deepchem.utils.pytorch_utils import gaussian_integral
from deepchem.utils.pytorch_utils import TensorNonTensorSeparator
from deepchem.utils.pytorch_utils import tallqr
from deepchem.utils.pytorch_utils import to_fortran_order

from deepchem.utils.safeops_utils import safepow
from deepchem.utils.safeops_utils import safenorm
155 changes: 155 additions & 0 deletions deepchem/utils/differentiation_utils/symeig.py
@@ -0,0 +1,155 @@
import torch
from typing import Optional, Sequence
from deepchem.utils.differentiation_utils import LinearOperator
import functools
from deepchem.utils.pytorch_utils import tallqr


def _set_initial_v(vinit_type: str,
                   dtype: torch.dtype,
                   device: torch.device,
                   batch_dims: Sequence,
                   na: int,
                   nguess: int,
                   M: Optional[LinearOperator] = None) -> torch.Tensor:
"""Set the initial guess for the eigenvectors.

Examples
--------
>>> import torch
>>> vinit_type = "eye"
>>> dtype = torch.float64
>>> device = torch.device("cpu")
>>> batch_dims = (2, 3)
>>> na = 4
>>> nguess = 2
>>> M = None
>>> V = _set_initial_v(vinit_type, dtype, device, batch_dims, na, nguess, M)
>>> V
tensor([[[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]],
<BLANKLINE>
[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]],
<BLANKLINE>
[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]]],
<BLANKLINE>
<BLANKLINE>
[[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]],
<BLANKLINE>
[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]],
<BLANKLINE>
[[1., 0.],
[0., 1.],
[0., 0.],
[0., 0.]]]], dtype=torch.float64)

Parameters
----------
vinit_type: str
Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
dtype: torch.dtype
Data type of the initial guess.
device: torch.device
Device of the initial guess.
batch_dims: Sequence
Batch dimensions of the initial guess.
na: int
Number of basis functions.
nguess: int
Number of initial guesses.
M: Optional[LinearOperator] (default None)
The overlap matrix. If None, identity matrix is used.

Returns
-------
V: torch.Tensor
Initial guess for the eigenvectors.

"""

    torch.manual_seed(12421)
    if vinit_type == "eye":
        nbatch = functools.reduce(lambda x, y: x * y, batch_dims, 1)
        V = torch.eye(na, nguess, dtype=dtype,
                      device=device).unsqueeze(0).repeat(nbatch, 1, 1).reshape(
                          *batch_dims, na, nguess)
    elif vinit_type == "randn":
        V = torch.randn((*batch_dims, na, nguess), dtype=dtype, device=device)
    elif vinit_type == "random" or vinit_type == "rand":
        V = torch.rand((*batch_dims, na, nguess), dtype=dtype, device=device)
    else:
        raise ValueError("Unknown v_init type: %s" % vinit_type)

    # orthogonalize V
    if isinstance(M, LinearOperator):
        V, R = tallqr(V, MV=M.mm(V))
    else:
        V, R = tallqr(V)
    return V


def _take_eigpairs(eival: torch.Tensor, eivec: torch.Tensor, neig: int,
                   mode: str):
"""Take the eigenpairs from the eigendecomposition.

Examples
--------
>>> import torch
>>> eival = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
>>> eivec = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
... [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
>>> neig = 2
>>> mode = "lowest"
>>> eival, eivec = _take_eigpairs(eival, eivec, neig, mode)
>>> eival
tensor([[1., 2.],
[4., 5.]])
>>> eivec
tensor([[[1., 2.],
[4., 5.],
[7., 8.]],
<BLANKLINE>
[[1., 2.],
[4., 5.],
[7., 8.]]])

Parameters
----------
eival: torch.Tensor
Eigenvalues of the linear operator. Shape: ``(*BV, na)``.
eivec: torch.Tensor
Eigenvectors of the linear operator. Shape: ``(*BV, na, na)``.
neig: int
Number of eigenvalues and eigenvectors to be calculated.
mode: str
Mode of the eigenvalues to be calculated (``"lowest"``, ``"uppest"``)

Returns
-------
eival: torch.Tensor
Eigenvalues of the linear operator.
eivec: torch.Tensor
Eigenvectors of the linear operator.

"""
    if mode == "lowest":
        eival = eival[..., :neig]
        eivec = eivec[..., :neig]
    else:
        eival = eival[..., -neig:]
        eivec = eivec[..., -neig:]
    return eival, eivec
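For context, a minimal sketch of how these two helpers might be combined (assumed usage, not part of this diff; a dense torch.linalg.eigh stands in for whatever iterative eigensolver later builds on them):

import torch
from deepchem.utils.differentiation_utils.symeig import _set_initial_v, _take_eigpairs

A = torch.randn(4, 4, dtype=torch.float64)
A = (A + A.T) / 2                              # symmetric operator
# orthonormal initial guess with 2 columns and no batch dimensions
V0 = _set_initial_v("eye", torch.float64, torch.device("cpu"), (), na=4, nguess=2)
eival, eivec = torch.linalg.eigh(A)            # dense decomposition as a stand-in
eival, eivec = _take_eigpairs(eival, eivec, neig=2, mode="lowest")  # keep 2 lowest pairs
print(V0.shape, eival.shape, eivec.shape)      # (4, 2), (2,), (4, 2)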
89 changes: 89 additions & 0 deletions deepchem/utils/pytorch_utils.py
@@ -344,3 +344,92 @@ def reconstruct_params(self, tensor_params, nontensor_params=None):
        for idx, p in zip(self.tensor_idxs, tensor_params):
            params[idx] = p
        return params


def tallqr(V, MV=None):
"""QR decomposition for tall and skinny matrix.

Examples
--------
>>> import torch
>>> from deepchem.utils.pytorch_utils import tallqr
>>> V = torch.randn(3, 2)
>>> Q, R = tallqr(V)
>>> Q.shape
torch.Size([3, 2])
>>> R.shape
torch.Size([2, 2])
>>> torch.allclose(Q @ R, V)
True

Parameters
----------
V: torch.Tensor
V is a matrix to be decomposed. (*BV, na, nguess)
MV: torch.Tensor
(*BM, na, nguess) where M is the basis to make Q M-orthogonal
if MV is None, then MV=V (default=None)

Returns
-------
Q: torch.Tensor
The orthogonal part (M-orthogonal when MV is given). Shape: (*BV, na, nguess)
R: torch.Tensor
The upper-triangular factor. Shape: (*BM, nguess, nguess)

"""
    if MV is None:
        MV = V
    VTV = torch.matmul(V.transpose(-2, -1), MV)  # (*BMV, nguess, nguess)
    R = torch.linalg.cholesky(VTV.transpose(-2, -1).conj()).transpose(
        -2, -1).conj()  # (*BMV, nguess, nguess)
    Rinv = torch.inverse(R)  # (*BMV, nguess, nguess)
    Q = torch.matmul(V, Rinv)
    return Q, R


def to_fortran_order(V):
"""Convert a tensor to Fortran order. (The last two dimensions are made Fortran order.)
Fortran order/ array is a special case in which all elements of an array are stored in
column-major order.

Examples
--------
>>> import torch
>>> from deepchem.utils.pytorch_utils import to_fortran_order
>>> V = torch.randn(3, 2)
>>> V.is_contiguous()
True
>>> V = to_fortran_order(V)
>>> V.is_contiguous()
False
>>> V.shape
torch.Size([3, 2])
>>> V = torch.randn(3, 2).transpose(-2, -1)
>>> V.is_contiguous()
False
>>> V = to_fortran_order(V)
>>> V.is_contiguous()
False
>>> V.shape
torch.Size([2, 3])

Parameters
----------
V: torch.Tensor
V is a matrix to be converted. (*BV, na, nguess)

Returns
-------
outV: torch.Tensor
Same values and shape (*BV, na, nguess), but with Fortran (column-major) strides in the last two dimensions.

"""
    if V.is_contiguous():
        # return V.set_(V.storage(), V.storage_offset(), V.size(), tuple(reversed(V.stride())))
        return V.transpose(-2, -1).contiguous().transpose(-2, -1)
    elif V.transpose(-2, -1).is_contiguous():
        return V
    else:
        raise RuntimeError(
            "Only the last two dimensions can be made Fortran order.")
16 changes: 16 additions & 0 deletions deepchem/utils/test/test_differentiation_utils.py
@@ -500,3 +500,19 @@ def test_assert_runtime():
        assert_runtime(False, "This should fail")
    except RuntimeError:
        pass


@pytest.mark.torch
def test_take_eigpairs():
    from deepchem.utils.differentiation_utils.symeig import _take_eigpairs
    eival = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    eivec = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
                          [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
    neig = 2
    mode = "lowest"
    eival, eivec = _take_eigpairs(eival, eivec, neig, mode)
    assert torch.allclose(eival, torch.tensor([[1., 2.], [4., 5.]]))
    assert torch.allclose(
        eivec,
        torch.tensor([[[1., 2.], [4., 5.], [7., 8.]],
                      [[1., 2.], [4., 5.], [7., 8.]]]))
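A complementary case for the other branch could look like this (a sketch only, not part of the diff; "uppest" is the mode string _take_eigpairs uses for the largest eigenvalues):

@pytest.mark.torch
def test_take_eigpairs_uppest():
    # Hypothetical companion test: "uppest" should return the last neig pairs.
    from deepchem.utils.differentiation_utils.symeig import _take_eigpairs
    eival = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    eivec = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
                          [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
    eival, eivec = _take_eigpairs(eival, eivec, 2, "uppest")
    assert torch.allclose(eival, torch.tensor([[2., 3.], [5., 6.]]))
    assert torch.allclose(
        eivec,
        torch.tensor([[[2., 3.], [5., 6.], [8., 9.]],
                      [[2., 3.], [5., 6.], [8., 9.]]]))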
24 changes: 22 additions & 2 deletions deepchem/utils/test/test_pytorch_utils.py
@@ -87,5 +87,25 @@ def test_TensorNonTensorSeparator():
    params = [a, b, c]
    separator = dc.utils.pytorch_utils.TensorNonTensorSeparator(params)
    tensor_params = separator.get_tensor_params()
    torch.allclose(tensor_params[0],
                   torch.tensor([5., 6., 7.], requires_grad=True))
    assert torch.allclose(tensor_params[0],
                          torch.tensor([5., 6., 7.], requires_grad=True))


@pytest.mark.torch
def test_tallqr():
    V = torch.randn(3, 2)
    Q, R = dc.utils.pytorch_utils.tallqr(V)
    assert Q.shape == torch.Size([3, 2])
    assert R.shape == torch.Size([2, 2])
    assert torch.allclose(Q @ R, V)


@pytest.mark.torch
def test_to_fortran_order():
    V = torch.randn(3, 2)
    assert V.is_contiguous()
    V = dc.utils.pytorch_utils.to_fortran_order(V)
    assert not V.is_contiguous()
    assert V.shape == torch.Size([3, 2])
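One more property that could be pinned down in a follow-up (sketch, not part of the diff): after to_fortran_order the last two dimensions should have column-major strides, i.e. unit stride along the rows, while the values stay the same.

@pytest.mark.torch
def test_to_fortran_order_strides():
    # Hypothetical follow-up check: column-major layout, unchanged values.
    V = torch.randn(3, 2)
    W = dc.utils.pytorch_utils.to_fortran_order(V)
    assert W.stride() == (1, 3)  # column-major: stride 1 along dim 0
    assert torch.allclose(W, V)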
8 changes: 8 additions & 0 deletions docs/source/api_reference/utils.rst
@@ -384,6 +384,10 @@ The utilities here are used to create an object that contains information about a

.. autofunction:: deepchem.utils.differentiation_utils.assert_runtime

.. autofunction:: deepchem.utils.differentiation_utils._set_initial_v

.. autofunction:: deepchem.utils.differentiation_utils._take_eigpairs

Attribute Utilities
-------------------

@@ -413,6 +417,10 @@ Pytorch Utilities

.. autofunction:: deepchem.utils.pytorch_utils.TensorNonTensorSeparator

.. autofunction:: deepchem.utils.pytorch_utils.tallqr

.. autofunction:: deepchem.utils.pytorch_utils.to_fortran_order

Batch Utilities
---------------
