Commit message:

* Implement mts.learn.nn.ReLU
* Implement mts.learn.nn.SiLU
* Update change log
* Update activation function docstrings
Showing 10 changed files with 391 additions and 6 deletions.
New file: metatensor.learn.nn.relu

@@ -0,0 +1,114 @@
from typing import List, Optional

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .module_map import ModuleMap


class ReLU(Module):
    """
    Module similar to :py:class:`torch.nn.ReLU` that works with
    :py:class:`metatensor.torch.TensorMap` objects.

    Applies a rectified linear unit transformation to each block of a
    :py:class:`TensorMap` passed to its :py:meth:`forward` method, indexed by
    the keys in ``in_keys``.

    Refer to the :py:class:`torch.nn.ReLU` documentation for a more detailed
    description of the parameters.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the
        input tensor map in the :py:meth:`forward` method.
    :param out_properties: list of :py:class:`Labels` (optional), the
        properties labels of the output. By default the output properties are
        relabeled using :py:meth:`Labels.range`.
    """

    def __init__(
        self,
        in_keys: Labels,
        out_properties: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()
        modules: List[Module] = [torch.nn.ReLU() for _ in range(len(in_keys))]
        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be
            transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)


class InvariantReLU(Module):
    """
    Module similar to :py:class:`torch.nn.ReLU` that works with
    :py:class:`metatensor.torch.TensorMap` objects, applying the
    transformation only to the invariant blocks.

    Applies a rectified linear unit transformation to each invariant block of
    a :py:class:`TensorMap` passed to its :py:meth:`forward` method. These are
    indexed by the keys in ``in_keys`` at the numeric indices passed in
    ``invariant_key_idxs``.

    Refer to the :py:class:`torch.nn.ReLU` documentation for a more detailed
    description of the parameters.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the
        input tensor map in the :py:meth:`forward` method.
    :param invariant_key_idxs: list of int, the indices of the invariant keys
        present in ``in_keys`` in the input :py:class:`TensorMap`. Only blocks
        for these keys will have the ReLU transformation applied. Covariant
        blocks will have the identity operator applied.
    :param out_properties: list of :py:class:`Labels` (optional), the
        properties labels of the output. By default the output properties are
        relabeled using :py:meth:`Labels.range`.
    """

    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        out_properties: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()
        modules: List[Module] = []
        for i in range(len(in_keys)):
            if i in invariant_key_idxs:  # Invariant block: apply ReLU
                module = torch.nn.ReLU()
            else:  # Covariant block: apply identity operator
                module = torch.nn.Identity()
            modules.append(module)
        self.module_map: ModuleMap = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be
            transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
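A minimal usage sketch (not part of the diff): build a one-block TensorMap with torch values and pass it through the new module. The label names ("key", "sample", "property") are illustrative, and the top-level import from metatensor.learn.nn is an assumption based on the commit message ("mts.learn.nn.ReLU"):

import numpy as np
import torch

from metatensor import Labels, TensorBlock, TensorMap
from metatensor.learn.nn import ReLU  # assumed re-export, per the commit message

# One key, one block of random values; no gradients attached
keys = Labels(names=["key"], values=np.array([[0]], dtype=np.int32))
block = TensorBlock(
    values=torch.randn(3, 4),
    samples=Labels.range("sample", 3),
    components=[],
    properties=Labels.range("property", 4),
)
tensor = TensorMap(keys, [block])

relu = ReLU(in_keys=keys)
output = relu(tensor)  # every negative entry in the block is now zero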
New file: metatensor.learn.nn.silu

@@ -0,0 +1,114 @@
from typing import List, Optional

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .module_map import ModuleMap


class SiLU(Module):
    """
    Module similar to :py:class:`torch.nn.SiLU` that works with
    :py:class:`metatensor.torch.TensorMap` objects.

    Applies a sigmoid linear unit transformation to each block of a
    :py:class:`TensorMap` passed to its :py:meth:`forward` method, indexed by
    the keys in ``in_keys``.

    Refer to the :py:class:`torch.nn.SiLU` documentation for a more detailed
    description of the parameters.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the
        input tensor map in the :py:meth:`forward` method.
    :param out_properties: list of :py:class:`Labels` (optional), the
        properties labels of the output. By default the output properties are
        relabeled using :py:meth:`Labels.range`.
    """

    def __init__(
        self,
        in_keys: Labels,
        out_properties: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()
        modules: List[Module] = [torch.nn.SiLU() for _ in range(len(in_keys))]
        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be
            transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)


class InvariantSiLU(Module):
    """
    Module similar to :py:class:`torch.nn.SiLU` that works with
    :py:class:`metatensor.torch.TensorMap` objects, applying the
    transformation only to the invariant blocks.

    Applies a sigmoid linear unit transformation to each invariant block of a
    :py:class:`TensorMap` passed to its :py:meth:`forward` method. These are
    indexed by the keys in ``in_keys`` at the numeric indices passed in
    ``invariant_key_idxs``.

    Refer to the :py:class:`torch.nn.SiLU` documentation for a more detailed
    description of the parameters.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the
        input tensor map in the :py:meth:`forward` method.
    :param invariant_key_idxs: list of int, the indices of the invariant keys
        present in ``in_keys`` in the input :py:class:`TensorMap`. Only blocks
        for these keys will have the SiLU transformation applied. Covariant
        blocks will have the identity operator applied.
    :param out_properties: list of :py:class:`Labels` (optional), the
        properties labels of the output. By default the output properties are
        relabeled using :py:meth:`Labels.range`.
    """

    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        out_properties: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()
        modules: List[Module] = []
        for i in range(len(in_keys)):
            if i in invariant_key_idxs:  # Invariant block: apply SiLU
                module = torch.nn.SiLU()
            else:  # Covariant block: apply identity operator
                module = torch.nn.Identity()
            modules.append(module)
        self.module_map: ModuleMap = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be
            transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
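A hedged sketch of the invariant variant, mirroring the key-selection pattern used in the equivariance test below: assuming `tensor` is a TensorMap whose keys carry an "o3_lambda" dimension (e.g. a spherical expansion), only the o3_lambda == 0 blocks receive the nonlinearity. The import path mirrors the test's `metatensor.learn.nn.relu` import and is an assumption:

from metatensor.learn.nn.silu import InvariantSiLU  # assumed module path

# Select the invariant blocks: those whose key has o3_lambda == 0
invariant_key_idxs = [
    i for i, key in enumerate(tensor.keys) if key["o3_lambda"] == 0
]

silu = InvariantSiLU(in_keys=tensor.keys, invariant_key_idxs=invariant_key_idxs)
output = silu(tensor)  # SiLU on invariant blocks, identity on covariant ones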
New file: equivariance tests for InvariantReLU

@@ -0,0 +1,47 @@
import pytest

import metatensor


torch = pytest.importorskip("torch")

from metatensor.learn.nn.relu import InvariantReLU  # noqa: E402

from ._rotation_utils import WignerDReal  # noqa: E402


@pytest.fixture
def tensor():
    tensor = metatensor.load(
        "../metatensor-operations/tests/data/qm7-spherical-expansion.npz",
        use_numpy=True,
    ).to(arrays="torch")
    tensor = metatensor.remove_gradients(tensor)
    return tensor


@pytest.fixture
def wigner_d_real():
    return WignerDReal(lmax=4, angles=(0.87641, 1.8729, 0.9187))


def test_equivariance(tensor, wigner_d_real):
    """
    Tests that application of an invariant ReLU layer is equivariant to an O3
    transformation of the input.
    """
    # Define the input and the rotated input
    x = tensor
    Rx = wigner_d_real.transform_tensormap_o3(x)

    # Define the InvariantReLU module, acting only on the invariant
    # (o3_lambda == 0) blocks
    f = InvariantReLU(
        in_keys=x.keys,
        invariant_key_idxs=[i for i, key in enumerate(x.keys) if key["o3_lambda"] == 0],
    )

    # Apply the module and the rotation in both orders
    Rfx = wigner_d_real.transform_tensormap_o3(f(x))  # R . f(x)
    fRx = f(Rx)  # f(R . x)

    assert metatensor.allclose(fRx, Rfx, atol=1e-10, rtol=1e-10)
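A companion test one might add (not in this diff) to pin down the gradient guard: loading the same data without stripping gradients should make the module raise. This assumes the qm7-spherical-expansion.npz file carries gradient blocks, as the fixture's remove_gradients call suggests:

def test_gradients_raise():
    # Load without removing gradients; forward() should refuse the tensor
    tensor = metatensor.load(
        "../metatensor-operations/tests/data/qm7-spherical-expansion.npz",
        use_numpy=True,
    ).to(arrays="torch")
    f = InvariantReLU(
        in_keys=tensor.keys,
        invariant_key_idxs=[
            i for i, key in enumerate(tensor.keys) if key["o3_lambda"] == 0
        ],
    )
    with pytest.raises(ValueError, match="Gradients not supported"):
        f(tensor)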