More metatensor.learn.nn building blocks #513

Merged: 4 commits merged into master on Apr 29, 2024
Conversation

@jwa7 (Member) commented on Feb 13, 2024

Description

More metatensor.learn.nn building blocks:

  • LayerNorm: applies a layer norm to all blocks
  • InvariantLayerNorm: applies a layer norm only to the invariant blocks at the specified key indices, and the identity operation to all others, thus maintaining equivariance
  • EquiLinear: a Linear layer where the bias can be applied only to the invariant blocks
  • Tanh: applies a tanh transformation to each block
  • InvariantTanh: applies tanh only to the invariant blocks
  • Sequential: builds sequential models, accepting any metatensor.learn.nn ModuleMap modules as args

Docstrings are in place and unit tests are still to come, but feedback is welcome in the meantime @agoscinski @abmazitov @Luthaf

Usage

import torch

import metatensor
from metatensor.learn import nn

X = metatensor.load("path/to/X.npz")
Y = metatensor.load("path/to/Y.npz")

in_keys = X.keys
in_features = [len(X[key].properties) for key in in_keys]
out_features = [len(Y[key].properties) for key in in_keys]
# indices of the invariant blocks, used by the Invariant*/Equi* modules
invariant_key_idxs = [
    i for i, key in enumerate(in_keys) if key["spherical_harmonics_l"] == 0
]

model = nn.Sequential(
    in_keys,
    nn.LayerNorm(
        in_keys=in_keys,
        normalized_shape=in_features,
        dtype=torch.float64,
    ),
    nn.Linear(
        in_keys=in_keys,
        in_features=in_features,
        out_features=4,
        bias=True,
        dtype=torch.float64,
    ),
    nn.Tanh(in_keys=in_keys),
    nn.Linear(
        in_keys=in_keys,
        in_features=4,
        out_features=1,
        out_properties=[Y.block(key).properties for key in in_keys],
        bias=True,
        dtype=torch.float64,
    ),
)

>>> model(X)
TensorMap with 8 blocks
keys: spherical_harmonics_l  species_center
                0                  1
                1                  1
                      ...
                2                  14
                3                  14
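
The invariant-specific modules from this PR can be slotted into the same pipeline. A hedged sketch, with module names taken from the description above and the invariant_key_idxs parameter assumed by analogy with the snippet above (the merged API may differ):

model = nn.Sequential(
    in_keys,
    nn.InvariantLayerNorm(
        in_keys=in_keys,
        invariant_key_idxs=invariant_key_idxs,
        normalized_shape=[in_features[i] for i in invariant_key_idxs],
        dtype=torch.float64,
    ),
    nn.EquiLinear(
        in_keys=in_keys,
        invariant_key_idxs=invariant_key_idxs,  # bias is only applied to invariant blocks
        in_features=in_features,
        out_features=4,
        bias=True,
        dtype=torch.float64,
    ),
    nn.InvariantTanh(in_keys=in_keys, invariant_key_idxs=invariant_key_idxs),
)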

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Documentation preview 📚: https://metatensor--513.org.readthedocs.build/en/513/

github-actions (bot) commented on Feb 13, 2024

Here is a pre-built version of the code in this pull request: wheels.zip. You can install it locally by unzipping wheels.zip and using pip to install the wheel matching your system.

@ceriottm (Contributor) commented:

As I mentioned today, I think with similar effort you could implement an EquiLayerNorm that also works on non-invariant layers. The only important thing is that the same scaling is applied to all components for one property. I think it'd be worth doing this in this PR - let me know if I'm missing something or you want to discuss this.
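
One way to read this suggestion, as a minimal torch-only sketch (not the PR's implementation): compute a single, rotation-invariant scale per property over the component axis and divide all components of that property by it. Note there is no mean subtraction, since that would break equivariance for covariant blocks.

import torch

def equi_scale(values: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # values has shape (samples, components, properties); the norm over the
    # component axis of an irreducible block is rotation invariant, so dividing
    # every component of a given property by it preserves equivariance
    norm = torch.sqrt(values.pow(2).mean(dim=1, keepdim=True) + eps)
    return values / norm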

@agoscinski (Contributor) left a comment:

Could you add to docs/src/learn/reference/nn/index.rst

.. autoclass:: metatensor.learn.nn.LayerNorm
   :members:

.. autoclass:: metatensor.learn.nn.EquiLayerNorm
   :members:

.. autoclass:: metatensor.learn.nn.Sequential
   :members:

This helps to check how the API looks in the docs. Not important, but useful.


It might be useful at some later point (not in this PR!) to extract a shared base class, if the Equi* modules end up having more in common. For the moment, though, this would only create code noise.

import copy
from typing import List, Optional, Union

import torch
from torch.nn import Module

from metatensor import Labels, TensorBlock, TensorMap
from metatensor.learn.nn import ModuleMap


class EquiModuleMap(ModuleMap):
    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        modules: List[Module],
        out_properties: Optional[List[Labels]] = None,
    ):
        self._invariant_key_idxs = invariant_key_idxs
        super().__init__(in_keys, modules, out_properties)

    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        module: Module,
        out_properties: Optional[List[Labels]] = None,
    ):
        # one copy of ``module`` per key, then build the Equi* module map
        modules = [copy.deepcopy(module) for _ in range(len(in_keys))]
        return cls(in_keys, invariant_key_idxs, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        out_blocks: List[TensorBlock] = []
        for i in range(len(tensor)):
            key = tensor.keys.entry(i)
            block = tensor.block(i)
            invariant = i in self._invariant_key_idxs
            if invariant:
                out_block = self.forward_invariant_block(key, block)
            else:
                out_block = self.forward_equivariant_block(key, block)

            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )
                out_block.add_gradient(
                    parameter=parameter,
                    gradient=(
                        self.forward_invariant_block(key, gradient)
                        if invariant
                        else self.forward_equivariant_block(key, gradient)
                    ),
                )
            out_blocks.append(out_block)

        return TensorMap(tensor.keys, out_blocks)

    def forward_invariant_block(self, key, block: TensorBlock) -> TensorBlock:
        # apply the wrapped module, via the per-block forward of the base ModuleMap
        return self.forward_block(key, block)

    def forward_equivariant_block(self, key, block: TensorBlock) -> TensorBlock:
        # identity on the values; return a fresh block (without gradients, they
        # are re-attached in forward); subclasses can override this
        return TensorBlock(
            values=block.values,
            samples=block.samples,
            components=block.components,
            properties=block.properties,
        )

class EquiLayerNorm(EquiModuleMap):
    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        normalized_shape: Union[
            Union[int, List[int], torch.Size],
            Union[List[int], List[List[int]], List[torch.Size]],
        ],
        eps: Union[float, List[float]] = 1e-5,
        elementwise_affine: Union[bool, List[bool]] = True,
        *,
        bias: Union[bool, List[bool]] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        out_properties: Optional[List[Labels]] = None,
    ):
        # do all the checks that you do in this PR
        modules = ...
        super().__init__(in_keys, invariant_key_idxs, modules, out_properties)

    # def forward_equivariant_block(self, key, block: TensorBlock) -> TensorBlock:
    #     # ... do what Michele suggested (same scaling for all components of a property)
    #     return block

Resolved review threads (outdated) on: python/metatensor-learn/metatensor/learn/nn/__init__.py, python/metatensor-learn/metatensor/learn/nn/layer_norm.py
@jwa7 jwa7 requested a review from Luthaf April 3, 2024 13:13
@jwa7 jwa7 changed the title metatensor.learn.nn building blocks: LayerNorm, EquiLayerNorm, and Sequential metatensor.learn.nn building blocks: {Equi}LayerNorm, EquiLinear, {Equi}Tanh, Sequential Apr 4, 2024
@jwa7 jwa7 changed the title metatensor.learn.nn building blocks: {Equi}LayerNorm, EquiLinear, {Equi}Tanh, Sequential metatensor.learn.nn building blocks Apr 11, 2024
@jwa7 jwa7 changed the title metatensor.learn.nn building blocks More metatensor.learn.nn building blocks Apr 11, 2024
@jwa7 jwa7 requested review from Luthaf and removed request for Luthaf April 12, 2024 08:36
@Luthaf (Contributor) left a comment:

This seems to be missing some tests that the code can be used with TorchScript.
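
For context, a minimal sketch of the kind of check being asked for; the metatensor.torch.learn import path and the reuse of X / in_keys from the usage example above are assumptions, not code from this PR:

import torch
import metatensor.torch
from metatensor.torch.learn import nn  # TorchScript-compatible variant (assumed path)

module = nn.Tanh(in_keys=in_keys)      # in_keys: a metatensor.torch Labels
scripted = torch.jit.script(module)    # should compile without errors
assert metatensor.torch.allclose(scripted(X), module(X))  # X: a metatensor.torch TensorMap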

Resolved review threads (outdated) on: docs/src/learn/reference/nn/index.rst, python/metatensor-learn/metatensor/learn/nn/_utils.py, python/metatensor-learn/metatensor/learn/nn/layer_norm.py, python/metatensor-learn/metatensor/learn/nn/linear.py, python/metatensor-learn/metatensor/learn/nn/tanh.py
On python/metatensor-learn/tests/sequential.py:

    prediction = model(tensor)
    assert metatensor.equal_metadata(
Contributor:

I think this should also check that the values are the same, at least for one block.

Member Author (@jwa7):

Sorry, check that which values are the same? I have no reference data here; should I write a regression test?

Contributor:

I meant writing the same sequential model with plain torch tensors and sending a single block through both. But with random initialization, this might be a lot of complexity for not much advantage, so we can maybe do without.
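
For the record, a rough sketch of what that could look like (hypothetical test code, not from this PR; torch_modules stands for plain torch.nn modules sharing the weights of the per-block modules in the metatensor model):

import torch

torch_model = torch.nn.Sequential(*torch_modules)   # hypothetical plain-torch twin

block = tensor.block(0)                              # one input block
expected = torch_model(block.values)                 # plain torch forward
actual = prediction.block(0).values                  # output of the metatensor model above
assert torch.allclose(actual, expected)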

Resolved review thread on tox.ini
@jwa7 jwa7 requested a review from Luthaf April 26, 2024 09:41
@Luthaf (Contributor) left a comment:

I cleaned up the set of commits, this should be good to go!

@Luthaf Luthaf merged commit 590e3ac into master Apr 29, 2024
23 of 24 checks passed
@Luthaf Luthaf deleted the learn-nn-modules branch April 29, 2024 15:54
@jwa7 jwa7 mentioned this pull request Apr 30, 2024
5 participants