Skip to content

Commit

Permalink
support init_frz_model for hybrid descriptor (#1112)
Browse files Browse the repository at this point in the history
* support init_frz_model for hybrid descriptor

Refactors some methods to implement it.
Also fixes some typos.

* rename `graph_def` to `model_file`

Co-authored-by: Denghui Lu <ludenghui.cs@gmail.com>

Co-authored-by: Denghui Lu <ludenghui.cs@gmail.com>
  • Loading branch information
njzjz and denghuilu committed Sep 8, 2021
1 parent 60797e0 commit 97be2f5
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 75 deletions.
57 changes: 48 additions & 9 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,26 +280,65 @@ def get_feed_dict(self,
feed_dict : dict[str, tf.Tensor]
The output feed_dict of current descriptor
"""
# TODO: currently only SeA has this method, but I think the method can be
# moved here as it doesn't contain anything related to a specific descriptor
raise NotImplementedError
feed_dict = {
't_coord:0' :coord_,
't_type:0' :atype_,
't_natoms:0' :natoms,
't_box:0' :box,
't_mesh:0' :mesh
}
return feed_dict

def init_variables(self,
embedding_net_variables: dict
) -> None:
model_file: str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
embedding_net_variables
The input dict which stores the embedding net variables
model_file : str
The input model file
suffix : str, optional
The suffix of the scope
Notes
-----
This method is called by others when the descriptor supported initialization from the given variables.
"""
# TODO: currently only SeA has this method, but I think the method can be
# moved here as it doesn't contain anything related to a specific descriptor
raise NotImplementedError(
"Descriptor %s doesn't support initialization from the given variables!" % type(self).__name__)

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
raise NotImplementedError("Descriptor %s doesn't support this property!" % type(self).__name__)

def pass_tensors_from_frz_model(self,
*tensors : tf.Tensor,
) -> None:
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
Parameters
----------
*tensors : tf.Tensor
passed tensors
Notes
-----
The number of parameters in the method must be equal to the numbers of returns in
:meth:`get_tensor_names`.
"""
raise NotImplementedError("Descriptor %s doesn't support this method!" % type(self).__name__)
52 changes: 52 additions & 0 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,55 @@ def enable_compression(self,
"""
for idx, ii in enumerate(self.descrpt_list):
ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}")

def init_variables(self,
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
for idx, ii in enumerate(self.descrpt_list):
ii.init_variables(model_file, suffix=f"{suffix}_{idx}")

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
tensor_names = []
for idx, ii in enumerate(self.descrpt_list):
tensor_names.extend(ii.get_tensor_names(suffix=f"{suffix}_{idx}"))
return tuple(tensor_names)

def pass_tensors_from_frz_model(self,
*tensors : tf.Tensor,
) -> None:
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
Parameters
----------
*tensors : tf.Tensor
passed tensors
"""
jj = 0
for ii in self.descrpt_list:
n_tensors = len(ii.get_tensor_names())
ii.pass_tensors_from_frz_model(*tensors[jj:jj+n_tensors])
jj += n_tensors
78 changes: 27 additions & 51 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from .descriptor import Descriptor

class DescrptSeA (Descriptor):
Expand Down Expand Up @@ -433,10 +433,10 @@ def build (self,
tf.summary.histogram('nlist', self.nlist)

self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat')
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv')
self.rij = tf.identity(self.rij, name = 'o_rij')
self.nlist = tf.identity(self.nlist, name = 'o_nlist')
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix)
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix)
self.rij = tf.identity(self.rij, name = 'o_rij' + suffix)
self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix)

self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
atype,
Expand All @@ -456,6 +456,21 @@ def get_rot_mat(self) -> tf.Tensor:
"""
return self.qmat

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')

def pass_tensors_from_frz_model(self,
descrpt_reshape : tf.Tensor,
descrpt_deriv : tf.Tensor,
Expand All @@ -481,60 +496,21 @@ def pass_tensors_from_frz_model(self,
self.descrpt_deriv = descrpt_deriv
self.descrpt_reshape = descrpt_reshape

def get_feed_dict(self,
coord_,
atype_,
natoms,
box,
mesh):
"""
generate the deed_dict for current descriptor
Parameters
----------
coord_
The coordinate of atoms
atype_
The type of atoms
natoms
The number of atoms. This tensor has the length of Ntypes + 2
natoms[0]: number of local atoms
natoms[1]: total number of atoms held by this processor
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
box
The box. Can be generated by deepmd.model.make_stat_input
mesh
For historical reasons, only the length of the Tensor matters.
if size of mesh == 6, pbc is assumed.
if size of mesh == 0, no-pbc is assumed.
Returns
-------
feed_dict
The output feed_dict of current descriptor
"""
feed_dict = {
't_coord:0' :coord_,
't_type:0' :atype_,
't_natoms:0' :natoms,
't_box:0' :box,
't_mesh:0' :mesh
}
return feed_dict


def init_variables(self,
embedding_net_variables: dict
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
embedding_net_variables
The input dict which stores the embedding net variables
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
self.embedding_net_variables = embedding_net_variables
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)


def prod_force_virial(self,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ def build (self,
name = 'descrpt_attr/ntypes',
dtype = tf.int32)
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
imported_tensors \
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
dout = imported_tensors[-1]
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])


if self.srtab is not None :
Expand Down
9 changes: 5 additions & 4 deletions deepmd/model/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from deepmd.env import tf
from deepmd.common import ClassArg
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import op_module
from deepmd.utils.graph import load_graph_def
from .model_stat import make_stat_input, merge_sys_stat
Expand Down Expand Up @@ -138,10 +138,11 @@ def build (self,
name = 'descrpt_attr/ntypes',
dtype = tf.int32)
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
imported_tensors \
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
dout = imported_tensors[-1]
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

rot_mat = self.descrpt.get_rot_mat()
rot_mat = tf.identity(rot_mat, name = 'o_rot_mat'+suffix)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def _init_from_frz_model(self):
# self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch...
# initialize fitting net with the given compressed frozen model
if self.model_type == 'original_model':
self.descrpt.init_variables(get_embedding_net_variables(self.run_opt.init_frz_model))
self.descrpt.init_variables(self.run_opt.init_frz_model)
self.fitting.init_variables(get_fitting_net_variables(self.run_opt.init_frz_model))
tf.constant("original_model", name = 'model_type', dtype = tf.string)
elif self.model_type == 'compressed_model':
Expand Down
20 changes: 13 additions & 7 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_tensor_by_type(node,
elif data_type == np.float32:
tensor = np.array(node.float_val)
else:
raise RunTimeError('model compression does not support the half precision')
raise RuntimeError('model compression does not support the half precision')
return tensor


Expand Down Expand Up @@ -139,40 +139,44 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
return embedding_net_nodes


def get_embedding_net_nodes(model_file: str) -> Dict:
def get_embedding_net_nodes(model_file: str, suffix: str = "") -> Dict:
"""
Get the embedding net nodes with the given frozen model(model_file)
Parameters
----------
model_file
The input frozen model path
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net nodes with the given frozen model
"""
_, graph_def = load_graph_def(model_file)
return get_embedding_net_nodes_from_graph_def(graph_def)
return get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)


def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef, suffix: str = "") -> Dict:
"""
Get the embedding net variables with the given tf.GraphDef object
Parameters
----------
graph_def
The input tf.GraphDef object
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net variables within the given tf.GraphDef object
"""
embedding_net_variables = {}
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def)
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)
for item in embedding_net_nodes:
node = embedding_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
Expand All @@ -184,22 +188,24 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return embedding_net_variables

def get_embedding_net_variables(model_file : str) -> Dict:
def get_embedding_net_variables(model_file : str, suffix: str = "") -> Dict:
"""
Get the embedding net variables with the given frozen model(model_file)
Parameters
----------
model_file
The input frozen model path
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net variables within the given frozen model
"""
_, graph_def = load_graph_def(model_file)
return get_embedding_net_variables_from_graph_def(graph_def)
return get_embedding_net_variables_from_graph_def(graph_def, suffix=suffix)


def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:
Expand Down

0 comments on commit 97be2f5

Please sign in to comment.