Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support init_frz_model for hybrid descriptor #1112

Merged
merged 2 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,26 +280,65 @@ def get_feed_dict(self,
feed_dict : dict[str, tf.Tensor]
The output feed_dict of current descriptor
"""
# TODO: currently only SeA has this method, but I think the method can be
# moved here as it doesn't contain anything related to a specific descriptor
raise NotImplementedError
feed_dict = {
't_coord:0' :coord_,
't_type:0' :atype_,
't_natoms:0' :natoms,
't_box:0' :box,
't_mesh:0' :mesh
}
return feed_dict

def init_variables(self,
embedding_net_variables: dict
) -> None:
model_file: str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
embedding_net_variables
The input dict which stores the embedding net variables
model_file : str
The input model file
suffix : str, optional
The suffix of the scope
Notes
-----
This method is called by others when the descriptor supported initialization from the given variables.
"""
# TODO: currently only SeA has this method, but I think the method can be
# moved here as it doesn't contain anything related to a specific descriptor
raise NotImplementedError(
"Descriptor %s doesn't support initialization from the given variables!" % type(self).__name__)

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
raise NotImplementedError("Descriptor %s doesn't support this property!" % type(self).__name__)

def pass_tensors_from_frz_model(self,
*tensors : tf.Tensor,
) -> None:
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
Parameters
----------
*tensors : tf.Tensor
passed tensors
Notes
-----
The number of parameters in the method must be equal to the numbers of returns in
:meth:`get_tensor_names`.
"""
raise NotImplementedError("Descriptor %s doesn't support this method!" % type(self).__name__)
52 changes: 52 additions & 0 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,55 @@ def enable_compression(self,
"""
for idx, ii in enumerate(self.descrpt_list):
ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}")

def init_variables(self,
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
for idx, ii in enumerate(self.descrpt_list):
ii.init_variables(model_file, suffix=f"{suffix}_{idx}")

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
tensor_names = []
for idx, ii in enumerate(self.descrpt_list):
tensor_names.extend(ii.get_tensor_names(suffix=f"{suffix}_{idx}"))
return tuple(tensor_names)

def pass_tensors_from_frz_model(self,
*tensors : tf.Tensor,
) -> None:
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
Parameters
----------
*tensors : tf.Tensor
passed tensors
"""
jj = 0
for ii in self.descrpt_list:
n_tensors = len(ii.get_tensor_names())
ii.pass_tensors_from_frz_model(*tensors[jj:jj+n_tensors])
jj += n_tensors
78 changes: 27 additions & 51 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from .descriptor import Descriptor

class DescrptSeA (Descriptor):
Expand Down Expand Up @@ -433,10 +433,10 @@ def build (self,
tf.summary.histogram('nlist', self.nlist)

self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat')
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv')
self.rij = tf.identity(self.rij, name = 'o_rij')
self.nlist = tf.identity(self.nlist, name = 'o_nlist')
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix)
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix)
self.rij = tf.identity(self.rij, name = 'o_rij' + suffix)
self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix)

self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
atype,
Expand All @@ -456,6 +456,21 @@ def get_rot_mat(self) -> tf.Tensor:
"""
return self.qmat

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Parameters
----------
suffix : str
The suffix of the scope
Returns
-------
Tuple[str]
Names of tensors
"""
return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')

def pass_tensors_from_frz_model(self,
descrpt_reshape : tf.Tensor,
descrpt_deriv : tf.Tensor,
Expand All @@ -481,60 +496,21 @@ def pass_tensors_from_frz_model(self,
self.descrpt_deriv = descrpt_deriv
self.descrpt_reshape = descrpt_reshape

def get_feed_dict(self,
coord_,
atype_,
natoms,
box,
mesh):
"""
generate the deed_dict for current descriptor
Parameters
----------
coord_
The coordinate of atoms
atype_
The type of atoms
natoms
The number of atoms. This tensor has the length of Ntypes + 2
natoms[0]: number of local atoms
natoms[1]: total number of atoms held by this processor
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
box
The box. Can be generated by deepmd.model.make_stat_input
mesh
For historical reasons, only the length of the Tensor matters.
if size of mesh == 6, pbc is assumed.
if size of mesh == 0, no-pbc is assumed.
Returns
-------
feed_dict
The output feed_dict of current descriptor
"""
feed_dict = {
't_coord:0' :coord_,
't_type:0' :atype_,
't_natoms:0' :natoms,
't_box:0' :box,
't_mesh:0' :mesh
}
return feed_dict


def init_variables(self,
embedding_net_variables: dict
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
embedding_net_variables
The input dict which stores the embedding net variables
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
self.embedding_net_variables = embedding_net_variables
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)


def prod_force_virial(self,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ def build (self,
name = 'descrpt_attr/ntypes',
dtype = tf.int32)
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
imported_tensors \
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
dout = imported_tensors[-1]
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])


if self.srtab is not None :
Expand Down
9 changes: 5 additions & 4 deletions deepmd/model/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from deepmd.env import tf
from deepmd.common import ClassArg
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import op_module
from deepmd.utils.graph import load_graph_def
from .model_stat import make_stat_input, merge_sys_stat
Expand Down Expand Up @@ -138,10 +138,11 @@ def build (self,
name = 'descrpt_attr/ntypes',
dtype = tf.int32)
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
imported_tensors \
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
dout = imported_tensors[-1]
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

rot_mat = self.descrpt.get_rot_mat()
rot_mat = tf.identity(rot_mat, name = 'o_rot_mat'+suffix)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def _init_from_frz_model(self):
# self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch...
# initialize fitting net with the given compressed frozen model
if self.model_type == 'original_model':
self.descrpt.init_variables(get_embedding_net_variables(self.run_opt.init_frz_model))
self.descrpt.init_variables(self.run_opt.init_frz_model)
self.fitting.init_variables(get_fitting_net_variables(self.run_opt.init_frz_model))
tf.constant("original_model", name = 'model_type', dtype = tf.string)
elif self.model_type == 'compressed_model':
Expand Down
20 changes: 13 additions & 7 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_tensor_by_type(node,
elif data_type == np.float32:
tensor = np.array(node.float_val)
else:
raise RunTimeError('model compression does not support the half precision')
raise RuntimeError('model compression does not support the half precision')
return tensor


Expand Down Expand Up @@ -139,40 +139,44 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
return embedding_net_nodes


def get_embedding_net_nodes(model_file: str) -> Dict:
def get_embedding_net_nodes(model_file: str, suffix: str = "") -> Dict:
"""
Get the embedding net nodes with the given frozen model(model_file)
Parameters
----------
model_file
The input frozen model path
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net nodes with the given frozen model
"""
_, graph_def = load_graph_def(model_file)
return get_embedding_net_nodes_from_graph_def(graph_def)
return get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)


def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef, suffix: str = "") -> Dict:
"""
Get the embedding net variables with the given tf.GraphDef object
Parameters
----------
graph_def
The input tf.GraphDef object
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net variables within the given tf.GraphDef object
"""
embedding_net_variables = {}
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def)
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)
for item in embedding_net_nodes:
node = embedding_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
Expand All @@ -184,22 +188,24 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return embedding_net_variables

def get_embedding_net_variables(model_file : str) -> Dict:
def get_embedding_net_variables(model_file : str, suffix: str = "") -> Dict:
"""
Get the embedding net variables with the given frozen model(model_file)
Parameters
----------
model_file
The input frozen model path
suffix : str, optional
The suffix of the scope
Returns
----------
Dict
The embedding net variables within the given frozen model
"""
_, graph_def = load_graph_def(model_file)
return get_embedding_net_variables_from_graph_def(graph_def)
return get_embedding_net_variables_from_graph_def(graph_def, suffix=suffix)


def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:
Expand Down