Skip to content

Commit

Permalink
Feat: add DOS net (#3452)
Browse files Browse the repository at this point in the history
This PR provides DOS fitting net in Pytorch.

Future TODO:
- [ ] Loss implementation
  - [ ] Training/Fine-tuning test
  - [ ] Jit test
- [ ] Doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: anyangml <ap@aisi.com>
  • Loading branch information
3 people committed Mar 15, 2024
1 parent d61b152 commit 2caf92c
Show file tree
Hide file tree
Showing 27 changed files with 1,019 additions and 196 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dipole_fitting import (
DipoleFitting,
)
from .dos_fitting import (
DOSFittingNet,
)
from .ener_fitting import (
EnergyFittingNet,
)
Expand All @@ -21,4 +24,5 @@
"DipoleFitting",
"EnergyFittingNet",
"PolarFitting",
"DOSFittingNet",
]
93 changes: 93 additions & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)

if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)

from deepmd.utils.version import (
check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
def __init__(
self,
ntypes: int,
dim_descrpt: int,
numb_dos: int = 300,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
bias_dos: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
trainable: Union[bool, List[bool]] = True,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = False,
exclude_types: List[int] = [],
# not used
seed: Optional[int] = None,
):
if bias_dos is not None:
self.bias_dos = bias_dos
else:
self.bias_dos = np.zeros((ntypes, numb_dos), dtype=DEFAULT_PRECISION)
super().__init__(
var_name="dos",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=numb_dos,
neuron=neuron,
resnet_dt=resnet_dt,
bias_atom=bias_dos,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
exclude_types=exclude_types,
)

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
data.pop("var_name", None)
data.pop("layer_name", None)
data.pop("use_aparam_as_mask", None)
data.pop("spin", None)
data.pop("atom_ener", None)
return super().deserialize(data)

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
dd = {
**super().serialize(),
"type": "dos",
}
dd["@variables"]["bias_atom_e"] = self.bias_atom_e

return dd
9 changes: 8 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class GeneralFitting(NativeOP, BaseFitting):
The dimension of the input descriptor.
neuron
Number of neurons :math:`N` in each hidden layer of the fitting net
bias_atom_e
Average enery per atom for each element.
resnet_dt
Time-step `dt` in the resnet construction:
:math:`y = x + dt * \phi (Wx + b)`
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
bias_atom_e: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[List[bool]] = None,
Expand Down Expand Up @@ -125,7 +128,11 @@ def __init__(

net_dim_out = self._net_out_dim()
# init constants
self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])
if bias_atom_e is None:
self.bias_atom_e = np.zeros([self.ntypes, net_dim_out])
else:
assert bias_atom_e.shape == (self.ntypes, net_dim_out)
self.bias_atom_e = bias_atom_e
if self.numb_fparam > 0:
self.fparam_avg = np.zeros(self.numb_fparam)
self.fparam_inv_std = np.ones(self.numb_fparam)
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class InvarFitting(GeneralFitting):
Number of atomic parameter
rcond
The condition number for the regression of atomic energy.
bias_atom
Bias for each element.
tot_ener_zero
Force the total energy to zero. Useful for the charge fitting.
trainable
Expand Down Expand Up @@ -117,10 +119,11 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
bias_atom: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[List[bool]] = None,
atom_ener: Optional[List[float]] = [],
atom_ener: Optional[List[float]] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
layer_name: Optional[List[Optional[str]]] = None,
Expand Down Expand Up @@ -152,6 +155,7 @@ def __init__(
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
rcond=rcond,
bias_atom_e=bias_atom,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
activation_function=activation_function,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/infer/deep_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def output_def(self) -> ModelOutputDef:
)
)

@property
def numb_dos(self) -> int:
"""Get the number of DOS."""
return self.get_numb_dos()

def eval(
self,
coords: np.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def get_standard_model(model_params):
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["mixed_types"] = descriptor.mixed_types()
fitting_net["embedding_width"] = descriptor.get_dim_emb()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
Expand Down
80 changes: 80 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
Optional,
)

import torch

from .dp_model import (
DPModel,
)


class DOSModel(DPModel):
model_type = "dos"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["dos_redu"]

if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["dos_redu"]

else:
model_predict = model_ret
return model_predict
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.dos import (
DOSFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
Expand Down Expand Up @@ -45,6 +48,9 @@ def __new__(
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
from deepmd.pt.model.model.dos_model import (
DOSModel,
)
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
Expand All @@ -68,6 +74,8 @@ def __new__(
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
elif isinstance(fitting, DOSFittingNet):
cls = DOSModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

Expand Down

0 comments on commit 2caf92c

Please sign in to comment.