feat: atom_ener in energy fitting (#3370)
Also, fix the TF serialization issue (it tried to store a tensor instead
of a NumPy array).

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
njzjz and wanghan-iapcm committed Mar 1, 2024
1 parent 54b14ef commit f684be8
Showing 6 changed files with 85 additions and 5 deletions.
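
The core of the change: for atom types with a specified atom_ener, the fitting network's prediction on a zero ("vacuum") descriptor is subtracted before the per-type bias is added, so that the bias alone (which atom_ener presumably constrains) determines the vacuum-limit output. A minimal NumPy sketch of that arithmetic follows; the names (toy_net, bias) are illustrative only, not deepmd API.

import numpy as np

rng = np.random.default_rng(0)
W, b = rng.normal(size=(4, 1)), rng.normal(size=(1,))


def toy_net(x: np.ndarray) -> np.ndarray:
    """Stand-in for a per-type fitting network."""
    return x @ W + b


xx = rng.normal(size=(3, 4))   # descriptors of three atoms of one type
xx_zeros = np.zeros_like(xx)   # the "vacuum" input, taken as zero for convenience
bias = -12345.6                # e.g. the atom_ener value requested for this type

out = toy_net(xx) - toy_net(xx_zeros) + bias           # what the new code path computes
vacuum_out = toy_net(xx_zeros) - toy_net(xx_zeros) + bias
assert np.allclose(vacuum_out, bias)                   # the vacuum prediction equals the bias
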
33 changes: 32 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
@@ -73,7 +73,10 @@ class GeneralFitting(NativeOP, BaseFitting):
different fitting nets for different atom types.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set zero.
remove_vaccum_contribution: List[bool], optional
Remove the vacuum contribution before the bias is added. The list assigns a flag to each
atom type. For `mixed_types`, provide `[True]`; otherwise it should be a list of the same
length as `ntypes`, indicating whether to remove the vacuum contribution for each atom type.
"""

def __init__(
@@ -95,6 +98,7 @@ def __init__(
spin: Any = None,
mixed_types: bool = True,
exclude_types: List[int] = [],
remove_vaccum_contribution: Optional[List[bool]] = None,
):
self.var_name = var_name
self.ntypes = ntypes
@@ -119,6 +123,7 @@ def __init__(
self.exclude_types = exclude_types
if self.spin is not None:
raise NotImplementedError("spin is not supported")
self.remove_vaccum_contribution = remove_vaccum_contribution

self.emask = AtomExcludeMask(self.ntypes, self.exclude_types)

@@ -298,6 +303,14 @@ def _call_common(
"which is not consistent with {self.dim_descrpt}."
)
xx = descriptor
if self.remove_vaccum_contribution is not None:
# TODO: Ideally, the input for vacuum should be computed;
# we consider it as always zero for convenience.
# Needs a compute_input_stats for vacuum passed from the
# descriptor.
xx_zeros = np.zeros_like(xx)
else:
xx_zeros = None
# check fparam dim, concatenate to input descriptor
if self.numb_fparam > 0:
assert fparam is not None, "fparam should not be None"
@@ -312,6 +325,11 @@ def _call_common(
[xx, fparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
[xx_zeros, fparam],
axis=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
@@ -326,6 +344,11 @@ def _call_common(
[xx, aparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
[xx_zeros, aparam],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
@@ -335,11 +358,19 @@ def _call_common(
(atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out]
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
len(self.remove_vaccum_contribution) > type_i
and not self.remove_vaccum_contribution[type_i]
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
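
The condition guarding the subtraction in _call_common above is a double negative and easy to misread. A hypothetical helper (a reading aid only, not deepmd code) restates when net(zeros) is subtracted for a given atom type:

from typing import List, Optional


def should_subtract_vacuum(flags: Optional[List[bool]], type_i: int) -> bool:
    """Equivalent reading of: flags is not None and not (len(flags) > type_i and not flags[type_i])."""
    if flags is None:           # remove_vaccum_contribution not requested at all
        return False
    if type_i < len(flags):     # an explicit per-type flag exists
        return flags[type_i]
    return True                 # a too-short list falls back to subtracting


assert should_subtract_vacuum(None, 0) is False
assert should_subtract_vacuum([True, False], 0) is True
assert should_subtract_vacuum([True, False], 1) is False
assert should_subtract_vacuum([True], 3) is True
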
5 changes: 3 additions & 2 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -136,8 +136,6 @@ def __init__(
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")
if atom_ener is not None and atom_ener != []:
raise NotImplementedError("atom_ener is not implemented")

self.dim_out = dim_out
self.atom_ener = atom_ener
@@ -159,6 +157,9 @@ def __init__(
spin=spin,
mixed_types=mixed_types,
exclude_types=exclude_types,
remove_vaccum_contribution=None
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
)

def serialize(self) -> dict:
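
Both this dpmodel fitting and the PyTorch fitting below derive remove_vaccum_contribution from the user-facing atom_ener list with the same expression. A quick, self-contained illustration of that mapping, evaluated for the values added to the consistency test at the bottom of this diff:

from typing import List, Optional


def to_remove_vaccum_flags(
    atom_ener: Optional[List[Optional[float]]],
) -> Optional[List[bool]]:
    """Same expression as passed to GeneralFitting above, written as a function."""
    if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0:
        return None
    return [x is not None for x in atom_ener]


assert to_remove_vaccum_flags(None) is None
assert to_remove_vaccum_flags([]) is None                     # empty list keeps the old behaviour
assert to_remove_vaccum_flags([-12345.6, None]) == [True, False]
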
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/ener.py
@@ -126,6 +126,9 @@ def __init__(
rcond=rcond,
seed=seed,
exclude_types=exclude_types,
remove_vaccum_contribution=None
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
**kwargs,
)

36 changes: 35 additions & 1 deletion deepmd/pt/model/task/fitting.py
@@ -239,7 +239,10 @@ class GeneralFitting(Fitting):
Random seed.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set zero.
remove_vaccum_contribution: List[bool], optional
Remove the vacuum contribution before the bias is added. The list assigns a flag to each
atom type. For `mixed_types`, provide `[True]`; otherwise it should be a list of the same
length as `ntypes`, indicating whether to remove the vacuum contribution for each atom type.
"""

def __init__(
@@ -258,6 +261,7 @@ def __init__(
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
remove_vaccum_contribution: Optional[List[bool]] = None,
**kwargs,
):
super().__init__()
@@ -275,6 +279,7 @@ def __init__(
self.rcond = rcond
# order matters, should be placed after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.remove_vaccum_contribution = remove_vaccum_contribution

net_dim_out = self._net_out_dim()
# init constants
@@ -479,6 +484,14 @@ def _forward_common(
aparam: Optional[torch.Tensor] = None,
):
xx = descriptor
if self.remove_vaccum_contribution is not None:
# TODO: Ideally, the input for vacuum should be computed;
# we consider it as always zero for convenience.
# Needs a compute_input_stats for vacuum passed from the
# descriptor.
xx_zeros = torch.zeros_like(xx)
else:
xx_zeros = None
nf, nloc, nd = xx.shape
net_dim_out = self._net_out_dim()

@@ -507,6 +520,11 @@ def _forward_common(
[xx, fparam],
dim=-1,
)
if xx_zeros is not None:
xx_zeros = torch.cat(
[xx_zeros, fparam],
dim=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
@@ -526,6 +544,11 @@ def _forward_common(
[xx, aparam],
dim=-1,
)
if xx_zeros is not None:
xx_zeros = torch.cat(
[xx_zeros, aparam],
dim=-1,
)

outs = torch.zeros(
(nf, nloc, net_dim_out),
Expand All @@ -534,6 +557,7 @@ def _forward_common(
) # jit assertion
if self.old_impl:
assert self.filter_layers_old is not None
assert xx_zeros is None
if self.mixed_types:
atom_property = self.filter_layers_old[0](xx) + self.bias_atom_e[atype]
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
@@ -549,6 +573,8 @@ def _forward_common(
atom_property = (
self.filter_layers.networks[0](xx) + self.bias_atom_e[atype]
)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
@@ -557,6 +583,14 @@ def _forward_common(
mask = (atype == type_i).unsqueeze(-1)
mask = torch.tile(mask, (1, 1, net_dim_out))
atom_property = ll(xx)
if xx_zeros is not None:
# must assert, otherwise jit is not happy
assert self.remove_vaccum_contribution is not None
if not (
len(self.remove_vaccum_contribution) > type_i
and not self.remove_vaccum_contribution[type_i]
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
outs = (
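
In the PyTorch mixed_types branch above, the correction is a single elementwise subtraction of the network's output on a zero descriptor. A minimal sketch, assuming a toy torch.nn.Linear in place of self.filter_layers.networks[0] and illustrative shapes:

import torch

torch.manual_seed(0)
nf, nloc, nd = 2, 5, 8                            # frames, local atoms, descriptor dim
net = torch.nn.Linear(nd, 1)                      # stand-in for filter_layers.networks[0]
bias_atom_e = torch.tensor([[0.0], [-12345.6]])   # per-type bias, shape (ntypes, 1)
atype = torch.randint(0, 2, (nf, nloc))           # atom types of each local atom

xx = torch.randn(nf, nloc, nd)
xx_zeros = torch.zeros_like(xx)

outs = net(xx) + bias_atom_e[atype]               # as in the diff
outs -= net(xx_zeros)                             # vacuum correction

# with a zero descriptor, the prediction reduces to the per-type bias
vacuum = net(xx_zeros) + bias_atom_e[atype] - net(xx_zeros)
assert torch.allclose(vacuum, bias_atom_e[atype])
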
2 changes: 1 addition & 1 deletion deepmd/tf/fit/ener.py
@@ -1003,7 +1003,7 @@ def serialize(self, suffix: str = "") -> dict:
"rcond": self.rcond,
"tot_ener_zero": self.tot_ener_zero,
"trainable": self.trainable,
"atom_ener": self.atom_ener,
"atom_ener": self.atom_ener_v,
"activation_function": self.activation_function_name,
"precision": self.fitting_precision.name,
"layer_name": self.layer_name,
11 changes: 11 additions & 0 deletions source/tests/consistent/fitting/test_ener.py
@@ -43,6 +43,7 @@
("float64", "float32"), # precision
(True, False), # mixed_types
(0, 1), # numb_fparam
([], [-12345.6, None]), # atom_ener
)
class TestEner(CommonTest, FittingTest, unittest.TestCase):
@property
@@ -52,13 +53,15 @@ def data(self) -> dict:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"seed": 20240217,
"atom_ener": atom_ener,
}

@property
@@ -68,6 +71,7 @@ def skip_tf(self) -> bool:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
# TODO: mixed_types
return mixed_types or CommonTest.skip_pt
@@ -79,6 +83,7 @@ def skip_pt(self) -> bool:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return CommonTest.skip_pt

@@ -105,6 +110,7 @@ def addtional_data(self) -> dict:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return {
"ntypes": self.ntypes,
@@ -118,6 +124,7 @@ def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return self.build_tf_fitting(
obj,
@@ -134,6 +141,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return (
pt_obj(
@@ -154,6 +162,7 @@ def eval_dp(self, dp_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return dp_obj(
self.inputs,
@@ -175,6 +184,7 @@ def rtol(self) -> float:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
if precision == "float64":
return 1e-10
@@ -191,6 +201,7 @@ def atol(self) -> float:
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
if precision == "float64":
return 1e-10