-
Notifications
You must be signed in to change notification settings - Fork 499
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add ABC for descriptors I'm going to add abstract base classes for different object, where a list of methods and attributes is defined to normalize classes and their external call by other classes. It's also useful to develop and extend new classes. The first one I did is the descriptor. * TYPE_CHECKING doesn't work in python 3.6 * fix warnings
- Loading branch information
Showing
8 changed files
with
296 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Dict, List, Tuple | ||
|
||
import numpy as np | ||
from deepmd.env import tf | ||
|
||
|
||
class Descriptor(ABC): | ||
r"""The abstract class for descriptors. All specific descriptors should | ||
be based on this class. | ||
The descriptor :math:`\mathcal{D}` describes the environment of an atom, | ||
which should be a function of coordinates and types of its neighbour atoms. | ||
Notes | ||
----- | ||
Only methods and attributes defined in this class are generally public, | ||
that can be called by other classes. | ||
""" | ||
|
||
@abstractmethod | ||
def get_rcut(self) -> float: | ||
""" | ||
Returns the cut-off radius. | ||
Returns | ||
------- | ||
float | ||
the cut-off radius | ||
Notes | ||
----- | ||
This method must be implemented, as it's called by other classes. | ||
""" | ||
|
||
@abstractmethod | ||
def get_ntypes(self) -> int: | ||
""" | ||
Returns the number of atom types. | ||
Returns | ||
------- | ||
int | ||
the number of atom types | ||
Notes | ||
----- | ||
This method must be implemented, as it's called by other classes. | ||
""" | ||
|
||
@abstractmethod | ||
def get_dim_out(self) -> int: | ||
""" | ||
Returns the output dimension of this descriptor. | ||
Returns | ||
------- | ||
int | ||
the output dimension of this descriptor | ||
Notes | ||
----- | ||
This method must be implemented, as it's called by other classes. | ||
""" | ||
|
||
def get_dim_rot_mat_1(self) -> int: | ||
""" | ||
Returns the first dimension of the rotation matrix. The rotation is of shape | ||
dim_1 x 3 | ||
Returns | ||
------- | ||
int | ||
the first dimension of the rotation matrix | ||
""" | ||
# TODO: I think this method should be implemented as it's called by dipole and | ||
# polar fitting network. However, currently not all descriptors have this | ||
# method. | ||
raise NotImplementedError | ||
|
||
def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]: | ||
""" | ||
Returns neighbor information. | ||
Returns | ||
------- | ||
nlist : tf.Tensor | ||
Neighbor list | ||
rij : tf.Tensor | ||
The relative distance between the neighbor and the center atom. | ||
sel_a : list[int] | ||
The number of neighbors with full information | ||
sel_r : list[int] | ||
The number of neighbors with only radial information | ||
""" | ||
# TODO: I think this method should be implemented as it's called by energy | ||
# model. However, se_ar and hybrid doesn't have this method. | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def compute_input_stats(self, | ||
data_coord: List[np.ndarray], | ||
data_box: List[np.ndarray], | ||
data_atype: List[np.ndarray], | ||
natoms_vec: List[np.ndarray], | ||
mesh: List[np.ndarray], | ||
input_dict: Dict[str, List[np.ndarray]] | ||
) -> None: | ||
""" | ||
Compute the statisitcs (avg and std) of the training data. The input will be | ||
normalized by the statistics. | ||
Parameters | ||
---------- | ||
data_coord : list[np.ndarray] | ||
The coordinates. Can be generated by | ||
:meth:`deepmd.model.model_stat.make_stat_input` | ||
data_box : list[np.ndarray] | ||
The box. Can be generated by | ||
:meth:`deepmd.model.model_stat.make_stat_input` | ||
data_atype : list[np.ndarray] | ||
The atom types. Can be generated by :meth:`deepmd.model.model_stat.make_stat_input` | ||
natoms_vec : list[np.ndarray] | ||
The vector for the number of atoms of the system and different types of | ||
atoms. Can be generated by :meth:`deepmd.model.model_stat.make_stat_input` | ||
mesh : list[np.ndarray] | ||
The mesh for neighbor searching. Can be generated by | ||
:meth:`deepmd.model.model_stat.make_stat_input` | ||
input_dict : dict[str, list[np.ndarray]] | ||
Dictionary for additional input | ||
Notes | ||
----- | ||
This method must be implemented, as it's called by other classes. | ||
""" | ||
|
||
@abstractmethod | ||
def build(self, | ||
coord_: tf.Tensor, | ||
atype_: tf.Tensor, | ||
natoms: tf.Tensor, | ||
box_: tf.Tensor, | ||
mesh: tf.Tensor, | ||
input_dict: Dict[str, Any], | ||
reuse: bool = None, | ||
suffix: str = '', | ||
) -> tf.Tensor: | ||
""" | ||
Build the computational graph for the descriptor. | ||
Parameters | ||
---------- | ||
coord_ : tf.Tensor | ||
The coordinate of atoms | ||
atype_ : tf.Tensor | ||
The type of atoms | ||
natoms : tf.Tensor | ||
The number of atoms. This tensor has the length of Ntypes + 2 | ||
natoms[0]: number of local atoms | ||
natoms[1]: total number of atoms held by this processor | ||
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms | ||
box : tf.Tensor | ||
The box of frames | ||
mesh : tf.Tensor | ||
For historical reasons, only the length of the Tensor matters. | ||
if size of mesh == 6, pbc is assumed. | ||
if size of mesh == 0, no-pbc is assumed. | ||
input_dict : dict[str, Any] | ||
Dictionary for additional inputs | ||
reuse : bool, optional | ||
The weights in the networks should be reused when get the variable. | ||
suffix : str, optional | ||
Name suffix to identify this descriptor | ||
Returns | ||
------- | ||
descriptor: tf.Tensor | ||
The output descriptor | ||
Notes | ||
----- | ||
This method must be implemented, as it's called by other classes. | ||
""" | ||
|
||
def enable_compression(self, | ||
min_nbor_dist: float, | ||
model_file: str = 'frozon_model.pb', | ||
table_extrapolate: float = 5., | ||
table_stride_1: float = 0.01, | ||
table_stride_2: float = 0.1, | ||
check_frequency: int = -1 | ||
) -> None: | ||
""" | ||
Reveive the statisitcs (distance, max_nbor_size and env_mat_range) of the | ||
training data. | ||
Parameters | ||
---------- | ||
min_nbor_dist : float | ||
The nearest distance between atoms | ||
model_file : str, default: 'frozon_model.pb' | ||
The original frozen model, which will be compressed by the program | ||
table_extrapolate : float, default: 5. | ||
The scale of model extrapolation | ||
table_stride_1 : float, default: 0.01 | ||
The uniform stride of the first table | ||
table_stride_2 : float, default: 0.1 | ||
The uniform stride of the second table | ||
check_frequency : int, default: -1 | ||
The overflow check frequency | ||
Notes | ||
----- | ||
This method is called by others when the descriptor supported compression. | ||
""" | ||
raise NotImplementedError( | ||
"Descriptor %s doesn't support compression!" % self.__name__) | ||
|
||
@abstractmethod | ||
def prod_force_virial(self, | ||
atom_ener: tf.Tensor, | ||
natoms: tf.Tensor | ||
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: | ||
""" | ||
Compute force and virial. | ||
Parameters | ||
---------- | ||
atom_ener : tf.Tensor | ||
The atomic energy | ||
natoms : tf.Tensor | ||
The number of atoms. This tensor has the length of Ntypes + 2 | ||
natoms[0]: number of local atoms | ||
natoms[1]: total number of atoms held by this processor | ||
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms | ||
Returns | ||
------- | ||
force : tf.Tensor | ||
The force on atoms | ||
virial : tf.Tensor | ||
The total virial | ||
atom_virial : tf.Tensor | ||
The atomic virial | ||
""" | ||
|
||
def get_feed_dict(self, | ||
coord_: tf.Tensor, | ||
atype_: tf.Tensor, | ||
natoms: tf.Tensor, | ||
box: tf.Tensor, | ||
mesh: tf.Tensor | ||
) -> Dict[str, tf.Tensor]: | ||
""" | ||
Generate the feed_dict for current descriptor | ||
Parameters | ||
---------- | ||
coord_ : tf.Tensor | ||
The coordinate of atoms | ||
atype_ : tf.Tensor | ||
The type of atoms | ||
natoms : tf.Tensor | ||
The number of atoms. This tensor has the length of Ntypes + 2 | ||
natoms[0]: number of local atoms | ||
natoms[1]: total number of atoms held by this processor | ||
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms | ||
box : tf.Tensor | ||
The box. Can be generated by deepmd.model.make_stat_input | ||
mesh : tf.Tensor | ||
For historical reasons, only the length of the Tensor matters. | ||
if size of mesh == 6, pbc is assumed. | ||
if size of mesh == 0, no-pbc is assumed. | ||
Returns | ||
------- | ||
feed_dict : dict[str, tf.Tensor] | ||
The output feed_dict of current descriptor | ||
""" | ||
# TODO: currently only SeA has this method, but I think the method can be | ||
# moved here as it doesn't contain anything related to a specific descriptor | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters