From 1a8fd7367d34bd939ee86b8b35cfaad5cc0eac23 Mon Sep 17 00:00:00 2001 From: Denghui Lu Date: Tue, 26 Oct 2021 08:17:32 +0800 Subject: [PATCH] Fix compress training bug within the dp train --init-frz-model interface (#1233) * fix compress training bug within the dp train --init-frz-model interface * address comments * rename _transfer_graph_def function within freeze.py --- deepmd/entrypoints/freeze.py | 41 +++++++++++++++++++++++++++-- deepmd/entrypoints/transfer.py | 21 ++------------- deepmd/env.py | 36 ++++++++++++++++++++++++++ deepmd/utils/graph.py | 47 ++++++++++++++++++++++++++-------- 4 files changed, 113 insertions(+), 32 deletions(-) diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py index afbb7659d4..9d17786e1a 100755 --- a/deepmd/entrypoints/freeze.py +++ b/deepmd/entrypoints/freeze.py @@ -7,9 +7,9 @@ """ import logging -from deepmd.env import tf -from deepmd.env import op_module +from deepmd.env import tf, FITTING_NET_PATTERN from deepmd.utils.sess import run_sess +from deepmd.utils.graph import get_pattern_nodes_from_graph_def from os.path import abspath # load grad of force module @@ -21,6 +21,36 @@ log = logging.getLogger(__name__) +def _transfer_fitting_net_trainable_variables(sess, old_graph_def, raw_graph_def): + old_pattern = FITTING_NET_PATTERN + raw_pattern = FITTING_NET_PATTERN\ + .replace('idt', 'idt+_\d+')\ + .replace('bias', 'bias+_\d+')\ + .replace('matrix', 'matrix+_\d+') + old_graph_nodes = get_pattern_nodes_from_graph_def( + old_graph_def, + old_pattern + ) + try : + raw_graph_def = tf.graph_util.convert_variables_to_constants( + sess, # The session is used to retrieve the weights + raw_graph_def, # The graph_def is used to retrieve the nodes + [n + '_1' for n in old_graph_nodes], # The output node names are used to select the usefull nodes + ) + except AssertionError: + # if there's no additional nodes + return old_graph_def + + raw_graph_nodes = get_pattern_nodes_from_graph_def( + raw_graph_def, + raw_pattern + ) + for node in old_graph_def.node: + if node.name not in old_graph_nodes.keys(): + continue + tensor = tf.make_ndarray(raw_graph_nodes[node.name + '_1']) + node.attr["value"].tensor.tensor_content = tensor.tostring() + return old_graph_def def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> List[str]: """Get node names based on model type. @@ -205,6 +235,13 @@ def freeze( output_node_list, # The output node names are used to select the usefull nodes ) + # If we need to transfer the fitting net variables + output_graph_def = _transfer_fitting_net_trainable_variables( + sess, + output_graph_def, + input_graph_def + ) + # Finally we serialize and dump the output graph to the filesystem with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) diff --git a/deepmd/entrypoints/transfer.py b/deepmd/entrypoints/transfer.py index 3a755fd61a..47509551a8 100644 --- a/deepmd/entrypoints/transfer.py +++ b/deepmd/entrypoints/transfer.py @@ -1,7 +1,7 @@ """Module used for transfering parameters between models.""" from typing import Dict, Optional, Sequence, Tuple -from deepmd.env import tf +from deepmd.env import tf, TRANSFER_PATTERN import re import numpy as np import logging @@ -225,24 +225,7 @@ def load_transform_node(graph: tf.Graph) -> Dict[str, tf.Tensor]: Dict[str, tf.Tensor] mapping on graph node names and corresponding tensors """ - transform_node_pattern = re.compile( - r"filter_type_\d+/matrix_\d+_\d+|" - r"filter_type_\d+/bias_\d+_\d+|" - r"filter_type_\d+/idt_\d+_\d+|" - r"layer_\d+_type_\d+/matrix|" - r"layer_\d+_type_\d+/bias|" - r"layer_\d+_type_\d+/idt|" - r"final_layer_type_\d+/matrix|" - r"descrpt_attr/t_avg|" - r"descrpt_attr/t_std|" - r"final_layer_type_\d+/bias|" - r"fitting_attr/t_fparam_avg|" - r"fitting_attr/t_fparam_istd|" - r"fitting_attr/t_aparam_avg|" - r"fitting_attr/t_aparam_istd|" - r"model_attr/t_tab_info|" - r"model_attr/t_tab_data|" - ) + transform_node_pattern = re.compile(TRANSFER_PATTERN) transform_node = {} for node in graph.node: diff --git a/deepmd/env.py b/deepmd/env.py index 92287d8aa5..6e6543697e 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -2,6 +2,7 @@ import logging import os +import re import platform from configparser import ConfigParser from imp import reload @@ -35,10 +36,45 @@ "reset_default_tf_session_config", "op_module", "op_grads_module", + "TRANSFER_PATTERN", + "FITTING_NET_PATTERN", + "EMBEDDING_NET_PATTERN", ] SHARED_LIB_MODULE = "op" +EMBEDDING_NET_PATTERN = str( + r"filter_type_\d+/matrix_\d+_\d+|" + r"filter_type_\d+/bias_\d+_\d+|" + r"filter_type_\d+/idt_\d+_\d+|" + r"filter_type_all/matrix_\d+_\d+|" + r"filter_type_all/matrix_\d+_\d+_\d+|" + r"filter_type_all/bias_\d+_\d+|" + r"filter_type_all/bias_\d+_\d+_\d+|" + r"filter_type_all/idt_\d+_\d+|" +) + +FITTING_NET_PATTERN = str( + r"layer_\d+_type_\d+/matrix|" + r"layer_\d+_type_\d+/bias|" + r"layer_\d+_type_\d+/idt|" + r"final_layer_type_\d+/matrix|" + r"final_layer_type_\d+/bias|" +) + +TRANSFER_PATTERN = \ + EMBEDDING_NET_PATTERN + \ + FITTING_NET_PATTERN + \ + str( + r"descrpt_attr/t_avg|" + r"descrpt_attr/t_std|" + r"fitting_attr/t_fparam_avg|" + r"fitting_attr/t_fparam_istd|" + r"fitting_attr/t_aparam_avg|" + r"fitting_attr/t_aparam_istd|" + r"model_attr/t_tab_info|" + r"model_attr/t_tab_data|" +) def set_env_if_empty(key: str, value: str, verbose: bool = True): """Set environment variable only if it is empty. diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index 6766750e4b..031454e4b9 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -1,7 +1,7 @@ import re import numpy as np from typing import Tuple, Dict -from deepmd.env import tf +from deepmd.env import tf, EMBEDDING_NET_PATTERN, FITTING_NET_PATTERN from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError @@ -112,6 +112,30 @@ def get_tensor_by_type(node, return tensor +def get_pattern_nodes_from_graph_def(graph_def: tf.GraphDef, pattern: str) -> Dict: + """ + Get the pattern nodes with the given tf.GraphDef object + + Parameters + ---------- + graph_def + The input tf.GraphDef object + pattern + The node pattern within the graph_def + + Returns + ---------- + Dict + The fitting net nodes within the given tf.GraphDef object + """ + nodes = {} + pattern = re.compile(pattern) + for node in graph_def.node: + if re.fullmatch(pattern, node.name) != None: + nodes[node.name] = node.attr["value"].tensor + return nodes + + def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = "") -> Dict: """ Get the embedding net nodes with the given tf.GraphDef object @@ -128,11 +152,16 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = Dict The embedding net nodes within the given tf.GraphDef object """ - embedding_net_nodes = {} - embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+" - for node in graph_def.node: - if re.fullmatch(embedding_net_pattern, node.name) != None: - embedding_net_nodes[node.name] = node.attr["value"].tensor + # embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+" + if suffix is not "": + embedding_net_pattern = EMBEDDING_NET_PATTERN\ + .replace('/idt', suffix + '/idt')\ + .replace('/bias', suffix + '/bias')\ + .replace('/matrix', suffix + '/matrix') + else: + embedding_net_pattern = EMBEDDING_NET_PATTERN + + embedding_net_nodes = get_pattern_nodes_from_graph_def(graph_def, embedding_net_pattern) for key in embedding_net_nodes.keys(): assert key.find('bias') > 0 or key.find( 'matrix') > 0, "currently, only support weight matrix and bias matrix at the tabulation op!" @@ -222,11 +251,7 @@ def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict: Dict The fitting net nodes within the given tf.GraphDef object """ - fitting_net_nodes = {} - fitting_net_pattern = "layer_\d+_type_\d+/matrix+|layer_\d+_type_\d+/bias+|layer_\d+_type_\d+/idt+|final_layer_type_\d+/matrix+|final_layer_type_\d+/bias" - for node in graph_def.node: - if re.fullmatch(fitting_net_pattern, node.name) != None: - fitting_net_nodes[node.name] = node.attr["value"].tensor + fitting_net_nodes = get_pattern_nodes_from_graph_def(graph_def, FITTING_NET_PATTERN) for key in fitting_net_nodes.keys(): assert key.find('bias') > 0 or key.find('matrix') > 0 or key.find( 'idt') > 0, "currently, only support weight matrix, bias and idt at the model compression process!"