Skip to content

Commit

Permalink
Apply fix
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Apr 4, 2024
1 parent bf44a2b commit 18c40b5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 9 additions & 1 deletion src/metatensor/models/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"]


class Identity(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: TensorMap) -> TensorMap:
return x


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
Expand Down Expand Up @@ -248,7 +256,7 @@ def __init__(
if hypers_bpnn["layernorm"]:
self.layernorm = LayerNormMap(self.all_species, soap_size)
else:
self.layernorm = torch.nn.Identity()
self.layernorm = Identity()

self.bpnn = MLPMap(self.all_species, hypers_bpnn)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_torchscript_with_identity():
},
)
hypers = copy.deepcopy(DEFAULT_HYPERS["model"])
hypers["model"]["bpnn"]["layernorm"] = False
hypers["bpnn"]["layernorm"] = False
soap_bpnn = Model(capabilities, hypers)
soap_bpnn = torch.jit.script(soap_bpnn)

Expand Down

0 comments on commit 18c40b5

Please sign in to comment.