More metatensor.learn.nn building blocks #513

Conversation
Here is a pre-built version of the code in this pull request: wheels.zip. You can install it locally by unzipping it.
As I mentioned today, I think with similar effort you could implement an EquiLayerNorm that also works on non-invariant layers. The only important thing is that the same scaling is applied to all components of a block.
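The idea can be sketched in plain numpy (a hypothetical helper, not the metatensor API): for a non-invariant block, skip the mean subtraction and divide all components of a sample by one common, rotation-invariant norm, so the normalization commutes with rotations.

```python
import numpy as np

# Hypothetical sketch, NOT the metatensor API: an equivariant layer norm
# must apply one common scale to all components of a sample. Skipping the
# mean subtraction and dividing by a rotation-invariant norm keeps the
# block equivariant.
def equi_layer_norm(values: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # values: (samples, components, properties), the components axis
    # holding e.g. the 2l + 1 spherical components of the block
    norm = np.sqrt((values**2).mean(axis=(1, 2), keepdims=True))
    return values / (norm + eps)
```

Because the norm is invariant under rotations of the component axis, normalizing before or after rotating gives the same result, which is exactly the equivariance property.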
Could you add to docs/src/learn/reference/nn/index.rst
.. autoclass:: metatensor.learn.nn.LayerNorm
:members:
.. autoclass:: metatensor.learn.nn.EquiLayerNorm
:members:
.. autoclass:: metatensor.learn.nn.Sequential
:members:
It helps to check how the API looks in the docs. Not important, but useful.
It might be useful at some later point (not in this PR!) to extract a common subclass, if the Equi* modules end up sharing more things. But for the moment this would only create code noise.
from typing import List, Optional, Union

import torch
from torch.nn import Module

from metatensor import Labels, TensorBlock, TensorMap
from metatensor.learn.nn import ModuleMap


class EquiModuleMap(ModuleMap):
    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: Union[int, List[int]],
        modules: List[Module],
        out_properties: Optional[List[Labels]] = None,
    ):
        self._invariant_key_idxs = invariant_key_idxs
        super().__init__(in_keys, modules, out_properties)

    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        invariant_key_idxs: Union[int, List[int]],
        module: Module,
        out_properties: Optional[List[Labels]] = None,
    ):
        module_map = ModuleMap.from_module(in_keys, module, out_properties)
        module_map._invariant_key_idxs = invariant_key_idxs
        return module_map

    def forward(self, tensor: TensorMap) -> TensorMap:
        out_blocks: List[TensorBlock] = []
        for i in range(len(tensor)):
            key = tensor.keys.entry(i)
            block = tensor.block(i)
            if i in self._invariant_key_idxs:
                out_block = self.forward_invariant_block(block)
            else:
                out_block = self.forward_equivariant_block(block)
            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )
                out_block.add_gradient(
                    parameter=parameter,
                    gradient=self.forward_block(key, gradient),
                )
            out_blocks.append(out_block)
        return TensorMap(tensor.keys, out_blocks)

    def forward_invariant_block(self, block: TensorBlock) -> TensorBlock:
        return self._forward_block(block)

    def forward_equivariant_block(self, block: TensorBlock) -> TensorBlock:
        # identity: equivariant blocks pass through unchanged
        return block


class EquiLayerNorm(EquiModuleMap):
    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        normalized_shape: Union[
            Union[int, List[int], torch.Size],
            Union[List[int], List[List[int]], List[torch.Size]],
        ],
        eps: Union[float, List[float]] = 1e-5,
        elementwise_affine: Union[bool, List[bool]] = True,
        *,
        bias: Union[bool, List[bool]] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        out_properties: Optional[List[Labels]] = None,
    ):
        # do all the checks that you do in this PR
        modules = ...
        super().__init__(in_keys, invariant_key_idxs, modules, out_properties)

    # def forward_equivariant_block(self, block: TensorBlock) -> TensorBlock:
    #     # ... do what Michele suggested
    #     return block
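Stripped of the metatensor machinery, the key dispatch in the suggested forward boils down to the following torch-free sketch (hypothetical names, not the metatensor API):

```python
import numpy as np

# Hypothetical sketch of the dispatch in EquiModuleMap.forward: blocks at
# invariant key indices get the module applied, all other (covariant)
# blocks pass through unchanged, which preserves equivariance.
def forward_blocks(block_values, invariant_key_idxs, module):
    out = []
    for i, values in enumerate(block_values):
        if i in invariant_key_idxs:
            out.append(module(values))  # e.g. a layer norm
        else:
            out.append(values)  # identity on equivariant blocks
    return out
```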
This seems to be missing some tests that the code can be used with TorchScript.
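A minimal version of such a test, using only plain torch (the stand-in module below is hypothetical; the real test would script the metatensor.learn.nn modules instead), would be:

```python
import torch

# Hypothetical stand-in module; torch.jit.script raises an error if the
# module is not TorchScript-compatible, so scripting it IS the test.
class TanhModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

scripted = torch.jit.script(TanhModule())
```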
    prediction = model(tensor)
    assert metatensor.equal_metadata(
I think this should also check that the values are the same, at least for one block.
Check that which values are the same, sorry? I have no reference data here; should I write a regression test?
I meant writing the same sequential for a normal torch tensor and sending a single block through. But with random initialization, this might be a lot of complexity and not much advantage, so we can maybe do without it.
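A hedged sketch of this comparison, using plain torch modules as stand-ins for the metatensor ones (to cope with the random initialization, the two models share one state dict):

```python
import torch

torch.manual_seed(0)

# reference model applied to a plain torch tensor
ref_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# the "per-block" model: copy the state dict so both share the same
# randomly initialized parameters (stand-in for the metatensor module)
block_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
block_model.load_state_dict(ref_model.state_dict())

x = torch.randn(4, 8)
reference = ref_model(x)
# in the real test, x would be the values of a single TensorBlock
per_block = block_model(x)
```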
Force-pushed from 11c89eb to f72174b.
I cleaned up the set of commits; this should be good to go!
Description

More metatensor.learn.nn building blocks:

- LayerNorm: applies a layer norm to all blocks
- InvariantLayerNorm: applies a layer norm only to invariant blocks at specified key indices, and the identity operation to all others, thus maintaining equivariance
- EquiLinear: allows the bias to be applied only to invariant blocks
- Tanh: tanh transformations on each block
- InvariantTanh: tanh only on invariant blocks
- Sequential: build sequential models, accepting any metatensor.learn.nn.ModuleMap as args

Docstrings are in place and unit tests are to come, but feedback is welcome in the meantime. @agoscinski @abmazitov @Luthaf
Usage
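As a rough torch-only analogy of the intended usage (the real Sequential chains ModuleMap instances over TensorMaps; every module below is a stand-in, not the metatensor API):

```python
import torch

# Hypothetical analogy: each plain torch module stands in for the
# corresponding per-block metatensor.learn.nn building block.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),   # stands in for EquiLinear
    torch.nn.Tanh(),         # stands in for Tanh / InvariantTanh
    torch.nn.LayerNorm(8),   # stands in for LayerNorm / InvariantLayerNorm
)
prediction = model(torch.randn(4, 8))
```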
Contributor (creator of pull-request) checklist
Reviewer checklist
📚 Documentation preview 📚: https://metatensor--513.org.readthedocs.build/en/513/