Skip to content

Commit

Permalink
Chore: refactor LinearAtomicModel serialize/deserialize (#3451)
Browse files Browse the repository at this point in the history
This PR refactors the serialization/ deserialization in
LinearEnergyAtomicModel using the plugin mechanism introduced in #3438 .

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] committed Mar 13, 2024
1 parent 36fdf53 commit dda4bc6
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 55 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)


@BaseAtomicModel.register("standard")
class DPAtomicModel(BaseAtomicModel):
"""Model give atomic prediction of some physical property.
Expand Down
49 changes: 21 additions & 28 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -225,40 +221,38 @@ def fitting_output_def(self) -> FittingOutputDef:
]
)

@staticmethod
def serialize(models, type_map) -> dict:
def serialize(self) -> dict:
return {
"@class": "Model",
"type": "linear",
"@version": 1,
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}

@staticmethod
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
type_map = data["type_map"]
type_map = data.pop("type_map")
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
return models, type_map
data.pop("models")
return cls(models, type_map, **data)

@abstractmethod
def _compute_weight(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlists_: List[np.ndarray],
) -> np.ndarray:
) -> List[np.ndarray]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError
nmodels = len(self.models)
return [np.ones(1) / nmodels for _ in range(nmodels)]

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down Expand Up @@ -335,10 +329,10 @@ def serialize(self) -> dict:
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"@version": 2,
"models": LinearEnergyAtomicModel(
models=[self.models[0], self.models[1]], type_map=self.type_map
).serialize(),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -349,16 +343,15 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

([dp_model, zbl_model], type_map) = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)
linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models"))
dp_model, zbl_model = linear_model.models
type_map = linear_model.type_map

return cls(
dp_model=dp_model,
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)


@BaseAtomicModel.register("pairtab")
class PairTabAtomicModel(BaseAtomicModel):
"""Pairwise tabulation energy model.
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
log = logging.getLogger(__name__)


@BaseAtomicModel.register("standard")
class DPAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Model give atomic prediction of some physical property.
Expand Down
52 changes: 25 additions & 27 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -260,35 +256,38 @@ def fitting_output_def(self) -> FittingOutputDef:
]
)

@staticmethod
def serialize(models, type_map) -> dict:
def serialize(self) -> dict:
return {
"@class": "Model",
"@version": 1,
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}

@staticmethod
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
model_names = data["model_name"]
type_map = data["type_map"]
data.pop("@class")
data.pop("type")
type_map = data.pop("type_map")
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
return models, type_map
data.pop("models")
return cls(models, type_map, **data)

@abstractmethod
def _compute_weight(
self, extended_coord, extended_atype, nlists_
) -> List[torch.Tensor]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError
nmodels = len(self.models)
return [
torch.ones(1, dtype=torch.float64, device=env.DEVICE) / nmodels
for _ in range(nmodels)
]

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down Expand Up @@ -400,11 +399,11 @@ def serialize(self) -> dict:
dd.update(
{
"@class": "Model",
"@version": 1,
"@version": 2,
"type": "zbl",
"models": LinearEnergyAtomicModel.serialize(
[self.models[0], self.models[1]], self.type_map
),
"models": LinearEnergyAtomicModel(
models=[self.models[0], self.models[1]], type_map=self.type_map
).serialize(),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -415,14 +414,13 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

[dp_model, zbl_model], type_map = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)
linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models"))
dp_model, zbl_model = linear_model.models
type_map = linear_model.type_map

data.pop("@class", None)
data.pop("type", None)
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)


@BaseAtomicModel.register("pairtab")
class PairTabAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Pairwise tabulation energy model.
Expand Down

0 comments on commit dda4bc6

Please sign in to comment.