Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid descriptor #3365

Merged
merged 15 commits into from
Mar 1, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .hybrid import (
DescrptHybrid,
)
from .make_base_descriptor import (
make_base_descriptor,
)
Expand All @@ -12,5 +15,6 @@
__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptHybrid",
"make_base_descriptor",
]
242 changes: 242 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.utils.nlist import (
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)


@BaseDescriptor.register("hybrid")
class DescrptHybrid(BaseDescriptor, NativeOP):
"""Concate a list of descriptors to form a new descriptor.

Parameters
----------
list : list : List[Union[BaseDescriptor, Dict[str, Any]]]
Build a descriptor from the concatenation of the list of descriptors.
The descriptor can be either an object or a dictionary.
"""

def __init__(
self,
list: List[Union[BaseDescriptor, Dict[str, Any]]],
) -> None:
super().__init__()
# warning: list is conflict with built-in list
descrpt_list = list
if descrpt_list == [] or descrpt_list is None:
raise RuntimeError(

Check warning on line 48 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L48

Added line #L48 was not covered by tests
"cannot build descriptor from an empty list of descriptors."
)
formatted_descript_list = []
for ii in descrpt_list:
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
formatted_descript_list.append(BaseDescriptor(**ii))
else:
raise NotImplementedError

Check warning on line 58 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L58

Added line #L58 was not covered by tests
self.descrpt_list = formatted_descript_list
self.numb_descrpt = len(self.descrpt_list)
for ii in range(1, self.numb_descrpt):
assert (
self.descrpt_list[ii].get_ntypes() == self.descrpt_list[0].get_ntypes()
), f"number of atom types in {ii}th descrptor {self.descrpt_list[0].__class__.__name__} does not match others"
# if hybrid sel is larger than sub sel, the nlist needs to be cut for each type
hybrid_sel = self.get_sel()

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'hybrid_sel' is unnecessary as it is
redefined
before this value is used.
This assignment to 'hybrid_sel' is unnecessary as it is
redefined
before this value is used.
self.nlist_cut_idx: List[np.ndarray] = []
if self.mixed_types() and not all(
descrpt.mixed_types() for descrpt in self.descrpt_list
):
self.sel_no_mixed_types = np.max(

Check warning on line 71 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L71

Added line #L71 was not covered by tests
[
descrpt.get_sel()
for descrpt in self.descrpt_list
if not descrpt.mixed_types()
],
axis=0,
).tolist()
else:
self.sel_no_mixed_types = None
for ii in range(self.numb_descrpt):
if self.mixed_types() == self.descrpt_list[ii].mixed_types():
hybrid_sel = self.get_sel()
else:
assert self.sel_no_mixed_types is not None
hybrid_sel = self.sel_no_mixed_types

Check warning on line 86 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L85-L86

Added lines #L85 - L86 were not covered by tests
sub_sel = self.descrpt_list[ii].get_sel()
start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1]
end_idx = start_idx + np.array(sub_sel)
cut_idx = np.concatenate(
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
)
self.nlist_cut_idx.append(cut_idx)

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item()

def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
if self.mixed_types():
return [

Check warning on line 102 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L102

Added line #L102 was not covered by tests
np.max(
[descrpt.get_nsel() for descrpt in self.descrpt_list], axis=0
).item()
]
else:
return np.max(
[descrpt.get_sel() for descrpt in self.descrpt_list], axis=0
).tolist()

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.descrpt_list[0].get_ntypes()

Check warning on line 114 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L114

Added line #L114 was not covered by tests

def get_dim_out(self) -> int:
"""Returns the output dimension."""
return np.sum([descrpt.get_dim_out() for descrpt in self.descrpt_list]).item()

Check warning on line 118 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L118

Added line #L118 was not covered by tests

def get_dim_emb(self) -> int:
"""Returns the output dimension."""
return np.sum([descrpt.get_dim_emb() for descrpt in self.descrpt_list]).item()

Check warning on line 122 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L122

Added line #L122 was not covered by tests

def mixed_types(self):
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
descrpt.compute_input_stats(merged, path)

Check warning on line 133 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L132-L133

Added lines #L132 - L133 were not covered by tests

def call(
self,
coord_ext,
atype_ext,
nlist,
mapping: Optional[np.ndarray] = None,
):
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3.
g2
The rotationally invariant pair-partical representation.
h2
The rotationally equivariant pair-partical representation.
sw
The smooth switch function.
"""
out_descriptor = []
out_gr = []
out_g2 = []
out_h2 = None
out_sw = None
if self.sel_no_mixed_types is not None:
nl_distinguish_types = nlist_distinguish_types(

Check warning on line 175 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L175

Added line #L175 was not covered by tests
nlist,
atype_ext,
self.sel_no_mixed_types,
)
else:
nl_distinguish_types = None
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = nlist[:, :, nci]
else:
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
nl = nl_distinguish_types[:, :, nci]

Check warning on line 189 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L188-L189

Added lines #L188 - L189 were not covered by tests
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
out_descriptor.append(odescriptor)
if gr is not None:
out_gr.append(gr)
if g2 is not None:
out_g2.append(g2)

Check warning on line 195 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L195

Added line #L195 was not covered by tests
if self.get_rcut() == descrpt.get_rcut():
out_h2 = h2
out_sw = sw

out_descriptor = np.concatenate(out_descriptor, axis=-1)
out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None
out_g2 = np.concatenate(out_g2, axis=-1) if out_g2 else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [

Check warning on line 217 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L216-L217

Added lines #L216 - L217 were not covered by tests
BaseDescriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy

Check warning on line 221 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L221

Added line #L221 was not covered by tests

def serialize(self) -> dict:
return {
"@class": "Descriptor",
"type": "hybrid",
"@version": 1,
"list": [descrpt.serialize() for descrpt in self.descrpt_list],
}

@classmethod
def deserialize(cls, data: dict) -> "DescrptHybrid":
data = data.copy()
class_name = data.pop("@class")
assert class_name == "Descriptor"
class_type = data.pop("type")
assert class_type == "hybrid"
check_version_compatibility(data.pop("@version"), 1, 1)
obj = cls(
list=[BaseDescriptor.deserialize(ii) for ii in data["list"]],
)
return obj
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from .hybrid import (
DescrptBlockHybrid,
DescrptHybrid,
)
from .repformers import (
DescrptBlockRepformers,
Expand All @@ -39,6 +40,7 @@
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
"prod_env_mat",
"DescrptGaussianLcc",
"DescrptBlockHybrid",
Expand Down