From 843a3c5ab34930c948ba4687c85cedef86f65176 Mon Sep 17 00:00:00 2001
From: Denghui Lu
Date: Tue, 23 Nov 2021 09:28:10 +0800
Subject: [PATCH] Fix model compression bug when fparam or aparam is not zero
 (#1306)

* fix model compression bug when fparam and aparam are not zero

* Update ener.py
---
 deepmd/fit/ener.py      | 30 +++++++++++++++++++++++++++---
 deepmd/train/trainer.py |  3 +++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py
index e6b0d0a763..f82b68f41b 100644
--- a/deepmd/fit/ener.py
+++ b/deepmd/fit/ener.py
@@ -9,7 +9,7 @@
 from deepmd.descriptor import DescrptLocFrame
 from deepmd.descriptor import DescrptSeA
 from deepmd.utils.type_embed import embed_atom_type
-from deepmd.utils.graph import get_fitting_net_variables
+from deepmd.utils.graph import get_fitting_net_variables, load_graph_def, get_tensor_by_name_from_graph
 from deepmd.env import global_cvt_2_tf_float
 from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
 
@@ -502,7 +502,31 @@ def init_variables(self,
         self.fitting_net_variables = get_fitting_net_variables(model_file)
 
 
-    def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
+    def enable_compression(self,
+                           model_file: str,
+                           suffix: str = ""
+    ) -> None:
+        """
+        Set the fitting net attributes from the frozen model_file when fparam or aparam is not zero
+
+        Parameters
+        ----------
+        model_file : str
+            The input frozen model file
+        suffix : str, optional
+            The suffix of the scope
+        """
+        if self.numb_fparam > 0 or self.numb_aparam > 0:
+            graph, _ = load_graph_def(model_file)
+            if self.numb_fparam > 0:
+                self.fparam_avg = get_tensor_by_name_from_graph(graph, 'fitting_attr%s/t_fparam_avg' % suffix)
+                self.fparam_inv_std = get_tensor_by_name_from_graph(graph, 'fitting_attr%s/t_fparam_istd' % suffix)
+            if self.numb_aparam > 0:
+                self.aparam_avg = get_tensor_by_name_from_graph(graph, 'fitting_attr%s/t_aparam_avg' % suffix)
+                self.aparam_inv_std = get_tensor_by_name_from_graph(graph, 'fitting_attr%s/t_aparam_istd' % suffix)
+
+
+    def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
         """
         Reveive the mixed precision setting.
 
@@ -512,4 +536,4 @@ def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
             The mixed precision setting used in the embedding net
         """
         self.mixed_prec = mixed_prec
-        self.fitting_precision = get_precision(mixed_prec['output_prec'])
\ No newline at end of file
+        self.fitting_precision = get_precision(mixed_prec['output_prec'])
diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py
index 0e1ed6d3e8..ea75b30bd1 100644
--- a/deepmd/train/trainer.py
+++ b/deepmd/train/trainer.py
@@ -293,6 +293,9 @@ def build (self,
         else :
             self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
             self.fitting.init_variables(self.model_param['compress']['model_file'])
+            # for fparam or aparam settings in 'ener' type fitting net
+            if self.fitting_type == 'ener':
+                self.fitting.enable_compression(self.model_param['compress']['model_file'])
 
         if self.is_compress or self.model_type == 'compressed_model':
             tf.constant("compressed_model", name = 'model_type', dtype = tf.string)
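
For context, the fparam/aparam tensors recovered above are the normalization statistics that the energy fitting net applies to frame and atomic parameters. Below is a minimal NumPy sketch of that standardization, assuming the convention (x - avg) * inv_std; it is an illustration only, not the deepmd-kit implementation, and normalize_fparam is a hypothetical helper.

import numpy as np

# Illustrative only (assumed convention, not deepmd-kit code): the fitting net
# standardizes frame parameters with the statistics stored in the frozen graph
# ('fitting_attr/t_fparam_avg' and 'fitting_attr/t_fparam_istd'). Restoring them
# in enable_compression() lets the compressed model normalize fparam/aparam the
# same way the original model did.
def normalize_fparam(fparam, fparam_avg, fparam_inv_std):
    """Standardize frame parameters: (fparam - avg) * inv_std."""
    return (fparam - fparam_avg) * fparam_inv_std

# Toy example: one frame with two frame parameters.
fparam = np.array([[300.0, 1.0]])
fparam_avg = np.array([290.0, 0.8])      # average recovered from the frozen graph
fparam_inv_std = np.array([0.1, 2.0])    # inverse std recovered from the frozen graph
print(normalize_fparam(fparam, fparam_avg, fparam_inv_std))  # [[1.  0.4]]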