Fix compress training bug within the dp train --init-frz-model interface #1233
Codecov Report
@@ Coverage Diff @@
## devel #1233 +/- ##
==========================================
- Coverage 76.02% 75.99% -0.04%
==========================================
Files 91 91
Lines 7367 7389 +22
==========================================
+ Hits 5601 5615 +14
- Misses 1766 1774 +8
deepmd/entrypoints/freeze.py
raw_graph_def, # The graph_def is used to retrieve the nodes
[n + '_1' for n in old_graph_nodes], # The output node names are used to select the usefull nodes
)
except Exception:
Is there any specific exception?
All fitting net variables get the _1 suffix; we can check this with the tf.trainable_variables() function. I think this is TensorFlow's default node naming behavior: when a specific variable name is not available in the graph (due to the usage of tf.import_graph_def), TF automatically adds a numeric suffix to that name. And each fitting_net node name is unique within the original graph (it ends with matrix, bias or idt), so we are fine to do so.
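The suffixing behavior described in this comment can be modeled in a few lines of plain Python. This is a toy sketch, not TensorFlow's implementation: the NameRegistry class, its unique_name method, and the node names are all hypothetical, and the only assumption is that a name already taken receives the first free _N suffix.

```python
class NameRegistry:
    """Toy model of a graph that keeps node names unique (hypothetical helper)."""

    def __init__(self):
        self._used = set()

    def unique_name(self, name):
        # Return the name unchanged if it is free, else append the first free _N suffix.
        candidate = name
        counter = 1
        while candidate in self._used:
            candidate = f"{name}_{counter}"
            counter += 1
        self._used.add(candidate)
        return candidate


registry = NameRegistry()
# Names taken first, e.g. by nodes imported from the frozen model (made-up names).
imported = [registry.unique_name(n) for n in ("layer_0/matrix", "layer_0/bias")]
# The same names requested again, e.g. by freshly created variables.
created = [registry.unique_name(n) for n in ("layer_0/matrix", "layer_0/bias")]
print(imported)
print(created)
```

Because every original name is unique, each suffixed name maps back to exactly one original node, which is why appending _1 is a safe lookup rule in this situation.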
I mean could you catch a specific exception (such as RuntimeError, etc) instead of general Exception?
Sure. It's the AssertionError.
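A minimal sketch of what catching the specific exception could look like, in plain Python. The functions select_nodes and select_with_fallback and the node names are hypothetical stand-ins, not the code in freeze.py; the sketch only assumes that the node-selection step fails with an AssertionError when a requested name is missing from the graph.

```python
def select_nodes(graph_nodes, wanted):
    """Hypothetical stand-in for a step that requires every wanted node to exist."""
    for name in wanted:
        assert name in graph_nodes, f"{name} is not in graph"
    return [name for name in graph_nodes if name in wanted]


def select_with_fallback(graph_nodes, old_graph_nodes):
    try:
        return select_nodes(graph_nodes, old_graph_nodes)
    except AssertionError:
        # Only the "node missing" failure triggers the retry with suffixed names;
        # any other error still propagates to the caller.
        return select_nodes(graph_nodes, [n + "_1" for n in old_graph_nodes])


# Made-up graph in which the fitting net nodes carry the _1 suffix.
graph_nodes = ["descrpt", "layer_0/matrix_1", "layer_0/bias_1"]
result = select_with_fallback(graph_nodes, ["layer_0/matrix", "layer_0/bias"])
print(result)
```

Narrowing the except clause keeps unrelated failures visible instead of silently routing them into the fallback path.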
deepmd/entrypoints/freeze.py
@@ -21,6 +21,36 @@

log = logging.getLogger(__name__)

def _transfer_graph_def(sess, old_graph_def, raw_graph_def):
_transfer_graph_def is not a good name for this function. It should specify which variables are transferred.
…to compress-training fix pip CI problem
The compress training code uses the tf.import_graph_def function to load the tf.Tensor and tf.Operation objects from the old graph def into the current default graph. However, this can lead to a variable name conflict during the model freeze process, which is the cause of issue #1194. According to the TensorFlow doc:
In this PR, the following changes are adopted to address #1194:
EMBEDDING_NET_PATTERN, FITTING_NET_PATTERN as well as the TRANSFER_PATTERN, to the deepmd.env module.
Note that this PR does not use the prefix parameter of the tf.import_graph_def function to solve #1194. Although that would be easier, it would change the node names permanently. Instead, this PR does not affect the graph structure or the node names within the graph, which is very important for model maintenance.
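To illustrate the trade-off, here is a toy sketch in plain Python of what importing under a prefix does to node names. It does not call TensorFlow; import_names is a hypothetical helper, the node names and the old_model prefix are made up, and the only assumption is that the prefix is prepended as prefix/name (in the TensorFlow API the prefix is passed through the name argument of tf.import_graph_def).

```python
def import_names(node_names, prefix=""):
    """Toy model: node names as they would appear after an import under a prefix."""
    return [f"{prefix}/{n}" if prefix else n for n in node_names]


original = ["layer_0/matrix", "layer_0/bias"]
prefixed = import_names(original, prefix="old_model")
# Code that looks nodes up by their original names no longer finds them.
still_found = [n for n in original if n in prefixed]
print(prefixed)
print(still_found)
```

Every downstream consumer that addresses nodes by name would need to know the prefix, which is the maintenance cost the PR avoids by leaving the names untouched.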