
AddLinearOperator [Linear Operator Helper Class] #3679

Merged
merged 4 commits on Nov 30, 2023
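In short, this PR adds lazy `+` and `-` composition to `LinearOperator` through a new `AddLinearOperator` helper class, plus the `indent` and `shape2str` utilities it uses for its string representation. A minimal usage sketch, assuming the import paths added in this PR; `MatOp` is a hypothetical subclass written only for illustration:

import torch
from deepchem.utils.differentiation_utils import LinearOperator

class MatOp(LinearOperator):
    """Hypothetical dense-matrix operator, defined only for this sketch."""

    def __init__(self, mat):
        super().__init__(shape=mat.shape, dtype=mat.dtype, device=mat.device)
        self.mat = mat

    def _getparamnames(self, prefix=""):
        return [prefix + "mat"]

    def _mv(self, x):
        return torch.matmul(self.mat, x)

a = MatOp(torch.tensor([[1., 2.], [3., 4.]]))
b = MatOp(torch.tensor([[0., 1.], [1., 0.]]))
x = torch.tensor([2., 1.])
print((a + b).mv(x))  # equals a.mv(x) + b.mv(x); the dense sum is never materialized
print((a - b).mv(x))  # __sub__ builds AddLinearOperator(a, b, mul=-1)

Neither operand is modified; addition just wraps both operators and defers the arithmetic to matvec time.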
3 changes: 3 additions & 0 deletions deepchem/utils/__init__.py
@@ -6,6 +6,9 @@

logger = logging.getLogger(__name__)

from deepchem.utils.misc_utils import indent
from deepchem.utils.misc_utils import shape2str

from deepchem.utils.batch_utils import batch_coulomb_matrix_features

from deepchem.utils.attribute_utils import set_attr
5 changes: 3 additions & 2 deletions deepchem/utils/differentiation_utils/__init__.py
@@ -5,12 +5,13 @@
from deepchem.utils.differentiation_utils.bcast import get_bcasted_dims
from deepchem.utils.differentiation_utils.bcast import match_dim

from deepchem.utils.differentiation_utils.linop import LinearOperator

from deepchem.utils.differentiation_utils.misc import set_default_option
from deepchem.utils.differentiation_utils.misc import get_and_pop_keys
from deepchem.utils.differentiation_utils.misc import get_method
from deepchem.utils.differentiation_utils.misc import dummy_context_manager
from deepchem.utils.differentiation_utils.misc import assert_runtime

from deepchem.utils.differentiation_utils.linop import LinearOperator
from deepchem.utils.differentiation_utils.linop import AddLinearOperator
except:
pass
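With this reordering, `LinearOperator` and `AddLinearOperator` are both re-exported from the package namespace (still guarded by the module's existing try/except), so downstream code can import them in one line. A sketch, assuming the paths in this diff:

from deepchem.utils.differentiation_utils import LinearOperator, AddLinearOperator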
293 changes: 276 additions & 17 deletions deepchem/utils/differentiation_utils/linop.py
@@ -5,8 +5,8 @@
from abc import abstractmethod
from contextlib import contextmanager
from scipy.sparse.linalg import LinearOperator as spLinearOperator
from deepchem.utils.differentiation_utils import EditableModule
from deepchem.utils.differentiation_utils import get_bcasted_dims
from deepchem.utils.differentiation_utils import EditableModule, get_bcasted_dims
from deepchem.utils import shape2str, indent

__all__ = ["LinearOperator"]

@@ -189,7 +189,7 @@ def __repr__(self) -> str:

"""
return "LinearOperator (%s) with shape %s, dtype = %s, device = %s" % \
(self.__class__.__name__, _shape2str(self.shape), self.dtype, self.device)
(self.__class__.__name__, shape2str(self.shape), self.dtype, self.device)

@abstractmethod
def _getparamnames(self, prefix: str = "") -> List[str]:
Expand Down Expand Up @@ -491,7 +491,126 @@ def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
raise KeyError("getparamnames for method %s is not implemented" %
methodname)

def __rsub__(self, b):
def __add__(self, b: LinearOperator):
"""Addition with another linear operator.

Examples
--------
>>> class Operator(LinearOperator):
... def __init__(self, mat: torch.Tensor, is_hermitian: bool) -> None:
... super(Operator, self).__init__(
... shape=mat.shape,
... is_hermitian=is_hermitian,
... dtype=mat.dtype,
... device=mat.device,
... _suppress_hermit_warning=True,
... )
... self.mat = mat
... def _mv(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat, x.unsqueeze(-1)).squeeze(-1)
... def _mm(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat, x)
... def _rmv(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat.transpose(-2, -1).conj(), x.unsqueeze(-1)).squeeze(-1)
... def _rmm(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat.transpose(-2, -1).conj(), x)
... def _fullmatrix(self) -> torch.Tensor:
... return self.mat
... def _getparamnames(self, prefix: str = "") -> List[str]:
... return [prefix + "mat"]
>>> op = Operator(torch.tensor([[1, 2.],
... [3, 4]]), is_hermitian=False)
>>> x = torch.tensor([[2, 2],
... [1, 2.]])
>>> op.mm(x)
tensor([[ 4., 6.],
[10., 14.]])
>>> op2 = op + op
>>> op2.mm(x)
tensor([[ 8., 12.],
[20., 28.]])

Parameters
----------
b: LinearOperator
The linear operator to be added.

Returns
-------
LinearOperator
The result of the addition.

"""
assert isinstance(
b, LinearOperator
), "Only addition with another LinearOperator is supported"
if self.shape[-2:] != b.shape[-2:]:
raise RuntimeError("Mismatch shape of add operation: %s and %s" %
(self.shape, b.shape))
return AddLinearOperator(self, b)

def __sub__(self, b: LinearOperator):
"""Subtraction with another linear operator.

Examples
--------
>>> class Operator(LinearOperator):
... def __init__(self, mat: torch.Tensor, is_hermitian: bool) -> None:
... super(Operator, self).__init__(
... shape=mat.shape,
... is_hermitian=is_hermitian,
... dtype=mat.dtype,
... device=mat.device,
... _suppress_hermit_warning=True,
... )
... self.mat = mat
... def _mv(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat, x.unsqueeze(-1)).squeeze(-1)
... def _mm(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat, x)
... def _rmv(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat.transpose(-2, -1).conj(), x.unsqueeze(-1)).squeeze(-1)
... def _rmm(self, x: torch.Tensor) -> torch.Tensor:
... return torch.matmul(self.mat.transpose(-2, -1).conj(), x)
... def _fullmatrix(self) -> torch.Tensor:
... return self.mat
... def _getparamnames(self, prefix: str = "") -> List[str]:
... return [prefix + "mat"]
>>> op = Operator(torch.tensor([[1, 2.],
... [3, 4]]), is_hermitian=False)
>>> op1 = Operator(torch.tensor([[0, 1.],
... [1, 2]]), is_hermitian=False)
>>> x = torch.tensor([[2, 2],
... [1, 2.]])
>>> op.mm(x)
tensor([[ 4., 6.],
[10., 14.]])
>>> op2 = op - op1
>>> op2.mm(x)
tensor([[3., 4.],
[6., 8.]])

Parameters
----------
b: LinearOperator
The linear operator to be subtracted.

Returns
-------
LinearOperator
The result of the subtraction.

"""

assert isinstance(
b, LinearOperator
), "Only subtraction with another LinearOperator is supported"
if self.shape[-2:] != b.shape[-2:]:
raise RuntimeError("Mismatch shape of add operation: %s and %s" %
(self.shape, b.shape))
return AddLinearOperator(self, b, -1)

def __rsub__(self, b: LinearOperator):
return b.__sub__(self)

# properties
@@ -587,19 +706,159 @@ def _assert_if_init_executed(self):
raise RuntimeError("super().__init__ must be executed first")


def _shape2str(shape):
"""Convert the shape to string representation.
It also nicely formats the shape to be readable.

Parameters
----------
shape: Sequence[int]
The shape to be converted to string representation.

Returns
-------
str
The string representation of the shape.

"""
return "(%s)" % (", ".join([str(s) for s in shape]))

# Helper Classes
class AddLinearOperator(LinearOperator):

Member comment: This should be added to the documentation. Can be done in a future PR.

"""Adds two linear operators.

Examples
--------
>>> import torch
>>> seed = torch.manual_seed(100)
>>> class MyLinOp(LinearOperator):
... def __init__(self, shape):
... super(MyLinOp, self).__init__(shape)
... self.param = torch.rand(shape)
... def _getparamnames(self, prefix=""):
... return [prefix + "param"]
... def _mv(self, x):
... return torch.matmul(self.param, x)
... def _rmv(self, x):
... return torch.matmul(self.param.transpose(-2,-1).conj(), x)
... def _mm(self, x):
... return torch.matmul(self.param, x)
... def _rmm(self, x):
... return torch.matmul(self.param.transpose(-2,-1).conj(), x)
... def _fullmatrix(self):
... return self.param
>>> linop1 = MyLinOp((1,3,1,2))
>>> linop2 = MyLinOp((1,3,1,2))
>>> linop = AddLinearOperator(linop1, linop2)
>>> print(linop)
AddLinearOperator with shape (1, 3, 1, 2) of:
* LinearOperator (MyLinOp) with shape (1, 3, 1, 2), dtype = torch.float32, device = cpu
* LinearOperator (MyLinOp) with shape (1, 3, 1, 2), dtype = torch.float32, device = cpu
>>> x = torch.rand(1,3,2,2)
>>> linop.mv(x)
tensor([[[[0.6256, 1.0689]],
<BLANKLINE>
[[0.6039, 0.5380]],
<BLANKLINE>
[[0.9702, 2.1129]]]])
>>> x = torch.rand(1,3,1,1)
>>> linop.rmv(x)
tensor([[[[0.1662],
[0.3813]],
<BLANKLINE>
[[0.4460],
[0.5705]],
<BLANKLINE>
[[0.5942],
[1.1089]]]])
>>> x = torch.rand(1,2,2,1)
>>> linop.mm(x)
tensor([[[[0.7845],
[0.5439]]],
<BLANKLINE>
<BLANKLINE>
[[[0.6518],
[0.4318]]],
<BLANKLINE>
<BLANKLINE>
[[[1.4336],
[0.9796]]]])

"""
return "(%s)" % (", ".join([str(s) for s in shape]))

def __init__(self, a: LinearOperator, b: LinearOperator, mul: int = 1):
"""Initialize the ``AddLinearOperator``.

Parameters
----------
a: LinearOperator
The first linear operator to be added.
b: LinearOperator
The second linear operator to be added.
mul: int
The multiplier of the second linear operator. Defaults to 1.
If -1, the second linear operator is subtracted.

"""
shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2],
b.shape[-1])
is_hermitian = a.is_hermitian and b.is_hermitian
super(AddLinearOperator, self).__init__(
shape=shape,
is_hermitian=is_hermitian,
dtype=a.dtype,
device=a.device,
_suppress_hermit_warning=True,
)
self.a = a
self.b = b
assert mul == 1 or mul == -1
self.mul = mul

def __repr__(self):
"""Representation of the ``AddLinearOperator``.

Returns
-------
str
The representation of the ``AddLinearOperator``.

"""
return "AddLinearOperator with shape %s of:\n * %s\n * %s" % \
(shape2str(self.shape),
indent(self.a.__repr__(), 3),
indent(self.b.__repr__(), 3))

def _mv(self, x: torch.Tensor) -> torch.Tensor:
"""Matrix-vector multiplication.

Parameters
----------
x: torch.Tensor
The vector with shape ``(...,q)`` on which the linear operation acts

Returns
-------
torch.Tensor
The result of the linear operation with shape ``(...,p)``

"""
return self.a._mv(x) + self.mul * self.b._mv(x)

def _rmv(self, x: torch.Tensor) -> torch.Tensor:
"""Transposed matrix-vector multiplication.

Parameters
----------
x: torch.Tensor
The vector of shape ``(...,p)`` on which the adjoint linear operation acts.

Returns
-------
torch.Tensor
The result of the adjoint linear operation with shape ``(...,q)``

"""
return self.a.rmv(x) + self.mul * self.b.rmv(x)

def _getparamnames(self, prefix: str = "") -> List[str]:
"""Get the parameter names that affects most of the methods (i.e. mm, mv, rmm, rmv).

Parameters
----------
prefix: str
The prefix to be appended in front of the parameters name.
This usually contains the dots.

Returns
-------
List[str]
List of parameter names (including the prefix) that affect
the ``LinearOperator``.

"""
return self.a._getparamnames(prefix=prefix + "a.") + \
self.b._getparamnames(prefix=prefix + "b.")
45 changes: 45 additions & 0 deletions deepchem/utils/misc_utils.py
@@ -0,0 +1,45 @@
"""
Utilities for miscellaneous tasks.
"""


def indent(s, nspace):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a future PR, please add unit tests for these utilities

"""Gives indentation of the second line and next lines.
It is used to format the string representation of an object.
Which might be containing multiples objects in it.
Usage: LinearOperator

Parameters
----------
s: str
The string to be indented.
nspace: int
The number of spaces to be indented.

Returns
-------
str
The indented string.

"""
spaces = " " * nspace
lines = [spaces + c if i > 0 else c for i, c in enumerate(s.split("\n"))]
return "\n".join(lines)


def shape2str(shape):
"""Convert the shape to string representation.
It also nicely formats the shape to be readable.

Parameters
----------
shape: Sequence[int]
The shape to be converted to string representation.

Returns
-------
str
The string representation of the shape.

"""
return "(%s)" % (", ".join([str(s) for s in shape]))