Implement ReLU and SiLU module maps in mts.learn.nn (#597)
* Implement mts.learn.nn.ReLU

* Implement mts.learn.nn.SiLU

* Update change log

* Update activation function docstrings
jwa7 committed May 1, 2024
1 parent 4b2ad59 commit 7e8a414
Showing 10 changed files with 391 additions and 6 deletions.
12 changes: 12 additions & 0 deletions docs/src/learn/reference/nn/index.rst
@@ -26,6 +26,18 @@ Modules
.. autoclass:: metatensor.learn.nn.InvariantTanh
    :members:

.. autoclass:: metatensor.learn.nn.ReLU
    :members:

.. autoclass:: metatensor.learn.nn.InvariantReLU
    :members:

.. autoclass:: metatensor.learn.nn.SiLU
    :members:

.. autoclass:: metatensor.learn.nn.InvariantSiLU
    :members:

.. autoclass:: metatensor.learn.nn.LayerNorm
    :members:

2 changes: 2 additions & 0 deletions python/metatensor-learn/CHANGELOG.md
@@ -19,6 +19,8 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### Added

- Added torch-style activation function module maps to `metatensor.learn.nn`: `ReLU`,
  `InvariantReLU`, `SiLU`, and `InvariantSiLU` (#597)
- Added torch-style neural network module maps to `metatensor.learn.nn`:
  `LayerNorm`, `InvariantLayerNorm`, `EquivariantLinear`, `Sequential`, `Tanh`,
  and `InvariantTanh` (#513)
2 changes: 2 additions & 0 deletions python/metatensor-learn/metatensor/learn/nn/__init__.py
@@ -2,5 +2,7 @@

from .layer_norm import InvariantLayerNorm, LayerNorm # noqa: F401
from .linear import EquivariantLinear, Linear # noqa: F401
from .relu import InvariantReLU, ReLU # noqa: F401
from .sequential import Sequential # noqa: F401
from .silu import InvariantSiLU, SiLU # noqa: F401
from .tanh import InvariantTanh, Tanh # noqa: F401
114 changes: 114 additions & 0 deletions python/metatensor-learn/metatensor/learn/nn/relu.py
@@ -0,0 +1,114 @@
from typing import List, Optional

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .module_map import ModuleMap


class ReLU(Module):
"""
Module similar to :py:class:`torch.nn.ReLU` that works with
:py:class:`metatensor.torch.TensorMap` objects.
Applies a rectified linear unit transformation transformation to each block of a
:py:class:`TensorMap` passed to its forward method, indexed by :param in_keys:.
Refer to the :py:class`torch.nn.ReLU` documentation for a more detailed description
of the parameters.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
tensor map in the :py:meth:`forward` method.
:param out_properties: list of :py:class`Labels` (optional), the properties labels
of the output. By default the output properties are relabeled using
Labels.range.
"""

def __init__(
self,
in_keys: Labels,
out_properties: Optional[Labels] = None,
*,
in_place: bool = False,
) -> None:
super().__init__()
modules: List[Module] = [torch.nn.ReLU() for i in range(len(in_keys))]
self.module_map = ModuleMap(in_keys, modules, out_properties)

def forward(self, tensor: TensorMap) -> TensorMap:
"""
Apply the transformation to the input tensor map `tensor`.
Note: currently not supporting gradients.
:param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
:return: :py:class:`TensorMap`
"""
# Currently not supporting gradients
if len(tensor[0].gradients_list()) != 0:
raise ValueError(
"Gradients not supported. Please use metatensor.remove_gradients()"
" before using this module"
)
return self.module_map(tensor)


class InvariantReLU(torch.nn.Module):
"""
Module similar to :py:class:`torch.nn.ReLU` that works with
:py:class:`metatensor.torch.TensorMap` objects, applying the transformation only to
the invariant blocks.
Applies a rectified linear unit transformation to each invariant block of a
:py:class:`TensorMap` passed to its :py:meth:`forward` method. These are indexed by
the keys in :param in_keys: at numeric indices passed in :param invariant_key_idxs:.
Refer to the :py:class`torch.nn.ReLU` documentation for a more detailed description
of the parameters.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
tensor map in the :py:meth:`forward` method.
:param invariant_key_idxs: list of int, the indices of the invariant keys present in
`in_keys` in the input :py:class:`TensorMap`. Only blocks for these keys will
have the ReLU transformation applied. Covariant blocks will have the identity
operator applied.
:param out_properties: list of :py:class`Labels` (optional), the properties labels
of the output. By default the output properties are relabeled using
Labels.range.
"""

def __init__(
self,
in_keys: Labels,
invariant_key_idxs: List[int],
out_properties: Optional[Labels] = None,
*,
in_place: bool = False,
) -> None:
super().__init__()
modules: List[Module] = []
for i in range(len(in_keys)):
if i in invariant_key_idxs: # Invariant block: apply ReLU
module = torch.nn.ReLU()
else: # Covariant block: apply identity operator
module = torch.nn.Identity()
modules.append(module)
self.module_map: ModuleMap = ModuleMap(in_keys, modules, out_properties)

def forward(self, tensor: TensorMap) -> TensorMap:
"""
Apply the transformation to the input tensor map `tensor`.
Note: currently not supporting gradients.
:param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
:return: :py:class:`TensorMap`
"""
# Currently not supporting gradients
if len(tensor[0].gradients_list()) != 0:
raise ValueError(
"Gradients not supported. Please use metatensor.remove_gradients()"
" before using this module"
)
return self.module_map(tensor)
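
For orientation, here is a minimal usage sketch of the two new classes. It is not part
of this commit; it assumes `tensor` is a gradient-free, torch-backed TensorMap whose
keys carry an "o3_lambda" dimension (as in the tests below), and that the modules are
reachable as `metatensor.learn.nn`:

from metatensor.learn import nn

# ReLU applied to every block of the tensor map
relu = nn.ReLU(in_keys=tensor.keys)
activated = relu(tensor)

# ReLU applied only to the invariant (o3_lambda == 0) blocks; covariant blocks
# pass through an identity operator, preserving equivariance
inv_relu = nn.InvariantReLU(
    in_keys=tensor.keys,
    invariant_key_idxs=[
        i for i, key in enumerate(tensor.keys) if key["o3_lambda"] == 0
    ],
)
equivariant_activated = inv_relu(tensor)
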
114 changes: 114 additions & 0 deletions python/metatensor-learn/metatensor/learn/nn/silu.py
@@ -0,0 +1,114 @@
from typing import List, Optional

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .module_map import ModuleMap


class SiLU(Module):
"""
Module similar to :py:class:`torch.nn.SiLU` that works with
:py:class:`metatensor.torch.TensorMap` objects.
Applies a sigmoid linear unit transformation transformation to each block of a
:py:class:`TensorMap` passed to its forward method, indexed by :param in_keys:.
Refer to the :py:class`torch.nn.SiLU` documentation for a more detailed description
of the parameters.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
tensor map in the :py:meth:`forward` method.
:param out_properties: list of :py:class`Labels` (optional), the properties labels
of the output. By default the output properties are relabeled using
Labels.range.
"""

def __init__(
self,
in_keys: Labels,
out_properties: Optional[Labels] = None,
*,
in_place: bool = False,
) -> None:
super().__init__()
modules: List[Module] = [torch.nn.SiLU() for i in range(len(in_keys))]
self.module_map = ModuleMap(in_keys, modules, out_properties)

def forward(self, tensor: TensorMap) -> TensorMap:
"""
Apply the transformation to the input tensor map `tensor`.
Note: currently not supporting gradients.
:param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
:return: :py:class:`TensorMap`
"""
# Currently not supporting gradients
if len(tensor[0].gradients_list()) != 0:
raise ValueError(
"Gradients not supported. Please use metatensor.remove_gradients()"
" before using this module"
)
return self.module_map(tensor)


class InvariantSiLU(torch.nn.Module):
"""
Module similar to :py:class:`torch.nn.SiLU` that works with
:py:class:`metatensor.torch.TensorMap` objects, applying the transformation only to
the invariant blocks.
Applies a sigmoid linear unit transformation to each invariant block of a
:py:class:`TensorMap` passed to its :py:meth:`forward` method. These are indexed by
the keys in :param in_keys: at numeric indices passed in :param invariant_key_idxs:.
Refer to the :py:class`torch.nn.SiLU` documentation for a more detailed description
of the parameters.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
tensor map in the :py:meth:`forward` method.
:param invariant_key_idxs: list of int, the indices of the invariant keys present in
`in_keys` in the input :py:class:`TensorMap`. Only blocks for these keys will
have the SiLU transformation applied. Covariant blocks will have the identity
operator applied.
:param out_properties: list of :py:class`Labels` (optional), the properties labels
of the output. By default the output properties are relabeled using
Labels.range.
"""

def __init__(
self,
in_keys: Labels,
invariant_key_idxs: List[int],
out_properties: Optional[Labels] = None,
*,
in_place: bool = False,
) -> None:
super().__init__()
modules: List[Module] = []
for i in range(len(in_keys)):
if i in invariant_key_idxs: # Invariant block: apply SiLU
module = torch.nn.SiLU()
else: # Covariant block: apply identity operator
module = torch.nn.Identity()
modules.append(module)
self.module_map: ModuleMap = ModuleMap(in_keys, modules, out_properties)

def forward(self, tensor: TensorMap) -> TensorMap:
"""
Apply the transformation to the input tensor map `tensor`.
Note: currently not supporting gradients.
:param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
:return: :py:class:`TensorMap`
"""
# Currently not supporting gradients
if len(tensor[0].gradients_list()) != 0:
raise ValueError(
"Gradients not supported. Please use metatensor.remove_gradients()"
" before using this module"
)
return self.module_map(tensor)
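
As a sanity-check sketch (not part of this commit), the SiLU module map should act
block-wise exactly like torch.nn.SiLU, i.e. x * sigmoid(x) on each block's values,
with block metadata handled by ModuleMap. This assumes the same gradient-free,
torch-backed `tensor` as above:

import torch
from metatensor.learn import nn

silu = nn.SiLU(in_keys=tensor.keys)
out = silu(tensor)

# Each output block's values should match torch's functional SiLU applied to the
# corresponding input block's values (only the properties labels may differ, since
# the output properties are relabeled by default)
for i in range(len(out.keys)):
    expected = torch.nn.functional.silu(tensor[i].values)
    assert torch.allclose(out[i].values, expected)
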
8 changes: 2 additions & 6 deletions python/metatensor-learn/metatensor/learn/nn/tanh.py
@@ -21,9 +21,7 @@ class Tanh(Module):
     :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
         tensor map in the :py:meth:`forward` method.
     :param out_properties: list of :py:class:`Labels` (optional), the properties labels
-        of the output. Used to determine the properties labels of the output. Because a
-        module could change the number of properties, the labels of the properties
-        cannot be preserved. By default the output properties are relabeled using
+        of the output. By default the output properties are relabeled using
         Labels.range.
     """

@@ -72,9 +70,7 @@ class InvariantTanh(torch.nn.Module):
         have the tanh transformation applied. Covariant blocks will have the identity
         operator applied.
     :param out_properties: list of :py:class:`Labels` (optional), the properties labels
-        of the output. Used to determine the properties labels of the output. Because a
-        module could change the number of properties, the labels of the properties
-        cannot be preserved. By default the output properties are relabeled using
+        of the output. By default the output properties are relabeled using
         Labels.range.
     """

47 changes: 47 additions & 0 deletions python/metatensor-learn/tests/relu.py
@@ -0,0 +1,47 @@
import pytest

import metatensor


torch = pytest.importorskip("torch")

from metatensor.learn.nn.relu import InvariantReLU # noqa: E402

from ._rotation_utils import WignerDReal # noqa: E402


@pytest.fixture
def tensor():
    tensor = metatensor.load(
        "../metatensor-operations/tests/data/qm7-spherical-expansion.npz",
        use_numpy=True,
    ).to(arrays="torch")
    tensor = metatensor.remove_gradients(tensor)
    return tensor


@pytest.fixture
def wigner_d_real():
    return WignerDReal(lmax=4, angles=(0.87641, 1.8729, 0.9187))


def test_equivariance(tensor, wigner_d_real):
"""
Tests that application of an invariant ReLU layer is equivariant to O3
transformation of the input.
"""
# Define input and rotated input
x = tensor
Rx = wigner_d_real.transform_tensormap_o3(x)

# Define the EquiLayerNorm module
f = InvariantReLU(
in_keys=x.keys,
invariant_key_idxs=[i for i, key in enumerate(x.keys) if key["o3_lambda"] == 0],
)

# Pass both through the linear layer
Rfx = wigner_d_real.transform_tensormap_o3(f(x)) # R . f(x)
fRx = f(Rx) # f(R . x)

assert metatensor.allclose(fRx, Rfx, atol=1e-10, rtol=1e-10)
19 changes: 19 additions & 0 deletions python/metatensor-learn/tests/sequential.py
@@ -50,10 +50,19 @@ def test_sequential_mlp(tensor):
        nn.Linear(
            in_keys=in_keys,
            in_features=4,
            out_features=2,
            bias=True,
            dtype=torch.float64,
        ),
        nn.ReLU(in_keys=in_keys),
        nn.Linear(
            in_keys=in_keys,
            in_features=2,
            out_features=1,
            bias=True,
            dtype=torch.float64,
        ),
        nn.SiLU(in_keys=in_keys),
    )

    prediction = model(tensor)
@@ -94,10 +103,20 @@ def test_sequential_equi_mlp(tensor, wigner_d_real):
            in_keys=in_keys,
            invariant_key_idxs=invariant_key_idxs,
            in_features=4,
            out_features=2,
            bias=True,
            dtype=torch.float64,
        ),
        nn.InvariantReLU(in_keys=in_keys, invariant_key_idxs=invariant_key_idxs),
        nn.EquivariantLinear(
            in_keys=in_keys,
            invariant_key_idxs=invariant_key_idxs,
            in_features=2,
            out_features=1,
            bias=True,
            dtype=torch.float64,
        ),
        nn.InvariantSiLU(in_keys=in_keys, invariant_key_idxs=invariant_key_idxs),
    )

    prediction = model(tensor)