Skip to content

Commit

Permalink
Fix bug in scripted SOAP-BPNN without LayerNorm (#166)
Browse files Browse the repository at this point in the history
---------
Co-authored-by: Sanggyu Chong <schong1215@gmail.com>
  • Loading branch information
frostedoyster committed Apr 4, 2024
1 parent 7badc61 commit 0fb52b9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
10 changes: 9 additions & 1 deletion src/metatensor/models/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"]


class Identity(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: TensorMap) -> TensorMap:
return x


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
Expand Down Expand Up @@ -248,7 +256,7 @@ def __init__(
if hypers_bpnn["layernorm"]:
self.layernorm = LayerNormMap(self.all_species, soap_size)
else:
self.layernorm = torch.nn.Identity()
self.layernorm = Identity()

self.bpnn = MLPMap(self.all_species, hypers_bpnn)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy

import ase
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, systems_to_torch

from metatensor.models.experimental.soap_bpnn import DEFAULT_HYPERS, Model

Expand All @@ -18,7 +21,44 @@ def test_torchscript():
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.script(soap_bpnn, {"energy": soap_bpnn.capabilities.outputs["energy"]})
soap_bpnn = torch.jit.script(soap_bpnn)

system = ase.Atoms(
"OHCN",
positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
)
soap_bpnn(
[systems_to_torch(system)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)


def test_torchscript_with_identity():
"""Tests that the model can be jitted."""

capabilities = ModelCapabilities(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)
hypers = copy.deepcopy(DEFAULT_HYPERS["model"])
hypers["bpnn"]["layernorm"] = False
soap_bpnn = Model(capabilities, hypers)
soap_bpnn = torch.jit.script(soap_bpnn)

system = ase.Atoms(
"OHCN",
positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
)
soap_bpnn(
[systems_to_torch(system)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)


def test_torchscript_save():
Expand All @@ -36,8 +76,6 @@ def test_torchscript_save():
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.save(
torch.jit.script(
soap_bpnn, {"energy": soap_bpnn.capabilities.outputs["energy"]}
),
torch.jit.script(soap_bpnn),
"soap_bpnn.pt",
)

0 comments on commit 0fb52b9

Please sign in to comment.