Skip to content

Commit

Permalink
Fix model compression bug when fparam or aparam is not zero (#1306)
Browse files Browse the repository at this point in the history
* fix model compression bug when fparam or aparam is not zero

* Update ener.py
  • Loading branch information
denghuilu committed Nov 23, 2021
1 parent 338ba31 commit 843a3c5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
30 changes: 27 additions & 3 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deepmd.descriptor import DescrptLocFrame
from deepmd.descriptor import DescrptSeA
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.graph import get_fitting_net_variables
from deepmd.utils.graph import get_fitting_net_variables, load_graph_def, get_tensor_by_name_from_graph

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
Expand Down Expand Up @@ -502,7 +502,31 @@ def init_variables(self,
self.fitting_net_variables = get_fitting_net_variables(model_file)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
def enable_compression(self,
                       model_file: str,
                       suffix: str = ""
) -> None:
    """
    Restore fparam/aparam normalization statistics from a frozen model.

    When the fitting net consumes frame parameters (fparam) or atomic
    parameters (aparam), their average and inverse standard deviation are
    read back from the frozen graph so the compressed model normalizes
    them identically to the original model.

    Parameters
    ----------
    model_file : str
        The input frozen model file
    suffix : str, optional
        The suffix of the scope
    """
    # Nothing to restore when the fitting net takes no extra parameters.
    if self.numb_fparam <= 0 and self.numb_aparam <= 0:
        return
    frozen_graph, _ = load_graph_def(model_file)
    if self.numb_fparam > 0:
        self.fparam_avg = get_tensor_by_name_from_graph(
            frozen_graph, 'fitting_attr%s/t_fparam_avg' % suffix)
        self.fparam_inv_std = get_tensor_by_name_from_graph(
            frozen_graph, 'fitting_attr%s/t_fparam_istd' % suffix)
    if self.numb_aparam > 0:
        self.aparam_avg = get_tensor_by_name_from_graph(
            frozen_graph, 'fitting_attr%s/t_aparam_avg' % suffix)
        self.aparam_inv_std = get_tensor_by_name_from_graph(
            frozen_graph, 'fitting_attr%s/t_aparam_istd' % suffix)


def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
"""
Receive the mixed precision setting.
Expand All @@ -512,4 +536,4 @@ def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
The mixed precision setting used in the embedding net
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec['output_prec'])
self.fitting_precision = get_precision(mixed_prec['output_prec'])
3 changes: 3 additions & 0 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def build (self,
else :
self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
self.fitting.init_variables(self.model_param['compress']['model_file'])
# for fparam or aparam settings in 'ener' type fitting net
if self.fitting_type == 'ener':
self.fitting.enable_compression(self.model_param['compress']['model_file'])

if self.is_compress or self.model_type == 'compressed_model':
tf.constant("compressed_model", name = 'model_type', dtype = tf.string)
Expand Down

0 comments on commit 843a3c5

Please sign in to comment.