Skip to content

Commit

Permalink
Fix compress training bug within the dp train --init-frz-model interf…
Browse files Browse the repository at this point in the history
…ace (#1233)

* fix compress training bug within the dp train --init-frz-model interface

* address comments

* rename _transfer_graph_def function within freeze.py
  • Loading branch information
denghuilu committed Oct 26, 2021
1 parent 6c41aa3 commit 1a8fd73
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 32 deletions.
41 changes: 39 additions & 2 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""

import logging
from deepmd.env import tf
from deepmd.env import op_module
from deepmd.env import tf, FITTING_NET_PATTERN
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import get_pattern_nodes_from_graph_def
from os.path import abspath

# load grad of force module
Expand All @@ -21,6 +21,36 @@

log = logging.getLogger(__name__)

def _transfer_fitting_net_trainable_variables(sess, old_graph_def, raw_graph_def):
old_pattern = FITTING_NET_PATTERN
raw_pattern = FITTING_NET_PATTERN\
.replace('idt', 'idt+_\d+')\
.replace('bias', 'bias+_\d+')\
.replace('matrix', 'matrix+_\d+')
old_graph_nodes = get_pattern_nodes_from_graph_def(
old_graph_def,
old_pattern
)
try :
raw_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
raw_graph_def, # The graph_def is used to retrieve the nodes
[n + '_1' for n in old_graph_nodes], # The output node names are used to select the usefull nodes
)
except AssertionError:
# if there's no additional nodes
return old_graph_def

raw_graph_nodes = get_pattern_nodes_from_graph_def(
raw_graph_def,
raw_pattern
)
for node in old_graph_def.node:
if node.name not in old_graph_nodes.keys():
continue
tensor = tf.make_ndarray(raw_graph_nodes[node.name + '_1'])
node.attr["value"].tensor.tensor_content = tensor.tostring()
return old_graph_def

def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> List[str]:
"""Get node names based on model type.
Expand Down Expand Up @@ -205,6 +235,13 @@ def freeze(
output_node_list, # The output node names are used to select the usefull nodes
)

# If we need to transfer the fitting net variables
output_graph_def = _transfer_fitting_net_trainable_variables(
sess,
output_graph_def,
input_graph_def
)

# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
Expand Down
21 changes: 2 additions & 19 deletions deepmd/entrypoints/transfer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module used for transfering parameters between models."""

from typing import Dict, Optional, Sequence, Tuple
from deepmd.env import tf
from deepmd.env import tf, TRANSFER_PATTERN
import re
import numpy as np
import logging
Expand Down Expand Up @@ -225,24 +225,7 @@ def load_transform_node(graph: tf.Graph) -> Dict[str, tf.Tensor]:
Dict[str, tf.Tensor]
mapping on graph node names and corresponding tensors
"""
transform_node_pattern = re.compile(
r"filter_type_\d+/matrix_\d+_\d+|"
r"filter_type_\d+/bias_\d+_\d+|"
r"filter_type_\d+/idt_\d+_\d+|"
r"layer_\d+_type_\d+/matrix|"
r"layer_\d+_type_\d+/bias|"
r"layer_\d+_type_\d+/idt|"
r"final_layer_type_\d+/matrix|"
r"descrpt_attr/t_avg|"
r"descrpt_attr/t_std|"
r"final_layer_type_\d+/bias|"
r"fitting_attr/t_fparam_avg|"
r"fitting_attr/t_fparam_istd|"
r"fitting_attr/t_aparam_avg|"
r"fitting_attr/t_aparam_istd|"
r"model_attr/t_tab_info|"
r"model_attr/t_tab_data|"
)
transform_node_pattern = re.compile(TRANSFER_PATTERN)

transform_node = {}
for node in graph.node:
Expand Down
36 changes: 36 additions & 0 deletions deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import re
import platform
from configparser import ConfigParser
from imp import reload
Expand Down Expand Up @@ -35,10 +36,45 @@
"reset_default_tf_session_config",
"op_module",
"op_grads_module",
"TRANSFER_PATTERN",
"FITTING_NET_PATTERN",
"EMBEDDING_NET_PATTERN",
]

SHARED_LIB_MODULE = "op"

EMBEDDING_NET_PATTERN = str(
r"filter_type_\d+/matrix_\d+_\d+|"
r"filter_type_\d+/bias_\d+_\d+|"
r"filter_type_\d+/idt_\d+_\d+|"
r"filter_type_all/matrix_\d+_\d+|"
r"filter_type_all/matrix_\d+_\d+_\d+|"
r"filter_type_all/bias_\d+_\d+|"
r"filter_type_all/bias_\d+_\d+_\d+|"
r"filter_type_all/idt_\d+_\d+|"
)

FITTING_NET_PATTERN = str(
r"layer_\d+_type_\d+/matrix|"
r"layer_\d+_type_\d+/bias|"
r"layer_\d+_type_\d+/idt|"
r"final_layer_type_\d+/matrix|"
r"final_layer_type_\d+/bias|"
)

TRANSFER_PATTERN = \
EMBEDDING_NET_PATTERN + \
FITTING_NET_PATTERN + \
str(
r"descrpt_attr/t_avg|"
r"descrpt_attr/t_std|"
r"fitting_attr/t_fparam_avg|"
r"fitting_attr/t_fparam_istd|"
r"fitting_attr/t_aparam_avg|"
r"fitting_attr/t_aparam_istd|"
r"model_attr/t_tab_info|"
r"model_attr/t_tab_data|"
)

def set_env_if_empty(key: str, value: str, verbose: bool = True):
"""Set environment variable only if it is empty.
Expand Down
47 changes: 36 additions & 11 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import numpy as np
from typing import Tuple, Dict
from deepmd.env import tf
from deepmd.env import tf, EMBEDDING_NET_PATTERN, FITTING_NET_PATTERN
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError

Expand Down Expand Up @@ -112,6 +112,30 @@ def get_tensor_by_type(node,
return tensor


def get_pattern_nodes_from_graph_def(graph_def: tf.GraphDef, pattern: str) -> Dict:
"""
Get the pattern nodes with the given tf.GraphDef object
Parameters
----------
graph_def
The input tf.GraphDef object
pattern
The node pattern within the graph_def
Returns
----------
Dict
The fitting net nodes within the given tf.GraphDef object
"""
nodes = {}
pattern = re.compile(pattern)
for node in graph_def.node:
if re.fullmatch(pattern, node.name) != None:
nodes[node.name] = node.attr["value"].tensor
return nodes


def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = "") -> Dict:
"""
Get the embedding net nodes with the given tf.GraphDef object
Expand All @@ -128,11 +152,16 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
Dict
The embedding net nodes within the given tf.GraphDef object
"""
embedding_net_nodes = {}
embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
for node in graph_def.node:
if re.fullmatch(embedding_net_pattern, node.name) != None:
embedding_net_nodes[node.name] = node.attr["value"].tensor
# embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
if suffix is not "":
embedding_net_pattern = EMBEDDING_NET_PATTERN\
.replace('/idt', suffix + '/idt')\
.replace('/bias', suffix + '/bias')\
.replace('/matrix', suffix + '/matrix')
else:
embedding_net_pattern = EMBEDDING_NET_PATTERN

embedding_net_nodes = get_pattern_nodes_from_graph_def(graph_def, embedding_net_pattern)
for key in embedding_net_nodes.keys():
assert key.find('bias') > 0 or key.find(
'matrix') > 0, "currently, only support weight matrix and bias matrix at the tabulation op!"
Expand Down Expand Up @@ -222,11 +251,7 @@ def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:
Dict
The fitting net nodes within the given tf.GraphDef object
"""
fitting_net_nodes = {}
fitting_net_pattern = "layer_\d+_type_\d+/matrix+|layer_\d+_type_\d+/bias+|layer_\d+_type_\d+/idt+|final_layer_type_\d+/matrix+|final_layer_type_\d+/bias"
for node in graph_def.node:
if re.fullmatch(fitting_net_pattern, node.name) != None:
fitting_net_nodes[node.name] = node.attr["value"].tensor
fitting_net_nodes = get_pattern_nodes_from_graph_def(graph_def, FITTING_NET_PATTERN)
for key in fitting_net_nodes.keys():
assert key.find('bias') > 0 or key.find('matrix') > 0 or key.find(
'idt') > 0, "currently, only support weight matrix, bias and idt at the model compression process!"
Expand Down

0 comments on commit 1a8fd73

Please sign in to comment.