Skip to content

Commit

Permalink
add ABC for descriptors (#1081)
Browse files Browse the repository at this point in the history
* add ABC for descriptors

I'm going to add abstract base classes for different objects, each defining a list of methods and attributes, to normalize the classes and the way they are called by other classes. This is also useful for developing and extending new classes.

The first one I did is the descriptor.

* TYPE_CHECKING doesn't work in python 3.6

* fix warnings
  • Loading branch information
njzjz committed Sep 2, 2021
1 parent da5f688 commit c824ff6
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 7 deletions.
282 changes: 282 additions & 0 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from deepmd.env import tf


class Descriptor(ABC):
    r"""The abstract class for descriptors. All specific descriptors should
    be based on this class.

    The descriptor :math:`\mathcal{D}` describes the environment of an atom,
    which should be a function of coordinates and types of its neighbour atoms.

    Notes
    -----
    Only methods and attributes defined in this class are generally public,
    that can be called by other classes.

    Methods decorated with ``abstractmethod`` must be overridden by every
    concrete descriptor; the remaining methods are optional and raise
    ``NotImplementedError`` unless a subclass provides them.
    """

    @abstractmethod
    def get_rcut(self) -> float:
        """
        Returns the cut-off radius.

        Returns
        -------
        float
            the cut-off radius

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def get_ntypes(self) -> int:
        """
        Returns the number of atom types.

        Returns
        -------
        int
            the number of atom types

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def get_dim_out(self) -> int:
        """
        Returns the output dimension of this descriptor.

        Returns
        -------
        int
            the output dimension of this descriptor

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    def get_dim_rot_mat_1(self) -> int:
        """
        Returns the first dimension of the rotation matrix. The rotation is of
        shape dim_1 x 3.

        Returns
        -------
        int
            the first dimension of the rotation matrix

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: I think this method should be implemented as it's called by dipole and
        # polar fitting network. However, currently not all descriptors have this
        # method.
        raise NotImplementedError

    def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """
        Returns neighbor information.

        Returns
        -------
        nlist : tf.Tensor
            Neighbor list
        rij : tf.Tensor
            The relative distance between the neighbor and the center atom.
        sel_a : list[int]
            The number of neighbors with full information
        sel_r : list[int]
            The number of neighbors with only radial information

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: I think this method should be implemented as it's called by energy
        # model. However, se_ar and hybrid doesn't have this method.
        raise NotImplementedError

    @abstractmethod
    def compute_input_stats(self,
                            data_coord: List[np.ndarray],
                            data_box: List[np.ndarray],
                            data_atype: List[np.ndarray],
                            natoms_vec: List[np.ndarray],
                            mesh: List[np.ndarray],
                            input_dict: Dict[str, List[np.ndarray]]
                            ) -> None:
        """
        Compute the statistics (avg and std) of the training data. The input
        will be normalized by the statistics.

        Parameters
        ----------
        data_coord : list[np.ndarray]
            The coordinates. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        data_box : list[np.ndarray]
            The box. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        data_atype : list[np.ndarray]
            The atom types. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        natoms_vec : list[np.ndarray]
            The vector for the number of atoms of the system and different
            types of atoms. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        mesh : list[np.ndarray]
            The mesh for neighbor searching. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        input_dict : dict[str, list[np.ndarray]]
            Dictionary for additional input

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: Dict[str, Any],
              reuse: Optional[bool] = None,
              suffix: str = '',
              ) -> tf.Tensor:
        """
        Build the computational graph for the descriptor.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinate of atoms
        atype_ : tf.Tensor
            The type of atoms
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_ : tf.Tensor
            The box of frames
        mesh : tf.Tensor
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.
        input_dict : dict[str, Any]
            Dictionary for additional inputs
        reuse : bool, optional
            The weights in the networks should be reused when get the variable.
        suffix : str, optional
            Name suffix to identify this descriptor

        Returns
        -------
        descriptor: tf.Tensor
            The output descriptor

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    def enable_compression(self,
                           min_nbor_dist: float,
                           model_file: str = 'frozon_model.pb',
                           table_extrapolate: float = 5.,
                           table_stride_1: float = 0.01,
                           table_stride_2: float = 0.1,
                           check_frequency: int = -1
                           ) -> None:
        """
        Receive the statistics (distance, max_nbor_size and env_mat_range) of
        the training data.

        Parameters
        ----------
        min_nbor_dist : float
            The nearest distance between atoms
        model_file : str, default: 'frozon_model.pb'
            The original frozen model, which will be compressed by the program
        table_extrapolate : float, default: 5.
            The scale of model extrapolation
        table_stride_1 : float, default: 0.01
            The uniform stride of the first table
        table_stride_2 : float, default: 0.1
            The uniform stride of the second table
        check_frequency : int, default: -1
            The overflow check frequency

        Raises
        ------
        NotImplementedError
            Always, unless the concrete descriptor overrides this method.

        Notes
        -----
        This method is called by others when the descriptor supports
        compression.
        """
        # Instances do not carry a ``__name__`` attribute (only classes do), so
        # the class name has to be looked up through ``type(self)``; using
        # ``self.__name__`` would raise AttributeError and mask the real error.
        raise NotImplementedError(
            "Descriptor %s doesn't support compression!" % type(self).__name__)

    @abstractmethod
    def prod_force_virial(self,
                          atom_ener: tf.Tensor,
                          natoms: tf.Tensor
                          ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial.

        Parameters
        ----------
        atom_ener : tf.Tensor
            The atomic energy
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force : tf.Tensor
            The force on atoms
        virial : tf.Tensor
            The total virial
        atom_virial : tf.Tensor
            The atomic virial
        """

    def get_feed_dict(self,
                      coord_: tf.Tensor,
                      atype_: tf.Tensor,
                      natoms: tf.Tensor,
                      box: tf.Tensor,
                      mesh: tf.Tensor
                      ) -> Dict[str, tf.Tensor]:
        """
        Generate the feed_dict for current descriptor

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinate of atoms
        atype_ : tf.Tensor
            The type of atoms
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box : tf.Tensor
            The box. Can be generated by deepmd.model.make_stat_input
        mesh : tf.Tensor
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.

        Returns
        -------
        feed_dict : dict[str, tf.Tensor]
            The output feed_dict of current descriptor

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: currently only SeA has this method, but I think the method can be
        # moved here as it doesn't contain anything related to a specific descriptor
        raise NotImplementedError
3 changes: 2 additions & 1 deletion deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# from deepmd.descriptor import DescrptSeAEbd
# from deepmd.descriptor import DescrptSeAEf
# from deepmd.descriptor import DescrptSeR
from .descriptor import Descriptor
from .se_a import DescrptSeA
from .se_r import DescrptSeR
from .se_ar import DescrptSeAR
Expand All @@ -20,7 +21,7 @@
from .se_a_ef import DescrptSeAEf
from .loc_frame import DescrptLocFrame

class DescrptHybrid ():
class DescrptHybrid (Descriptor):
"""Concate a list of descriptors to form a new descriptor.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor

class DescrptLocFrame () :
class DescrptLocFrame (Descriptor) :
"""Defines a local frame at each atom, and the compute the descriptor as local
coordinates under this frame.
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from .descriptor import Descriptor

class DescrptSeA ():
class DescrptSeA (Descriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from .se_a import DescrptSeA
from .descriptor import Descriptor

class DescrptSeAEf ():
class DescrptSeAEf (Descriptor):
"""
Parameters
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from .se_a import DescrptSeA
from .se_r import DescrptSeR
from deepmd.env import op_module
from .descriptor import Descriptor

class DescrptSeAR ():
class DescrptSeAR (Descriptor):
def __init__ (self, jdata):
args = ClassArg()\
.add('a', dict, must = True) \
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor

class DescrptSeR ():
class DescrptSeR (Descriptor):
"""DeepPot-SE constructed from radial information of atomic configurations.
The embedding takes the distance between atoms as input.
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor

class DescrptSeT ():
class DescrptSeT (Descriptor):
"""DeepPot-SE constructed from all information (both angular and radial) of atomic
configurations.
Expand Down

0 comments on commit c824ff6

Please sign in to comment.