
Enable mixed precision support for deepmd-kit #1285

Merged
merged 17 commits into deepmodeling:devel on Nov 23, 2021

Conversation

denghuilu
Member

This PR enables mixed-precision training as well as mixed-precision inference for deepmd-kit. Without any change to the input script, one can enable mixed-precision training simply by setting the environment variable DP_ENABLE_MIXED_PREC to fp16 (a usage sketch follows below).

Main changes:

  • add the DP_ENABLE_MIXED_PREC environment variable to control mixed-precision training. Note that currently only tf.float16 is enabled as the compute precision.
  • set the default embedding-net and fitting-net precision in argcheck.py according to the environment variable DP_INTERFACE_PREC.
  • use a dynamic loss scale for gradient updates.
  • add documentation for mixed-precision support.

According to our example water benchmark system, with TF-2.6.0, CUDA-11.0, and an NVIDIA V100 GPU, the speed of dp training decreases slightly with mixed precision, while inference on a system of 12288 atoms gains a speedup of about a factor of 3.

It is strongly recommended to enable the mixed-precision settings with CUDA toolkit 11.0 or above.
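A minimal usage sketch of the environment-variable switch described above, assuming the variable is read before deepmd-kit starts; the variable name comes from this PR, while the training entry point and input file name are placeholders.

```python
import os

# Illustrative only: request the fp16 mixed-precision path before training starts.
os.environ["DP_ENABLE_MIXED_PREC"] = "fp16"   # tf.float16 compute precision

# Equivalent shell usage (placeholder input file):
#   DP_ENABLE_MIXED_PREC=fp16 dp train input.json
```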

@njzjz
Member

njzjz commented Nov 15, 2021

@wanghan-iapcm an import error is caught in the latest dpdata

@codecov-commenter

codecov-commenter commented Nov 15, 2021

Codecov Report

Merging #1285 (e7d357b) into devel (4af4ea5) will not change coverage.
The diff coverage is n/a.

❗ Current head e7d357b differs from pull request most recent head 6fa19c9. Consider uploading reports for the commit 6fa19c9 to get more accurate results
Impacted file tree graph

@@           Coverage Diff           @@
##            devel    #1285   +/-   ##
=======================================
  Coverage   64.28%   64.28%           
=======================================
  Files           5        5           
  Lines          14       14           
=======================================
  Hits            9        9           
  Misses          5        5           

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4af4ea5...6fa19c9. Read the comment docs.

@@ -345,6 +348,9 @@ def _build_training(self):
        optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
    else:
        optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate)
    if DP_ENABLE_MIXED_PRECISION:
        # enable dynamic loss scale of the gradients
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
Member

This function has been moved to tf.mixed_precision.enable_mixed_precision_graph_rewrite. https://www.tensorflow.org/api_docs/python/tf/compat/v1/mixed_precision/enable_mixed_precision_graph_rewrite What TF version do you use? Do you know if it is supported in all TF versions?

Member Author

I found this function in Nvidia's official documentation and tested it in TF-1.14.0 and TF-2.6.0 environments. Since it is deprecated, I will switch to the new tf.mixed_precision.enable_mixed_precision_graph_rewrite function.

Member

@njzjz njzjz Nov 17, 2021

The method was available since v1.12 (tensorflow/tensorflow@02730dc) and then was renamed in v2.4 (tensorflow/tensorflow@0112286). We may need to raise an error for TF<1.12.
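A hedged sketch of the guard discussed here, not the PR's exact code: it errors out below TF 1.12 and tries the post-2.4 location of the graph-rewrite helper before falling back to the older one. The (major, minor) tuple argument and the fallback logic are assumptions for the example; intermediate TF 2.x API variants are not handled.

```python
import tensorflow as tf

def wrap_with_graph_rewrite(optimizer, tf_version):
    """Enable the mixed-precision graph rewrite, refusing TF < 1.12."""
    if tf_version < (1, 12):   # tf_version is assumed to be a (major, minor) tuple
        raise RuntimeError("mixed-precision training requires TensorFlow >= 1.12")
    try:
        # location after the v2.4 rename noted above
        rewrite = tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite
    except AttributeError:
        # older location, available since v1.12
        rewrite = tf.train.experimental.enable_mixed_precision_graph_rewrite
    return rewrite(optimizer)
```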

Member Author

Sure

@amcadmus
Member

@wanghan-iapcm an import error is caught in the latest dpdata

pymatgen... could you please help fix it? thanks!

@njzjz
Member

njzjz commented Nov 16, 2021

pymatgen... could you please help fix it? thanks!

See deepmodeling/dpdata#217.

@denghuilu
Member Author

There are some problems with mixed-precision training for the se_r and se_t descriptor types; they are under investigation.

@denghuilu
Member Author

There are some problems with mixed-precision training for the se_r and se_t descriptor types; they are under investigation.

@amcadmus @njzjz There are still some errors when training with mixed precision on the se_r or se_t descriptor types, so I suggest we merge support for the se_a type first.

@denghuilu denghuilu requested a review from njzjz November 21, 2021 13:40
@@ -345,6 +358,12 @@ def _build_training(self):
        optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
    else:
        optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate)
    if self.mixed_prec is not None:
        # check the TF_VERSION, when TF < 1.12, mixed precision is not allowed
        if TF_VERSION < "1.12":
Member

>>> "1.8"<"1.12"
False

@njzjz
Member

njzjz commented Nov 21, 2021

Can you also support hybrid?

@denghuilu
Member Author

Can you also support hybrid?

As we said, there are still some errors when using the se_r or se_t descriptor types. Hybrid is not yet ready for use.

@njzjz
Member

njzjz commented Nov 22, 2021

It would be useful to support a hybrid descriptor that mixes two se_a descriptors.

Comment on lines 19 to 20
trainable = True,
trainable = False,
Collaborator

Why introduce this change?

Member Author

That was a temporary change for debugging; I'll fix it.

@@ -44,6 +49,12 @@ def one_layer(inputs,
                    b_initializer,
                    trainable = trainable)
    variable_summaries(b, 'bias')

    if mixed_prec is not None and outputs_size != 1:
Collaborator

I do not like this idea.
For dipole and polar, the size of the output layer is not 1, so they would end up using fp16, which is not what we want.

Comment on lines 79 to 80
if mixed_prec is not None and outputs_size != 1:
    idt = tf.cast(idt, get_precision(mixed_prec['compute_prec']))
Collaborator

Again outputs_size != 1 may not be a good idea.
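A sketch of the alternative the review seems to point at: decide the cast from an explicit flag marking the final layer instead of inferring it from outputs_size. The names maybe_cast and final_layer are illustrative, not the PR's API, and tf.as_dtype stands in for deepmd's get_precision helper.

```python
import tensorflow as tf

def maybe_cast(x, mixed_prec, final_layer):
    """Cast to the low compute precision only for non-final layers.

    Dipole/polar fitting nets have outputs_size != 1 but still want their last
    layer kept in the higher output precision, so the decision is made explicit.
    """
    if mixed_prec is not None and not final_layer:
        return tf.cast(x, tf.as_dtype(mixed_prec["compute_prec"]))
    return x
```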

Comment on lines +752 to +753
if self.mixed_prec is not None:
    inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec']))
Collaborator

Do we need this line? The inputs are cast to compute_prec in networks.one_layer or networks.embedding_net anyway.

Member Author

  1. There is a matrix multiplication outside the embedding net, so we need to cast the inputs to match the dtype of the embedding-net output (see the small illustration below).
  2. Half-precision slicing will be more efficient.
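A small illustration of point 1, with toy tensors standing in for deepmd's actual descriptor contraction (names and shapes are made up): tf.matmul refuses mixed fp32/fp16 operands, so the inputs are cast to the compute precision first, as in the quoted lines.

```python
import tensorflow as tf

inputs = tf.random.uniform([4, 3], dtype=tf.float32)    # stand-in for the descriptor inputs
net_out = tf.random.uniform([3, 8], dtype=tf.float16)   # stand-in for the embedding-net output

# tf.matmul(inputs, net_out) would fail with a dtype mismatch; cast first.
prod = tf.matmul(tf.cast(inputs, tf.float16), net_out)
```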

Comment on lines 266 to 282
def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
    """
    Receive the mixed precision setting.

    Parameters
    ----------
    mixed_prec
        The mixed precision setting used in the embedding net

    Notes
    -----
    This method is called by others when the descriptor supports mixed precision.
    """
    raise NotImplementedError(
        "Descriptor %s doesn't support mixed precision training!" % type(self).__name__)


Member

lint errors appear here

@@ -345,6 +358,15 @@ def _build_training(self):
        optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
    else:
        optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate)
    if self.mixed_prec is not None:
        TF_VERSION_LIST = [int(item) for item in TF_VERSION.split('.')]
Member

int(item) will cause an error if the version is a pre-release, e.g. v2.6.0-rc1. See https://github.com/tensorflow/tensorflow/blob/ff68385595088304cf772086b9a259a65b007622/tensorflow/core/public/version.h#L35-L37

Member

I suggest using a third-party class such as Specifier from the packaging library.
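A sketch of robust version comparison with the third-party packaging library (using Version for brevity rather than the Specifier class mentioned above); it handles pre-release tags such as 2.6.0-rc1 that break int() parsing and avoids the lexicographic string comparison shown earlier.

```python
from packaging.version import Version

# pre-release tags parse cleanly and compare as expected
assert Version("2.6.0-rc1") >= Version("1.12")
# numeric comparison, unlike the string comparison "1.8" < "1.12" (which is False)
assert Version("1.8") < Version("1.12")
```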

Comment on lines +612 to +613
Argument("output_prec", str, optional=True, default="float32", doc=doc_output_prec),
Argument("compute_prec", str, optional=False, default="float16", doc=doc_compute_prec),
Member

Is the default behavior to enable mixed precision?

Member Author

The mixed_precision section is optional within the training section (see line 617), so mixed precision is off by default. However, once the mixed_precision section is set, one must provide the compute_prec key.
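For illustration, a sketch of how the optional mixed_precision section inside the training section might look, written as a Python dict mirroring the JSON input script; output_prec and compute_prec come from the diff above, while every other key and value here is a placeholder.

```python
# Hypothetical fragment of a training input, shown as a Python dict for illustration.
training_section = {
    "numb_steps": 1000000,           # placeholder value
    "mixed_precision": {             # optional: omit this block to train in full precision
        "output_prec": "float32",    # precision kept for the network outputs
        "compute_prec": "float16",   # required once the block is present
    },
}
```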

@wanghan-iapcm wanghan-iapcm merged commit f40e14e into deepmodeling:devel Nov 23, 2021